diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc index 8095cbeeea4a..2b860b6b63ec 100644 --- a/src/runtime/disco/nccl/nccl.cc +++ b/src/runtime/disco/nccl/nccl.cc @@ -120,7 +120,7 @@ void AllReduce(NDArray send, ReduceKind reduce_kind, bool in_group, NDArray recv int64_t numel = shape->Product(); deviceStream_t stream = ctx->GetDefaultStream(); DataType dtype = DataType(send->dtype); - if (dtype == DataType::NVFloat8E4M3() || dtype == DataType::NVFloat8E5M2()) { + if (dtype == DataType::Float8E4M3FN() || dtype == DataType::Float8E5M2()) { LOG(FATAL) << "Float8 data type cannot be allreduced, as nccl does not support this data type."; } NCCL_CALL(ncclAllReduce(send->data, recv->data, numel, diff --git a/src/runtime/disco/nccl/nccl_context.h b/src/runtime/disco/nccl/nccl_context.h index fff165bfdd04..e24687d8675f 100644 --- a/src/runtime/disco/nccl/nccl_context.h +++ b/src/runtime/disco/nccl/nccl_context.h @@ -86,8 +86,8 @@ inline ncclDataType_t AsNCCLDataType(runtime::DataType dtype) { if (dtype == DataType::Int(8)) { return ncclInt8; } - if (dtype == DataType::UInt(8) || dtype == DataType::NVFloat8E4M3() || - dtype == DataType::NVFloat8E5M2()) { + if (dtype == DataType::UInt(8) || dtype == DataType::Float8E4M3FN() || + dtype == DataType::Float8E5M2()) { // For float8 data type, pretend to be uint8 in nccl. // And will throw error when allreduce, as it makes no sense in this case. return ncclUint8;