diff --git a/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu b/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu index b8732357c7bd..5164958afeb5 100644 --- a/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu +++ b/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu @@ -67,8 +67,8 @@ void tvm_cutlass_fp8_blockwise_scaled_gemm(NDArray a, NDArray b, NDArray scales_ CHECK_EQ(scales_b->shape[1] * block_size_1, k); using tvm::runtime::DataType; - CHECK_EQ(DataType(a->dtype), DataType::NVFloat8E4M3()); - CHECK_EQ(DataType(b->dtype), DataType::NVFloat8E4M3()); + CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN()); + CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN()); CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32)); CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32)); CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8)); @@ -128,8 +128,8 @@ void tvm_cutlass_fp8_blockwise_scaled_bmm(NDArray a, NDArray b, NDArray scales_a CHECK_EQ(scales_b->shape[2] * block_size_1, k); using tvm::runtime::DataType; - CHECK_EQ(DataType(a->dtype), DataType::NVFloat8E4M3()); - CHECK_EQ(DataType(b->dtype), DataType::NVFloat8E4M3()); + CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN()); + CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN()); CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32)); CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32)); CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8)); diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc index 8095cbeeea4a..2b860b6b63ec 100644 --- a/src/runtime/disco/nccl/nccl.cc +++ b/src/runtime/disco/nccl/nccl.cc @@ -120,7 +120,7 @@ void AllReduce(NDArray send, ReduceKind reduce_kind, bool in_group, NDArray recv int64_t numel = shape->Product(); deviceStream_t stream = ctx->GetDefaultStream(); DataType dtype = DataType(send->dtype); - if (dtype == DataType::NVFloat8E4M3() || dtype == DataType::NVFloat8E5M2()) { + if (dtype == DataType::Float8E4M3FN() || dtype == DataType::Float8E5M2()) { LOG(FATAL) << "Float8 data type cannot be allreduced, as nccl does not support this data type."; } NCCL_CALL(ncclAllReduce(send->data, recv->data, numel, diff --git a/src/runtime/disco/nccl/nccl_context.h b/src/runtime/disco/nccl/nccl_context.h index fff165bfdd04..e24687d8675f 100644 --- a/src/runtime/disco/nccl/nccl_context.h +++ b/src/runtime/disco/nccl/nccl_context.h @@ -86,8 +86,8 @@ inline ncclDataType_t AsNCCLDataType(runtime::DataType dtype) { if (dtype == DataType::Int(8)) { return ncclInt8; } - if (dtype == DataType::UInt(8) || dtype == DataType::NVFloat8E4M3() || - dtype == DataType::NVFloat8E5M2()) { + if (dtype == DataType::UInt(8) || dtype == DataType::Float8E4M3FN() || + dtype == DataType::Float8E5M2()) { // For float8 data type, pretend to be uint8 in nccl. // And will throw error when allreduce, as it makes no sense in this case. return ncclUint8;