forked from mlc-ai/relax
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Runtime] Introduce MSCCLPP with NCCL equivalent interface (#16804)
* [Runtime] Introduce MSCCLPP with NCCL equivalent interface * Add a fast and simple AllReduce kernel (sum only) using using mscclpp smChannel scratch for small reductions up to 2**24 bytes.
- Loading branch information
Showing
6 changed files
with
1,161 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
// Copyright (c) Microsoft Corporation. | ||
// Licensed under the MIT license. | ||
|
||
#ifndef MSCCL_COMMON_HPP_ | ||
#define MSCCL_COMMON_HPP_ | ||
|
||
#if defined(__HIP_PLATFORM_AMD__) | ||
#define WARP_SIZE 64 | ||
#define __syncwarp() __builtin_amdgcn_wave_barrier() | ||
#else | ||
#define WARP_SIZE 32 | ||
#endif | ||
|
||
constexpr int NRANKS_PER_NODE = 8; | ||
constexpr int SCRATCH_SIZE = 1024 * 1024 * 70; // 35 thread-blocks * 8 ranks * 256KB = 70MB | ||
|
||
template <typename To, typename From> | ||
__forceinline__ __device__ To bit_cast(const From& src) { | ||
static_assert(sizeof(To) == sizeof(From), "Size mismatch for bit_cast"); | ||
|
||
union { | ||
From f; | ||
To t; | ||
} u; | ||
u.f = src; | ||
return u.t; | ||
} | ||
|
||
template <typename T> | ||
__forceinline__ __device__ T add_elements(T a, T b) { | ||
return a + b; | ||
} | ||
|
||
template <> | ||
__forceinline__ __device__ __half2 add_elements(__half2 a, __half2 b) { | ||
return __hadd2(a, b); | ||
} | ||
|
||
template <typename T> | ||
__forceinline__ __device__ int4 add_vectors_helper(int4 a, int4 b) { | ||
int4 ret; | ||
ret.w = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.w), bit_cast<T, int>(b.w))); | ||
ret.x = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.x), bit_cast<T, int>(b.x))); | ||
ret.y = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.y), bit_cast<T, int>(b.y))); | ||
ret.z = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.z), bit_cast<T, int>(b.z))); | ||
return ret; | ||
} | ||
|
||
template <typename T> | ||
__forceinline__ __device__ int4 add_vectors(int4 a, int4 b) { | ||
return add_vectors_helper<T>(a, b); | ||
} | ||
|
||
template <> | ||
__forceinline__ __device__ int4 add_vectors<__half>(int4 a, int4 b) { | ||
return add_vectors_helper<__half2>(a, b); | ||
} | ||
|
||
template <typename T> | ||
__forceinline__ __device__ uint2 add_vectors_helper(uint2 a, uint2 b) { | ||
uint2 ret; | ||
ret.x = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.x), bit_cast<T, int>(b.x))); | ||
ret.y = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.y), bit_cast<T, int>(b.y))); | ||
return ret; | ||
} | ||
|
||
template <typename T> | ||
__forceinline__ __device__ uint2 add_vectors(uint2 a, uint2 b) { | ||
return add_vectors_helper<T>(a, b); | ||
} | ||
|
||
template <> | ||
__forceinline__ __device__ uint2 add_vectors<__half>(uint2 a, uint2 b) { | ||
return add_vectors_helper<__half2>(a, b); | ||
} | ||
|
||
template <typename T> | ||
__forceinline__ __device__ int add_vectors_helper(int a, int b) { | ||
return bit_cast<int, T>(add_elements(bit_cast<T, int>(a), bit_cast<T, int>(b))); | ||
} | ||
|
||
template <typename T> | ||
__forceinline__ __device__ int add_vectors(int a, int b) { | ||
return add_vectors_helper<T>(a, b); | ||
} | ||
|
||
template <> | ||
__forceinline__ __device__ int add_vectors<__half>(int a, int b) { | ||
return add_vectors_helper<__half2>(a, b); | ||
} | ||
|
||
template <typename T> | ||
__forceinline__ __device__ uint32_t add_vectors_helper(uint32_t a, uint32_t b) { | ||
return bit_cast<uint32_t, T>(add_elements(bit_cast<T, uint32_t>(a), bit_cast<T, uint32_t>(b))); | ||
} | ||
|
||
template <typename T> | ||
__forceinline__ __device__ uint32_t add_vectors(uint32_t a, uint32_t b) { | ||
return add_vectors_helper<T>(a, b); | ||
} | ||
|
||
template <> | ||
__forceinline__ __device__ uint32_t add_vectors<__half>(uint32_t a, uint32_t b) { | ||
return add_vectors_helper<__half2>(a, b); | ||
} | ||
|
||
#endif // MSCCL_COMMON_HPP_ |
Oops, something went wrong.