diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 5db131c44f2a..d08bef2ab91a 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -273,6 +273,7 @@ PrimExpr cast(const DataType& t, PrimExpr value, Span span) { } else if (const FloatImmNode* op = value.as()) { return make_const(t, op->value, op->span); } + ICHECK(!value.dtype().is_handle()) << "Can't cast a handle to other types."; return tir::Cast(t, value, span); } else { if (value.dtype().lanes() == 1) { diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py index dbae0b6fa516..de94464187b0 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest import tvm from tvm import te import numpy as np @@ -104,6 +105,11 @@ def test_cast(): assert isinstance(z, tvm.tir.Broadcast) assert z.lanes == 4 + s = tvm.tir.StringImm("s") + with pytest.raises(tvm.error.TVMError) as cm: + s.astype("int") + assert "Can't cast a handle to other types" in str(cm.execption) + def test_attr(): x = te.var("x") @@ -468,28 +474,4 @@ def test_block_blockrealize(): if __name__ == "__main__": - test_intimm_cond() - test_buffer_load_store() - test_vars() - test_prim_func() - test_cast() - test_attr() - test_const() - test_scalar_dtype_inference() - test_make() - test_ir() - test_basic() - test_stmt() - test_let() - test_dir() - test_dtype() - test_any() - test_all() - test_bitwise() - test_float_bitwise() - test_shift_bounds() - test_divide_by_zero() - test_isnan() - test_equality() - test_equality_string_imm() - test_block_blockrealize() + pytest.main([__file__])