bfloat16 support #336

Merged 1 commit on Aug 12, 2024
2 changes: 1 addition & 1 deletion apps/nccl/CMakeLists.txt
@@ -11,7 +11,7 @@ endif()
add_library(mscclpp_nccl_obj OBJECT)
target_sources(mscclpp_nccl_obj PRIVATE ${SOURCES})
target_sources(mscclpp_nccl_obj PUBLIC FILE_SET HEADERS FILES ${HEADERS})
target_include_directories(mscclpp_nccl_obj PRIVATE include SYSTEM PRIVATE ${GPU_INCLUDE_DIRS})
target_include_directories(mscclpp_nccl_obj PRIVATE include ${PROJECT_SOURCE_DIR}/src/include SYSTEM PRIVATE ${GPU_INCLUDE_DIRS})
target_link_libraries(mscclpp_nccl_obj PRIVATE ${GPU_LIBRARIES} PUBLIC mscclpp_obj)
set_target_properties(mscclpp_nccl_obj PROPERTIES LINKER_LANGUAGE CXX POSITION_INDEPENDENT_CODE 1 VERSION ${MSCCLPP_VERSION} SOVERSION ${MSCCLPP_SOVERSION})
if(USE_CUDA)
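This one-line build change adds the core library's internal src/include directory to the NCCL shim's include path; judging by the include edits in the files below, this is what lets apps/nccl resolve the relocated gpu_data_types.hpp header.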
27 changes: 26 additions & 1 deletion apps/nccl/src/allreduce.hpp
@@ -7,12 +7,12 @@
#include <mscclpp/concurrency_device.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/gpu.hpp>
#include <mscclpp/gpu_data_types.hpp>
#include <mscclpp/packet_device.hpp>
#include <mscclpp/sm_channel.hpp>
#include <mscclpp/sm_channel_device.hpp>

#include "common.hpp"
#include "gpu_data_types.hpp"

__device__ mscclpp::DeviceSyncer deviceSyncer;

@@ -38,6 +38,11 @@ __forceinline__ __device__ __half2 add_elements(__half2 a, __half2 b) {
return __hadd2(a, b);
}

template <>
__forceinline__ __device__ __bfloat162 add_elements(__bfloat162 a, __bfloat162 b) {
return __hadd2(a, b);
}

template <typename T>
__forceinline__ __device__ int4 add_vectors_helper(int4 a, int4 b) {
int4 ret;
@@ -58,6 +63,11 @@ __forceinline__ __device__ int4 add_vectors<__half>(int4 a, int4 b) {
return add_vectors_helper<__half2>(a, b);
}

template <>
__forceinline__ __device__ int4 add_vectors<__bfloat16>(int4 a, int4 b) {
return add_vectors_helper<__bfloat162>(a, b);
}

template <typename T>
__forceinline__ __device__ uint2 add_vectors_helper(uint2 a, uint2 b) {
uint2 ret;
@@ -76,6 +86,11 @@ __forceinline__ __device__ uint2 add_vectors<__half>(uint2 a, uint2 b) {
return add_vectors_helper<__half2>(a, b);
}

template <>
__forceinline__ __device__ uint2 add_vectors<__bfloat16>(uint2 a, uint2 b) {
return add_vectors_helper<__bfloat162>(a, b);
}

template <typename T>
__forceinline__ __device__ int add_vectors_helper(int a, int b) {
return bit_cast<int, T>(add_elements(bit_cast<T, int>(a), bit_cast<T, int>(b)));
@@ -91,6 +106,11 @@ __forceinline__ __device__ int add_vectors<__half>(int a, int b) {
return add_vectors_helper<__half2>(a, b);
}

template <>
__forceinline__ __device__ int add_vectors<__bfloat16>(int a, int b) {
return add_vectors_helper<__bfloat162>(a, b);
}

template <typename T>
__forceinline__ __device__ uint32_t add_vectors_helper(uint32_t a, uint32_t b) {
return bit_cast<uint32_t, T>(add_elements(bit_cast<T, uint32_t>(a), bit_cast<T, uint32_t>(b)));
@@ -106,6 +126,11 @@ __forceinline__ __device__ uint32_t add_vectors<__half>(uint32_t a, uint32_t b)
return add_vectors_helper<__half2>(a, b);
}

template <>
__forceinline__ __device__ uint32_t add_vectors<__bfloat16>(uint32_t a, uint32_t b) {
return add_vectors_helper<__bfloat162>(a, b);
}

template <typename T>
__forceinline__ __device__ void vectorSum(T* dst, T* src, size_t nElem, int blockId, int nBlocks) {
size_t nInt4 = nElem / 4;
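The new specializations extend the existing __half chain to bfloat16: each 32-bit lane of an int4, uint2, int, or uint32_t register is reinterpreted as a packed __bfloat162 pair and reduced with the __hadd2 intrinsic, so two elements are added per instruction. A standalone sketch of the 32-bit case, mirroring the diff's bit_cast/add_elements helpers (add_bf16x2 is an illustrative name, not from the PR):

#include <cstdint>

#include <cuda_bf16.h>

// Reinterpret one equally sized type as another, as the diff's bit_cast does.
template <typename To, typename From>
__forceinline__ __device__ To bit_cast(const From& src) {
  static_assert(sizeof(To) == sizeof(From), "size mismatch");
  union {
    From f;
    To t;
  } u;
  u.f = src;
  return u.t;
}

// One 32-bit lane carries two bfloat16 values; __hadd2 sums both at once.
__forceinline__ __device__ uint32_t add_bf16x2(uint32_t a, uint32_t b) {
  __nv_bfloat162 va = bit_cast<__nv_bfloat162, uint32_t>(a);
  __nv_bfloat162 vb = bit_cast<__nv_bfloat162, uint32_t>(b);
  return bit_cast<uint32_t, __nv_bfloat162>(__hadd2(va, vb));
}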
9 changes: 9 additions & 0 deletions apps/nccl/src/nccl.cu
@@ -262,6 +262,11 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
smOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(),
NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
break;
case ncclBfloat16:
CUDACHECK(allreduce((__bfloat16*)sendbuff, (__bfloat16*)comm->scratchBuff.get(), (__bfloat16*)recvbuff,
smChannels, smOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE,
comm->comm->bootstrap()->getNranks(), count, stream));
break;
case ncclInt32:
case ncclUint32:
CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, smOutChannels,
@@ -498,6 +503,10 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t
comm->executor->execute(rank, (float*)sendbuff, (float*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT32,
1024, *plan, stream, mscclpp::PacketType::LL8);
break;
case ncclBfloat16:
comm->executor->execute(rank, (__bfloat16*)sendbuff, (__bfloat16*)recvbuff, bytes, bytes,
mscclpp::DataType::BFLOAT16, 1024, *plan, stream, mscclpp::PacketType::LL8);
break;
case ncclInt32:
case ncclUint32:
comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, bytes, bytes, mscclpp::DataType::UINT32, 1024,
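With both the fallback switch and the executor switch handling ncclBfloat16, a standard NCCL call now reduces bfloat16 buffers end to end. A hypothetical caller, using only the public NCCL API (comm and stream assumed to be initialized elsewhere):

#include <cuda_bf16.h>
#include <nccl.h>

// In-place sum-allreduce of `count` bfloat16 elements across the communicator.
ncclResult_t allreduceBf16(__nv_bfloat16* buf, size_t count, ncclComm_t comm,
                           cudaStream_t stream) {
  return ncclAllReduce(buf, buf, count, ncclBfloat16, ncclSum, comm, stream);
}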
1 change: 1 addition & 0 deletions include/mscclpp/executor.hpp
@@ -15,6 +15,7 @@ enum class DataType {
UINT32,
FLOAT16,
FLOAT32,
BFLOAT16,
};

enum class PacketType {
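Appending BFLOAT16 after FLOAT32, rather than inserting it mid-list, leaves the integer values of the existing enumerators unchanged, so anything that already serialized those values stays compatible.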
2 changes: 1 addition & 1 deletion include/mscclpp/nvls_device.hpp
@@ -8,7 +8,7 @@
#include <type_traits>

#if defined(MSCCLPP_DEVICE_CUDA)
#include <mscclpp/gpu_data_types.hpp>
#include <cuda_fp16.h>
#endif // defined(MSCCLPP_DEVICE_CUDA)

#include "device.hpp"
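Since gpu_data_types.hpp is no longer included from this public header, the NVLS device header pulls in <cuda_fp16.h> directly; the fp16 types appear to be all it needs.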
3 changes: 2 additions & 1 deletion python/mscclpp/executor_py.cpp
@@ -16,7 +16,8 @@ void register_executor(nb::module_& m) {
.value("int32", DataType::INT32)
.value("uint32", DataType::UINT32)
.value("float16", DataType::FLOAT16)
.value("float32", DataType::FLOAT32);
.value("float32", DataType::FLOAT32)
.value("bfloat16", DataType::BFLOAT16);

nb::enum_<PacketType>(m, "PacketType").value("LL8", PacketType::LL8).value("LL16", PacketType::LL16);

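On the Python side this surfaces as DataType.bfloat16, usable anywhere the executor bindings accept a data type.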
10 changes: 10 additions & 0 deletions src/executor/execution_kernel.cu
@@ -49,6 +49,16 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
#else
);
#endif
break;
case DataType::BFLOAT16:
executionKernel<__bfloat16><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
rank, (__bfloat16*)src, (__bfloat16*)dst, (__bfloat16*)scratch, scratchSize, plan, flag
#if defined(ENABLE_NPKIT)
,
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
#else
);
#endif
break;
}
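launchKernel converts the runtime DataType tag into a compile-time template argument, one switch arm per type, and the new arm instantiates executionKernel<__bfloat16>. A self-contained sketch of that dispatch pattern, with hypothetical names in place of the executor's:

#include <cuda_bf16.h>
#include <cuda_fp16.h>

enum class DataType { FLOAT16, BFLOAT16 };

// Toy stand-in for executionKernel: doubles each element via float math.
template <typename T>
__global__ void toyKernel(T* data, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] = (T)((float)data[i] * 2.0f);
}

// Runtime tag -> compile-time type: each case instantiates the same template.
void launch(DataType dtype, void* data, int n, cudaStream_t stream) {
  int threads = 256;
  int blocks = (n + threads - 1) / threads;
  switch (dtype) {
    case DataType::FLOAT16:
      toyKernel<__half><<<blocks, threads, 0, stream>>>((__half*)data, n);
      break;
    case DataType::BFLOAT16:
      toyKernel<__nv_bfloat16><<<blocks, threads, 0, stream>>>((__nv_bfloat16*)data, n);
      break;
  }
}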
32 changes: 31 additions & 1 deletion src/include/execution_kernel.hpp
@@ -15,7 +15,7 @@
#include "execution_common.hpp"

#if defined(MSCCLPP_DEVICE_COMPILE)
#include <mscclpp/gpu_data_types.hpp>
#include "gpu_data_types.hpp"

namespace {
template <typename To, typename From>
@@ -60,6 +60,11 @@ MSCCLPP_DEVICE_INLINE int4 add_vectors<__half>(int4 a, int4 b) {
return add_vectors_helper<__half2>(a, b);
}

template <>
MSCCLPP_DEVICE_INLINE int4 add_vectors<__bfloat16>(int4 a, int4 b) {
return add_vectors_helper<__bfloat162>(a, b);
}

template <typename T>
MSCCLPP_DEVICE_INLINE uint2 add_vectors_helper(uint2 a, uint2 b) {
uint2 ret;
@@ -78,6 +83,11 @@ MSCCLPP_DEVICE_INLINE __attribute__((unused)) uint2 add_vectors<__half>(uint2 a,
return add_vectors_helper<__half2>(a, b);
}

template <>
MSCCLPP_DEVICE_INLINE __attribute__((unused)) uint2 add_vectors<__bfloat16>(uint2 a, uint2 b) {
return add_vectors_helper<__bfloat162>(a, b);
}

template <typename T>
MSCCLPP_DEVICE_INLINE int add_vectors_helper(int a, int b) {
return bit_cast<int, T>(add_elements(bit_cast<T, int>(a), bit_cast<T, int>(b)));
@@ -93,6 +103,11 @@ MSCCLPP_DEVICE_INLINE __attribute__((unused)) int add_vectors<__half>(int a, int
return add_vectors_helper<__half2>(a, b);
}

template <>
MSCCLPP_DEVICE_INLINE __attribute__((unused)) int add_vectors<__bfloat16>(int a, int b) {
return add_vectors_helper<__bfloat162>(a, b);
}

template <typename T>
MSCCLPP_DEVICE_INLINE uint32_t add_vectors_helper(uint32_t a, uint32_t b) {
return bit_cast<uint32_t, T>(add_elements(bit_cast<T, uint32_t>(a), bit_cast<T, uint32_t>(b)));
@@ -108,6 +123,11 @@ MSCCLPP_DEVICE_INLINE uint32_t add_vectors<__half>(uint32_t a, uint32_t b) {
return add_vectors_helper<__half2>(a, b);
}

template <>
MSCCLPP_DEVICE_INLINE uint32_t add_vectors<__bfloat16>(uint32_t a, uint32_t b) {
return add_vectors_helper<__bfloat162>(a, b);
}

} // namespace
#endif // defined(MSCCLPP_DEVICE_COMPILE)

@@ -502,6 +522,16 @@ class ExecutionKernel {
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
#else
);
#endif
break;
case DataType::BFLOAT16:
executionKernel<__bfloat16, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
rank, (__bfloat16*)src, (__bfloat16*)dst, (__bfloat16*)scratch, scratchSize, plan, flag
#if defined(ENABLE_NPKIT)
,
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
#else
);
#endif
break;
}
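These device-side specializations mirror the ones added to apps/nccl/src/allreduce.hpp above: the executor and the NCCL shim each carry their own copy of the add_vectors family, so both headers gain the __bfloat16 overloads.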
gpu_data_types.hpp
@@ -9,6 +9,10 @@
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>

using __bfloat16 = __hip_bfloat16;
using __bfloat162 = __hip_bfloat162;
#define __CUDA_BF16_TYPES_EXIST__

#else

#include <cuda_fp16.h>
@@ -19,6 +23,9 @@
#include <cuda_fp8.h>
#endif

using __bfloat16 = __nv_bfloat16;
using __bfloat162 = __nv_bfloat162;

#endif

#endif // MSCCLPP_GPU_DATA_TYPES_HPP_
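The aliases give shared device code one spelling, __bfloat16/__bfloat162, that resolves to __hip_bfloat16/__hip_bfloat162 under ROCm and __nv_bfloat16/__nv_bfloat162 under CUDA; the HIP branch also defines __CUDA_BF16_TYPES_EXIST__, the feature macro cuda_bf16.h normally provides, so portable #ifdef checks behave the same on both platforms. A hypothetical kernel written against the alias:

#include "gpu_data_types.hpp"  // the header from this diff

// Compiles unchanged for CUDA and ROCm; __bfloat16 resolves per platform.
__global__ void scaleBf16(__bfloat16* x, int n, float s) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] = (__bfloat16)((float)x[i] * s);
}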