@@ -71,6 +71,12 @@ inline int prepare_and_add_type(PyTypeObject *type, PyObject *module)
7171template <typename T>
7272inline bool check_trailing_shape (T array, char const * name, long d1)
7373{
74+ if (array.ndim () != 2 ) {
75+ PyErr_Format (PyExc_ValueError,
76+ " Expected 2-dimensional array, got %ld" ,
77+ array.ndim ());
78+ return false ;
79+ }
7480 if (array.shape (1 ) != d1) {
7581 PyErr_Format (PyExc_ValueError,
7682 " %s must have shape (N, %ld), got (%ld, %ld)" ,
@@ -83,6 +89,12 @@ inline bool check_trailing_shape(T array, char const* name, long d1)
8389template <typename T>
8490inline bool check_trailing_shape (T array, char const * name, long d1, long d2)
8591{
92+ if (array.ndim () != 3 ) {
93+ PyErr_Format (PyExc_ValueError,
94+ " Expected 3-dimensional array, got %ld" ,
95+ array.ndim ());
96+ return false ;
97+ }
8698 if (array.shape (1 ) != d1 || array.shape (2 ) != d2) {
8799 PyErr_Format (PyExc_ValueError,
88100 " %s must have shape (N, %ld, %ld), got (%ld, %ld, %ld)" ,
0 commit comments