numpy.h 89 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312
  1. /*
  2. pybind11/numpy.h: Basic NumPy support, vectorize() wrapper
  3. Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
  4. All rights reserved. Use of this source code is governed by a
  5. BSD-style license that can be found in the LICENSE file.
  6. */
  7. #pragma once
  8. #include "pybind11.h"
  9. #include "detail/common.h"
  10. #include "complex.h"
  11. #include "gil_safe_call_once.h"
  12. #include "pytypes.h"
  13. #include <algorithm>
  14. #include <array>
  15. #include <cstdint>
  16. #include <cstdlib>
  17. #include <cstring>
  18. #include <functional>
  19. #include <numeric>
  20. #include <sstream>
  21. #include <string>
  22. #include <type_traits>
  23. #include <typeindex>
  24. #include <utility>
  25. #include <vector>
  26. #if defined(PYBIND11_NUMPY_1_ONLY)
  27. # error "PYBIND11_NUMPY_1_ONLY is no longer supported (see PR #5595)."
  28. #endif
  29. /* This will be true on all flat address space platforms and allows us to reduce the
  30. whole npy_intp / ssize_t / Py_intptr_t business down to just ssize_t for all size
  31. and dimension types (e.g. shape, strides, indexing), instead of inflicting this
  32. upon the library user.
  33. Note that NumPy 2 now uses ssize_t for `npy_intp` to simplify this. */
  34. static_assert(sizeof(::pybind11::ssize_t) == sizeof(Py_intptr_t), "ssize_t != Py_intptr_t");
  35. static_assert(std::is_signed<Py_intptr_t>::value, "Py_intptr_t must be signed");
  36. // We now can reinterpret_cast between py::ssize_t and Py_intptr_t (MSVC + PyPy cares)
  37. PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
  38. PYBIND11_WARNING_DISABLE_MSVC(4127)
  39. class dtype; // Forward declaration
  40. class array; // Forward declaration
  41. template <typename>
  42. struct numpy_scalar; // Forward declaration
  43. PYBIND11_NAMESPACE_BEGIN(detail)
  44. template <>
  45. struct handle_type_name<dtype> {
  46. static constexpr auto name = const_name("numpy.dtype");
  47. };
  48. template <>
  49. struct handle_type_name<array> {
  50. static constexpr auto name = const_name("numpy.ndarray");
  51. };
  52. template <typename type, typename SFINAE = void>
  53. struct npy_format_descriptor;
  54. /* NumPy 1 proxy (always includes legacy fields) */
  55. struct PyArrayDescr1_Proxy {
  56. PyObject_HEAD
  57. PyObject *typeobj;
  58. char kind;
  59. char type;
  60. char byteorder;
  61. char flags;
  62. int type_num;
  63. int elsize;
  64. int alignment;
  65. char *subarray;
  66. PyObject *fields;
  67. PyObject *names;
  68. };
  69. struct PyArrayDescr_Proxy {
  70. PyObject_HEAD
  71. PyObject *typeobj;
  72. char kind;
  73. char type;
  74. char byteorder;
  75. char _former_flags;
  76. int type_num;
  77. /* Additional fields are NumPy version specific. */
  78. };
  79. /* NumPy 2 proxy, including legacy fields */
  80. struct PyArrayDescr2_Proxy {
  81. PyObject_HEAD
  82. PyObject *typeobj;
  83. char kind;
  84. char type;
  85. char byteorder;
  86. char _former_flags;
  87. int type_num;
  88. std::uint64_t flags;
  89. ssize_t elsize;
  90. ssize_t alignment;
  91. PyObject *metadata;
  92. Py_hash_t hash;
  93. void *reserved_null[2];
  94. /* The following fields only exist if 0 <= type_num < 2056 */
  95. char *subarray;
  96. PyObject *fields;
  97. PyObject *names;
  98. };
  99. struct PyArray_Proxy {
  100. PyObject_HEAD
  101. char *data;
  102. int nd;
  103. ssize_t *dimensions;
  104. ssize_t *strides;
  105. PyObject *base;
  106. PyObject *descr;
  107. int flags;
  108. };
  109. struct PyVoidScalarObject_Proxy {
  110. PyObject_VAR_HEAD char *obval;
  111. PyArrayDescr_Proxy *descr;
  112. int flags;
  113. PyObject *base;
  114. };
  115. struct numpy_type_info {
  116. PyObject *dtype_ptr;
  117. std::string format_str;
  118. };
  119. struct numpy_internals {
  120. std::unordered_map<std::type_index, numpy_type_info> registered_dtypes;
  121. numpy_type_info *get_type_info(const std::type_info &tinfo, bool throw_if_missing = true) {
  122. auto it = registered_dtypes.find(std::type_index(tinfo));
  123. if (it != registered_dtypes.end()) {
  124. return &(it->second);
  125. }
  126. if (throw_if_missing) {
  127. pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name());
  128. }
  129. return nullptr;
  130. }
  131. template <typename T>
  132. numpy_type_info *get_type_info(bool throw_if_missing = true) {
  133. return get_type_info(typeid(typename std::remove_cv<T>::type), throw_if_missing);
  134. }
  135. };
  136. PYBIND11_NOINLINE void load_numpy_internals(numpy_internals *&ptr) {
  137. ptr = &get_or_create_shared_data<numpy_internals>("_numpy_internals");
  138. }
  139. inline numpy_internals &get_numpy_internals() {
  140. static numpy_internals *ptr = nullptr;
  141. if (!ptr) {
  142. load_numpy_internals(ptr);
  143. }
  144. return *ptr;
  145. }
  146. PYBIND11_NOINLINE module_ import_numpy_core_submodule(const char *submodule_name) {
  147. module_ numpy = module_::import("numpy");
  148. str version_string = numpy.attr("__version__");
  149. module_ numpy_lib = module_::import("numpy.lib");
  150. object numpy_version = numpy_lib.attr("NumpyVersion")(version_string);
  151. int major_version = numpy_version.attr("major").cast<int>();
  152. /* `numpy.core` was renamed to `numpy._core` in NumPy 2.0 as it officially
  153. became a private module. */
  154. std::string numpy_core_path = major_version >= 2 ? "numpy._core" : "numpy.core";
  155. return module_::import((numpy_core_path + "." + submodule_name).c_str());
  156. }
  157. template <typename T>
  158. struct same_size {
  159. template <typename U>
  160. using as = bool_constant<sizeof(T) == sizeof(U)>;
  161. };
  // End of the candidate list: no candidate matched, so report -1.
  template <typename Target>
  constexpr int platform_lookup() {
      return -1;
  }

  // Lookup a type according to its size, and return a value corresponding to the NumPy typenum.
  // Candidates are tried in order; the code paired with the first candidate whose size equals
  // sizeof(Target) is returned.
  template <typename Target, typename Candidate, typename... Remaining, typename... Codes>
  constexpr int platform_lookup(int code, Codes... remaining_codes) {
      return sizeof(Target) != sizeof(Candidate)
                 ? platform_lookup<Target, Remaining...>(remaining_codes...)
                 : code;
  }
  171. struct npy_api {
  172. // If you change this code, please review `normalized_dtype_num` below.
  173. enum constants {
  174. NPY_ARRAY_C_CONTIGUOUS_ = 0x0001,
  175. NPY_ARRAY_F_CONTIGUOUS_ = 0x0002,
  176. NPY_ARRAY_OWNDATA_ = 0x0004,
  177. NPY_ARRAY_FORCECAST_ = 0x0010,
  178. NPY_ARRAY_ENSUREARRAY_ = 0x0040,
  179. NPY_ARRAY_ALIGNED_ = 0x0100,
  180. NPY_ARRAY_WRITEABLE_ = 0x0400,
  181. NPY_BOOL_ = 0,
  182. NPY_BYTE_,
  183. NPY_UBYTE_,
  184. NPY_SHORT_,
  185. NPY_USHORT_,
  186. NPY_INT_,
  187. NPY_UINT_,
  188. NPY_LONG_,
  189. NPY_ULONG_,
  190. NPY_LONGLONG_,
  191. NPY_ULONGLONG_,
  192. NPY_FLOAT_,
  193. NPY_DOUBLE_,
  194. NPY_LONGDOUBLE_,
  195. NPY_CFLOAT_,
  196. NPY_CDOUBLE_,
  197. NPY_CLONGDOUBLE_,
  198. NPY_OBJECT_ = 17,
  199. NPY_STRING_,
  200. NPY_UNICODE_,
  201. NPY_VOID_,
  202. // Platform-dependent normalization
  203. NPY_INT8_ = NPY_BYTE_,
  204. NPY_UINT8_ = NPY_UBYTE_,
  205. NPY_INT16_ = NPY_SHORT_,
  206. NPY_UINT16_ = NPY_USHORT_,
  207. // `npy_common.h` defines the integer aliases. In order, it checks:
  208. // NPY_BITSOF_LONG, NPY_BITSOF_LONGLONG, NPY_BITSOF_INT, NPY_BITSOF_SHORT, NPY_BITSOF_CHAR
  209. // and assigns the alias to the first matching size, so we should check in this order.
  210. NPY_INT32_
  211. = platform_lookup<std::int32_t, long, int, short>(NPY_LONG_, NPY_INT_, NPY_SHORT_),
  212. NPY_UINT32_ = platform_lookup<std::uint32_t, unsigned long, unsigned int, unsigned short>(
  213. NPY_ULONG_, NPY_UINT_, NPY_USHORT_),
  214. NPY_INT64_
  215. = platform_lookup<std::int64_t, long, long long, int>(NPY_LONG_, NPY_LONGLONG_, NPY_INT_),
  216. NPY_UINT64_
  217. = platform_lookup<std::uint64_t, unsigned long, unsigned long long, unsigned int>(
  218. NPY_ULONG_, NPY_ULONGLONG_, NPY_UINT_),
  219. NPY_FLOAT32_ = platform_lookup<float, double, float, long double>(
  220. NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
  221. NPY_FLOAT64_ = platform_lookup<double, double, float, long double>(
  222. NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
  223. NPY_COMPLEX64_
  224. = platform_lookup<std::complex<float>,
  225. std::complex<double>,
  226. std::complex<float>,
  227. std::complex<long double>>(NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
  228. NPY_COMPLEX128_
  229. = platform_lookup<std::complex<double>,
  230. std::complex<double>,
  231. std::complex<float>,
  232. std::complex<long double>>(NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
  233. NPY_CHAR_ = std::is_signed<char>::value ? NPY_BYTE_ : NPY_UBYTE_,
  234. };
  235. unsigned int PyArray_RUNTIME_VERSION_;
  236. struct PyArray_Dims {
  237. Py_intptr_t *ptr;
  238. int len;
  239. };
  240. static npy_api &get() {
  241. PYBIND11_CONSTINIT static gil_safe_call_once_and_store<npy_api> storage;
  242. return storage.call_once_and_store_result(lookup).get_stored();
  243. }
  244. bool PyArray_Check_(PyObject *obj) const {
  245. return PyObject_TypeCheck(obj, PyArray_Type_) != 0;
  246. }
  247. bool PyArrayDescr_Check_(PyObject *obj) const {
  248. return PyObject_TypeCheck(obj, PyArrayDescr_Type_) != 0;
  249. }
  250. unsigned int (*PyArray_GetNDArrayCFeatureVersion_)();
  251. PyObject *(*PyArray_DescrFromType_)(int);
  252. PyObject *(*PyArray_TypeObjectFromType_)(int);
  253. PyObject *(*PyArray_NewFromDescr_)(PyTypeObject *,
  254. PyObject *,
  255. int,
  256. Py_intptr_t const *,
  257. Py_intptr_t const *,
  258. void *,
  259. int,
  260. PyObject *);
  261. // Unused. Not removed because that affects ABI of the class.
  262. PyObject *(*PyArray_DescrNewFromType_)(int);
  263. int (*PyArray_CopyInto_)(PyObject *, PyObject *);
  264. PyObject *(*PyArray_NewCopy_)(PyObject *, int);
  265. PyTypeObject *PyArray_Type_;
  266. PyTypeObject *PyVoidArrType_Type_;
  267. PyTypeObject *PyArrayDescr_Type_;
  268. PyObject *(*PyArray_DescrFromScalar_)(PyObject *);
  269. PyObject *(*PyArray_Scalar_)(void *, PyObject *, PyObject *);
  270. void (*PyArray_ScalarAsCtype_)(PyObject *, void *);
  271. PyObject *(*PyArray_FromAny_)(PyObject *, PyObject *, int, int, int, PyObject *);
  272. int (*PyArray_DescrConverter_)(PyObject *, PyObject **);
  273. bool (*PyArray_EquivTypes_)(PyObject *, PyObject *);
  274. PyObject *(*PyArray_Squeeze_)(PyObject *);
  275. // Unused. Not removed because that affects ABI of the class.
  276. int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
  277. PyObject *(*PyArray_Resize_)(PyObject *, PyArray_Dims *, int, int);
  278. PyObject *(*PyArray_Newshape_)(PyObject *, PyArray_Dims *, int);
  279. PyObject *(*PyArray_View_)(PyObject *, PyObject *, PyObject *);
  280. private:
  281. enum functions {
  282. API_PyArray_GetNDArrayCFeatureVersion = 211,
  283. API_PyArray_Type = 2,
  284. API_PyArrayDescr_Type = 3,
  285. API_PyVoidArrType_Type = 39,
  286. API_PyArray_DescrFromType = 45,
  287. API_PyArray_TypeObjectFromType = 46,
  288. API_PyArray_DescrFromScalar = 57,
  289. API_PyArray_Scalar = 60,
  290. API_PyArray_ScalarAsCtype = 62,
  291. API_PyArray_FromAny = 69,
  292. API_PyArray_Resize = 80,
  293. // CopyInto was slot 82 and 50 was effectively an alias. NumPy 2 removed 82.
  294. API_PyArray_CopyInto = 50,
  295. API_PyArray_NewCopy = 85,
  296. API_PyArray_NewFromDescr = 94,
  297. API_PyArray_DescrNewFromType = 96,
  298. API_PyArray_Newshape = 135,
  299. API_PyArray_Squeeze = 136,
  300. API_PyArray_View = 137,
  301. API_PyArray_DescrConverter = 174,
  302. API_PyArray_EquivTypes = 182,
  303. API_PyArray_SetBaseObject = 282
  304. };
  305. static npy_api lookup() {
  306. module_ m = detail::import_numpy_core_submodule("multiarray");
  307. auto c = m.attr("_ARRAY_API");
  308. void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), nullptr);
  309. if (api_ptr == nullptr) {
  310. raise_from(PyExc_SystemError, "FAILURE obtaining numpy _ARRAY_API pointer.");
  311. throw error_already_set();
  312. }
  313. npy_api api;
  314. #define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
  315. DECL_NPY_API(PyArray_GetNDArrayCFeatureVersion);
  316. api.PyArray_RUNTIME_VERSION_ = api.PyArray_GetNDArrayCFeatureVersion_();
  317. if (api.PyArray_RUNTIME_VERSION_ < 0x7) {
  318. pybind11_fail("pybind11 numpy support requires numpy >= 1.7.0");
  319. }
  320. DECL_NPY_API(PyArray_Type);
  321. DECL_NPY_API(PyVoidArrType_Type);
  322. DECL_NPY_API(PyArrayDescr_Type);
  323. DECL_NPY_API(PyArray_DescrFromType);
  324. DECL_NPY_API(PyArray_TypeObjectFromType);
  325. DECL_NPY_API(PyArray_DescrFromScalar);
  326. DECL_NPY_API(PyArray_Scalar);
  327. DECL_NPY_API(PyArray_ScalarAsCtype);
  328. DECL_NPY_API(PyArray_FromAny);
  329. DECL_NPY_API(PyArray_Resize);
  330. DECL_NPY_API(PyArray_CopyInto);
  331. DECL_NPY_API(PyArray_NewCopy);
  332. DECL_NPY_API(PyArray_NewFromDescr);
  333. DECL_NPY_API(PyArray_DescrNewFromType);
  334. DECL_NPY_API(PyArray_Newshape);
  335. DECL_NPY_API(PyArray_Squeeze);
  336. DECL_NPY_API(PyArray_View);
  337. DECL_NPY_API(PyArray_DescrConverter);
  338. DECL_NPY_API(PyArray_EquivTypes);
  339. DECL_NPY_API(PyArray_SetBaseObject);
  340. #undef DECL_NPY_API
  341. return api;
  342. }
  343. };
  // Type trait: true only for specializations of std::complex.
  template <typename T>
  struct is_complex : std::integral_constant<bool, false> {};

  template <typename Scalar>
  struct is_complex<std::complex<Scalar>> : std::integral_constant<bool, true> {};
  348. template <typename T, typename = void>
  349. struct npy_format_descriptor_name;
  350. template <typename T>
  351. struct npy_format_descriptor_name<T, enable_if_t<std::is_integral<T>::value>> {
  352. static constexpr auto name = const_name<std::is_same<T, bool>::value>(
  353. const_name("numpy.bool"),
  354. const_name<std::is_signed<T>::value>("numpy.int", "numpy.uint")
  355. + const_name<sizeof(T) * 8>());
  356. };
  357. template <typename T>
  358. struct npy_format_descriptor_name<T, enable_if_t<std::is_floating_point<T>::value>> {
  359. static constexpr auto name = const_name < std::is_same<T, float>::value
  360. || std::is_same<T, const float>::value
  361. || std::is_same<T, double>::value
  362. || std::is_same<T, const double>::value
  363. > (const_name("numpy.float") + const_name<sizeof(T) * 8>(),
  364. const_name("numpy.longdouble"));
  365. };
  366. template <typename T>
  367. struct npy_format_descriptor_name<T, enable_if_t<is_complex<T>::value>> {
  368. static constexpr auto name = const_name < std::is_same<typename T::value_type, float>::value
  369. || std::is_same<typename T::value_type, const float>::value
  370. || std::is_same<typename T::value_type, double>::value
  371. || std::is_same<typename T::value_type, const double>::value
  372. > (const_name("numpy.complex")
  373. + const_name<sizeof(typename T::value_type) * 16>(),
  374. const_name("numpy.longcomplex"));
  375. };
  376. template <typename T>
  377. struct numpy_scalar_info {};
  378. #define PYBIND11_NUMPY_SCALAR_IMPL(ctype_, typenum_) \
  379. template <> \
  380. struct numpy_scalar_info<ctype_> { \
  381. static constexpr auto name = npy_format_descriptor_name<ctype_>::name; \
  382. static constexpr int typenum = npy_api::typenum_##_; \
  383. }
  384. // boolean type
  385. PYBIND11_NUMPY_SCALAR_IMPL(bool, NPY_BOOL);
  386. // character types
  387. PYBIND11_NUMPY_SCALAR_IMPL(char, NPY_CHAR);
  388. PYBIND11_NUMPY_SCALAR_IMPL(signed char, NPY_BYTE);
  389. PYBIND11_NUMPY_SCALAR_IMPL(unsigned char, NPY_UBYTE);
  390. // signed integer types
  391. PYBIND11_NUMPY_SCALAR_IMPL(std::int16_t, NPY_INT16);
  392. PYBIND11_NUMPY_SCALAR_IMPL(std::int32_t, NPY_INT32);
  393. PYBIND11_NUMPY_SCALAR_IMPL(std::int64_t, NPY_INT64);
  394. // unsigned integer types
  395. PYBIND11_NUMPY_SCALAR_IMPL(std::uint16_t, NPY_UINT16);
  396. PYBIND11_NUMPY_SCALAR_IMPL(std::uint32_t, NPY_UINT32);
  397. PYBIND11_NUMPY_SCALAR_IMPL(std::uint64_t, NPY_UINT64);
  398. // floating point types
  399. PYBIND11_NUMPY_SCALAR_IMPL(float, NPY_FLOAT);
  400. PYBIND11_NUMPY_SCALAR_IMPL(double, NPY_DOUBLE);
  401. PYBIND11_NUMPY_SCALAR_IMPL(long double, NPY_LONGDOUBLE);
  402. // complex types
  403. PYBIND11_NUMPY_SCALAR_IMPL(std::complex<float>, NPY_CFLOAT);
  404. PYBIND11_NUMPY_SCALAR_IMPL(std::complex<double>, NPY_CDOUBLE);
  405. PYBIND11_NUMPY_SCALAR_IMPL(std::complex<long double>, NPY_CLONGDOUBLE);
  406. #undef PYBIND11_NUMPY_SCALAR_IMPL
  407. // This table normalizes typenums by mapping NPY_INT_, NPY_LONG, ... to NPY_INT32_, NPY_INT64, ...
  408. // This is needed to correctly handle situations where multiple typenums map to the same type,
  409. // e.g. NPY_LONG_ may be equivalent to NPY_INT_ or NPY_LONGLONG_ despite having a different
  410. // typenum. The normalized typenum should always match the values used in npy_format_descriptor.
  411. // If you change this code, please review `enum constants` above.
  412. static constexpr int normalized_dtype_num[npy_api::NPY_VOID_ + 1] = {
  413. // NPY_BOOL_ =>
  414. npy_api::NPY_BOOL_,
  415. // NPY_BYTE_ =>
  416. npy_api::NPY_BYTE_,
  417. // NPY_UBYTE_ =>
  418. npy_api::NPY_UBYTE_,
  419. // NPY_SHORT_ =>
  420. npy_api::NPY_INT16_,
  421. // NPY_USHORT_ =>
  422. npy_api::NPY_UINT16_,
  423. // NPY_INT_ =>
  424. sizeof(int) == sizeof(std::int16_t) ? npy_api::NPY_INT16_
  425. : sizeof(int) == sizeof(std::int32_t) ? npy_api::NPY_INT32_
  426. : sizeof(int) == sizeof(std::int64_t) ? npy_api::NPY_INT64_
  427. : npy_api::NPY_INT_,
  428. // NPY_UINT_ =>
  429. sizeof(unsigned int) == sizeof(std::uint16_t) ? npy_api::NPY_UINT16_
  430. : sizeof(unsigned int) == sizeof(std::uint32_t) ? npy_api::NPY_UINT32_
  431. : sizeof(unsigned int) == sizeof(std::uint64_t) ? npy_api::NPY_UINT64_
  432. : npy_api::NPY_UINT_,
  433. // NPY_LONG_ =>
  434. sizeof(long) == sizeof(std::int16_t) ? npy_api::NPY_INT16_
  435. : sizeof(long) == sizeof(std::int32_t) ? npy_api::NPY_INT32_
  436. : sizeof(long) == sizeof(std::int64_t) ? npy_api::NPY_INT64_
  437. : npy_api::NPY_LONG_,
  438. // NPY_ULONG_ =>
  439. sizeof(unsigned long) == sizeof(std::uint16_t) ? npy_api::NPY_UINT16_
  440. : sizeof(unsigned long) == sizeof(std::uint32_t) ? npy_api::NPY_UINT32_
  441. : sizeof(unsigned long) == sizeof(std::uint64_t) ? npy_api::NPY_UINT64_
  442. : npy_api::NPY_ULONG_,
  443. // NPY_LONGLONG_ =>
  444. sizeof(long long) == sizeof(std::int16_t) ? npy_api::NPY_INT16_
  445. : sizeof(long long) == sizeof(std::int32_t) ? npy_api::NPY_INT32_
  446. : sizeof(long long) == sizeof(std::int64_t) ? npy_api::NPY_INT64_
  447. : npy_api::NPY_LONGLONG_,
  448. // NPY_ULONGLONG_ =>
  449. sizeof(unsigned long long) == sizeof(std::uint16_t) ? npy_api::NPY_UINT16_
  450. : sizeof(unsigned long long) == sizeof(std::uint32_t) ? npy_api::NPY_UINT32_
  451. : sizeof(unsigned long long) == sizeof(std::uint64_t) ? npy_api::NPY_UINT64_
  452. : npy_api::NPY_ULONGLONG_,
  453. // NPY_FLOAT_ =>
  454. npy_api::NPY_FLOAT_,
  455. // NPY_DOUBLE_ =>
  456. npy_api::NPY_DOUBLE_,
  457. // NPY_LONGDOUBLE_ =>
  458. npy_api::NPY_LONGDOUBLE_,
  459. // NPY_CFLOAT_ =>
  460. npy_api::NPY_CFLOAT_,
  461. // NPY_CDOUBLE_ =>
  462. npy_api::NPY_CDOUBLE_,
  463. // NPY_CLONGDOUBLE_ =>
  464. npy_api::NPY_CLONGDOUBLE_,
  465. // NPY_OBJECT_ =>
  466. npy_api::NPY_OBJECT_,
  467. // NPY_STRING_ =>
  468. npy_api::NPY_STRING_,
  469. // NPY_UNICODE_ =>
  470. npy_api::NPY_UNICODE_,
  471. // NPY_VOID_ =>
  472. npy_api::NPY_VOID_,
  473. };
  474. inline PyArray_Proxy *array_proxy(void *ptr) { return reinterpret_cast<PyArray_Proxy *>(ptr); }
  475. inline const PyArray_Proxy *array_proxy(const void *ptr) {
  476. return reinterpret_cast<const PyArray_Proxy *>(ptr);
  477. }
  478. inline PyArrayDescr_Proxy *array_descriptor_proxy(PyObject *ptr) {
  479. return reinterpret_cast<PyArrayDescr_Proxy *>(ptr);
  480. }
  481. inline const PyArrayDescr_Proxy *array_descriptor_proxy(const PyObject *ptr) {
  482. return reinterpret_cast<const PyArrayDescr_Proxy *>(ptr);
  483. }
  484. inline const PyArrayDescr1_Proxy *array_descriptor1_proxy(const PyObject *ptr) {
  485. return reinterpret_cast<const PyArrayDescr1_Proxy *>(ptr);
  486. }
  487. inline const PyArrayDescr2_Proxy *array_descriptor2_proxy(const PyObject *ptr) {
  488. return reinterpret_cast<const PyArrayDescr2_Proxy *>(ptr);
  489. }
  490. inline bool check_flags(const void *ptr, int flag) {
  491. return (flag == (array_proxy(ptr)->flags & flag));
  492. }
  // Type trait: true only for specializations of std::array.
  template <typename T>
  struct is_std_array : std::integral_constant<bool, false> {};

  template <typename Element, size_t Count>
  struct is_std_array<std::array<Element, Count>> : std::integral_constant<bool, true> {};
  497. template <typename T>
  498. struct array_info_scalar {
  499. using type = T;
  500. static constexpr bool is_array = false;
  501. static constexpr bool is_empty = false;
  502. static constexpr auto extents = const_name("");
  503. static void append_extents(list & /* shape */) {}
  504. };
  505. // Computes underlying type and a comma-separated list of extents for array
  506. // types (any mix of std::array and built-in arrays). An array of char is
  507. // treated as scalar because it gets special handling.
  508. template <typename T>
  509. struct array_info : array_info_scalar<T> {};
  510. template <typename T, size_t N>
  511. struct array_info<std::array<T, N>> {
  512. using type = typename array_info<T>::type;
  513. static constexpr bool is_array = true;
  514. static constexpr bool is_empty = (N == 0) || array_info<T>::is_empty;
  515. static constexpr size_t extent = N;
  516. // appends the extents to shape
  517. static void append_extents(list &shape) {
  518. shape.append(N);
  519. array_info<T>::append_extents(shape);
  520. }
  521. static constexpr auto extents = const_name<array_info<T>::is_array>(
  522. ::pybind11::detail::concat(const_name<N>(), array_info<T>::extents), const_name<N>());
  523. };
  524. // For numpy we have special handling for arrays of characters, so we don't include
  525. // the size in the array extents.
  526. template <size_t N>
  527. struct array_info<char[N]> : array_info_scalar<char[N]> {};
  528. template <size_t N>
  529. struct array_info<std::array<char, N>> : array_info_scalar<std::array<char, N>> {};
  530. template <typename T, size_t N>
  531. struct array_info<T[N]> : array_info<std::array<T, N>> {};
  532. template <typename T>
  533. using remove_all_extents_t = typename array_info<T>::type;
  534. template <typename T>
  535. using is_pod_struct
  536. = all_of<std::is_standard_layout<T>, // since we're accessing directly in memory
  537. // we need a standard layout type
  538. #if defined(__GLIBCXX__) \
  539. && (__GLIBCXX__ < 20150422 || __GLIBCXX__ == 20150426 || __GLIBCXX__ == 20150623 \
  540. || __GLIBCXX__ == 20150626 || __GLIBCXX__ == 20160803)
  541. // libstdc++ < 5 (including versions 4.8.5, 4.9.3 and 4.9.4 which were released after
  542. // 5) don't implement is_trivially_copyable, so approximate it
  543. std::is_trivially_destructible<T>,
  544. satisfies_any_of<T, std::has_trivial_copy_constructor, std::has_trivial_copy_assign>,
  545. #else
  546. std::is_trivially_copyable<T>,
  547. #endif
  548. satisfies_none_of<T,
  549. std::is_reference,
  550. std::is_array,
  551. is_std_array,
  552. std::is_arithmetic,
  553. is_complex,
  554. std::is_enum>>;
  555. // Replacement for std::is_pod (deprecated in C++20)
  556. template <typename T>
  557. using is_pod = all_of<std::is_standard_layout<T>, std::is_trivial<T>>;
  // Recursion end: no indices left, so nothing more is added to the offset.
  template <ssize_t Dim = 0, typename Strides>
  ssize_t byte_offset_unsafe(const Strides &) {
      return 0;
  }

  // Sums index * stride over the given indices, starting at dimension `Dim`.
  // Nothing is checked: neither the number of indices nor their values.
  template <ssize_t Dim = 0, typename Strides, typename... Ix>
  ssize_t byte_offset_unsafe(const Strides &strides, ssize_t i, Ix... index) {
      const ssize_t remaining_dims = byte_offset_unsafe<Dim + 1>(strides, index...);
      return i * strides[Dim] + remaining_dims;
  }
  566. /**
  567. * Proxy class providing unsafe, unchecked const access to array data. This is constructed through
  568. * the `unchecked<T, N>()` method of `array` or the `unchecked<N>()` method of `array_t<T>`. `Dims`
  569. * will be -1 for dimensions determined at runtime.
  570. */
  571. template <typename T, ssize_t Dims>
  572. class unchecked_reference {
  573. protected:
  574. static constexpr bool Dynamic = Dims < 0;
  575. const unsigned char *data_;
  576. // Storing the shape & strides in local variables (i.e. these arrays) allows the compiler to
  577. // make large performance gains on big, nested loops, but requires compile-time dimensions
  578. conditional_t<Dynamic, const ssize_t *, std::array<ssize_t, (size_t) Dims>> shape_, strides_;
  579. const ssize_t dims_;
  580. friend class pybind11::array;
  581. // Constructor for compile-time dimensions:
  582. template <bool Dyn = Dynamic>
  583. unchecked_reference(const void *data,
  584. const ssize_t *shape,
  585. const ssize_t *strides,
  586. enable_if_t<!Dyn, ssize_t>)
  587. : data_{reinterpret_cast<const unsigned char *>(data)}, dims_{Dims} {
  588. for (size_t i = 0; i < (size_t) dims_; i++) {
  589. shape_[i] = shape[i];
  590. strides_[i] = strides[i];
  591. }
  592. }
  593. // Constructor for runtime dimensions:
  594. template <bool Dyn = Dynamic>
  595. unchecked_reference(const void *data,
  596. const ssize_t *shape,
  597. const ssize_t *strides,
  598. enable_if_t<Dyn, ssize_t> dims)
  599. : data_{reinterpret_cast<const unsigned char *>(data)}, shape_{shape}, strides_{strides},
  600. dims_{dims} {}
  601. public:
  602. /**
  603. * Unchecked const reference access to data at the given indices. For a compile-time known
  604. * number of dimensions, this requires the correct number of arguments; for run-time
  605. * dimensionality, this is not checked (and so is up to the caller to use safely).
  606. */
  607. template <typename... Ix>
  608. const T &operator()(Ix... index) const {
  609. static_assert(ssize_t{sizeof...(Ix)} == Dims || Dynamic,
  610. "Invalid number of indices for unchecked array reference");
  611. return *reinterpret_cast<const T *>(data_
  612. + byte_offset_unsafe(strides_, ssize_t(index)...));
  613. }
  614. /**
  615. * Unchecked const reference access to data; this operator only participates if the reference
  616. * is to a 1-dimensional array. When present, this is exactly equivalent to `obj(index)`.
  617. */
  618. template <ssize_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>>
  619. const T &operator[](ssize_t index) const {
  620. return operator()(index);
  621. }
  622. /// Pointer access to the data at the given indices.
  623. template <typename... Ix>
  624. const T *data(Ix... ix) const {
  625. return &operator()(ssize_t(ix)...);
  626. }
  627. /// Returns the item size, i.e. sizeof(T)
  628. constexpr static ssize_t itemsize() { return sizeof(T); }
  629. /// Returns the shape (i.e. size) of dimension `dim`
  630. ssize_t shape(ssize_t dim) const { return shape_[(size_t) dim]; }
  631. /// Returns the number of dimensions of the array
  632. ssize_t ndim() const { return dims_; }
  633. /// Returns the total number of elements in the referenced array, i.e. the product of the
  634. /// shapes
  635. template <bool Dyn = Dynamic>
  636. enable_if_t<!Dyn, ssize_t> size() const {
  637. return std::accumulate(
  638. shape_.begin(), shape_.end(), (ssize_t) 1, std::multiplies<ssize_t>());
  639. }
  640. template <bool Dyn = Dynamic>
  641. enable_if_t<Dyn, ssize_t> size() const {
  642. return std::accumulate(shape_, shape_ + ndim(), (ssize_t) 1, std::multiplies<ssize_t>());
  643. }
  644. /// Returns the total number of bytes used by the referenced data. Note that the actual span
  645. /// in memory may be larger if the referenced array has non-contiguous strides (e.g. for a
  646. /// slice).
  647. ssize_t nbytes() const { return size() * itemsize(); }
  648. };
  649. template <typename T, ssize_t Dims>
  650. class unchecked_mutable_reference : public unchecked_reference<T, Dims> {
  651. friend class pybind11::array;
  652. using ConstBase = unchecked_reference<T, Dims>;
  653. using ConstBase::ConstBase;
  654. using ConstBase::Dynamic;
  655. public:
  656. // Bring in const-qualified versions from base class
  657. using ConstBase::operator();
  658. using ConstBase::operator[];
  659. /// Mutable, unchecked access to data at the given indices.
  660. template <typename... Ix>
  661. T &operator()(Ix... index) {
  662. static_assert(ssize_t{sizeof...(Ix)} == Dims || Dynamic,
  663. "Invalid number of indices for unchecked array reference");
  664. return const_cast<T &>(ConstBase::operator()(index...));
  665. }
  666. /**
  667. * Mutable, unchecked access data at the given index; this operator only participates if the
  668. * reference is to a 1-dimensional array (or has runtime dimensions). When present, this is
  669. * exactly equivalent to `obj(index)`.
  670. */
  671. template <ssize_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>>
  672. T &operator[](ssize_t index) {
  673. return operator()(index);
  674. }
  675. /// Mutable pointer access to the data at the given indices.
  676. template <typename... Ix>
  677. T *mutable_data(Ix... ix) {
  678. return &operator()(ssize_t(ix)...);
  679. }
  680. };
  681. template <typename T, ssize_t Dim>
  682. struct type_caster<unchecked_reference<T, Dim>> {
  683. static_assert(Dim == 0 && Dim > 0 /* always fail */,
  684. "unchecked array proxy object is not castable");
  685. };
  686. template <typename T, ssize_t Dim>
  687. struct type_caster<unchecked_mutable_reference<T, Dim>>
  688. : type_caster<unchecked_reference<T, Dim>> {};
  689. template <typename T>
  690. struct type_caster<numpy_scalar<T>> {
  691. using value_type = T;
  692. using type_info = numpy_scalar_info<T>;
  693. PYBIND11_TYPE_CASTER(numpy_scalar<T>, type_info::name);
  694. static handle &target_type() {
  695. static handle tp = npy_api::get().PyArray_TypeObjectFromType_(type_info::typenum);
  696. return tp;
  697. }
  698. static handle &target_dtype() {
  699. static handle tp = npy_api::get().PyArray_DescrFromType_(type_info::typenum);
  700. return tp;
  701. }
  702. bool load(handle src, bool) {
  703. if (isinstance(src, target_type())) {
  704. npy_api::get().PyArray_ScalarAsCtype_(src.ptr(), &value.value);
  705. return true;
  706. }
  707. return false;
  708. }
  709. static handle cast(numpy_scalar<T> src, return_value_policy, handle) {
  710. return npy_api::get().PyArray_Scalar_(&src.value, target_dtype().ptr(), nullptr);
  711. }
  712. };
  713. PYBIND11_NAMESPACE_END(detail)
  /// Value wrapper holding a single `T`. Construction from `T` and conversion back to `T` are
  /// explicit; assignment from `T` and (in)equality comparison are provided.
  template <typename T>
  struct numpy_scalar {
      using value_type = T;

      value_type value;

      numpy_scalar() = default;
      explicit numpy_scalar(value_type v) : value(v) {}

      explicit operator value_type() const { return value; }

      numpy_scalar &operator=(value_type v) {
          value = v;
          return *this;
      }

      friend bool operator==(const numpy_scalar &lhs, const numpy_scalar &rhs) {
          return lhs.value == rhs.value;
      }

      friend bool operator!=(const numpy_scalar &lhs, const numpy_scalar &rhs) {
          return !(lhs == rhs);
      }
  };
  730. template <typename T>
  731. numpy_scalar<T> make_scalar(T value) {
  732. return numpy_scalar<T>(value);
  733. }
  734. class dtype : public object {
  735. public:
  736. PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_)
  737. explicit dtype(const buffer_info &info) {
  738. dtype descr(_dtype_from_pep3118()(pybind11::str(info.format)));
  739. // If info.itemsize == 0, use the value calculated from the format string
  740. m_ptr = descr.strip_padding(info.itemsize != 0 ? info.itemsize : descr.itemsize())
  741. .release()
  742. .ptr();
  743. }
  744. explicit dtype(const pybind11::str &format) : dtype(from_args(format)) {}
  745. explicit dtype(const std::string &format) : dtype(pybind11::str(format)) {}
  746. explicit dtype(const char *format) : dtype(pybind11::str(format)) {}
  747. dtype(list names, list formats, list offsets, ssize_t itemsize) {
  748. dict args;
  749. args["names"] = std::move(names);
  750. args["formats"] = std::move(formats);
  751. args["offsets"] = std::move(offsets);
  752. args["itemsize"] = pybind11::int_(itemsize);
  753. m_ptr = from_args(args).release().ptr();
  754. }
  755. /// Return dtype for the given typenum (one of the NPY_TYPES).
  756. /// https://numpy.org/devdocs/reference/c-api/array.html#c.PyArray_DescrFromType
  757. explicit dtype(int typenum)
  758. : object(detail::npy_api::get().PyArray_DescrFromType_(typenum), stolen_t{}) {
  759. if (m_ptr == nullptr) {
  760. throw error_already_set();
  761. }
  762. }
  763. /// This is essentially the same as calling numpy.dtype(args) in Python.
  764. static dtype from_args(const object &args) {
  765. PyObject *ptr = nullptr;
  766. if ((detail::npy_api::get().PyArray_DescrConverter_(args.ptr(), &ptr) == 0) || !ptr) {
  767. throw error_already_set();
  768. }
  769. return reinterpret_steal<dtype>(ptr);
  770. }
  771. /// Return dtype associated with a C++ type.
  772. template <typename T>
  773. static dtype of() {
  774. return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::dtype();
  775. }
  776. /// Return the type number associated with a C++ type.
  777. /// This is the constexpr equivalent of `dtype::of<T>().num()`.
  778. template <typename T>
  779. static constexpr int num_of() {
  780. return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::value;
  781. }
  782. /// Size of the data type in bytes.
  783. ssize_t itemsize() const {
  784. if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
  785. return detail::array_descriptor1_proxy(m_ptr)->elsize;
  786. }
  787. return detail::array_descriptor2_proxy(m_ptr)->elsize;
  788. }
  789. /// Returns true for structured data types.
  790. bool has_fields() const {
  791. if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
  792. return detail::array_descriptor1_proxy(m_ptr)->names != nullptr;
  793. }
  794. const auto *proxy = detail::array_descriptor2_proxy(m_ptr);
  795. if (proxy->type_num < 0 || proxy->type_num >= 2056) {
  796. return false;
  797. }
  798. return proxy->names != nullptr;
  799. }
  800. /// Single-character code for dtype's kind.
  801. /// For example, floating point types are 'f' and integral types are 'i'.
  802. char kind() const { return detail::array_descriptor_proxy(m_ptr)->kind; }
  803. /// Single-character for dtype's type.
  804. /// For example, ``float`` is 'f', ``double`` 'd', ``int`` 'i', and ``long`` 'l'.
  805. char char_() const {
  806. // Note: The signature, `dtype::char_` follows the naming of NumPy's
  807. // public Python API (i.e., ``dtype.char``), rather than its internal
  808. // C API (``PyArray_Descr::type``).
  809. return detail::array_descriptor_proxy(m_ptr)->type;
  810. }
  811. /// Type number of dtype. Note that different values may be returned for equivalent types,
  812. /// e.g. even though ``long`` may be equivalent to ``int`` or ``long long``, they still have
  813. /// different type numbers. Consider using `normalized_num` to avoid this.
  814. int num() const {
  815. // Note: The signature, `dtype::num` follows the naming of NumPy's public
  816. // Python API (i.e., ``dtype.num``), rather than its internal
  817. // C API (``PyArray_Descr::type_num``).
  818. return detail::array_descriptor_proxy(m_ptr)->type_num;
  819. }
  820. /// Type number of dtype, normalized to match the return value of `num_of` for equivalent
  821. /// types. This function can be used to write switch statements that correctly handle
  822. /// equivalent types with different type numbers.
  823. int normalized_num() const {
  824. int value = num();
  825. if (value >= 0 && value <= detail::npy_api::NPY_VOID_) {
  826. return detail::normalized_dtype_num[value];
  827. }
  828. return value;
  829. }
  830. /// Single character for byteorder
  831. char byteorder() const { return detail::array_descriptor_proxy(m_ptr)->byteorder; }
  832. /// Alignment of the data type
  833. ssize_t alignment() const {
  834. if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
  835. return detail::array_descriptor1_proxy(m_ptr)->alignment;
  836. }
  837. return detail::array_descriptor2_proxy(m_ptr)->alignment;
  838. }
  839. /// Flags for the array descriptor
  840. std::uint64_t flags() const {
  841. if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
  842. return (unsigned char) detail::array_descriptor1_proxy(m_ptr)->flags;
  843. }
  844. return detail::array_descriptor2_proxy(m_ptr)->flags;
  845. }
  846. private:
  847. static object &_dtype_from_pep3118() {
  848. PYBIND11_CONSTINIT static gil_safe_call_once_and_store<object> storage;
  849. return storage
  850. .call_once_and_store_result([]() {
  851. return detail::import_numpy_core_submodule("_internal")
  852. .attr("_dtype_from_pep3118");
  853. })
  854. .get_stored();
  855. }
  856. dtype strip_padding(ssize_t itemsize) {
  857. // Recursively strip all void fields with empty names that are generated for
  858. // padding fields (as of NumPy v1.11).
  859. if (!has_fields()) {
  860. return *this;
  861. }
  862. struct field_descr {
  863. pybind11::str name;
  864. object format;
  865. pybind11::int_ offset;
  866. field_descr(pybind11::str &&name, object &&format, pybind11::int_ &&offset)
  867. : name{std::move(name)}, format{std::move(format)}, offset{std::move(offset)} {};
  868. };
  869. auto field_dict = attr("fields").cast<dict>();
  870. std::vector<field_descr> field_descriptors;
  871. field_descriptors.reserve(field_dict.size());
  872. for (auto field : field_dict.attr("items")()) {
  873. auto spec = field.cast<tuple>();
  874. auto name = spec[0].cast<pybind11::str>();
  875. auto spec_fo = spec[1].cast<tuple>();
  876. auto format = spec_fo[0].cast<dtype>();
  877. auto offset = spec_fo[1].cast<pybind11::int_>();
  878. if ((len(name) == 0u) && format.kind() == 'V') {
  879. continue;
  880. }
  881. field_descriptors.emplace_back(
  882. std::move(name), format.strip_padding(format.itemsize()), std::move(offset));
  883. }
  884. std::sort(field_descriptors.begin(),
  885. field_descriptors.end(),
  886. [](const field_descr &a, const field_descr &b) {
  887. return a.offset.cast<int>() < b.offset.cast<int>();
  888. });
  889. list names, formats, offsets;
  890. for (auto &descr : field_descriptors) {
  891. names.append(std::move(descr.name));
  892. formats.append(std::move(descr.format));
  893. offsets.append(std::move(descr.offset));
  894. }
  895. return dtype(std::move(names), std::move(formats), std::move(offsets), itemsize);
  896. }
  897. };
  898. class array : public buffer {
  899. public:
  900. PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
  901. enum {
  902. c_style = detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_,
  903. f_style = detail::npy_api::NPY_ARRAY_F_CONTIGUOUS_,
  904. forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
  905. };
  906. array() : array(0, static_cast<const double *>(nullptr)) {}
  907. using ShapeContainer = detail::any_container<ssize_t>;
  908. using StridesContainer = detail::any_container<ssize_t>;
  909. // Constructs an array taking shape/strides from arbitrary container types
  910. array(const pybind11::dtype &dt,
  911. ShapeContainer shape,
  912. StridesContainer strides,
  913. const void *ptr = nullptr,
  914. handle base = handle()) {
  915. if (strides->empty()) {
  916. *strides = detail::c_strides(*shape, dt.itemsize());
  917. }
  918. auto ndim = shape->size();
  919. if (ndim != strides->size()) {
  920. pybind11_fail("NumPy: shape ndim doesn't match strides ndim");
  921. }
  922. auto descr = dt;
  923. int flags = 0;
  924. if (base && ptr) {
  925. if (isinstance<array>(base)) {
  926. /* Copy flags from base (except ownership bit) */
  927. flags = reinterpret_borrow<array>(base).flags()
  928. & ~detail::npy_api::NPY_ARRAY_OWNDATA_;
  929. } else {
  930. /* Writable by default, easy to downgrade later on if needed */
  931. flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
  932. }
  933. }
  934. auto &api = detail::npy_api::get();
  935. auto tmp = reinterpret_steal<object>(api.PyArray_NewFromDescr_(
  936. api.PyArray_Type_,
  937. descr.release().ptr(),
  938. (int) ndim,
  939. // Use reinterpret_cast for PyPy on Windows (remove if fixed, checked on 7.3.1)
  940. reinterpret_cast<Py_intptr_t *>(shape->data()),
  941. reinterpret_cast<Py_intptr_t *>(strides->data()),
  942. const_cast<void *>(ptr),
  943. flags,
  944. nullptr));
  945. if (!tmp) {
  946. throw error_already_set();
  947. }
  948. if (ptr) {
  949. if (base) {
  950. api.PyArray_SetBaseObject_(tmp.ptr(), base.inc_ref().ptr());
  951. } else {
  952. tmp = reinterpret_steal<object>(
  953. api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */));
  954. }
  955. }
  956. m_ptr = tmp.release().ptr();
  957. }
  958. array(const pybind11::dtype &dt,
  959. ShapeContainer shape,
  960. const void *ptr = nullptr,
  961. handle base = handle())
  962. : array(dt, std::move(shape), {}, ptr, base) {}
  963. template <typename T,
  964. typename
  965. = detail::enable_if_t<std::is_integral<T>::value && !std::is_same<bool, T>::value>>
  966. array(const pybind11::dtype &dt, T count, const void *ptr = nullptr, handle base = handle())
  967. : array(dt, {{count}}, ptr, base) {}
  968. template <typename T>
  969. array(ShapeContainer shape, StridesContainer strides, const T *ptr, handle base = handle())
  970. : array(pybind11::dtype::of<T>(),
  971. std::move(shape),
  972. std::move(strides),
  973. reinterpret_cast<const void *>(ptr),
  974. base) {}
  975. template <typename T>
  976. array(ShapeContainer shape, const T *ptr, handle base = handle())
  977. : array(std::move(shape), {}, ptr, base) {}
  978. template <typename T>
  979. explicit array(ssize_t count, const T *ptr, handle base = handle())
  980. : array({count}, {}, ptr, base) {}
  981. explicit array(const buffer_info &info, handle base = handle())
  982. : array(pybind11::dtype(info), info.shape, info.strides, info.ptr, base) {}
  983. /// Array descriptor (dtype)
  984. pybind11::dtype dtype() const {
  985. return reinterpret_borrow<pybind11::dtype>(detail::array_proxy(m_ptr)->descr);
  986. }
  987. /// Total number of elements
  988. ssize_t size() const {
  989. return std::accumulate(shape(), shape() + ndim(), (ssize_t) 1, std::multiplies<ssize_t>());
  990. }
  991. /// Byte size of a single element
  992. ssize_t itemsize() const { return dtype().itemsize(); }
  993. /// Total number of bytes
  994. ssize_t nbytes() const { return size() * itemsize(); }
  995. /// Number of dimensions
  996. ssize_t ndim() const { return detail::array_proxy(m_ptr)->nd; }
  997. /// Base object
  998. object base() const { return reinterpret_borrow<object>(detail::array_proxy(m_ptr)->base); }
  999. /// Dimensions of the array
  1000. const ssize_t *shape() const { return detail::array_proxy(m_ptr)->dimensions; }
  1001. /// Dimension along a given axis
  1002. ssize_t shape(ssize_t dim) const {
  1003. if (dim >= ndim()) {
  1004. fail_dim_check(dim, "invalid axis");
  1005. }
  1006. return shape()[dim];
  1007. }
  1008. /// Strides of the array
  1009. const ssize_t *strides() const { return detail::array_proxy(m_ptr)->strides; }
  1010. /// Stride along a given axis
  1011. ssize_t strides(ssize_t dim) const {
  1012. if (dim >= ndim()) {
  1013. fail_dim_check(dim, "invalid axis");
  1014. }
  1015. return strides()[dim];
  1016. }
  1017. /// Return the NumPy array flags
  1018. int flags() const { return detail::array_proxy(m_ptr)->flags; }
  1019. /// If set, the array is writeable (otherwise the buffer is read-only)
  1020. bool writeable() const {
  1021. return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
  1022. }
  1023. /// If set, the array owns the data (will be freed when the array is deleted)
  1024. bool owndata() const {
  1025. return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
  1026. }
  1027. /// Pointer to the contained data. If index is not provided, points to the
  1028. /// beginning of the buffer. May throw if the index would lead to out of bounds access.
  1029. template <typename... Ix>
  1030. const void *data(Ix... index) const {
  1031. return static_cast<const void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
  1032. }
  1033. /// Mutable pointer to the contained data. If index is not provided, points to the
  1034. /// beginning of the buffer. May throw if the index would lead to out of bounds access.
  1035. /// May throw if the array is not writeable.
  1036. template <typename... Ix>
  1037. void *mutable_data(Ix... index) {
  1038. check_writeable();
  1039. return static_cast<void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
  1040. }
  1041. /// Byte offset from beginning of the array to a given index (full or partial).
  1042. /// May throw if the index would lead to out of bounds access.
  1043. template <typename... Ix>
  1044. ssize_t offset_at(Ix... index) const {
  1045. if ((ssize_t) sizeof...(index) > ndim()) {
  1046. fail_dim_check(sizeof...(index), "too many indices for an array");
  1047. }
  1048. return byte_offset(ssize_t(index)...);
  1049. }
  1050. ssize_t offset_at() const { return 0; }
  1051. /// Item count from beginning of the array to a given index (full or partial).
  1052. /// May throw if the index would lead to out of bounds access.
  1053. template <typename... Ix>
  1054. ssize_t index_at(Ix... index) const {
  1055. return offset_at(index...) / itemsize();
  1056. }
  1057. /**
  1058. * Returns a proxy object that provides access to the array's data without bounds or
  1059. * dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with
  1060. * care: the array must not be destroyed or reshaped for the duration of the returned object,
  1061. * and the caller must take care not to access invalid dimensions or dimension indices.
  1062. */
  1063. template <typename T, ssize_t Dims = -1>
  1064. detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() & {
  1065. if (Dims >= 0 && ndim() != Dims) {
  1066. throw std::domain_error("array has incorrect number of dimensions: "
  1067. + std::to_string(ndim()) + "; expected "
  1068. + std::to_string(Dims));
  1069. }
  1070. return detail::unchecked_mutable_reference<T, Dims>(
  1071. mutable_data(), shape(), strides(), ndim());
  1072. }
  1073. /**
  1074. * Returns a proxy object that provides const access to the array's data without bounds or
  1075. * dimensionality checking. Unlike `mutable_unchecked()`, this does not require that the
  1076. * underlying array have the `writable` flag. Use with care: the array must not be destroyed
  1077. * or reshaped for the duration of the returned object, and the caller must take care not to
  1078. * access invalid dimensions or dimension indices.
  1079. */
  1080. template <typename T, ssize_t Dims = -1>
  1081. detail::unchecked_reference<T, Dims> unchecked() const & {
  1082. if (Dims >= 0 && ndim() != Dims) {
  1083. throw std::domain_error("array has incorrect number of dimensions: "
  1084. + std::to_string(ndim()) + "; expected "
  1085. + std::to_string(Dims));
  1086. }
  1087. return detail::unchecked_reference<T, Dims>(data(), shape(), strides(), ndim());
  1088. }
  1089. /// Return a new view with all of the dimensions of length 1 removed
  1090. array squeeze() {
  1091. auto &api = detail::npy_api::get();
  1092. return reinterpret_steal<array>(api.PyArray_Squeeze_(m_ptr));
  1093. }
  1094. /// Resize array to given shape
  1095. /// If refcheck is true and more that one reference exist to this array
  1096. /// then resize will succeed only if it makes a reshape, i.e. original size doesn't change
  1097. void resize(ShapeContainer new_shape, bool refcheck = true) {
  1098. detail::npy_api::PyArray_Dims d
  1099. = {// Use reinterpret_cast for PyPy on Windows (remove if fixed, checked on 7.3.1)
  1100. reinterpret_cast<Py_intptr_t *>(new_shape->data()),
  1101. int(new_shape->size())};
  1102. // try to resize, set ordering param to -1 cause it's not used anyway
  1103. auto new_array = reinterpret_steal<object>(
  1104. detail::npy_api::get().PyArray_Resize_(m_ptr, &d, int(refcheck), -1));
  1105. if (!new_array) {
  1106. throw error_already_set();
  1107. }
  1108. if (isinstance<array>(new_array)) {
  1109. *this = std::move(new_array);
  1110. }
  1111. }
  1112. /// Optional `order` parameter omitted, to be added as needed.
  1113. array reshape(ShapeContainer new_shape) {
  1114. detail::npy_api::PyArray_Dims d
  1115. = {reinterpret_cast<Py_intptr_t *>(new_shape->data()), int(new_shape->size())};
  1116. auto new_array
  1117. = reinterpret_steal<array>(detail::npy_api::get().PyArray_Newshape_(m_ptr, &d, 0));
  1118. if (!new_array) {
  1119. throw error_already_set();
  1120. }
  1121. return new_array;
  1122. }
  1123. /// Create a view of an array in a different data type.
  1124. /// This function may fundamentally reinterpret the data in the array.
  1125. /// It is the responsibility of the caller to ensure that this is safe.
  1126. /// Only supports the `dtype` argument, the `type` argument is omitted,
  1127. /// to be added as needed.
  1128. array view(const std::string &dtype) {
  1129. auto &api = detail::npy_api::get();
  1130. auto new_view = reinterpret_steal<array>(api.PyArray_View_(
  1131. m_ptr, dtype::from_args(pybind11::str(dtype)).release().ptr(), nullptr));
  1132. if (!new_view) {
  1133. throw error_already_set();
  1134. }
  1135. return new_view;
  1136. }
  1137. /// Ensure that the argument is a NumPy array
  1138. /// In case of an error, nullptr is returned and the Python error is cleared.
  1139. static array ensure(handle h, int ExtraFlags = 0) {
  1140. auto result = reinterpret_steal<array>(raw_array(h.ptr(), ExtraFlags));
  1141. if (!result) {
  1142. PyErr_Clear();
  1143. }
  1144. return result;
  1145. }
  1146. protected:
  1147. template <typename, typename>
  1148. friend struct detail::npy_format_descriptor;
  1149. void fail_dim_check(ssize_t dim, const std::string &msg) const {
  1150. throw index_error(msg + ": " + std::to_string(dim) + " (ndim = " + std::to_string(ndim())
  1151. + ')');
  1152. }
  1153. template <typename... Ix>
  1154. ssize_t byte_offset(Ix... index) const {
  1155. check_dimensions(index...);
  1156. return detail::byte_offset_unsafe(strides(), ssize_t(index)...);
  1157. }
  1158. void check_writeable() const {
  1159. if (!writeable()) {
  1160. throw std::domain_error("array is not writeable");
  1161. }
  1162. }
  1163. template <typename... Ix>
  1164. void check_dimensions(Ix... index) const {
  1165. check_dimensions_impl(ssize_t(0), shape(), ssize_t(index)...);
  1166. }
  1167. void check_dimensions_impl(ssize_t, const ssize_t *) const {}
  1168. template <typename... Ix>
  1169. void check_dimensions_impl(ssize_t axis, const ssize_t *shape, ssize_t i, Ix... index) const {
  1170. if (i >= *shape) {
  1171. throw index_error(std::string("index ") + std::to_string(i)
  1172. + " is out of bounds for axis " + std::to_string(axis)
  1173. + " with size " + std::to_string(*shape));
  1174. }
  1175. check_dimensions_impl(axis + 1, shape + 1, index...);
  1176. }
  1177. /// Create array from any object -- always returns a new reference
  1178. static PyObject *raw_array(PyObject *ptr, int ExtraFlags = 0) {
  1179. if (ptr == nullptr) {
  1180. set_error(PyExc_ValueError, "cannot create a pybind11::array from a nullptr");
  1181. return nullptr;
  1182. }
  1183. return detail::npy_api::get().PyArray_FromAny_(
  1184. ptr, nullptr, 0, 0, detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr);
  1185. }
  1186. };
  1187. template <typename T, int ExtraFlags = array::forcecast>
  1188. class array_t : public array {
  1189. private:
  1190. struct private_ctor {};
  1191. // Delegating constructor needed when both moving and accessing in the same constructor
  1192. array_t(private_ctor,
  1193. ShapeContainer &&shape,
  1194. StridesContainer &&strides,
  1195. const T *ptr,
  1196. handle base)
  1197. : array(std::move(shape), std::move(strides), ptr, base) {}
  1198. public:
  1199. static_assert(!detail::array_info<T>::is_array, "Array types cannot be used with array_t");
  1200. using value_type = T;
  1201. array_t() : array(0, static_cast<const T *>(nullptr)) {}
  1202. array_t(handle h, borrowed_t) : array(h, borrowed_t{}) {}
  1203. array_t(handle h, stolen_t) : array(h, stolen_t{}) {}
  1204. PYBIND11_DEPRECATED("Use array_t<T>::ensure() instead")
  1205. array_t(handle h, bool is_borrowed) : array(raw_array_t(h.ptr()), stolen_t{}) {
  1206. if (!m_ptr) {
  1207. PyErr_Clear();
  1208. }
  1209. if (!is_borrowed) {
  1210. Py_XDECREF(h.ptr());
  1211. }
  1212. }
  1213. // NOLINTNEXTLINE(google-explicit-constructor)
  1214. array_t(const object &o) : array(raw_array_t(o.ptr()), stolen_t{}) {
  1215. if (!m_ptr) {
  1216. throw error_already_set();
  1217. }
  1218. }
  1219. explicit array_t(const buffer_info &info, handle base = handle()) : array(info, base) {}
  1220. array_t(ShapeContainer shape,
  1221. StridesContainer strides,
  1222. const T *ptr = nullptr,
  1223. handle base = handle())
  1224. : array(std::move(shape), std::move(strides), ptr, base) {}
  1225. explicit array_t(ShapeContainer shape, const T *ptr = nullptr, handle base = handle())
  1226. : array_t(private_ctor{},
  1227. std::move(shape),
  1228. (ExtraFlags & f_style) != 0 ? detail::f_strides(*shape, itemsize())
  1229. : detail::c_strides(*shape, itemsize()),
  1230. ptr,
  1231. base) {}
  1232. explicit array_t(ssize_t count, const T *ptr = nullptr, handle base = handle())
  1233. : array({count}, {}, ptr, base) {}
  1234. constexpr ssize_t itemsize() const { return sizeof(T); }
  1235. template <typename... Ix>
  1236. ssize_t index_at(Ix... index) const {
  1237. return offset_at(index...) / itemsize();
  1238. }
  1239. template <typename... Ix>
  1240. const T *data(Ix... index) const {
  1241. return static_cast<const T *>(array::data(index...));
  1242. }
  1243. template <typename... Ix>
  1244. T *mutable_data(Ix... index) {
  1245. return static_cast<T *>(array::mutable_data(index...));
  1246. }
  1247. // Reference to element at a given index
  1248. template <typename... Ix>
  1249. const T &at(Ix... index) const {
  1250. if ((ssize_t) sizeof...(index) != ndim()) {
  1251. fail_dim_check(sizeof...(index), "index dimension mismatch");
  1252. }
  1253. return *(static_cast<const T *>(array::data())
  1254. + byte_offset(ssize_t(index)...) / itemsize());
  1255. }
  1256. // Mutable reference to element at a given index
  1257. template <typename... Ix>
  1258. T &mutable_at(Ix... index) {
  1259. if ((ssize_t) sizeof...(index) != ndim()) {
  1260. fail_dim_check(sizeof...(index), "index dimension mismatch");
  1261. }
  1262. return *(static_cast<T *>(array::mutable_data())
  1263. + byte_offset(ssize_t(index)...) / itemsize());
  1264. }
  1265. /**
  1266. * Returns a proxy object that provides access to the array's data without bounds or
  1267. * dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with
  1268. * care: the array must not be destroyed or reshaped for the duration of the returned object,
  1269. * and the caller must take care not to access invalid dimensions or dimension indices.
  1270. */
  1271. template <ssize_t Dims = -1>
  1272. detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() & {
  1273. return array::mutable_unchecked<T, Dims>();
  1274. }
  1275. /**
  1276. * Returns a proxy object that provides const access to the array's data without bounds or
  1277. * dimensionality checking. Unlike `mutable_unchecked()`, this does not require that the
  1278. * underlying array have the `writable` flag. Use with care: the array must not be destroyed
  1279. * or reshaped for the duration of the returned object, and the caller must take care not to
  1280. * access invalid dimensions or dimension indices.
  1281. */
  1282. template <ssize_t Dims = -1>
  1283. detail::unchecked_reference<T, Dims> unchecked() const & {
  1284. return array::unchecked<T, Dims>();
  1285. }
  1286. /// Ensure that the argument is a NumPy array of the correct dtype (and if not, try to convert
  1287. /// it). In case of an error, nullptr is returned and the Python error is cleared.
  1288. static array_t ensure(handle h) {
  1289. auto result = reinterpret_steal<array_t>(raw_array_t(h.ptr()));
  1290. if (!result) {
  1291. PyErr_Clear();
  1292. }
  1293. return result;
  1294. }
  1295. static bool check_(handle h) {
  1296. const auto &api = detail::npy_api::get();
  1297. return api.PyArray_Check_(h.ptr())
  1298. && api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr,
  1299. dtype::of<T>().ptr())
  1300. && detail::check_flags(h.ptr(), ExtraFlags & (array::c_style | array::f_style));
  1301. }
  1302. protected:
  1303. /// Create array from any object -- always returns a new reference
  1304. static PyObject *raw_array_t(PyObject *ptr) {
  1305. if (ptr == nullptr) {
  1306. set_error(PyExc_ValueError, "cannot create a pybind11::array_t from a nullptr");
  1307. return nullptr;
  1308. }
  1309. return detail::npy_api::get().PyArray_FromAny_(ptr,
  1310. dtype::of<T>().release().ptr(),
  1311. 0,
  1312. 0,
  1313. detail::npy_api::NPY_ARRAY_ENSUREARRAY_
  1314. | ExtraFlags,
  1315. nullptr);
  1316. }
  1317. };
  1318. template <typename T>
  1319. struct format_descriptor<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
  1320. static std::string format() {
  1321. return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::format();
  1322. }
  1323. };
  1324. template <size_t N>
  1325. struct format_descriptor<char[N]> {
  1326. static std::string format() { return std::to_string(N) + 's'; }
  1327. };
  1328. template <size_t N>
  1329. struct format_descriptor<std::array<char, N>> {
  1330. static std::string format() { return std::to_string(N) + 's'; }
  1331. };
  1332. template <typename T>
  1333. struct format_descriptor<T, detail::enable_if_t<std::is_enum<T>::value>> {
  1334. static std::string format() {
  1335. return format_descriptor<
  1336. typename std::remove_cv<typename std::underlying_type<T>::type>::type>::format();
  1337. }
  1338. };
  1339. template <typename T>
  1340. struct format_descriptor<T, detail::enable_if_t<detail::array_info<T>::is_array>> {
  1341. static std::string format() {
  1342. using namespace detail;
  1343. static constexpr auto extents = const_name("(") + array_info<T>::extents + const_name(")");
  1344. return extents.text + format_descriptor<remove_all_extents_t<T>>::format();
  1345. }
  1346. };
  1347. PYBIND11_NAMESPACE_BEGIN(detail)
  1348. template <typename T, int ExtraFlags>
  1349. struct pyobject_caster<array_t<T, ExtraFlags>> {
  1350. using type = array_t<T, ExtraFlags>;
  1351. bool load(handle src, bool convert) {
  1352. if (!convert && !type::check_(src)) {
  1353. return false;
  1354. }
  1355. value = type::ensure(src);
  1356. return static_cast<bool>(value);
  1357. }
  1358. static handle cast(const handle &src, return_value_policy /* policy */, handle /* parent */) {
  1359. return src.inc_ref();
  1360. }
  1361. PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name);
  1362. };
  1363. template <typename T>
  1364. struct compare_buffer_info<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
  1365. static bool compare(const buffer_info &b) {
  1366. return npy_api::get().PyArray_EquivTypes_(dtype::of<T>().ptr(), dtype(b).ptr());
  1367. }
  1368. };
  1369. template <typename T>
  1370. struct npy_format_descriptor<
  1371. T,
  1372. enable_if_t<satisfies_any_of<T, std::is_arithmetic, is_complex>::value>>
  1373. : npy_format_descriptor_name<T> {
  1374. private:
  1375. // NB: the order here must match the one in common.h
  1376. constexpr static const int values[15] = {npy_api::NPY_BOOL_,
  1377. npy_api::NPY_BYTE_,
  1378. npy_api::NPY_UBYTE_,
  1379. npy_api::NPY_INT16_,
  1380. npy_api::NPY_UINT16_,
  1381. npy_api::NPY_INT32_,
  1382. npy_api::NPY_UINT32_,
  1383. npy_api::NPY_INT64_,
  1384. npy_api::NPY_UINT64_,
  1385. npy_api::NPY_FLOAT_,
  1386. npy_api::NPY_DOUBLE_,
  1387. npy_api::NPY_LONGDOUBLE_,
  1388. npy_api::NPY_CFLOAT_,
  1389. npy_api::NPY_CDOUBLE_,
  1390. npy_api::NPY_CLONGDOUBLE_};
  1391. public:
  1392. static constexpr int value = values[detail::is_fmt_numeric<T>::index];
  1393. static pybind11::dtype dtype() { return pybind11::dtype(/*typenum*/ value); }
  1394. };
  1395. template <typename T>
  1396. struct npy_format_descriptor<
  1397. T,
  1398. enable_if_t<is_same_ignoring_cvref<T, PyObject *>::value
  1399. || ((std::is_same<T, handle>::value || std::is_same<T, object>::value)
  1400. && sizeof(T) == sizeof(PyObject *))>> {
  1401. static constexpr auto name = const_name("numpy.object_");
  1402. static constexpr int value = npy_api::NPY_OBJECT_;
  1403. static pybind11::dtype dtype() { return pybind11::dtype(/*typenum*/ value); }
  1404. };
  1405. #define PYBIND11_DECL_CHAR_FMT \
  1406. static constexpr auto name = const_name("S") + const_name<N>(); \
  1407. static pybind11::dtype dtype() { \
  1408. return pybind11::dtype(std::string("S") + std::to_string(N)); \
  1409. }
  1410. template <size_t N>
  1411. struct npy_format_descriptor<char[N]> {
  1412. PYBIND11_DECL_CHAR_FMT
  1413. };
  1414. template <size_t N>
  1415. struct npy_format_descriptor<std::array<char, N>> {
  1416. PYBIND11_DECL_CHAR_FMT
  1417. };
  1418. #undef PYBIND11_DECL_CHAR_FMT
  1419. template <typename T>
  1420. struct npy_format_descriptor<T, enable_if_t<array_info<T>::is_array>> {
  1421. private:
  1422. using base_descr = npy_format_descriptor<typename array_info<T>::type>;
  1423. public:
  1424. static_assert(!array_info<T>::is_empty, "Zero-sized arrays are not supported");
  1425. static constexpr auto name
  1426. = const_name("(") + array_info<T>::extents + const_name(")") + base_descr::name;
  1427. static pybind11::dtype dtype() {
  1428. list shape;
  1429. array_info<T>::append_extents(shape);
  1430. return pybind11::dtype::from_args(
  1431. pybind11::make_tuple(base_descr::dtype(), std::move(shape)));
  1432. }
  1433. };
  1434. template <typename T>
  1435. struct npy_format_descriptor<T, enable_if_t<std::is_enum<T>::value>> {
  1436. private:
  1437. using base_descr = npy_format_descriptor<typename std::underlying_type<T>::type>;
  1438. public:
  1439. static constexpr auto name = base_descr::name;
  1440. static pybind11::dtype dtype() { return base_descr::dtype(); }
  1441. };
  1442. struct field_descriptor {
  1443. const char *name;
  1444. ssize_t offset;
  1445. ssize_t size;
  1446. std::string format;
  1447. dtype descr;
  1448. };
  1449. PYBIND11_NOINLINE void register_structured_dtype(any_container<field_descriptor> fields,
  1450. const std::type_info &tinfo,
  1451. ssize_t itemsize,
  1452. bool (*direct_converter)(PyObject *, void *&)) {
  1453. auto &numpy_internals = get_numpy_internals();
  1454. if (numpy_internals.get_type_info(tinfo, false)) {
  1455. pybind11_fail("NumPy: dtype is already registered");
  1456. }
  1457. // Use ordered fields because order matters as of NumPy 1.14:
  1458. // https://docs.scipy.org/doc/numpy/release.html#multiple-field-indexing-assignment-of-structured-arrays
  1459. std::vector<field_descriptor> ordered_fields(std::move(fields));
  1460. std::sort(
  1461. ordered_fields.begin(),
  1462. ordered_fields.end(),
  1463. [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });
  1464. list names, formats, offsets;
  1465. for (auto &field : ordered_fields) {
  1466. if (!field.descr) {
  1467. pybind11_fail(std::string("NumPy: unsupported field dtype: `") + field.name + "` @ "
  1468. + tinfo.name());
  1469. }
  1470. names.append(pybind11::str(field.name));
  1471. formats.append(field.descr);
  1472. offsets.append(pybind11::int_(field.offset));
  1473. }
  1474. auto *dtype_ptr
  1475. = pybind11::dtype(std::move(names), std::move(formats), std::move(offsets), itemsize)
  1476. .release()
  1477. .ptr();
  1478. // There is an existing bug in NumPy (as of v1.11): trailing bytes are
  1479. // not encoded explicitly into the format string. This will supposedly
  1480. // get fixed in v1.12; for further details, see these:
  1481. // - https://github.com/numpy/numpy/issues/7797
  1482. // - https://github.com/numpy/numpy/pull/7798
  1483. // Because of this, we won't use numpy's logic to generate buffer format
  1484. // strings and will just do it ourselves.
  1485. ssize_t offset = 0;
  1486. std::ostringstream oss;
  1487. // mark the structure as unaligned with '^', because numpy and C++ don't
  1488. // always agree about alignment (particularly for complex), and we're
  1489. // explicitly listing all our padding. This depends on none of the fields
  1490. // overriding the endianness. Putting the ^ in front of individual fields
  1491. // isn't guaranteed to work due to https://github.com/numpy/numpy/issues/9049
  1492. oss << "^T{";
  1493. for (auto &field : ordered_fields) {
  1494. if (field.offset > offset) {
  1495. oss << (field.offset - offset) << 'x';
  1496. }
  1497. oss << field.format << ':' << field.name << ':';
  1498. offset = field.offset + field.size;
  1499. }
  1500. if (itemsize > offset) {
  1501. oss << (itemsize - offset) << 'x';
  1502. }
  1503. oss << '}';
  1504. auto format_str = oss.str();
  1505. // Smoke test: verify that NumPy properly parses our buffer format string
  1506. auto &api = npy_api::get();
  1507. auto arr = array(buffer_info(nullptr, itemsize, format_str, 1));
  1508. if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr())) {
  1509. pybind11_fail("NumPy: invalid buffer descriptor!");
  1510. }
  1511. auto tindex = std::type_index(tinfo);
  1512. numpy_internals.registered_dtypes[tindex] = {dtype_ptr, std::move(format_str)};
  1513. with_internals([tindex, &direct_converter](internals &internals) {
  1514. internals.direct_conversions[tindex].push_back(direct_converter);
  1515. });
  1516. }
  1517. template <typename T, typename SFINAE>
  1518. struct npy_format_descriptor {
  1519. static_assert(is_pod_struct<T>::value,
  1520. "Attempt to use a non-POD or unimplemented POD type as a numpy dtype");
  1521. static constexpr auto name = make_caster<T>::name;
  1522. static pybind11::dtype dtype() { return reinterpret_borrow<pybind11::dtype>(dtype_ptr()); }
  1523. static std::string format() {
  1524. static auto format_str = get_numpy_internals().get_type_info<T>(true)->format_str;
  1525. return format_str;
  1526. }
  1527. static void register_dtype(any_container<field_descriptor> fields) {
  1528. register_structured_dtype(std::move(fields),
  1529. typeid(typename std::remove_cv<T>::type),
  1530. sizeof(T),
  1531. &direct_converter);
  1532. }
  1533. private:
  1534. static PyObject *dtype_ptr() {
  1535. static PyObject *ptr = get_numpy_internals().get_type_info<T>(true)->dtype_ptr;
  1536. return ptr;
  1537. }
  1538. static bool direct_converter(PyObject *obj, void *&value) {
  1539. auto &api = npy_api::get();
  1540. if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_)) {
  1541. return false;
  1542. }
  1543. if (auto descr = reinterpret_steal<object>(api.PyArray_DescrFromScalar_(obj))) {
  1544. if (api.PyArray_EquivTypes_(dtype_ptr(), descr.ptr())) {
  1545. value = ((PyVoidScalarObject_Proxy *) obj)->obval;
  1546. return true;
  1547. }
  1548. }
  1549. return false;
  1550. }
  1551. };
  // CLion's code model gets a cheap stand-in for the heavy macros below (no effect on builds).
  #ifdef __CLION_IDE__
  #    define PYBIND11_NUMPY_DTYPE(Type, ...) ((void) 0)
  #    define PYBIND11_NUMPY_DTYPE_EX(Type, ...) ((void) 0)
  #else

  // Builds a field_descriptor (name, offset, size, buffer format, dtype) for member `Field` of
  // struct `T`, exposed to NumPy under the name `Name`.
  #    define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name) \
          ::pybind11::detail::field_descriptor { \
              Name, offsetof(T, Field), sizeof(decltype(std::declval<T>().Field)), \
                  ::pybind11::format_descriptor<decltype(std::declval<T>().Field)>::format(), \
                  ::pybind11::detail::npy_format_descriptor< \
                      decltype(std::declval<T>().Field)>::dtype() \
          }

  // Same as above, with the C++ member name doubling as the NumPy field name.
  #    define PYBIND11_FIELD_DESCRIPTOR(T, Field) PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, #Field)

  // The main idea of this macro is borrowed from https://github.com/swansontec/map-macro
  // (C) William Swanson, Paul Fultz
  //
  // Nested EVAL levels force the preprocessor to rescan the argument repeatedly.
  #    define PYBIND11_EVAL0(...) __VA_ARGS__
  #    define PYBIND11_EVAL1(...) PYBIND11_EVAL0(PYBIND11_EVAL0(PYBIND11_EVAL0(__VA_ARGS__)))
  #    define PYBIND11_EVAL2(...) PYBIND11_EVAL1(PYBIND11_EVAL1(PYBIND11_EVAL1(__VA_ARGS__)))
  #    define PYBIND11_EVAL3(...) PYBIND11_EVAL2(PYBIND11_EVAL2(PYBIND11_EVAL2(__VA_ARGS__)))
  #    define PYBIND11_EVAL4(...) PYBIND11_EVAL3(PYBIND11_EVAL3(PYBIND11_EVAL3(__VA_ARGS__)))
  #    define PYBIND11_EVAL(...) PYBIND11_EVAL4(PYBIND11_EVAL4(PYBIND11_EVAL4(__VA_ARGS__)))

  // End-of-list detection: the `()` sentinel appended by the *_LIST macros turns
  // PYBIND11_MAP_GET_END into a call, which selects PYBIND11_MAP_END as the next step.
  #    define PYBIND11_MAP_END(...)
  #    define PYBIND11_MAP_OUT
  #    define PYBIND11_MAP_COMMA ,
  #    define PYBIND11_MAP_GET_END() 0, PYBIND11_MAP_END
  #    define PYBIND11_MAP_NEXT0(test, next, ...) next PYBIND11_MAP_OUT
  #    define PYBIND11_MAP_NEXT1(test, next) PYBIND11_MAP_NEXT0(test, next, 0)
  #    define PYBIND11_MAP_NEXT(test, next) PYBIND11_MAP_NEXT1(PYBIND11_MAP_GET_END test, next)

  // MSVC is not as eager to expand macros, hence this workaround
  #    if defined(_MSC_VER) && !defined(__clang__)
  #        define PYBIND11_MAP_LIST_NEXT1(test, next) \
              PYBIND11_EVAL0(PYBIND11_MAP_NEXT0(test, PYBIND11_MAP_COMMA next, 0))
  #    else
  #        define PYBIND11_MAP_LIST_NEXT1(test, next) \
              PYBIND11_MAP_NEXT0(test, PYBIND11_MAP_COMMA next, 0)
  #    endif
  #    define PYBIND11_MAP_LIST_NEXT(test, next) \
          PYBIND11_MAP_LIST_NEXT1(PYBIND11_MAP_GET_END test, next)

  // LIST0 and LIST1 call each other alternately, consuming one argument per step.
  #    define PYBIND11_MAP_LIST0(f, t, x, peek, ...) \
          f(t, x) PYBIND11_MAP_LIST_NEXT(peek, PYBIND11_MAP_LIST1)(f, t, peek, __VA_ARGS__)
  #    define PYBIND11_MAP_LIST1(f, t, x, peek, ...) \
          f(t, x) PYBIND11_MAP_LIST_NEXT(peek, PYBIND11_MAP_LIST0)(f, t, peek, __VA_ARGS__)

  // PYBIND11_MAP_LIST(f, t, a1, a2, ...) expands to f(t, a1), f(t, a2), ...
  #    define PYBIND11_MAP_LIST(f, t, ...) \
          PYBIND11_EVAL(PYBIND11_MAP_LIST1(f, t, __VA_ARGS__, (), 0))

  // Registers `Type` as a structured dtype; the variadic arguments are its member names.
  #    define PYBIND11_NUMPY_DTYPE(Type, ...) \
          ::pybind11::detail::npy_format_descriptor<Type>::register_dtype( \
              ::std::vector<::pybind11::detail::field_descriptor>{ \
                  PYBIND11_MAP_LIST(PYBIND11_FIELD_DESCRIPTOR, Type, __VA_ARGS__)})

  // Two-arguments-per-step variant of the list machinery above.
  #    if defined(_MSC_VER) && !defined(__clang__)
  #        define PYBIND11_MAP2_LIST_NEXT1(test, next) \
              PYBIND11_EVAL0(PYBIND11_MAP_NEXT0(test, PYBIND11_MAP_COMMA next, 0))
  #    else
  #        define PYBIND11_MAP2_LIST_NEXT1(test, next) \
              PYBIND11_MAP_NEXT0(test, PYBIND11_MAP_COMMA next, 0)
  #    endif
  #    define PYBIND11_MAP2_LIST_NEXT(test, next) \
          PYBIND11_MAP2_LIST_NEXT1(PYBIND11_MAP_GET_END test, next)
  #    define PYBIND11_MAP2_LIST0(f, t, x1, x2, peek, ...) \
          f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT(peek, PYBIND11_MAP2_LIST1)(f, t, peek, __VA_ARGS__)
  #    define PYBIND11_MAP2_LIST1(f, t, x1, x2, peek, ...) \
          f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT(peek, PYBIND11_MAP2_LIST0)(f, t, peek, __VA_ARGS__)

  // PYBIND11_MAP2_LIST(f, t, a1, a2, ...) expands to f(t, a1, a2), f(t, a3, a4), ...
  #    define PYBIND11_MAP2_LIST(f, t, ...) \
          PYBIND11_EVAL(PYBIND11_MAP2_LIST1(f, t, __VA_ARGS__, (), 0))

  // Registers `Type` as a structured dtype; the variadic arguments are (member, "name") pairs.
  #    define PYBIND11_NUMPY_DTYPE_EX(Type, ...) \
          ::pybind11::detail::npy_format_descriptor<Type>::register_dtype( \
              ::std::vector<::pybind11::detail::field_descriptor>{ \
                  PYBIND11_MAP2_LIST(PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)})

  #endif // __CLION_IDE__
  /// Byte-pointer cursor over one strided buffer, used when stepping through several buffers
  /// in lock-step. Each dimension gets a precomputed pointer adjustment that `increment()` applies.
  class common_iterator {
  public:
      using container_type = std::vector<ssize_t>;
      using value_type = container_type::value_type;
      using size_type = container_type::size_type;

      /// Null cursor without any strides.
      common_iterator() : m_strides() {}

      /// \param ptr      start of the buffer
      /// \param strides  byte strides per dimension
      /// \param shape    extents per dimension (expected to have as many entries as `strides`)
      ///
      /// The adjustment for the last dimension is its plain stride. For every outer dimension j
      /// it is `strides[j] + m_strides[j+1] - strides[j+1] * shape[j+1]`, i.e. the outer stride
      /// corrected for the distance the inner dimensions have already advanced.
      common_iterator(void *ptr, const container_type &strides, const container_type &shape)
          : p_ptr(static_cast<char *>(ptr)), m_strides(strides.size()) {
          // Zero-dimensional input: there is nothing to precompute, and calling back() on the
          // empty vectors (or evaluating size() - 1) would be undefined behaviour.
          if (m_strides.empty()) {
              return;
          }
          m_strides.back() = static_cast<value_type>(strides.back());
          for (size_type i = m_strides.size() - 1; i != 0; --i) {
              const size_type j = i - 1;
              const auto extent = static_cast<value_type>(shape[i]);
              m_strides[j] = strides[j] + m_strides[i] - strides[i] * extent;
          }
      }

      /// Advance the cursor by the precomputed adjustment of dimension `dim`.
      void increment(size_type dim) { p_ptr += m_strides[dim]; }

      /// Current position.
      void *data() const { return p_ptr; }

  private:
      char *p_ptr{nullptr};
      container_type m_strides;
  };
  1643. template <size_t N>
  1644. class multi_array_iterator {
  1645. public:
  1646. using container_type = std::vector<ssize_t>;
  1647. multi_array_iterator(const std::array<buffer_info, N> &buffers, const container_type &shape)
  1648. : m_shape(shape.size()), m_index(shape.size(), 0), m_common_iterator() {
  1649. // Manual copy to avoid conversion warning if using std::copy
  1650. for (size_t i = 0; i < shape.size(); ++i) {
  1651. m_shape[i] = shape[i];
  1652. }
  1653. container_type strides(shape.size());
  1654. for (size_t i = 0; i < N; ++i) {
  1655. init_common_iterator(buffers[i], shape, m_common_iterator[i], strides);
  1656. }
  1657. }
  1658. multi_array_iterator &operator++() {
  1659. for (size_t j = m_index.size(); j != 0; --j) {
  1660. size_t i = j - 1;
  1661. if (++m_index[i] != m_shape[i]) {
  1662. increment_common_iterator(i);
  1663. break;
  1664. }
  1665. m_index[i] = 0;
  1666. }
  1667. return *this;
  1668. }
  1669. template <size_t K, class T = void>
  1670. T *data() const {
  1671. return reinterpret_cast<T *>(m_common_iterator[K].data());
  1672. }
  1673. private:
  1674. using common_iter = common_iterator;
  1675. void init_common_iterator(const buffer_info &buffer,
  1676. const container_type &shape,
  1677. common_iter &iterator,
  1678. container_type &strides) {
  1679. auto buffer_shape_iter = buffer.shape.rbegin();
  1680. auto buffer_strides_iter = buffer.strides.rbegin();
  1681. auto shape_iter = shape.rbegin();
  1682. auto strides_iter = strides.rbegin();
  1683. while (buffer_shape_iter != buffer.shape.rend()) {
  1684. if (*shape_iter == *buffer_shape_iter) {
  1685. *strides_iter = *buffer_strides_iter;
  1686. } else {
  1687. *strides_iter = 0;
  1688. }
  1689. ++buffer_shape_iter;
  1690. ++buffer_strides_iter;
  1691. ++shape_iter;
  1692. ++strides_iter;
  1693. }
  1694. std::fill(strides_iter, strides.rend(), 0);
  1695. iterator = common_iter(buffer.ptr, strides, shape);
  1696. }
  1697. void increment_common_iterator(size_t dim) {
  1698. for (auto &iter : m_common_iterator) {
  1699. iter.increment(dim);
  1700. }
  1701. }
  1702. container_type m_shape;
  1703. container_type m_index;
  1704. std::array<common_iter, N> m_common_iterator;
  1705. };
  /// Classification of a broadcast as returned by `broadcast()`.
  enum class broadcast_trivial {
      non_trivial, ///< neither of the trivial cases below applies
      c_trivial,   ///< trivial with C-contiguous storage
      f_trivial    ///< trivial with Fortran-contiguous storage
  };
  1707. // Populates the shape and number of dimensions for the set of buffers. Returns a
  1708. // broadcast_trivial enum value indicating whether the broadcast is "trivial"--that is, has each
  1709. // buffer being either a singleton or a full-size, C-contiguous (`c_trivial`) or Fortran-contiguous
  1710. // (`f_trivial`) storage buffer; returns `non_trivial` otherwise.
  1711. template <size_t N>
  1712. broadcast_trivial
  1713. broadcast(const std::array<buffer_info, N> &buffers, ssize_t &ndim, std::vector<ssize_t> &shape) {
  1714. ndim = std::accumulate(
  1715. buffers.begin(), buffers.end(), ssize_t(0), [](ssize_t res, const buffer_info &buf) {
  1716. return std::max(res, buf.ndim);
  1717. });
  1718. shape.clear();
  1719. shape.resize((size_t) ndim, 1);
  1720. // Figure out the output size, and make sure all input arrays conform (i.e. are either size 1
  1721. // or the full size).
  1722. for (size_t i = 0; i < N; ++i) {
  1723. auto res_iter = shape.rbegin();
  1724. auto end = buffers[i].shape.rend();
  1725. for (auto shape_iter = buffers[i].shape.rbegin(); shape_iter != end;
  1726. ++shape_iter, ++res_iter) {
  1727. const auto &dim_size_in = *shape_iter;
  1728. auto &dim_size_out = *res_iter;
  1729. // Each input dimension can either be 1 or `n`, but `n` values must match across
  1730. // buffers
  1731. if (dim_size_out == 1) {
  1732. dim_size_out = dim_size_in;
  1733. } else if (dim_size_in != 1 && dim_size_in != dim_size_out) {
  1734. pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!");
  1735. }
  1736. }
  1737. }
  1738. bool trivial_broadcast_c = true;
  1739. bool trivial_broadcast_f = true;
  1740. for (size_t i = 0; i < N && (trivial_broadcast_c || trivial_broadcast_f); ++i) {
  1741. if (buffers[i].size == 1) {
  1742. continue;
  1743. }
  1744. // Require the same number of dimensions:
  1745. if (buffers[i].ndim != ndim) {
  1746. return broadcast_trivial::non_trivial;
  1747. }
  1748. // Require all dimensions be full-size:
  1749. if (!std::equal(buffers[i].shape.cbegin(), buffers[i].shape.cend(), shape.cbegin())) {
  1750. return broadcast_trivial::non_trivial;
  1751. }
  1752. // Check for C contiguity (but only if previous inputs were also C contiguous)
  1753. if (trivial_broadcast_c) {
  1754. ssize_t expect_stride = buffers[i].itemsize;
  1755. auto end = buffers[i].shape.crend();
  1756. for (auto shape_iter = buffers[i].shape.crbegin(),
  1757. stride_iter = buffers[i].strides.crbegin();
  1758. trivial_broadcast_c && shape_iter != end;
  1759. ++shape_iter, ++stride_iter) {
  1760. if (expect_stride == *stride_iter) {
  1761. expect_stride *= *shape_iter;
  1762. } else {
  1763. trivial_broadcast_c = false;
  1764. }
  1765. }
  1766. }
  1767. // Check for Fortran contiguity (if previous inputs were also F contiguous)
  1768. if (trivial_broadcast_f) {
  1769. ssize_t expect_stride = buffers[i].itemsize;
  1770. auto end = buffers[i].shape.cend();
  1771. for (auto shape_iter = buffers[i].shape.cbegin(),
  1772. stride_iter = buffers[i].strides.cbegin();
  1773. trivial_broadcast_f && shape_iter != end;
  1774. ++shape_iter, ++stride_iter) {
  1775. if (expect_stride == *stride_iter) {
  1776. expect_stride *= *shape_iter;
  1777. } else {
  1778. trivial_broadcast_f = false;
  1779. }
  1780. }
  1781. }
  1782. }
  1783. return trivial_broadcast_c ? broadcast_trivial::c_trivial
  1784. : trivial_broadcast_f ? broadcast_trivial::f_trivial
  1785. : broadcast_trivial::non_trivial;
  1786. }
  1787. template <typename T>
  1788. struct vectorize_arg {
  1789. static_assert(!std::is_rvalue_reference<T>::value,
  1790. "Functions with rvalue reference arguments cannot be vectorized");
  1791. // The wrapped function gets called with this type:
  1792. using call_type = remove_reference_t<T>;
  1793. // Is this a vectorized argument?
  1794. static constexpr bool vectorize
  1795. = satisfies_any_of<call_type, std::is_arithmetic, is_complex, is_pod>::value
  1796. && satisfies_none_of<call_type,
  1797. std::is_pointer,
  1798. std::is_array,
  1799. is_std_array,
  1800. std::is_enum>::value
  1801. && (!std::is_reference<T>::value
  1802. || (std::is_lvalue_reference<T>::value && std::is_const<call_type>::value));
  1803. // Accept this type: an array for vectorized types, otherwise the type as-is:
  1804. using type = conditional_t<vectorize, array_t<remove_cv_t<call_type>, array::forcecast>, T>;
  1805. };
  1806. // py::vectorize when a return type is present
  1807. template <typename Func, typename Return, typename... Args>
  1808. struct vectorize_returned_array {
  1809. using Type = array_t<Return>;
  1810. static Type create(broadcast_trivial trivial, const std::vector<ssize_t> &shape) {
  1811. if (trivial == broadcast_trivial::f_trivial) {
  1812. return array_t<Return, array::f_style>(shape);
  1813. }
  1814. return array_t<Return>(shape);
  1815. }
  1816. static Return *mutable_data(Type &array) { return array.mutable_data(); }
  1817. static Return call(Func &f, Args &...args) { return f(args...); }
  1818. static void call(Return *out, size_t i, Func &f, Args &...args) { out[i] = f(args...); }
  1819. };
  1820. // py::vectorize when a return type is not present
  1821. template <typename Func, typename... Args>
  1822. struct vectorize_returned_array<Func, void, Args...> {
  1823. using Type = none;
  1824. static Type create(broadcast_trivial, const std::vector<ssize_t> &) { return none(); }
  1825. static void *mutable_data(Type &) { return nullptr; }
  1826. static detail::void_type call(Func &f, Args &...args) {
  1827. f(args...);
  1828. return {};
  1829. }
  1830. static void call(void *, size_t, Func &f, Args &...args) { f(args...); }
  1831. };
  1832. template <typename Func, typename Return, typename... Args>
  1833. struct vectorize_helper {
  1834. // NVCC for some reason breaks if NVectorized is private
  1835. #ifdef __CUDACC__
  1836. public:
  1837. #else
  1838. private:
  1839. #endif
  1840. static constexpr size_t N = sizeof...(Args);
  1841. static constexpr size_t NVectorized = constexpr_sum(vectorize_arg<Args>::vectorize...);
  1842. static_assert(
  1843. NVectorized >= 1,
  1844. "pybind11::vectorize(...) requires a function with at least one vectorizable argument");
  1845. public:
  1846. template <typename T,
  1847. // SFINAE to prevent shadowing the copy constructor.
  1848. typename = detail::enable_if_t<
  1849. !std::is_same<vectorize_helper, typename std::decay<T>::type>::value>>
  1850. explicit vectorize_helper(T &&f) : f(std::forward<T>(f)) {}
  1851. object operator()(typename vectorize_arg<Args>::type... args) {
  1852. return run(args...,
  1853. make_index_sequence<N>(),
  1854. select_indices<vectorize_arg<Args>::vectorize...>(),
  1855. make_index_sequence<NVectorized>());
  1856. }
  1857. private:
  1858. remove_reference_t<Func> f;
  1859. // Internal compiler error in MSVC 19.16.27025.1 (Visual Studio 2017 15.9.4), when compiling
  1860. // with "/permissive-" flag when arg_call_types is manually inlined.
  1861. using arg_call_types = std::tuple<typename vectorize_arg<Args>::call_type...>;
  1862. template <size_t Index>
  1863. using param_n_t = typename std::tuple_element<Index, arg_call_types>::type;
  1864. using returned_array = vectorize_returned_array<Func, Return, Args...>;
  1865. // Runs a vectorized function given arguments tuple and three index sequences:
  1866. // - Index is the full set of 0 ... (N-1) argument indices;
  1867. // - VIndex is the subset of argument indices with vectorized parameters, letting us access
  1868. // vectorized arguments (anything not in this sequence is passed through)
  1869. // - BIndex is a incremental sequence (beginning at 0) of the same size as VIndex, so that
  1870. // we can store vectorized buffer_infos in an array (argument VIndex has its buffer at
  1871. // index BIndex in the array).
  1872. template <size_t... Index, size_t... VIndex, size_t... BIndex>
  1873. object run(typename vectorize_arg<Args>::type &...args,
  1874. index_sequence<Index...> i_seq,
  1875. index_sequence<VIndex...> vi_seq,
  1876. index_sequence<BIndex...> bi_seq) {
  1877. // Pointers to values the function was called with; the vectorized ones set here will start
  1878. // out as array_t<T> pointers, but they will be changed them to T pointers before we make
  1879. // call the wrapped function. Non-vectorized pointers are left as-is.
  1880. std::array<void *, N> params{{reinterpret_cast<void *>(&args)...}};
  1881. // The array of `buffer_info`s of vectorized arguments:
  1882. std::array<buffer_info, NVectorized> buffers{
  1883. {reinterpret_cast<array *>(params[VIndex])->request()...}};
  1884. /* Determine dimensions parameters of output array */
  1885. ssize_t nd = 0;
  1886. std::vector<ssize_t> shape(0);
  1887. auto trivial = broadcast(buffers, nd, shape);
  1888. auto ndim = (size_t) nd;
  1889. size_t size
  1890. = std::accumulate(shape.begin(), shape.end(), (size_t) 1, std::multiplies<size_t>());
  1891. // If all arguments are 0-dimension arrays (i.e. single values) return a plain value (i.e.
  1892. // not wrapped in an array).
  1893. if (size == 1 && ndim == 0) {
  1894. PYBIND11_EXPAND_SIDE_EFFECTS(params[VIndex] = buffers[BIndex].ptr);
  1895. return cast(
  1896. returned_array::call(f, *reinterpret_cast<param_n_t<Index> *>(params[Index])...));
  1897. }
  1898. auto result = returned_array::create(trivial, shape);
  1899. PYBIND11_WARNING_PUSH
  1900. #ifdef PYBIND11_DETECTED_CLANG_WITH_MISLEADING_CALL_STD_MOVE_EXPLICITLY_WARNING
  1901. PYBIND11_WARNING_DISABLE_CLANG("-Wreturn-std-move")
  1902. #endif
  1903. if (size == 0) {
  1904. return result;
  1905. }
  1906. /* Call the function */
  1907. auto *mutable_data = returned_array::mutable_data(result);
  1908. if (trivial == broadcast_trivial::non_trivial) {
  1909. apply_broadcast(buffers, params, mutable_data, size, shape, i_seq, vi_seq, bi_seq);
  1910. } else {
  1911. apply_trivial(buffers, params, mutable_data, size, i_seq, vi_seq, bi_seq);
  1912. }
  1913. return result;
  1914. PYBIND11_WARNING_POP
  1915. }
  1916. template <size_t... Index, size_t... VIndex, size_t... BIndex>
  1917. void apply_trivial(std::array<buffer_info, NVectorized> &buffers,
  1918. std::array<void *, N> &params,
  1919. Return *out,
  1920. size_t size,
  1921. index_sequence<Index...>,
  1922. index_sequence<VIndex...>,
  1923. index_sequence<BIndex...>) {
  1924. // Initialize an array of mutable byte references and sizes with references set to the
  1925. // appropriate pointer in `params`; as we iterate, we'll increment each pointer by its size
  1926. // (except for singletons, which get an increment of 0).
  1927. std::array<std::pair<unsigned char *&, const size_t>, NVectorized> vecparams{
  1928. {std::pair<unsigned char *&, const size_t>(
  1929. reinterpret_cast<unsigned char *&>(params[VIndex] = buffers[BIndex].ptr),
  1930. buffers[BIndex].size == 1 ? 0 : sizeof(param_n_t<VIndex>))...}};
  1931. for (size_t i = 0; i < size; ++i) {
  1932. returned_array::call(
  1933. out, i, f, *reinterpret_cast<param_n_t<Index> *>(params[Index])...);
  1934. for (auto &x : vecparams) {
  1935. x.first += x.second;
  1936. }
  1937. }
  1938. }
  1939. template <size_t... Index, size_t... VIndex, size_t... BIndex>
  1940. void apply_broadcast(std::array<buffer_info, NVectorized> &buffers,
  1941. std::array<void *, N> &params,
  1942. Return *out,
  1943. size_t size,
  1944. const std::vector<ssize_t> &output_shape,
  1945. index_sequence<Index...>,
  1946. index_sequence<VIndex...>,
  1947. index_sequence<BIndex...>) {
  1948. multi_array_iterator<NVectorized> input_iter(buffers, output_shape);
  1949. for (size_t i = 0; i < size; ++i, ++input_iter) {
  1950. PYBIND11_EXPAND_SIDE_EFFECTS((params[VIndex] = input_iter.template data<BIndex>()));
  1951. returned_array::call(
  1952. out, i, f, *reinterpret_cast<param_n_t<Index> *>(std::get<Index>(params))...);
  1953. }
  1954. }
  1955. };
  1956. template <typename Func, typename Return, typename... Args>
  1957. vectorize_helper<Func, Return, Args...> vectorize_extractor(const Func &f, Return (*)(Args...)) {
  1958. return detail::vectorize_helper<Func, Return, Args...>(f);
  1959. }
  1960. template <typename T, int Flags>
  1961. struct handle_type_name<array_t<T, Flags>> {
  1962. static constexpr auto name
  1963. = io_name("typing.Annotated[numpy.typing.ArrayLike, ", "numpy.typing.NDArray[")
  1964. + npy_format_descriptor<T>::name + const_name("]");
  1965. };
  1966. PYBIND11_NAMESPACE_END(detail)
  1967. // Vanilla pointer vectorizer:
  1968. template <typename Return, typename... Args>
  1969. detail::vectorize_helper<Return (*)(Args...), Return, Args...> vectorize(Return (*f)(Args...)) {
  1970. return detail::vectorize_helper<Return (*)(Args...), Return, Args...>(f);
  1971. }
  1972. // lambda vectorizer:
  1973. template <typename Func, detail::enable_if_t<detail::is_lambda<Func>::value, int> = 0>
  1974. auto vectorize(Func &&f)
  1975. -> decltype(detail::vectorize_extractor(std::forward<Func>(f),
  1976. (detail::function_signature_t<Func> *) nullptr)) {
  1977. return detail::vectorize_extractor(std::forward<Func>(f),
  1978. (detail::function_signature_t<Func> *) nullptr);
  1979. }
  1980. // Vectorize a class method (non-const):
  1981. template <typename Return,
  1982. typename Class,
  1983. typename... Args,
  1984. typename Helper = detail::vectorize_helper<
  1985. decltype(std::mem_fn(std::declval<Return (Class::*)(Args...)>())),
  1986. Return,
  1987. Class *,
  1988. Args...>>
  1989. Helper vectorize(Return (Class::*f)(Args...)) {
  1990. return Helper(std::mem_fn(f));
  1991. }
  1992. // Vectorize a class method (const):
  1993. template <typename Return,
  1994. typename Class,
  1995. typename... Args,
  1996. typename Helper = detail::vectorize_helper<
  1997. decltype(std::mem_fn(std::declval<Return (Class::*)(Args...) const>())),
  1998. Return,
  1999. const Class *,
  2000. Args...>>
  2001. Helper vectorize(Return (Class::*f)(Args...) const) {
  2002. return Helper(std::mem_fn(f));
  2003. }
  2004. PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)