diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 603211b59ebc..a72439079ef7 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -658,9 +658,10 @@ def BindParams( for k, v in params.items(): if isinstance(v, np.ndarray): v = tvm.nd.array(v) - assert isinstance( - v, tvm.runtime.NDArray - ), f"param values are expected to be TVM.NDArray or numpy.ndarray, but got {type(v)}" + assert isinstance(v, (tvm.runtime.NDArray, tvm.relax.Constant)), ( + f"param values are expected to be TVM.NDArray," + f"numpy.ndarray or tvm.relax.Constant, but got {type(v)}" + ) tvm_params[k] = v return _ffi_api.BindParams(func_name, tvm_params) # type: ignore