diff --git a/include/tvm/topi/cuda/reduction.h b/include/tvm/topi/cuda/reduction.h index 7160419422a6..51f35ed8dc25 100644 --- a/include/tvm/topi/cuda/reduction.h +++ b/include/tvm/topi/cuda/reduction.h @@ -70,7 +70,7 @@ Schedule ScheduleReduce(const Target& target, Operation op, Schedule sch, if (out_stage->op.as()->axis.size() > 0) { all_reduce = false; num_thread = 32; - if (target->kind->name == "opencl") { + if (target->kind->name == "opencl" || target->kind->name == "metal") { // Without this, CL_INVALID_WORK_GROUP_SIZE occurs with python tests. // Don't know why. num_thread = 16; diff --git a/python/tvm/topi/cuda/reduction.py b/python/tvm/topi/cuda/reduction.py index ceab71640533..b9d02d9c81d8 100644 --- a/python/tvm/topi/cuda/reduction.py +++ b/python/tvm/topi/cuda/reduction.py @@ -37,7 +37,7 @@ def _schedule_reduce(op, sch, is_idx_reduce=False): all_reduce = False num_thread = 32 target = tvm.target.Target.current() - if target and target.kind.name == "opencl": + if target and (target.kind.name == "opencl" or target.kind.name == "metal"): # without it, CL_INVALID_WORK_GROUP_SIZE occurred when running test_topi_reduce.py # don't know why num_thread = 16