diff --git a/python/tvm/topi/cuda/pooling.py b/python/tvm/topi/cuda/pooling.py
index 1b5cc94c3db0..ba2e7da8e11e 100644
--- a/python/tvm/topi/cuda/pooling.py
+++ b/python/tvm/topi/cuda/pooling.py
@@ -65,7 +65,7 @@ def _schedule_non_global(Pool):
     def traverse(OP):
         """Internal traverse function"""
         # inline all one-to-one-mapping operators except the last stage (output)
-        if tag.is_broadcast(OP.tag):
+        if tag.is_injective(OP.tag):
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors:
diff --git a/python/tvm/topi/x86/pooling.py b/python/tvm/topi/x86/pooling.py
index db0f9faf1970..b3f4eedec67c 100644
--- a/python/tvm/topi/x86/pooling.py
+++ b/python/tvm/topi/x86/pooling.py
@@ -89,7 +89,7 @@ def _schedule(PaddedInput, Pool):
     def traverse(OP):
         """Internal traverse function"""
         # inline all one-to-one-mapping operators except the last stage (output)
-        if tag.is_broadcast(OP.tag):
+        if tag.is_injective(OP.tag):
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors:
@@ -137,7 +137,7 @@ def schedule_adaptive_pool(outs):
     def traverse(OP):
         """Internal traverse function"""
        # inline all one-to-one-mapping operators except the last stage (output)
-        if tag.is_broadcast(OP.tag):
+        if tag.is_injective(OP.tag):
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors: