diff --git a/pytensor/tensor/c_code/dimshuffle.c b/pytensor/tensor/c_code/dimshuffle.c index b99a0ee419..0bfc5df3bb 100644 --- a/pytensor/tensor/c_code/dimshuffle.c +++ b/pytensor/tensor/c_code/dimshuffle.c @@ -33,12 +33,17 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA npy_intp original_size = PyArray_SIZE(input); npy_intp new_size = 1; for (npy_intp i = 0; i < nd_out; ++i) { + // We set the strides of length-1 dimensions to PyArray_ITEMSIZE(input). + // The value is arbitrary, because there is never a next element. + // np.expand_dims(x, 0) and x[None] do different things here. + // I would prefer zero, but there are some poorly implemented BLAS operations + // that don't handle zero strides correctly. At least they won't fail because of DimShuffle. if (new_order[i] != -1) { dimensions[i] = PyArray_DIMS(input)[new_order[i]]; - strides[i] = PyArray_DIMS(input)[new_order[i]] == 1 ? 0 : PyArray_STRIDES(input)[new_order[i]]; + strides[i] = PyArray_DIMS(input)[new_order[i]] == 1 ? PyArray_ITEMSIZE(input) : PyArray_STRIDES(input)[new_order[i]]; } else { dimensions[i] = 1; - strides[i] = 0; + strides[i] = PyArray_ITEMSIZE(input); } new_size *= dimensions[i]; } diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index d5aac0113b..77d41a03c5 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -185,14 +185,13 @@ def test_c_views(self): # as the broadcasted value; that way, we'll be able to tell that we're getting # junk data from a poorly constructed array view. 
x_val = np.broadcast_to(2039, (5000,)) - expected_x_val = x_val[None] for i in range(1): inputs[0].storage[0] = x_val thunk() # Make sure it's a view of the original data assert np.shares_memory(x_val, outputs[0].storage[0]) # Confirm the right strides - assert outputs[0].storage[0].strides == expected_x_val.strides + assert outputs[0].storage[0].strides[-1] == 0 # Confirm the broadcasted value in the output assert np.array_equiv(outputs[0].storage[0], 2039)