diff --git a/src/relax/backend/contrib/hipblas/codegen.cc b/src/relax/backend/contrib/hipblas/codegen.cc index 4a4666a92daf..2e78e0961451 100644 --- a/src/relax/backend/contrib/hipblas/codegen.cc +++ b/src/relax/backend/contrib/hipblas/codegen.cc @@ -96,7 +96,7 @@ Array HipblasCompiler(Array functions, Map()); } return compiled_functions; diff --git a/src/runtime/contrib/hipblas/hipblas.cc b/src/runtime/contrib/hipblas/hipblas.cc index 07331e33defe..fb4e394e7fc2 100644 --- a/src/runtime/contrib/hipblas/hipblas.cc +++ b/src/runtime/contrib/hipblas/hipblas.cc @@ -300,8 +300,8 @@ inline void CallGemmEx(ffi::PackedArgs args, ffi::Any* ret, hipblasHandle_t hdl) << "leading dimension must divide 4 for int8 gemm"; ICHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride(B) % 4 == 0) << "leading dimension must divide 4 for int8 gemm"; - double alpha = args.size() > 5 ? args[5] : 1.0; - double beta = args.size() > 6 ? args[6] : 0.0; + double alpha = args.size() > 5 ? args[5].cast() : 1.0; + double beta = args.size() > 6 ? args[6].cast() : 0.0; hipblasDatatype_t hip_in_type = GetHipBlasDataType(A->dtype); hipblasDatatype_t hip_out_type = GetHipBlasDataType(C->dtype); @@ -359,8 +359,8 @@ inline void CallBatchGemmEx(ffi::PackedArgs args, ffi::Any* ret, hipblasHandle_t << "leading dimension must divide 4 for int8 gemm"; ICHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride3D(B) % 4 == 0) << "leading dimension must divide 4 for int8 gemm"; - double alpha = args.size() > 5 ? args[5] : 1.0; - double beta = args.size() > 6 ? args[6] : 0.0; + double alpha = args.size() > 5 ? args[5].cast() : 1.0; + double beta = args.size() > 6 ? args[6].cast() : 0.0; int A_stride = A->shape[1] * A->shape[2]; int B_stride = B->shape[1] * B->shape[2]; diff --git a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc index 3f4be327c4b2..60e439125c10 100644 --- a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc +++ b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc @@ -72,12 +72,10 @@ class HipblasJSONRuntime : public JSONRuntimeBase { for (size_t i = 0; i < static_cast(args.size()); i++) { auto eid = i < input_var_eid_.size() ? input_var_eid_[i] : EntryID(outputs_[i - input_var_eid_.size()]); - ICHECK(args[i].type_code() == kTVMNDArrayHandle || args[i].type_code() == kTVMDLTensorHandle) - << "Expect NDArray or DLTensor as inputs"; const DLTensor* arg; - if (args[i].IsObjectRef()) { - NDArray arr = args[i]; + if (auto opt_nd = args[i].as()) { + NDArray arr = opt_nd.value(); arg = arr.operator->(); } else { arg = args[i].cast();