Commit

[Metal] Reduce number of threads for reduction layers (apache#8206)
Reduce the default number of threads in reduction kernels for Metal.
Default code generation produced a thread block of size 32x32x1, i.e.
1024 threads per threadgroup (32 * 32 * 1). Sometimes the device does
not have enough resources for that many threads, and in that case an
exception is raised because the block size is greater than the value
of maxTotalThreadsPerThreadgroup.
To prevent this, decrease the default number of threads. With this fix
every model should work with the default codegen, and auto-tuning or
auto-scheduling will select the optimal number of threads.
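
The arithmetic behind the failure can be sketched as follows. This is a minimal illustration, not TVM code: `threads_per_threadgroup` and `fits_device` are hypothetical helpers, `ASSUMED_LIMIT = 512` is an assumed example value for maxTotalThreadsPerThreadgroup (the real value depends on the device and kernel), and the 16x16x1 block assumes that the reduced `num_thread = 16` applies to both thread axes the same way 32 did.

```python
def threads_per_threadgroup(block):
    """Return the total number of threads in a block given as (x, y, z)."""
    x, y, z = block
    return x * y * z


def fits_device(block, max_total_threads):
    """Check whether the block stays within the per-threadgroup limit."""
    return threads_per_threadgroup(block) <= max_total_threads


# Assumption: an example limit below 1024; not a value taken from the commit.
ASSUMED_LIMIT = 512

old_block = (32, 32, 1)  # default before this commit (num_thread = 32)
new_block = (16, 16, 1)  # assumed default after this commit (num_thread = 16)

for name, block in (("old", old_block), ("new", new_block)):
    total = threads_per_threadgroup(block)
    print(name, block, total, "ok" if fits_device(block, ASSUMED_LIMIT) else "too large")
```

Under that assumed limit the old block (1024 threads) is rejected while the new one (256 threads) fits, which is the situation the commit message describes.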
echuraev authored and trevor-m committed Jun 17, 2021
1 parent c77914f commit 7c068dc
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion include/tvm/topi/cuda/reduction.h
@@ -70,7 +70,7 @@ Schedule ScheduleReduce(const Target& target, Operation op, Schedule sch,
   if (out_stage->op.as<ComputeOpNode>()->axis.size() > 0) {
     all_reduce = false;
     num_thread = 32;
-    if (target->kind->name == "opencl") {
+    if (target->kind->name == "opencl" || target->kind->name == "metal") {
       // Without this, CL_INVALID_WORK_GROUP_SIZE occurs with python tests.
       // Don't know why.
       num_thread = 16;
2 changes: 1 addition & 1 deletion python/tvm/topi/cuda/reduction.py
@@ -37,7 +37,7 @@ def _schedule_reduce(op, sch, is_idx_reduce=False):
         all_reduce = False
         num_thread = 32
         target = tvm.target.Target.current()
-        if target and target.kind.name == "opencl":
+        if target and (target.kind.name == "opencl" or target.kind.name == "metal"):
             # without it, CL_INVALID_WORK_GROUP_SIZE occurred when running test_topi_reduce.py
             # don't know why
             num_thread = 16
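
The selection logic changed in both files can be sketched in isolation as below. This is a standalone illustration under stated assumptions, not the TVM implementation: `default_num_thread` and `REDUCED_THREAD_TARGETS` are hypothetical names, and the target kind is passed as a plain string (or `None` when no target is set) so that the sketch runs without TVM installed.

```python
# Hypothetical names for illustration; not part of the TVM API.
REDUCED_THREAD_TARGETS = ("opencl", "metal")


def default_num_thread(kind_name):
    """Return the default thread count per axis for a reduction schedule.

    kind_name is the target kind as a string, or None if no target is set.
    """
    num_thread = 32
    if kind_name in REDUCED_THREAD_TARGETS:
        # Same effect as the patched condition: use fewer threads on
        # OpenCL and Metal to stay within the device's work-group limits.
        num_thread = 16
    return num_thread


for kind in ("cuda", "opencl", "metal", None):
    print(kind, default_num_thread(kind))
```

Other target kinds, and the case with no current target, keep the default of 32, matching the behavior of the patched condition.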
