Skip to content

Commit

Permalink
Add max threads checking for Metal
Browse files Browse the repository at this point in the history
Previously, this limit was only checked by Metal API Validation in Xcode;
without it, the program would crash or output wrong results.
  • Loading branch information
xndcn committed Dec 22, 2020
1 parent b22598c commit 2fb72ab
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 0 deletions.
15 changes: 15 additions & 0 deletions src/runtime/metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ WEAK mtl_compute_pipeline_state *new_compute_pipeline_state_with_function(mtl_de
return result;
}

// Returns the value of the `maxTotalThreadsPerThreadgroup` selector sent to
// the given compute pipeline state object.
WEAK unsigned long get_max_total_threads_per_threadgroup(mtl_compute_pipeline_state *pipeline_state) {
    // objc_msgSend is cast to a function pointer whose signature matches this
    // particular message (receiver + selector, returning unsigned long).
    typedef unsigned long (*max_threads_getter_fn)(objc_id receiver, objc_sel selector);
    max_threads_getter_fn send_message = (max_threads_getter_fn)&objc_msgSend;

    objc_sel selector = sel_getUid("maxTotalThreadsPerThreadgroup");
    unsigned long max_threads = (*send_message)(pipeline_state, selector);
    return max_threads;
}

WEAK void set_compute_pipeline_state(mtl_compute_command_encoder *encoder, mtl_compute_pipeline_state *pipeline_state) {
typedef void (*set_compute_pipeline_state_method)(objc_id encoder, objc_sel sel, objc_id pipeline_state);
set_compute_pipeline_state_method method = (set_compute_pipeline_state_method)&objc_msgSend;
Expand Down Expand Up @@ -796,6 +802,15 @@ WEAK int halide_metal_run(void *user_context,
error(user_context) << "Metal: Could not allocate pipeline state.\n";
return -1;
}

uint64_t max_total_threads_per_threadgroup = get_max_total_threads_per_threadgroup(pipeline_state);
if (max_total_threads_per_threadgroup < (uint64_t)(threadsX * threadsY * threadsZ)) {
error(user_context) << "Metal: threadsX(" << threadsX << ") * threadsY(" << threadsY << ") * threadsZ(" << threadsZ << ") (" << (threadsX * threadsY * threadsZ) << ") must be <= " << max_total_threads_per_threadgroup << ". (device threadgroup size limit)\n";
end_encoding(encoder);
release_ns_object(pipeline_state);
return -1;
}

set_compute_pipeline_state(encoder, pipeline_state);

size_t total_args_size = 0;
Expand Down
1 change: 1 addition & 0 deletions test/error/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ tests(GROUPS error
memoize_different_compute_store.cpp
memoize_redefine_eviction_key.cpp
metal_vector_too_large.cpp
metal_threads_too_large.cpp
missing_args.cpp
no_default_device.cpp
nonexistent_update_stage.cpp
Expand Down
33 changes: 33 additions & 0 deletions test/error/metal_threads_too_large.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#include "Halide.h"
#include "halide_test_dirs.h"

using namespace Halide;

// Error test for the Metal threadgroup size check.
//
// Builds a trivial pipeline (`f = input + 42`) scheduled with `y` as GPU
// blocks and `x` as GPU threads on Metal, then realizes it over a 65536x1
// buffer. The file lives under test/error, so the realization is expected to
// fail before the verification loop below is reached; the loop and the final
// "Success!" only run if the oversized launch was (incorrectly) accepted.
//
// Returns 0 if the pipeline ran and produced correct output, 1 (or aborts via
// assert) if the output is wrong.
int main(int argc, char **argv) {
    ImageParam im(UInt(16), 2, "input");
    Func f("f");
    Var x("x"), y("y");

    f(x, y) = im(x, y) + 42;
    f.gpu_blocks(y).gpu_threads(x, DeviceAPI::Metal);

    // An x extent of 65536 is intended to exceed the device's
    // `maxTotalThreadsPerThreadgroup` limit.
    Buffer<uint16_t> input = lambda(x, y, cast<uint16_t>(x + y)).realize(65536, 1);
    im.set(input);

    Buffer<uint16_t> output(input.width(), input.height());
    f.realize(output);
    output.copy_to_host();

    for (int32_t i = 0; i < output.width(); i++) {
        for (int32_t j = 0; j < output.height(); j++) {
            // Compute the expected value from the loop indices. The Halide
            // Vars `x`/`y` are symbolic and must not be used here: streaming
            // `x + y + 42` would print an expression, not a number.
            const uint16_t expected = uint16_t(i + j + 42);
            if (output(i, j) != expected) {
                std::cerr << "Expected " << expected << " at (" << i << ", " << j << ") got " << output(i, j) << "\n";
                assert(false);
                // assert() compiles away under NDEBUG; still report failure.
                return 1;
            }
        }
    }

    printf("Success!\n");
    return 0;
}

0 comments on commit 2fb72ab

Please sign in to comment.