diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 1ef65230591b..951944e618ab 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -96,7 +96,8 @@ def _compute(*idxs): axis_index = 0 for i in range(0, len(idxs)): if i not in real_axis: - indices.append(idxs[i]) + dim = tvm.tir.if_then_else(a.shape[len(indices)] != 1, idxs[i], 0) + indices.append(dim) axis_index += 1 return a(*indices)