diff --git a/src/runtime/HalideBuffer.h b/src/runtime/HalideBuffer.h index d39f5461c9f4..5416e8bab8d8 100644 --- a/src/runtime/HalideBuffer.h +++ b/src/runtime/HalideBuffer.h @@ -2172,9 +2172,13 @@ class Buffer { } } + // Return pair is template - HALIDE_NEVER_INLINE static bool for_each_value_prep(for_each_value_task_dim *t, - const halide_buffer_t **buffers) { + HALIDE_NEVER_INLINE static std::pair for_each_value_prep(for_each_value_task_dim *t, + const halide_buffer_t **buffers) { + const int dimensions = buffers[0]->dimensions; + assert(dimensions > 0); + // Check the buffers all have clean host allocations for (int i = 0; i < N; i++) { if (buffers[i]->device) { @@ -2188,8 +2192,6 @@ class Buffer { } } - const int dimensions = buffers[0]->dimensions; - // Extract the strides in all the dimensions for (int i = 0; i < dimensions; i++) { for (int j = 0; j < N; j++) { @@ -2219,42 +2221,47 @@ class Buffer { } if (flat) { t[i - 1].extent *= t[i].extent; - for (int j = i; j < d; j++) { + for (int j = i; j < d - 1; j++) { t[j] = t[j + 1]; } i--; d--; - t[d].extent = 1; } } + // Note that we assert() that dimensions > 0 above + // (our one-and-only caller will only call us that way) + // so the unchecked access to t[0] should be safe. bool innermost_strides_are_one = true; - if (dimensions > 0) { - for (int i = 0; i < N; i++) { - innermost_strides_are_one &= (t[0].stride[i] == 1); - } + for (int i = 0; i < N; i++) { + innermost_strides_are_one &= (t[0].stride[i] == 1); } - return innermost_strides_are_one; + return {d, innermost_strides_are_one}; } template void for_each_value_impl(Fn &&f, Args &&...other_buffers) const { if (dimensions() > 0) { + const size_t alloc_size = dimensions() * sizeof(for_each_value_task_dim); Buffer<>::for_each_value_task_dim *t = - (Buffer<>::for_each_value_task_dim *)HALIDE_ALLOCA((dimensions() + 1) * sizeof(for_each_value_task_dim)); + (Buffer<>::for_each_value_task_dim *)HALIDE_ALLOCA(alloc_size); // Move the preparatory code into a non-templated helper to // save code size. const halide_buffer_t *buffers[] = {&buf, (&other_buffers.buf)...}; - bool innermost_strides_are_one = Buffer<>::for_each_value_prep(t, buffers); - - Buffer<>::for_each_value_helper(f, dimensions() - 1, - innermost_strides_are_one, - t, - data(), (other_buffers.data())...); - } else { - f(*data(), (*other_buffers.data())...); + auto [new_dims, innermost_strides_are_one] = Buffer<>::for_each_value_prep(t, buffers); + if (new_dims > 0) { + Buffer<>::for_each_value_helper(f, new_dims - 1, + innermost_strides_are_one, + t, + data(), (other_buffers.data())...); + return; + } + // else fall thru } + + // zero-dimensional case + f(*data(), (*other_buffers.data())...); } // @}