1
+
2
+ /*
3
+ pybind11/eigen_tensor.h: Transparent conversion for Eigen tensors
4
+
5
+ Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
6
+
7
+ All rights reserved. Use of this source code is governed by a
8
+ BSD-style license that can be found in the LICENSE file.
9
+ */
10
+
11
+ #pragma once
12
+
13
+ /* HINT: To suppress warnings originating from the Eigen headers, use -isystem.
14
+ See also:
15
+ https://stackoverflow.com/questions/2579576/i-dir-vs-isystem-dir
16
+ https://stackoverflow.com/questions/1741816/isystem-for-ms-visual-studio-c-compiler
17
+ */
18
+
19
+ #include " ../numpy.h"
20
+
21
+ // The C4127 suppression was introduced for Eigen 3.4.0. In theory we could
22
+ // make it version specific, or even remove it later, but considering that
23
+ // 1. C4127 is generally far more distracting than useful for modern template code, and
24
+ // 2. we definitely want to ignore any MSVC warnings originating from Eigen code,
25
+ // it is probably best to keep this around indefinitely.
26
+ #if defined(_MSC_VER)
27
+ # pragma warning(push)
28
+ # pragma warning(disable : 4554) // Tensor.h warning
29
+ // C5054: operator '&': deprecated between enumerations of different types
30
+ #elif defined(__MINGW32__)
31
+ # pragma GCC diagnostic push
32
+ # pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
33
+ #endif
34
+
35
+ #include < unsupported/Eigen/CXX11/Tensor>
36
+
37
+ #if defined(_MSC_VER)
38
+ # pragma warning(pop)
39
+ #elif defined(__MINGW32__)
40
+ # pragma GCC diagnostic pop
41
+ #endif
42
+
43
+ PYBIND11_NAMESPACE_BEGIN (PYBIND11_NAMESPACE)
44
+
45
+ PYBIND11_NAMESPACE_BEGIN(detail)
46
+
47
+ template <typename T>
48
+ constexpr int compute_array_flag_from_tensor() {
49
+ static_assert (((int ) T::Layout == (int ) Eigen::RowMajor)
50
+ || ((int ) T::Layout == (int ) Eigen::ColMajor),
51
+ " Layout must be row or column major" );
52
+ return ((int ) T::Layout == (int ) Eigen::RowMajor) ? array::c_style : array::f_style;
53
+ }
54
+
55
+ template <typename T>
56
+ struct eigen_tensor_helper {};
57
+
58
+ template <typename Scalar_, int NumIndices_, int Options_, typename IndexType>
59
+ struct eigen_tensor_helper <Eigen::Tensor<Scalar_, NumIndices_, Options_, IndexType>> {
60
+ using T = Eigen::Tensor<Scalar_, NumIndices_, Options_, IndexType>;
61
+ using ValidType = void ;
62
+
63
+ static std::array<typename T::Index, T::NumIndices> get_shape (const T &f) {
64
+ return f.dimensions ();
65
+ }
66
+
67
+ static constexpr bool
68
+ is_correct_shape (const std::array<typename T::Index, T::NumIndices> & /* shape*/ ) {
69
+ return true ;
70
+ }
71
+
72
+ template <typename T>
73
+ struct helper {};
74
+
75
+ template <size_t ... Is>
76
+ struct helper <index_sequence<Is...>> {
77
+ static constexpr auto value = concat(const_name(((void ) Is, " ?" ))...);
78
+ };
79
+
80
+ static constexpr auto dimensions_descriptor
81
+ = helper<decltype(make_index_sequence<T::NumIndices>())>::value;
82
+ };
83
+
84
+ template <typename Scalar_, typename std::ptrdiff_t ... Indices, int Options_, typename IndexType>
85
+ struct eigen_tensor_helper <
86
+ Eigen::TensorFixedSize<Scalar_, Eigen::Sizes<Indices...>, Options_, IndexType>> {
87
+ using T = Eigen::TensorFixedSize<Scalar_, Eigen::Sizes<Indices...>, Options_, IndexType>;
88
+ using ValidType = void ;
89
+
90
+ static constexpr std::array<typename T::Index, T::NumIndices> get_shape (const T & /* f*/ ) {
91
+ return get_shape ();
92
+ }
93
+
94
+ static constexpr std::array<typename T::Index, T::NumIndices> get_shape () {
95
+ return {{Indices...}};
96
+ }
97
+
98
+ static bool is_correct_shape (const std::array<typename T::Index, T::NumIndices> &shape) {
99
+ return get_shape () == shape;
100
+ }
101
+
102
+ static constexpr auto dimensions_descriptor = concat(const_name<Indices>()...);
103
+ };
104
+
105
+ template <typename T>
106
+ struct get_tensor_descriptor {
107
+ static constexpr auto value
108
+ = const_name(" numpy.ndarray[" ) + npy_format_descriptor<typename T::Scalar>::name
109
+ + const_name(" [" ) + eigen_tensor_helper<T>::dimensions_descriptor
110
+ + const_name(" ], flags.writeable, " )
111
+ + const_name<(int ) T::Layout == (int ) Eigen::RowMajor>(" flags.c_contiguous" ,
112
+ " flags.f_contiguous" );
113
+ };
114
+
115
+ template <typename Type>
116
+ struct type_caster <Type, typename eigen_tensor_helper<Type>::ValidType> {
117
+ using H = eigen_tensor_helper<Type>;
118
+ PYBIND11_TYPE_CASTER (Type, get_tensor_descriptor<Type>::value);
119
+
120
+ bool load (handle src, bool /* convert*/ ) {
121
+ array_t <typename Type::Scalar, compute_array_flag_from_tensor<Type>()> a (
122
+ reinterpret_borrow<object>(src));
123
+
124
+ if (a.ndim () != Type::NumIndices) {
125
+ return false ;
126
+ }
127
+
128
+ std::array<typename Type::Index, Type::NumIndices> shape;
129
+ std::copy (a.shape (), a.shape () + Type::NumIndices, shape.begin ());
130
+
131
+ if (!H::is_correct_shape (shape)) {
132
+ return false ;
133
+ }
134
+
135
+ value = Eigen::TensorMap<Type>(const_cast <typename Type::Scalar *>(a.data ()), shape);
136
+
137
+ return true ;
138
+ }
139
+
140
+ static handle cast (Type &&src, return_value_policy policy, handle parent) {
141
+ if (policy == return_value_policy::reference
142
+ || policy == return_value_policy::reference_internal) {
143
+ pybind11_fail (" Cannot use a reference return value policy for an rvalue" );
144
+ }
145
+ return cast_impl (&src, return_value_policy::move, parent);
146
+ }
147
+
148
+ static handle cast (const Type &&src, return_value_policy policy, handle parent) {
149
+ if (policy == return_value_policy::reference
150
+ || policy == return_value_policy::reference_internal) {
151
+ pybind11_fail (" Cannot use a reference return value policy for an rvalue" );
152
+ }
153
+ return cast_impl (&src, return_value_policy::move, parent);
154
+ }
155
+
156
+ static handle cast (Type &src, return_value_policy policy, handle parent) {
157
+ if (policy == return_value_policy::automatic
158
+ || policy == return_value_policy::automatic_reference) {
159
+ policy = return_value_policy::copy;
160
+ }
161
+ return cast_impl (&src, policy, parent);
162
+ }
163
+
164
+ static handle cast (const Type &src, return_value_policy policy, handle parent) {
165
+ if (policy == return_value_policy::automatic
166
+ || policy == return_value_policy::automatic_reference) {
167
+ policy = return_value_policy::copy;
168
+ }
169
+ return cast (&src, policy, parent);
170
+ }
171
+
172
+ static handle cast (Type *src, return_value_policy policy, handle parent) {
173
+ if (policy == return_value_policy::automatic) {
174
+ policy = return_value_policy::take_ownership;
175
+ } else if (policy == return_value_policy::automatic_reference) {
176
+ policy = return_value_policy::reference;
177
+ }
178
+ return cast_impl (src, policy, parent);
179
+ }
180
+
181
+ static handle cast (const Type *src, return_value_policy policy, handle parent) {
182
+ if (policy == return_value_policy::automatic) {
183
+ policy = return_value_policy::take_ownership;
184
+ } else if (policy == return_value_policy::automatic_reference) {
185
+ policy = return_value_policy::reference;
186
+ }
187
+ return cast_impl (src, policy, parent);
188
+ }
189
+
190
+ template <typename C>
191
+ static handle cast_impl (C *src, return_value_policy policy, handle parent) {
192
+ object parent_object;
193
+ bool writeable = false ;
194
+ switch (policy) {
195
+ case return_value_policy::move:
196
+ if (std::is_const<C>::value) {
197
+ pybind11_fail (" Cannot move from a constant reference" );
198
+ }
199
+ {
200
+ Eigen::aligned_allocator<Type> allocator;
201
+ Type *copy = ::new (allocator.allocate (1 )) Type (std::move (*src));
202
+ src = copy;
203
+ }
204
+
205
+ parent_object = capsule (src, [](void *ptr) {
206
+ Eigen::aligned_allocator<Type> allocator;
207
+ Type *copy = (Type *) ptr;
208
+ copy->~Type ();
209
+ allocator.deallocate (copy, 1 );
210
+ });
211
+ writeable = true ;
212
+ break ;
213
+
214
+ case return_value_policy::take_ownership:
215
+ if (std::is_const<C>::value) {
216
+ pybind11_fail (" Cannot take ownership of a const reference" );
217
+ }
218
+ parent_object = capsule (src, [](void *ptr) { delete (Type *) ptr; });
219
+ writeable = true ;
220
+ break ;
221
+
222
+ case return_value_policy::copy:
223
+ parent_object = {};
224
+ writeable = true ;
225
+ break ;
226
+
227
+ case return_value_policy::reference:
228
+ parent_object = none ();
229
+ writeable = !std::is_const<C>::value;
230
+ break ;
231
+
232
+ case return_value_policy::reference_internal:
233
+ // Default should do the right thing
234
+ parent_object = reinterpret_borrow<object>(parent);
235
+ writeable = !std::is_const<C>::value;
236
+ break ;
237
+
238
+ default :
239
+ pybind11_fail (" pybind11 bug in eigen.h, please file a bug report" );
240
+ }
241
+
242
+ handle result = array_t <typename Type::Scalar, compute_array_flag_from_tensor<Type>()>(
243
+ H::get_shape (*src), src->data (), parent_object)
244
+ .release ();
245
+
246
+ if (!writeable) {
247
+ array_proxy (result.ptr ())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
248
+ }
249
+
250
+ return result;
251
+ }
252
+ };
253
+
254
+ template <typename Type>
255
+ struct type_caster <Eigen::TensorMap<Type>, typename eigen_tensor_helper<Type>::ValidType> {
256
+ using H = eigen_tensor_helper<Type>;
257
+
258
+ bool load (handle src, bool /* convert*/ ) {
259
+ // Note that we have a lot more checks here as we want to make sure to avoid copies
260
+ auto a = reinterpret_borrow<array>(src);
261
+ if ((a.flags () & compute_array_flag_from_tensor<Type>()) == 0 ) {
262
+ return false ;
263
+ }
264
+
265
+ if (!a.dtype ().is (dtype::of<typename Type::Scalar>())) {
266
+ return false ;
267
+ }
268
+
269
+ if (a.ndim () != Type::NumIndices) {
270
+ return false ;
271
+ }
272
+
273
+ std::array<typename Type::Index, Type::NumIndices> shape;
274
+ std::copy (a.shape (), a.shape () + Type::NumIndices, shape.begin ());
275
+
276
+ if (!H::is_correct_shape (shape)) {
277
+ return false ;
278
+ }
279
+
280
+ value.reset (new Eigen::TensorMap<Type>(
281
+ reinterpret_cast <typename Type::Scalar *>(a.mutable_data ()), shape));
282
+
283
+ return true ;
284
+ }
285
+
286
+ static handle cast (Eigen::TensorMap<Type> &&src, return_value_policy policy, handle parent) {
287
+ return cast_impl (&src, policy, parent);
288
+ }
289
+
290
+ static handle
291
+ cast (const Eigen::TensorMap<Type> &&src, return_value_policy policy, handle parent) {
292
+ return cast_impl (&src, policy, parent);
293
+ }
294
+
295
+ static handle cast (Eigen::TensorMap<Type> &src, return_value_policy policy, handle parent) {
296
+ if (policy == return_value_policy::automatic
297
+ || policy == return_value_policy::automatic_reference) {
298
+ policy = return_value_policy::copy;
299
+ }
300
+ return cast_impl (&src, policy, parent);
301
+ }
302
+
303
+ static handle
304
+ cast (const Eigen::TensorMap<Type> &src, return_value_policy policy, handle parent) {
305
+ if (policy == return_value_policy::automatic
306
+ || policy == return_value_policy::automatic_reference) {
307
+ policy = return_value_policy::copy;
308
+ }
309
+ return cast (&src, policy, parent);
310
+ }
311
+
312
+ static handle cast (Eigen::TensorMap<Type> *src, return_value_policy policy, handle parent) {
313
+ if (policy == return_value_policy::automatic) {
314
+ policy = return_value_policy::take_ownership;
315
+ } else if (policy == return_value_policy::automatic_reference) {
316
+ policy = return_value_policy::reference;
317
+ }
318
+ return cast_impl (src, policy, parent);
319
+ }
320
+
321
+ static handle
322
+ cast (const Eigen::TensorMap<Type> *src, return_value_policy policy, handle parent) {
323
+ if (policy == return_value_policy::automatic) {
324
+ policy = return_value_policy::take_ownership;
325
+ } else if (policy == return_value_policy::automatic_reference) {
326
+ policy = return_value_policy::reference;
327
+ }
328
+ return cast_impl (src, policy, parent);
329
+ }
330
+
331
+ template <typename C>
332
+ static handle cast_impl (C *src, return_value_policy policy, handle parent) {
333
+ object parent_object;
334
+ constexpr bool writeable = !std::is_const<C>::value;
335
+ switch (policy) {
336
+ case return_value_policy::reference:
337
+ parent_object = none ();
338
+ break ;
339
+
340
+ case return_value_policy::reference_internal:
341
+ // Default should do the right thing
342
+ parent_object = reinterpret_borrow<object>(parent);
343
+ break ;
344
+
345
+ default :
346
+ // move, take_ownership don't make any sense for a ref/map:
347
+ pybind11_fail (" Invalid return_value_policy for Eigen Map type, must be either "
348
+ " reference or reference_internal" );
349
+ }
350
+
351
+ handle result = array_t <typename Type::Scalar, compute_array_flag_from_tensor<Type>()>(
352
+ H::get_shape (*src), src->data (), parent_object)
353
+ .release ();
354
+
355
+ if (!writeable) {
356
+ array_proxy (result.ptr ())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
357
+ }
358
+
359
+ return result;
360
+ }
361
+
362
+ protected:
363
+ // TODO: Move to std::optional once std::optional has more support
364
+ std::unique_ptr<Eigen::TensorMap<Type>> value;
365
+
366
+ public:
367
+ static constexpr auto name = get_tensor_descriptor<Type>::value;
368
+ explicit operator Eigen::TensorMap<Type> *() {
369
+ return value.get ();
370
+ } /* NOLINT(bugprone-macro-parentheses) */
371
+ explicit operator Eigen::TensorMap<Type> &() {
372
+ return *value;
373
+ } /* NOLINT(bugprone-macro-parentheses) */
374
+ explicit operator Eigen::TensorMap<Type> &&() && {
375
+ return std::move (*value);
376
+ } /* NOLINT(bugprone-macro-parentheses) */
377
+
378
+ template <typename T_>
379
+ using cast_op_type = ::pybind11::detail::movable_cast_op_type<T_>;
380
+ };
381
+
382
+ PYBIND11_NAMESPACE_END (detail)
383
+ PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
0 commit comments