diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index b90d2c362c915..1ddf4d5965d35 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -105,7 +105,7 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i size_cast_dtype = "float64" target = tvm.target.Target.current() - if "vulkan" in str(target) and isinstance(scan_axis_size, tvm.tir.expr.Var): + if "vulkan" in str(target) and not isinstance(scan_axis_size, tvm.tir.expr.IntImm): # SPIRV seems to have an issue with float64 intrinsic # TODO(masahi): Eliminate this concern by adding TIR level CSE msg = """