From 1738e8e59a15694be169c7b4f9e0c068ff2ec856 Mon Sep 17 00:00:00 2001 From: gaopengf Date: Fri, 11 Jul 2025 01:46:16 -0400 Subject: [PATCH 01/16] add shm allreduce --- gloo/CMakeLists.txt | 2 + gloo/allreduce.cc | 11 + gloo/allreduce.h | 24 ++ gloo/allreduce_shm.cc | 741 ++++++++++++++++++++++++++++++++++++++++++ gloo/allreduce_shm.h | 8 + 5 files changed, 786 insertions(+) create mode 100644 gloo/allreduce_shm.cc create mode 100644 gloo/allreduce_shm.h diff --git a/gloo/CMakeLists.txt b/gloo/CMakeLists.txt index 186fe1288..fb65defd5 100644 --- a/gloo/CMakeLists.txt +++ b/gloo/CMakeLists.txt @@ -11,6 +11,7 @@ list(APPEND GLOO_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/allgatherv.cc" "${CMAKE_CURRENT_SOURCE_DIR}/allreduce.cc" "${CMAKE_CURRENT_SOURCE_DIR}/allreduce_local.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/allreduce_shm.cc" "${CMAKE_CURRENT_SOURCE_DIR}/alltoall.cc" "${CMAKE_CURRENT_SOURCE_DIR}/alltoallv.cc" "${CMAKE_CURRENT_SOURCE_DIR}/barrier.cc" @@ -34,6 +35,7 @@ list(APPEND GLOO_HDRS "${CMAKE_CURRENT_SOURCE_DIR}/allreduce_local.h" "${CMAKE_CURRENT_SOURCE_DIR}/allreduce_ring.h" "${CMAKE_CURRENT_SOURCE_DIR}/allreduce_ring_chunked.h" + "${CMAKE_CURRENT_SOURCE_DIR}/allreduce_shm.h" "${CMAKE_CURRENT_SOURCE_DIR}/alltoall.h" "${CMAKE_CURRENT_SOURCE_DIR}/alltoallv.h" "${CMAKE_CURRENT_SOURCE_DIR}/barrier.h" diff --git a/gloo/allreduce.cc b/gloo/allreduce.cc index 080f7f302..511e8d3d3 100644 --- a/gloo/allreduce.cc +++ b/gloo/allreduce.cc @@ -15,6 +15,7 @@ #include "gloo/common/logging.h" #include "gloo/math.h" #include "gloo/types.h" +#include "gloo/allreduce_shm.h" namespace gloo { @@ -95,6 +96,7 @@ BroadcastRangeFunction genLocalBroadcastFunction(const BufferVector& out) { } void allreduce(const detail::AllreduceOptionsImpl& opts) { + //printf("In gloo::allreduce\n"); if (opts.elements == 0) { return; } @@ -153,6 +155,15 @@ void ring( const auto slot = Slot::build(kAllreduceSlotPrefix, opts.tag); const size_t totalBytes = opts.elements * opts.elementSize; + + if (is_intra_node(context->size)) { + shm(opts); + return; + } + + //shm(opts); + //return; + // Note: context->size > 1 const auto recvRank = (context->size + context->rank + 1) % context->size; const auto sendRank = (context->size + context->rank - 1) % context->size; diff --git a/gloo/allreduce.h b/gloo/allreduce.h index 904eb8b32..2133cf2f3 100644 --- a/gloo/allreduce.h +++ b/gloo/allreduce.h @@ -11,9 +11,13 @@ #include #include #include +#include #include "gloo/context.h" #include "gloo/transport/unbound_buffer.h" +#include "gloo/types.h" +//#include "gloo/allreduce_shm.h" + namespace gloo { @@ -41,6 +45,12 @@ struct AllreduceOptionsImpl { BCUBE = 2, }; + enum ScalarType { + BFLOAT16, + HALF, + FLOAT, + }; + explicit AllreduceOptionsImpl(const std::shared_ptr& context) : context(context), timeout(context->getTimeout()), @@ -54,6 +64,9 @@ struct AllreduceOptionsImpl { // Algorithm selection. Algorithm algorithm; + // Scalar type + ScalarType scalarType; + // Input and output buffers. // The output is used as input if input is not specified. 
std::vector> in; @@ -90,6 +103,7 @@ class AllreduceOptions { public: using Func = detail::AllreduceOptionsImpl::Func; using Algorithm = detail::AllreduceOptionsImpl::Algorithm; + using ScalarType = detail::AllreduceOptionsImpl::ScalarType; explicit AllreduceOptions(const std::shared_ptr& context) : impl_(context) {} @@ -154,6 +168,16 @@ class AllreduceOptions { template void setOutputs(std::vector ptrs, size_t elements) { + //printf("set outputs\n"); + if (std::is_same_v) { + printf("output type is float\n"); + impl_.scalarType = ScalarType::FLOAT; + } else if (std::is_same_v) { + printf("output type is float16\n"); + impl_.scalarType = ScalarType::HALF; + } else { + printf("Unknown datatype\n"); + } setOutputs(ptrs.data(), ptrs.size(), elements); } diff --git a/gloo/allreduce_shm.cc b/gloo/allreduce_shm.cc new file mode 100644 index 000000000..1a6498a14 --- /dev/null +++ b/gloo/allreduce_shm.cc @@ -0,0 +1,741 @@ +#include "gloo/allreduce_shm.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace gloo { + +namespace { +#define VECTOR_LENGTH_IN_BYTES 32 +// states for collectives +enum coll_state { + coll_begin = 0, + coll_allreduce_naive__copy_in_done, + coll_allreduce_naive__reduce_done, + // alternative state when allreduce is working on alternative buffer + // of the double buffer. + coll_alt1_allreduce_naive__copy_in_done, + coll_alt2_allreduce_naive__copy_in_done, + coll_alt1_allreduce_naive__reduce_done, +}; + +// SHM building blocks +struct SharedData { + const char* name; + int descriptor; + void* bytes; + size_t nbytes; +}; + +void shared_open(SharedData* data, const char* name, size_t nbytes) { + int d = shm_open(name, O_RDWR, S_IRUSR | S_IWUSR); + if (d != -1) { + void* bytes = mmap(NULL, nbytes, PROT_READ | PROT_WRITE, MAP_SHARED, d, 0); + data->name = name; + data->descriptor = d; + data->bytes = bytes; + data->nbytes = nbytes; + } else { + if (errno != ENOENT) { + // don't print if shm can not be found because we want to loop over from + // caller again until the other ranks created the shm + printf("shared_open %s failed, errno=%d\n", name, errno); + } + data->descriptor = -1; + } +} + +void shared_create( + SharedData* data, + const char* name, + void* bytes, + size_t nbytes) { + int d = shm_open(name, O_CREAT | O_RDWR, S_IRUSR | S_IWUSR); + if (d != -1) { + if (nbytes = write(d, bytes, nbytes)) { + shared_open(data, name, nbytes); + } + } else { + printf("shared_create %s failed\n", name); + } +} + +static int world_rank = -1; +static int world_size = -1; +static bool is_initialized = false; + +// SHM based allreduce helper functions +// buffer that holds shm name +#define NAME_BUF_SIZE 1000 +#define MAX_BUF_SIZE 1048576 * 32 +#define NAIVE_ALLREDUCE_THRESHOLD 1048576 +#define SHM_BUFFER_NAME "deepspeed_allreduce_buffer" +struct allreduce_workspace { + enum coll_state states[2]; // idx=0 -- state for symmetric_naive_all_reduce + // idx=1 -- state for distributed_naive_all_reduce + // double buffer to avoid syncing between rounds + // offset=0 -- 2*NAIVE_ALLREDUCE_THRESHOLD : buffer for + // symmetric_naive_all_reduce after that : buffer for + // distributed_naive_all_reduce + char buffer[2 * NAIVE_ALLREDUCE_THRESHOLD + 2 * MAX_BUF_SIZE]; +}; + +#define BUFFER0_OFFSET(current_buffer) current_buffer* NAIVE_ALLREDUCE_THRESHOLD +#define BUFFER1_OFFSET(current_buffer) \ + 2 * NAIVE_ALLREDUCE_THRESHOLD + current_buffer* MAX_BUF_SIZE + +struct allreduce_workspace** workspace; + +// buffer for small messages, double buffer 
+char** symmetric_buffer[2]; +// buffer for large messages, double buffer +char** distributed_buffer[2]; + +void wait_buffer_state_until_2( + int index, + enum coll_state state0, + enum coll_state state1, + int state_group) { + volatile enum coll_state* state_ptr = + &(workspace[index]->states[state_group]); + + while (1) { + volatile enum coll_state cur_state = *state_ptr; + if (cur_state == state0 || cur_state == state1) + break; + } +} + +__m512 cvt_bf16_to_fp32(const __m256i src) __attribute__((target("avx512bw"))); +inline __m512 cvt_bf16_to_fp32(const __m256i src) { + auto y = _mm512_cvtepu16_epi32(src); + return _mm512_castsi512_ps(_mm512_bslli_epi128(y, 2)); +} + +inline __m256i cvt_fp32_to_bf16(const __m512 src) + __attribute__((target("avx512bw"))); +inline __m256i cvt_fp32_to_bf16(const __m512 src) { + __m512i value = _mm512_castps_si512(src); + __m512i nan = _mm512_set1_epi32(0xffff); + auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q); + __m512i ones = _mm512_set1_epi32(0x1); + __m512i vec_bias = _mm512_set1_epi32(0x7fff); + // uint32_t lsb = (input >> 16) & 1; + auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones); + // uint32_t rounding_bias = 0x7fff + lsb; + t_value = _mm512_add_epi32(t_value, vec_bias); + // input += rounding_bias; + t_value = _mm512_add_epi32(t_value, value); + // input = input >> 16; + t_value = _mm512_srli_epi32(t_value, 16); + // Check NaN before converting back to bf16 + t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value); + return _mm512_cvtusepi32_epi16(t_value); +} + +__m512 cvt_fp16_to_fp32(const __m256i src) __attribute__((target("avx512bw"))); +inline __m512 cvt_fp16_to_fp32(const __m256i src) { + return _mm512_cvtph_ps(src); +} + +inline __m256i cvt_fp32_to_fp16(const __m512 src) + __attribute__((target("avx512bw"))); +inline __m256i cvt_fp32_to_fp16(const __m512 src) { + return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); +} + +void reduce_bf16_buffers( + int start_elements, + int num_elements, + char* to_buffer, + char** buffers) __attribute__((target("avx512bw"))); + +void reduce_fp16_buffers( + int start_elements, + int num_elements, + char* to_buffer, + char** buffers) __attribute__((target("avx512bw"))); + +void reduce_fp32_buffers( + int start_elements, + int num_elements, + char* to_buffer, + char** buffers) __attribute__((target("avx512bw"))); + +void reduce_all_buffers( + int start_elements, + int num_elements, + AllreduceOptions::ScalarType scalar_type, + int to_buffer_idx, + char* to_buffer, + char** buffers) { + switch (scalar_type) { + case AllreduceOptions::ScalarType::BFLOAT16: + assert(!"BFloat16 not supported in gloo yet."); + reduce_bf16_buffers(start_elements, num_elements, to_buffer, buffers); + break; + case AllreduceOptions::ScalarType::HALF: + reduce_fp16_buffers(start_elements, num_elements, to_buffer, buffers); + break; + case AllreduceOptions::ScalarType::FLOAT: + reduce_fp32_buffers(start_elements, num_elements, to_buffer, buffers); + break; + default: + assert(!"Should not get here"); + } +} + +#define CVT_ADD_BF16(x) \ + do { \ + auto in##x##_val = \ + cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[x] + i))); \ + inout_val = _mm512_add_ps(inout_val, in##x##_val); \ + } while (0) + +// Reduce functions down below use vectorized algorithm, the number of bytes +// processed each iteration depends on vector length. 256bit vector ==> 32 +// bytes, 512bit vector ==> 64 bytes If you change implementation of +// reduce_bf16_buffers, etc. 
, check whether this number needs to be changed +#define VECTOR_LENGTH_IN_BYTES 32 + +void reduce_bf16_buffers( + int start_elements, + int num_elements, + char* to_buffer, + char** buffers) { + const int element_size = 2; + const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size; + int main_elements = num_elements - (num_elements % vector_length); + int remain_elements = num_elements % vector_length; + + // process aligned part +#pragma omp parallel for + for (int i = start_elements * element_size; + i < (start_elements + main_elements) * element_size; + i += VECTOR_LENGTH_IN_BYTES) { + auto inout_val = + cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[0] + i))); + switch (world_size) { + case 16: + CVT_ADD_BF16(15); + case 15: + CVT_ADD_BF16(14); + case 14: + CVT_ADD_BF16(13); + case 13: + CVT_ADD_BF16(12); + case 12: + CVT_ADD_BF16(11); + case 11: + CVT_ADD_BF16(10); + case 10: + CVT_ADD_BF16(9); + case 9: + CVT_ADD_BF16(8); + case 8: + CVT_ADD_BF16(7); + case 7: + CVT_ADD_BF16(6); + case 6: + CVT_ADD_BF16(5); + case 5: + CVT_ADD_BF16(4); + case 4: + CVT_ADD_BF16(3); + case 3: + CVT_ADD_BF16(2); + case 2: + CVT_ADD_BF16(1); + case 1: + break; + default: + for (int j = 1; j < world_size; j++) { + auto in_val = + cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i))); + inout_val = _mm512_add_ps(inout_val, in_val); + } + } + _mm256_storeu_si256((__m256i*)(to_buffer + i), cvt_fp32_to_bf16(inout_val)); + } + + // process remaining part + // todo: support bfloat16 + /* + int i = (start_elements + main_elements) * element_size; + while (remain_elements > 0) { + float val = 0.0f; + for (int j = 0; j < world_size; j++) { + val += *(at::BFloat16*)(buffers[j] + i); + } + *(at::BFloat16*)(to_buffer + i) = val; + remain_elements--; + i += element_size; + } + */ +} + +#define CVT_ADD_FP16(x) \ + do { \ + auto in##x##_val = \ + cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[x] + i))); \ + inout_val = _mm512_add_ps(inout_val, in##x##_val); \ + } while (0) + +void reduce_fp16_buffers( + int start_elements, + int num_elements, + char* to_buffer, + char** buffers) { + const int element_size = 2; + const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size; + int main_elements = num_elements - (num_elements % vector_length); + int remain_elements = num_elements % vector_length; + + // process aligned part +#pragma omp parallel for + for (int i = start_elements * element_size; + i < (start_elements + main_elements) * element_size; + i += VECTOR_LENGTH_IN_BYTES) { + auto inout_val = + cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[0] + i))); + switch (world_size) { + case 16: + CVT_ADD_FP16(15); + case 15: + CVT_ADD_FP16(14); + case 14: + CVT_ADD_FP16(13); + case 13: + CVT_ADD_FP16(12); + case 12: + CVT_ADD_FP16(11); + case 11: + CVT_ADD_FP16(10); + case 10: + CVT_ADD_FP16(9); + case 9: + CVT_ADD_FP16(8); + case 8: + CVT_ADD_FP16(7); + case 7: + CVT_ADD_FP16(6); + case 6: + CVT_ADD_FP16(5); + case 5: + CVT_ADD_FP16(4); + case 4: + CVT_ADD_FP16(3); + case 3: + CVT_ADD_FP16(2); + case 2: + CVT_ADD_FP16(1); + case 1: + break; + default: + for (int j = 1; j < world_size; j++) { + auto in_val = + cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i))); + inout_val = _mm512_add_ps(inout_val, in_val); + } + } + _mm256_storeu_si256((__m256i*)(to_buffer + i), cvt_fp32_to_fp16(inout_val)); + } + + + + // process remaining part + int i = (start_elements + main_elements) * element_size; + while (remain_elements > 0) { + float16 val =float16(0.0f); + for (int j = 0; 
j < world_size; j++) { + val += *(float16*)(buffers[j] + i); + } + *(float16*)(to_buffer + i) = val; + remain_elements--; + i += element_size; + } +} + +#define CVT_ADD_F32(x) \ + do { \ + auto in##x##_val = _mm256_loadu_ps((float*)(buffers[x] + i)); \ + inout_val = _mm256_add_ps(inout_val, in##x##_val); \ + } while (0) + +void reduce_fp32_buffers( + int start_elements, + int num_elements, + char* to_buffer, + char** buffers) { + const int element_size = 4; + const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size; + int main_elements = num_elements - (num_elements % vector_length); + int remain_elements = num_elements % vector_length; + + // process aligned part +#pragma omp parallel for + for (int i = start_elements * element_size; + i < (start_elements + main_elements) * element_size; + i += VECTOR_LENGTH_IN_BYTES) { + auto inout_val = _mm256_loadu_ps((float*)(buffers[0] + i)); + switch (world_size) { + case 16: + CVT_ADD_F32(15); + case 15: + CVT_ADD_F32(14); + case 14: + CVT_ADD_F32(13); + case 13: + CVT_ADD_F32(12); + case 12: + CVT_ADD_F32(11); + case 11: + CVT_ADD_F32(10); + case 10: + CVT_ADD_F32(9); + case 9: + CVT_ADD_F32(8); + case 8: + CVT_ADD_F32(7); + case 7: + CVT_ADD_F32(6); + case 6: + CVT_ADD_F32(5); + case 5: + CVT_ADD_F32(4); + case 4: + CVT_ADD_F32(3); + case 3: + CVT_ADD_F32(2); + case 2: + CVT_ADD_F32(1); + case 1: + break; + default: + for (int j = 1; j < world_size; j++) { + auto in_val = _mm256_loadu_ps((float*)(buffers[j] + i)); + inout_val = _mm256_add_ps(inout_val, in_val); + } + } + _mm256_storeu_ps((float*)(to_buffer + i), inout_val); + } + + // process remaining part + int i = (start_elements + main_elements) * element_size; + while (remain_elements > 0) { + float val = 0.0f; + for (int j = 0; j < world_size; j++) { + val += *(float*)(buffers[j] + i); + } + *(float*)(to_buffer + i) = val; + remain_elements--; + i += element_size; + } +} + +void shm_initialize(int size, int rank, char* addr_string, char* port_string) { + world_size = size; + world_rank = rank; + + char shm_name_prefix[NAME_BUF_SIZE]; + char shm_name[NAME_BUF_SIZE]; + snprintf( + shm_name_prefix, + NAME_BUF_SIZE, + "%s_%d_%s_%s", + SHM_BUFFER_NAME, + getuid(), + addr_string, + port_string); + // create shared workspace for SHM based allreduce + SharedData allreduce_buffer; + // allocate workspace_buf for current rank + struct allreduce_workspace* workspace_buf; + struct allreduce_workspace* workspace_buf_other; + workspace_buf = + (struct allreduce_workspace*)malloc(sizeof(struct allreduce_workspace)); + int written = snprintf(shm_name, NAME_BUF_SIZE, "%s_%d", shm_name_prefix, rank); + if (written >= NAME_BUF_SIZE) { + std::cout << "[warning]: written >= NAME_BUF_SIZE" << std::endl; + } + shared_create( + &allreduce_buffer, + shm_name, + workspace_buf, + sizeof(struct allreduce_workspace)); + workspace_buf = (struct allreduce_workspace*)allreduce_buffer.bytes; + workspace_buf->states[0] = coll_alt2_allreduce_naive__copy_in_done; + workspace_buf->states[1] = coll_begin; + + // create the workspace pointer list + workspace = (struct allreduce_workspace**)malloc( + size * sizeof(struct allreduce_workspace*)); + symmetric_buffer[0] = (char**)malloc(size * sizeof(char**)); + symmetric_buffer[1] = (char**)malloc(size * sizeof(char**)); + distributed_buffer[0] = (char**)malloc(size * sizeof(char**)); + distributed_buffer[1] = (char**)malloc(size * sizeof(char**)); + + // map shm of all ranks + for (int i = 0; i < size; i++) { + if (i != rank) { + int written = snprintf(shm_name, 
NAME_BUF_SIZE, "%s_%d", shm_name_prefix, i); + if (written >= NAME_BUF_SIZE) { + std::cout << "[warning]: written >= NAME_BUF_SIZE" << std::endl; + } + // printf("open %s, %d\n", shm_name, rank); + do { + shared_open( + &allreduce_buffer, shm_name, sizeof(struct allreduce_workspace)); + } while (allreduce_buffer.descriptor == -1 && errno == ENOENT); + workspace_buf_other = (struct allreduce_workspace*)allreduce_buffer.bytes; + workspace[i] = workspace_buf_other; + } else { + workspace[i] = workspace_buf; + } + symmetric_buffer[0][i] = workspace[i]->buffer + BUFFER0_OFFSET(0); + symmetric_buffer[1][i] = workspace[i]->buffer + BUFFER0_OFFSET(1); + distributed_buffer[0][i] = workspace[i]->buffer + BUFFER1_OFFSET(0); + distributed_buffer[1][i] = workspace[i]->buffer + BUFFER1_OFFSET(1); + } +} + +static void parallel_memcpy(void* to, void* from, size_t n_bytes) + __attribute__((target("avx512bw"))); +static void parallel_memcpy(void* to, void* from, size_t n_bytes) { + auto aligned_bytes = n_bytes - (n_bytes % VECTOR_LENGTH_IN_BYTES); + // process aligned part +#pragma omp parallel for + for (int i = 0; i < aligned_bytes; i += VECTOR_LENGTH_IN_BYTES) { + auto val = _mm256_loadu_si256((__m256i*)((char*)from + i)); + _mm256_storeu_si256((__m256i*)((char*)to + i), val); + } + + // process remaining part + for (int i = aligned_bytes; i < n_bytes; i++) { + *((char*)to + i) = *((char*)from + i); + } +} + +#define positive_mod(num, mod) ((((num) % (mod)) + (mod)) % (mod)) +#define rank_mod(rank) positive_mod(rank, world_size) +size_t slice_size(size_t chunk_el, int slice_idx) { + size_t slice_size = chunk_el / world_size; + return slice_idx == world_size - 1 ? slice_size + (chunk_el % world_size) + : slice_size; +} + +char* slice_data(char* data_ptr, size_t chunk_el, int el_size, int slice_idx) { + size_t slice_size = chunk_el / world_size; + size_t el_offset = slice_size * slice_idx; + return data_ptr + el_offset * el_size; +} + +size_t slice_el_start(size_t chunk_el, int slice_idx) { + size_t slice_size = chunk_el / world_size; + return slice_size * slice_idx; +} + +void symmetric_naive_all_reduce( + char* data_ptr, + AllreduceOptions::ScalarType scalar_type, + size_t chunk_size, + size_t chunk_el) { + const int state_group = 0; + static int current_buffer = 0; + static int state_idx = 0; + + enum coll_state copy_current, copy_next; + + switch (state_idx) { + case 0: + copy_current = coll_allreduce_naive__copy_in_done; + copy_next = coll_alt1_allreduce_naive__copy_in_done; + break; + case 1: + copy_current = coll_alt1_allreduce_naive__copy_in_done; + copy_next = coll_alt2_allreduce_naive__copy_in_done; + break; + case 2: + copy_current = coll_alt2_allreduce_naive__copy_in_done; + copy_next = coll_allreduce_naive__copy_in_done; + break; + default: + assert(!"Should not get here."); + } + state_idx = (state_idx + 1) % 3; + + parallel_memcpy( + symmetric_buffer[current_buffer][world_rank], data_ptr, chunk_size); + std::atomic_thread_fence(std::memory_order_release); + workspace[world_rank]->states[state_group] = copy_current; + + for (int i = 0; i < world_size; i++) { + // wait until the other rank copy the buffer + if (i != world_rank) { + wait_buffer_state_until_2(i, copy_current, copy_next, state_group); + } + } + + // each rank reduce the buffer independently so therre is no need for + // synchronization afterward + reduce_all_buffers( + 0, + chunk_el, + scalar_type, + world_rank, + data_ptr, + symmetric_buffer[current_buffer]); + + // switch buffer + current_buffer = 1 - current_buffer; +} + +// 
naive allreduce distributed, each rank do naive reduce on its slice +void distributed_naive_reduce( + char* data_ptr, + AllreduceOptions::ScalarType scalar_type, + size_t chunk_size, + size_t chunk_el) { + const int state_group = 1; + static int current_buffer = 0; + static int state_idx = 0; + + enum coll_state copy_current, copy_next, reduce_current; + + // similar to symmetric_naive_allreduce, but here we only need two sets of + // states, because distributed naive reduce has two barriers in the algorithm + switch (state_idx) { + case 0: + copy_current = coll_allreduce_naive__copy_in_done; + reduce_current = coll_allreduce_naive__reduce_done; + copy_next = coll_alt1_allreduce_naive__copy_in_done; + break; + case 1: + copy_current = coll_alt1_allreduce_naive__copy_in_done; + reduce_current = coll_alt1_allreduce_naive__reduce_done; + copy_next = coll_allreduce_naive__copy_in_done; + break; + default: + assert(!"Should not get here."); + } + state_idx = (state_idx + 1) % 2; + + int data_size = chunk_size / chunk_el; + parallel_memcpy( + distributed_buffer[current_buffer][world_rank], data_ptr, chunk_size); + std::atomic_thread_fence(std::memory_order_release); + workspace[world_rank]->states[state_group] = copy_current; + + for (int i = 0; i < world_size; i++) { + // wait until all the other ranks copy the buffer + if (i != world_rank) + wait_buffer_state_until_2(i, copy_current, reduce_current, state_group); + } + + // reduce scatter + reduce_all_buffers( + slice_el_start(chunk_el, world_rank), + slice_size(chunk_el, world_rank), + scalar_type, + world_rank, + distributed_buffer[current_buffer][world_rank], + distributed_buffer[current_buffer]); + std::atomic_thread_fence(std::memory_order_release); + workspace[world_rank]->states[state_group] = reduce_current; + + for (int i = 0; i < world_size; i++) { + // wait until all the other ranks reduce the buffer + if (i != world_rank) + wait_buffer_state_until_2(i, reduce_current, copy_next, state_group); + } + + for (int i = 0; i < world_size; i++) { + int rank = (i + world_rank) % world_size; + parallel_memcpy( + slice_data(data_ptr, chunk_el, data_size, rank), + slice_data( + distributed_buffer[current_buffer][rank], + chunk_el, + chunk_size / chunk_el, + rank), + slice_size(chunk_el, rank) * data_size); + } + + current_buffer = 1 - current_buffer; +} + +} // namespace + +bool is_intra_node(const int size) { + // must launch with torchrun + auto local_size_string = std::getenv("LOCAL_WORLD_SIZE"); + int local_size = 0; + if (local_size_string != NULL) { + local_size = std::stoi(local_size_string); + } + + return size > 1 && size == local_size; +} + + +void shm(const detail::AllreduceOptionsImpl& opts) { + + //printf("In shm allreduce\n"); + const auto& context = opts.context; + if (!is_initialized) { + + //int size = context->size; + //int rank = context->rank; + + int size = std::stoi(std::getenv("PMI_SIZE")); + int rank = std::stoi(std::getenv("PMI_RANK")); + + world_size = size; + world_rank = rank; + is_initialized = true; + + auto addr_string = std::getenv("MASTER_ADDR"); + if (addr_string == NULL) { + addr_string = ""; + } + auto port_string = std::getenv("MASTER_PORT"); + if (port_string == NULL) { + port_string = ""; + } + // std::cout << "size: " << size << std::endl; + // std::cout << "rank: " << rank << std::endl; + // std::cout << "addr_string: " << addr_string << std::endl; + // std::cout << "port_string: " << port_string << std::endl; + shm_initialize(size, rank, addr_string, port_string); + } + + const size_t data_size = 
opts.elements * opts.elementSize; + const std::vector>& out = opts.out; + void* data = out[0].get()->ptr; + + for (int offset = 0; offset < data_size; offset += MAX_BUF_SIZE) { + auto data_ptr = ((char*)(data) + offset); + size_t chunk_size = + data_size - offset > MAX_BUF_SIZE ? MAX_BUF_SIZE : data_size - offset; + size_t chunk_el = chunk_size / (data_size / opts.elements); + if (chunk_size < NAIVE_ALLREDUCE_THRESHOLD) { + symmetric_naive_all_reduce( + data_ptr, opts.scalarType, chunk_size, chunk_el); + } else { + distributed_naive_reduce( + data_ptr, opts.scalarType, chunk_size, chunk_el); + } + } + +} + +} //namespace gloo + diff --git a/gloo/allreduce_shm.h b/gloo/allreduce_shm.h new file mode 100644 index 000000000..e9236759c --- /dev/null +++ b/gloo/allreduce_shm.h @@ -0,0 +1,8 @@ +#include "gloo/allreduce.h" + +namespace gloo { + +bool is_intra_node(const int size); +void shm(const detail::AllreduceOptionsImpl& opts); + +} // namespace gloo \ No newline at end of file From 76d111461f802df7a5c634437d6244aec89a0951 Mon Sep 17 00:00:00 2001 From: gaopengf Date: Wed, 16 Jul 2025 03:52:12 -0400 Subject: [PATCH 02/16] add bf16 and half support --- gloo/CMakeLists.txt | 5 +++++ gloo/allreduce.h | 34 +++++++++++++++++++++++++--------- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/gloo/CMakeLists.txt b/gloo/CMakeLists.txt index fb65defd5..6b0ac60b0 100644 --- a/gloo/CMakeLists.txt +++ b/gloo/CMakeLists.txt @@ -188,6 +188,11 @@ if(USE_ROCM) endif() endif() +message(STATUS "GLOO_USE_TORCH_DTYPES : ${GLOO_USE_TORCH_DTYPES} ${GLOO_TORCH_DIR}") +if(GLOO_USE_TORCH_DTYPES) +target_include_directories(gloo PRIVATE ${GLOO_TORCH_DIR}) +endif() + # Install if necessary. # If the Gloo build is included from another project's build, it may # want to statically link with Gloo and not install any artifacts. diff --git a/gloo/allreduce.h b/gloo/allreduce.h index 2133cf2f3..2ca69ca94 100644 --- a/gloo/allreduce.h +++ b/gloo/allreduce.h @@ -18,6 +18,11 @@ #include "gloo/types.h" //#include "gloo/allreduce_shm.h" +#define GPF_PRINT(...) 
do {\ + printf("GPF_DEBUG:");\ + printf(__VA_ARGS__);\ + printf("\n");\ +}while(0) namespace gloo { @@ -39,6 +44,11 @@ struct AllreduceOptionsImpl { // using Func = std::function; +#if GLOO_USE_TORCH_DTYPES +using BFloat16 = c10::BFloat16; +using Half = c10::Half; +#endif + enum Algorithm { UNSPECIFIED = 0, RING = 1, @@ -49,6 +59,7 @@ struct AllreduceOptionsImpl { BFLOAT16, HALF, FLOAT, + UNKNOWN, }; explicit AllreduceOptionsImpl(const std::shared_ptr& context) @@ -169,18 +180,23 @@ class AllreduceOptions { template void setOutputs(std::vector ptrs, size_t elements) { //printf("set outputs\n"); - if (std::is_same_v) { - printf("output type is float\n"); - impl_.scalarType = ScalarType::FLOAT; - } else if (std::is_same_v) { - printf("output type is float16\n"); - impl_.scalarType = ScalarType::HALF; - } else { - printf("Unknown datatype\n"); - } + // default is float + impl_.scalarType = ScalarType::FLOAT; + +#if GLOO_USE_TORCH_DTYPES +if (std::is_same_v) { + //GPF_PRINT("output type is half"); + impl_.scalarType = ScalarType::HALF; +} else if (std::is_same_v) { + impl_.scalarType = ScalarType::BFLOAT16; + //GPF_PRINT("output type is bfloat16"); +} +#endif setOutputs(ptrs.data(), ptrs.size(), elements); } + + template void setOutputs(T** ptrs, size_t len, size_t elements) { impl_.elements = elements; From 2d152a33ccbc0012d92320b5e5b90a0f32921349 Mon Sep 17 00:00:00 2001 From: gaopengf Date: Fri, 18 Jul 2025 01:35:01 -0400 Subject: [PATCH 03/16] remove bf16 support --- gloo/allreduce.h | 41 ++++++++++++++++++++--------------------- gloo/allreduce_shm.cc | 8 ++++---- 2 files changed, 24 insertions(+), 25 deletions(-) diff --git a/gloo/allreduce.h b/gloo/allreduce.h index 2ca69ca94..4ef031073 100644 --- a/gloo/allreduce.h +++ b/gloo/allreduce.h @@ -16,7 +16,6 @@ #include "gloo/context.h" #include "gloo/transport/unbound_buffer.h" #include "gloo/types.h" -//#include "gloo/allreduce_shm.h" #define GPF_PRINT(...) 
do {\ printf("GPF_DEBUG:");\ @@ -44,11 +43,6 @@ struct AllreduceOptionsImpl { // using Func = std::function; -#if GLOO_USE_TORCH_DTYPES -using BFloat16 = c10::BFloat16; -using Half = c10::Half; -#endif - enum Algorithm { UNSPECIFIED = 0, RING = 1, @@ -179,24 +173,9 @@ class AllreduceOptions { template void setOutputs(std::vector ptrs, size_t elements) { - //printf("set outputs\n"); - // default is float - impl_.scalarType = ScalarType::FLOAT; - -#if GLOO_USE_TORCH_DTYPES -if (std::is_same_v) { - //GPF_PRINT("output type is half"); - impl_.scalarType = ScalarType::HALF; -} else if (std::is_same_v) { - impl_.scalarType = ScalarType::BFLOAT16; - //GPF_PRINT("output type is bfloat16"); -} -#endif setOutputs(ptrs.data(), ptrs.size(), elements); } - - template void setOutputs(T** ptrs, size_t len, size_t elements) { impl_.elements = elements; @@ -230,6 +209,26 @@ if (std::is_same_v) { friend void allreduce(const AllreduceOptions&); }; +#if GLOO_USE_TORCH_DTYPES + template <> + void AllreduceOptions::setOutputs(std::vector ptrs, size_t elements) { + impl_.scalarType = ScalarType::HALF; + setOutputs(ptrs.data(), ptrs.size(), elements); + } + + template <> + void AllreduceOptions::setOutputs(std::vector ptrs, size_t elements) { + impl_.scalarType = ScalarType::BFLOAT16; + setOutputs(ptrs.data(), ptrs.size(), elements); + } + + template <> + void AllreduceOptions::setOutputs(std::vector ptrs, size_t elements) { + impl_.scalarType = ScalarType::FLOAT; + setOutputs(ptrs.data(), ptrs.size(), elements); + } +#endif + void allreduce(const AllreduceOptions& opts); } // namespace gloo diff --git a/gloo/allreduce_shm.cc b/gloo/allreduce_shm.cc index 1a6498a14..0169c3f6e 100644 --- a/gloo/allreduce_shm.cc +++ b/gloo/allreduce_shm.cc @@ -179,7 +179,7 @@ void reduce_all_buffers( char** buffers) { switch (scalar_type) { case AllreduceOptions::ScalarType::BFLOAT16: - assert(!"BFloat16 not supported in gloo yet."); + GLOO_ENFORCE(false, "Bfloat16 for shm_allreduce is not supported yet."); reduce_bf16_buffers(start_elements, num_elements, to_buffer, buffers); break; case AllreduceOptions::ScalarType::HALF: @@ -273,9 +273,9 @@ void reduce_bf16_buffers( while (remain_elements > 0) { float val = 0.0f; for (int j = 0; j < world_size; j++) { - val += *(at::BFloat16*)(buffers[j] + i); + val += *(c10::BFloat16*)(buffers[j] + i); } - *(at::BFloat16*)(to_buffer + i) = val; + *(BFloat16*)(to_buffer + i) = val; remain_elements--; i += element_size; } @@ -688,7 +688,6 @@ bool is_intra_node(const int size) { void shm(const detail::AllreduceOptionsImpl& opts) { - //printf("In shm allreduce\n"); const auto& context = opts.context; if (!is_initialized) { @@ -715,6 +714,7 @@ void shm(const detail::AllreduceOptionsImpl& opts) { // std::cout << "addr_string: " << addr_string << std::endl; // std::cout << "port_string: " << port_string << std::endl; shm_initialize(size, rank, addr_string, port_string); + GPF_PRINT("SHM reduce has been initialized"); } const size_t data_size = opts.elements * opts.elementSize; From 8c29eeba1c15aa203393dd916ad92ebaa44b3964 Mon Sep 17 00:00:00 2001 From: gaopengf Date: Tue, 22 Jul 2025 01:38:39 -0400 Subject: [PATCH 04/16] add bf16 support --- gloo/allreduce_shm.cc | 110 +++++++++++++++++++++++------------------- 1 file changed, 60 insertions(+), 50 deletions(-) diff --git a/gloo/allreduce_shm.cc b/gloo/allreduce_shm.cc index 0169c3f6e..848533f5b 100644 --- a/gloo/allreduce_shm.cc +++ b/gloo/allreduce_shm.cc @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -14,6 +15,9 
@@ namespace gloo { namespace { + +using ReductionFunction = AllreduceOptions::Func; + #define VECTOR_LENGTH_IN_BYTES 32 // states for collectives enum coll_state { @@ -156,19 +160,47 @@ void reduce_bf16_buffers( int start_elements, int num_elements, char* to_buffer, - char** buffers) __attribute__((target("avx512bw"))); + char** buffers, + ReductionFunction fn) __attribute__((target("avx512bw"))); void reduce_fp16_buffers( int start_elements, int num_elements, char* to_buffer, - char** buffers) __attribute__((target("avx512bw"))); + char** buffers, + ReductionFunction fn) __attribute__((target("avx512bw"))); void reduce_fp32_buffers( int start_elements, int num_elements, char* to_buffer, - char** buffers) __attribute__((target("avx512bw"))); + char** buffers, + ReductionFunction fn) __attribute__((target("avx512bw"))); + +void reduce_remaining_part( + int start_elements, + int num_elements, + int remain_elements, + int main_elements, + int element_size, + char *to_buffer, + char **buffers, + ReductionFunction fn){ + size_t offset = (start_elements + main_elements) * element_size; + while (remain_elements > 0) { + memcpy(to_buffer + offset, buffers[0] + offset, element_size); + for (int j = 1; j < world_size; j++) { + + fn(to_buffer + offset, + to_buffer + offset, + buffers[j] + offset, + 1); + + } + remain_elements--; + offset += element_size; + } +} void reduce_all_buffers( int start_elements, @@ -176,17 +208,18 @@ void reduce_all_buffers( AllreduceOptions::ScalarType scalar_type, int to_buffer_idx, char* to_buffer, - char** buffers) { + char** buffers, + ReductionFunction fn) { switch (scalar_type) { case AllreduceOptions::ScalarType::BFLOAT16: - GLOO_ENFORCE(false, "Bfloat16 for shm_allreduce is not supported yet."); - reduce_bf16_buffers(start_elements, num_elements, to_buffer, buffers); + //GLOO_ENFORCE(false, "Bfloat16 for shm_allreduce is not supported yet."); + reduce_bf16_buffers(start_elements, num_elements, to_buffer, buffers, fn); break; case AllreduceOptions::ScalarType::HALF: - reduce_fp16_buffers(start_elements, num_elements, to_buffer, buffers); + reduce_fp16_buffers(start_elements, num_elements, to_buffer, buffers, fn); break; case AllreduceOptions::ScalarType::FLOAT: - reduce_fp32_buffers(start_elements, num_elements, to_buffer, buffers); + reduce_fp32_buffers(start_elements, num_elements, to_buffer, buffers, fn); break; default: assert(!"Should not get here"); @@ -210,7 +243,8 @@ void reduce_bf16_buffers( int start_elements, int num_elements, char* to_buffer, - char** buffers) { + char** buffers, + ReductionFunction fn) { const int element_size = 2; const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size; int main_elements = num_elements - (num_elements % vector_length); @@ -267,19 +301,7 @@ void reduce_bf16_buffers( } // process remaining part - // todo: support bfloat16 - /* - int i = (start_elements + main_elements) * element_size; - while (remain_elements > 0) { - float val = 0.0f; - for (int j = 0; j < world_size; j++) { - val += *(c10::BFloat16*)(buffers[j] + i); - } - *(BFloat16*)(to_buffer + i) = val; - remain_elements--; - i += element_size; - } - */ + reduce_remaining_part(start_elements, num_elements, remain_elements, main_elements, element_size, to_buffer, buffers, fn); } #define CVT_ADD_FP16(x) \ @@ -293,7 +315,8 @@ void reduce_fp16_buffers( int start_elements, int num_elements, char* to_buffer, - char** buffers) { + char** buffers, + ReductionFunction fn) { const int element_size = 2; const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size; 
int main_elements = num_elements - (num_elements % vector_length); @@ -352,16 +375,7 @@ void reduce_fp16_buffers( // process remaining part - int i = (start_elements + main_elements) * element_size; - while (remain_elements > 0) { - float16 val =float16(0.0f); - for (int j = 0; j < world_size; j++) { - val += *(float16*)(buffers[j] + i); - } - *(float16*)(to_buffer + i) = val; - remain_elements--; - i += element_size; - } + reduce_remaining_part(start_elements, num_elements, remain_elements, main_elements, element_size, to_buffer, buffers, fn); } #define CVT_ADD_F32(x) \ @@ -374,7 +388,8 @@ void reduce_fp32_buffers( int start_elements, int num_elements, char* to_buffer, - char** buffers) { + char** buffers, + ReductionFunction fn) { const int element_size = 4; const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size; int main_elements = num_elements - (num_elements % vector_length); @@ -429,16 +444,7 @@ void reduce_fp32_buffers( } // process remaining part - int i = (start_elements + main_elements) * element_size; - while (remain_elements > 0) { - float val = 0.0f; - for (int j = 0; j < world_size; j++) { - val += *(float*)(buffers[j] + i); - } - *(float*)(to_buffer + i) = val; - remain_elements--; - i += element_size; - } + reduce_remaining_part(start_elements, num_elements, remain_elements, main_elements, element_size, to_buffer, buffers, fn); } void shm_initialize(int size, int rank, char* addr_string, char* port_string) { @@ -547,7 +553,8 @@ void symmetric_naive_all_reduce( char* data_ptr, AllreduceOptions::ScalarType scalar_type, size_t chunk_size, - size_t chunk_el) { + size_t chunk_el, + ReductionFunction fn) { const int state_group = 0; static int current_buffer = 0; static int state_idx = 0; @@ -592,7 +599,8 @@ void symmetric_naive_all_reduce( scalar_type, world_rank, data_ptr, - symmetric_buffer[current_buffer]); + symmetric_buffer[current_buffer], + fn); // switch buffer current_buffer = 1 - current_buffer; @@ -603,7 +611,8 @@ void distributed_naive_reduce( char* data_ptr, AllreduceOptions::ScalarType scalar_type, size_t chunk_size, - size_t chunk_el) { + size_t chunk_el, + ReductionFunction fn) { const int state_group = 1; static int current_buffer = 0; static int state_idx = 0; @@ -647,7 +656,8 @@ void distributed_naive_reduce( scalar_type, world_rank, distributed_buffer[current_buffer][world_rank], - distributed_buffer[current_buffer]); + distributed_buffer[current_buffer], + fn); std::atomic_thread_fence(std::memory_order_release); workspace[world_rank]->states[state_group] = reduce_current; @@ -728,10 +738,10 @@ void shm(const detail::AllreduceOptionsImpl& opts) { size_t chunk_el = chunk_size / (data_size / opts.elements); if (chunk_size < NAIVE_ALLREDUCE_THRESHOLD) { symmetric_naive_all_reduce( - data_ptr, opts.scalarType, chunk_size, chunk_el); + data_ptr, opts.scalarType, chunk_size, chunk_el, opts.reduce); } else { distributed_naive_reduce( - data_ptr, opts.scalarType, chunk_size, chunk_el); + data_ptr, opts.scalarType, chunk_size, chunk_el, opts.reduce); } } From 554d3176fefbef1fcbfd7043a540d98595e9b177 Mon Sep 17 00:00:00 2001 From: gaopengf Date: Thu, 24 Jul 2025 03:23:17 -0400 Subject: [PATCH 05/16] use reduce function to do reduce job --- gloo/CMakeLists.txt | 4 - gloo/allreduce.h | 31 ---- gloo/allreduce_shm.cc | 357 +++++------------------------------------- 3 files changed, 43 insertions(+), 349 deletions(-) diff --git a/gloo/CMakeLists.txt b/gloo/CMakeLists.txt index 6b0ac60b0..db54496ef 100644 --- a/gloo/CMakeLists.txt +++ b/gloo/CMakeLists.txt @@ 
-188,10 +188,6 @@ if(USE_ROCM) endif() endif() -message(STATUS "GLOO_USE_TORCH_DTYPES : ${GLOO_USE_TORCH_DTYPES} ${GLOO_TORCH_DIR}") -if(GLOO_USE_TORCH_DTYPES) -target_include_directories(gloo PRIVATE ${GLOO_TORCH_DIR}) -endif() # Install if necessary. # If the Gloo build is included from another project's build, it may diff --git a/gloo/allreduce.h b/gloo/allreduce.h index 4ef031073..8fc494358 100644 --- a/gloo/allreduce.h +++ b/gloo/allreduce.h @@ -49,13 +49,6 @@ struct AllreduceOptionsImpl { BCUBE = 2, }; - enum ScalarType { - BFLOAT16, - HALF, - FLOAT, - UNKNOWN, - }; - explicit AllreduceOptionsImpl(const std::shared_ptr& context) : context(context), timeout(context->getTimeout()), @@ -69,9 +62,6 @@ struct AllreduceOptionsImpl { // Algorithm selection. Algorithm algorithm; - // Scalar type - ScalarType scalarType; - // Input and output buffers. // The output is used as input if input is not specified. std::vector> in; @@ -108,7 +98,6 @@ class AllreduceOptions { public: using Func = detail::AllreduceOptionsImpl::Func; using Algorithm = detail::AllreduceOptionsImpl::Algorithm; - using ScalarType = detail::AllreduceOptionsImpl::ScalarType; explicit AllreduceOptions(const std::shared_ptr& context) : impl_(context) {} @@ -209,26 +198,6 @@ class AllreduceOptions { friend void allreduce(const AllreduceOptions&); }; -#if GLOO_USE_TORCH_DTYPES - template <> - void AllreduceOptions::setOutputs(std::vector ptrs, size_t elements) { - impl_.scalarType = ScalarType::HALF; - setOutputs(ptrs.data(), ptrs.size(), elements); - } - - template <> - void AllreduceOptions::setOutputs(std::vector ptrs, size_t elements) { - impl_.scalarType = ScalarType::BFLOAT16; - setOutputs(ptrs.data(), ptrs.size(), elements); - } - - template <> - void AllreduceOptions::setOutputs(std::vector ptrs, size_t elements) { - impl_.scalarType = ScalarType::FLOAT; - setOutputs(ptrs.data(), ptrs.size(), elements); - } -#endif - void allreduce(const AllreduceOptions& opts); } // namespace gloo diff --git a/gloo/allreduce_shm.cc b/gloo/allreduce_shm.cc index 848533f5b..6af27d3ad 100644 --- a/gloo/allreduce_shm.cc +++ b/gloo/allreduce_shm.cc @@ -118,333 +118,62 @@ void wait_buffer_state_until_2( } } -__m512 cvt_bf16_to_fp32(const __m256i src) __attribute__((target("avx512bw"))); -inline __m512 cvt_bf16_to_fp32(const __m256i src) { - auto y = _mm512_cvtepu16_epi32(src); - return _mm512_castsi512_ps(_mm512_bslli_epi128(y, 2)); -} - -inline __m256i cvt_fp32_to_bf16(const __m512 src) - __attribute__((target("avx512bw"))); -inline __m256i cvt_fp32_to_bf16(const __m512 src) { - __m512i value = _mm512_castps_si512(src); - __m512i nan = _mm512_set1_epi32(0xffff); - auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q); - __m512i ones = _mm512_set1_epi32(0x1); - __m512i vec_bias = _mm512_set1_epi32(0x7fff); - // uint32_t lsb = (input >> 16) & 1; - auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones); - // uint32_t rounding_bias = 0x7fff + lsb; - t_value = _mm512_add_epi32(t_value, vec_bias); - // input += rounding_bias; - t_value = _mm512_add_epi32(t_value, value); - // input = input >> 16; - t_value = _mm512_srli_epi32(t_value, 16); - // Check NaN before converting back to bf16 - t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value); - return _mm512_cvtusepi32_epi16(t_value); -} - -__m512 cvt_fp16_to_fp32(const __m256i src) __attribute__((target("avx512bw"))); -inline __m512 cvt_fp16_to_fp32(const __m256i src) { - return _mm512_cvtph_ps(src); -} - -inline __m256i cvt_fp32_to_fp16(const __m512 src) - 
__attribute__((target("avx512bw"))); -inline __m256i cvt_fp32_to_fp16(const __m512 src) { - return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); -} - -void reduce_bf16_buffers( - int start_elements, - int num_elements, - char* to_buffer, - char** buffers, - ReductionFunction fn) __attribute__((target("avx512bw"))); - -void reduce_fp16_buffers( - int start_elements, - int num_elements, - char* to_buffer, - char** buffers, - ReductionFunction fn) __attribute__((target("avx512bw"))); - -void reduce_fp32_buffers( - int start_elements, - int num_elements, - char* to_buffer, - char** buffers, - ReductionFunction fn) __attribute__((target("avx512bw"))); - -void reduce_remaining_part( - int start_elements, - int num_elements, - int remain_elements, - int main_elements, - int element_size, - char *to_buffer, - char **buffers, - ReductionFunction fn){ - size_t offset = (start_elements + main_elements) * element_size; - while (remain_elements > 0) { - memcpy(to_buffer + offset, buffers[0] + offset, element_size); - for (int j = 1; j < world_size; j++) { - - fn(to_buffer + offset, - to_buffer + offset, - buffers[j] + offset, - 1); - - } - remain_elements--; - offset += element_size; - } -} - void reduce_all_buffers( int start_elements, int num_elements, - AllreduceOptions::ScalarType scalar_type, + int element_size, int to_buffer_idx, char* to_buffer, char** buffers, ReductionFunction fn) { - switch (scalar_type) { - case AllreduceOptions::ScalarType::BFLOAT16: - //GLOO_ENFORCE(false, "Bfloat16 for shm_allreduce is not supported yet."); - reduce_bf16_buffers(start_elements, num_elements, to_buffer, buffers, fn); - break; - case AllreduceOptions::ScalarType::HALF: - reduce_fp16_buffers(start_elements, num_elements, to_buffer, buffers, fn); - break; - case AllreduceOptions::ScalarType::FLOAT: - reduce_fp32_buffers(start_elements, num_elements, to_buffer, buffers, fn); - break; - default: - assert(!"Should not get here"); - } -} - -#define CVT_ADD_BF16(x) \ - do { \ - auto in##x##_val = \ - cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[x] + i))); \ - inout_val = _mm512_add_ps(inout_val, in##x##_val); \ - } while (0) - -// Reduce functions down below use vectorized algorithm, the number of bytes -// processed each iteration depends on vector length. 256bit vector ==> 32 -// bytes, 512bit vector ==> 64 bytes If you change implementation of -// reduce_bf16_buffers, etc. 
, check whether this number needs to be changed -#define VECTOR_LENGTH_IN_BYTES 32 - -void reduce_bf16_buffers( - int start_elements, - int num_elements, - char* to_buffer, - char** buffers, - ReductionFunction fn) { - const int element_size = 2; const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size; int main_elements = num_elements - (num_elements % vector_length); int remain_elements = num_elements % vector_length; - - // process aligned part + #pragma omp parallel for for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size; i += VECTOR_LENGTH_IN_BYTES) { - auto inout_val = - cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[0] + i))); - switch (world_size) { - case 16: - CVT_ADD_BF16(15); - case 15: - CVT_ADD_BF16(14); - case 14: - CVT_ADD_BF16(13); - case 13: - CVT_ADD_BF16(12); - case 12: - CVT_ADD_BF16(11); - case 11: - CVT_ADD_BF16(10); - case 10: - CVT_ADD_BF16(9); - case 9: - CVT_ADD_BF16(8); - case 8: - CVT_ADD_BF16(7); - case 7: - CVT_ADD_BF16(6); - case 6: - CVT_ADD_BF16(5); - case 5: - CVT_ADD_BF16(4); - case 4: - CVT_ADD_BF16(3); - case 3: - CVT_ADD_BF16(2); - case 2: - CVT_ADD_BF16(1); - case 1: - break; - default: - for (int j = 1; j < world_size; j++) { - auto in_val = - cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i))); - inout_val = _mm512_add_ps(inout_val, in_val); + memcpy(to_buffer + i, buffers[0] + i, element_size); + switch (world_size){ + case 16: fn(to_buffer + i, to_buffer + i, buffers[15] + i, vector_length); + case 15: fn(to_buffer + i, to_buffer + i, buffers[14] + i, vector_length); + case 14: fn(to_buffer + i, to_buffer + i, buffers[13] + i, vector_length); + case 13: fn(to_buffer + i, to_buffer + i, buffers[12] + i, vector_length); + case 12: fn(to_buffer + i, to_buffer + i, buffers[11] + i, vector_length); + case 11: fn(to_buffer + i, to_buffer + i, buffers[10] + i, vector_length); + case 10: fn(to_buffer + i, to_buffer + i, buffers[9] + i, vector_length); + case 9: fn(to_buffer + i, to_buffer + i, buffers[8] + i, vector_length); + case 8: fn(to_buffer + i, to_buffer + i, buffers[7] + i, vector_length); + case 7: fn(to_buffer + i, to_buffer + i, buffers[6] + i, vector_length); + case 6: fn(to_buffer + i, to_buffer + i, buffers[5] + i, vector_length); + case 5: fn(to_buffer + i, to_buffer + i, buffers[4] + i, vector_length); + case 4: fn(to_buffer + i, to_buffer + i, buffers[3] + i, vector_length); + case 3: fn(to_buffer + i, to_buffer + i, buffers[2] + i, vector_length); + case 2: fn(to_buffer + i, to_buffer + i, buffers[1] + i, vector_length); + case 1: break; + default: + for (int j = 1; j < world_size; j++) { + fn(to_buffer + i, to_buffer + i, buffers[j] + i, vector_length); + } } - } - _mm256_storeu_si256((__m256i*)(to_buffer + i), cvt_fp32_to_bf16(inout_val)); - } - - // process remaining part - reduce_remaining_part(start_elements, num_elements, remain_elements, main_elements, element_size, to_buffer, buffers, fn); -} - -#define CVT_ADD_FP16(x) \ - do { \ - auto in##x##_val = \ - cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[x] + i))); \ - inout_val = _mm512_add_ps(inout_val, in##x##_val); \ - } while (0) - -void reduce_fp16_buffers( - int start_elements, - int num_elements, - char* to_buffer, - char** buffers, - ReductionFunction fn) { - const int element_size = 2; - const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size; - int main_elements = num_elements - (num_elements % vector_length); - int remain_elements = num_elements % vector_length; - - // process 
aligned part -#pragma omp parallel for - for (int i = start_elements * element_size; - i < (start_elements + main_elements) * element_size; - i += VECTOR_LENGTH_IN_BYTES) { - auto inout_val = - cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[0] + i))); - switch (world_size) { - case 16: - CVT_ADD_FP16(15); - case 15: - CVT_ADD_FP16(14); - case 14: - CVT_ADD_FP16(13); - case 13: - CVT_ADD_FP16(12); - case 12: - CVT_ADD_FP16(11); - case 11: - CVT_ADD_FP16(10); - case 10: - CVT_ADD_FP16(9); - case 9: - CVT_ADD_FP16(8); - case 8: - CVT_ADD_FP16(7); - case 7: - CVT_ADD_FP16(6); - case 6: - CVT_ADD_FP16(5); - case 5: - CVT_ADD_FP16(4); - case 4: - CVT_ADD_FP16(3); - case 3: - CVT_ADD_FP16(2); - case 2: - CVT_ADD_FP16(1); - case 1: - break; - default: - for (int j = 1; j < world_size; j++) { - auto in_val = - cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i))); - inout_val = _mm512_add_ps(inout_val, in_val); } - } - _mm256_storeu_si256((__m256i*)(to_buffer + i), cvt_fp32_to_fp16(inout_val)); - } - - - - // process remaining part - reduce_remaining_part(start_elements, num_elements, remain_elements, main_elements, element_size, to_buffer, buffers, fn); -} - -#define CVT_ADD_F32(x) \ - do { \ - auto in##x##_val = _mm256_loadu_ps((float*)(buffers[x] + i)); \ - inout_val = _mm256_add_ps(inout_val, in##x##_val); \ - } while (0) - -void reduce_fp32_buffers( - int start_elements, - int num_elements, - char* to_buffer, - char** buffers, - ReductionFunction fn) { - const int element_size = 4; - const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size; - int main_elements = num_elements - (num_elements % vector_length); - int remain_elements = num_elements % vector_length; - // process aligned part -#pragma omp parallel for - for (int i = start_elements * element_size; - i < (start_elements + main_elements) * element_size; - i += VECTOR_LENGTH_IN_BYTES) { - auto inout_val = _mm256_loadu_ps((float*)(buffers[0] + i)); - switch (world_size) { - case 16: - CVT_ADD_F32(15); - case 15: - CVT_ADD_F32(14); - case 14: - CVT_ADD_F32(13); - case 13: - CVT_ADD_F32(12); - case 12: - CVT_ADD_F32(11); - case 11: - CVT_ADD_F32(10); - case 10: - CVT_ADD_F32(9); - case 9: - CVT_ADD_F32(8); - case 8: - CVT_ADD_F32(7); - case 7: - CVT_ADD_F32(6); - case 6: - CVT_ADD_F32(5); - case 5: - CVT_ADD_F32(4); - case 4: - CVT_ADD_F32(3); - case 3: - CVT_ADD_F32(2); - case 2: - CVT_ADD_F32(1); - case 1: - break; - default: - for (int j = 1; j < world_size; j++) { - auto in_val = _mm256_loadu_ps((float*)(buffers[j] + i)); - inout_val = _mm256_add_ps(inout_val, in_val); - } + size_t offset = (start_elements + main_elements) * element_size; + while (remain_elements > 0) { + memcpy(to_buffer + offset, buffers[0] + offset, element_size); + for (int j = 1; j < world_size; j++) { + + fn(to_buffer + offset, + to_buffer + offset, + buffers[j] + offset, + 1); + } - _mm256_storeu_ps((float*)(to_buffer + i), inout_val); + remain_elements--; + offset += element_size; } - - // process remaining part - reduce_remaining_part(start_elements, num_elements, remain_elements, main_elements, element_size, to_buffer, buffers, fn); + } void shm_initialize(int size, int rank, char* addr_string, char* port_string) { @@ -551,7 +280,7 @@ size_t slice_el_start(size_t chunk_el, int slice_idx) { void symmetric_naive_all_reduce( char* data_ptr, - AllreduceOptions::ScalarType scalar_type, + int element_size, size_t chunk_size, size_t chunk_el, ReductionFunction fn) { @@ -596,7 +325,7 @@ void symmetric_naive_all_reduce( reduce_all_buffers( 0, 
chunk_el, - scalar_type, + element_size, world_rank, data_ptr, symmetric_buffer[current_buffer], @@ -609,7 +338,7 @@ void symmetric_naive_all_reduce( // naive allreduce distributed, each rank do naive reduce on its slice void distributed_naive_reduce( char* data_ptr, - AllreduceOptions::ScalarType scalar_type, + int element_size, size_t chunk_size, size_t chunk_el, ReductionFunction fn) { @@ -653,7 +382,7 @@ void distributed_naive_reduce( reduce_all_buffers( slice_el_start(chunk_el, world_rank), slice_size(chunk_el, world_rank), - scalar_type, + element_size, world_rank, distributed_buffer[current_buffer][world_rank], distributed_buffer[current_buffer], @@ -738,10 +467,10 @@ void shm(const detail::AllreduceOptionsImpl& opts) { size_t chunk_el = chunk_size / (data_size / opts.elements); if (chunk_size < NAIVE_ALLREDUCE_THRESHOLD) { symmetric_naive_all_reduce( - data_ptr, opts.scalarType, chunk_size, chunk_el, opts.reduce); + data_ptr, opts.elementSize, chunk_size, chunk_el, opts.reduce); } else { distributed_naive_reduce( - data_ptr, opts.scalarType, chunk_size, chunk_el, opts.reduce); + data_ptr, opts.elementSize, chunk_size, chunk_el, opts.reduce); } } From 0fdde35a948e81970d07f2c149d6eda3801ed44d Mon Sep 17 00:00:00 2001 From: gaopengf Date: Thu, 24 Jul 2025 03:27:29 -0400 Subject: [PATCH 06/16] refine format --- gloo/CMakeLists.txt | 1 - gloo/allreduce.cc | 1 - gloo/allreduce.h | 8 -------- gloo/allreduce_shm.cc | 6 ++++++ gloo/allreduce_shm.h | 2 +- 5 files changed, 7 insertions(+), 11 deletions(-) diff --git a/gloo/CMakeLists.txt b/gloo/CMakeLists.txt index db54496ef..fb65defd5 100644 --- a/gloo/CMakeLists.txt +++ b/gloo/CMakeLists.txt @@ -188,7 +188,6 @@ if(USE_ROCM) endif() endif() - # Install if necessary. # If the Gloo build is included from another project's build, it may # want to statically link with Gloo and not install any artifacts. diff --git a/gloo/allreduce.cc b/gloo/allreduce.cc index 511e8d3d3..4099dd757 100644 --- a/gloo/allreduce.cc +++ b/gloo/allreduce.cc @@ -96,7 +96,6 @@ BroadcastRangeFunction genLocalBroadcastFunction(const BufferVector& out) { } void allreduce(const detail::AllreduceOptionsImpl& opts) { - //printf("In gloo::allreduce\n"); if (opts.elements == 0) { return; } diff --git a/gloo/allreduce.h b/gloo/allreduce.h index 8fc494358..904eb8b32 100644 --- a/gloo/allreduce.h +++ b/gloo/allreduce.h @@ -11,17 +11,9 @@ #include #include #include -#include #include "gloo/context.h" #include "gloo/transport/unbound_buffer.h" -#include "gloo/types.h" - -#define GPF_PRINT(...) do {\ - printf("GPF_DEBUG:");\ - printf(__VA_ARGS__);\ - printf("\n");\ -}while(0) namespace gloo { diff --git a/gloo/allreduce_shm.cc b/gloo/allreduce_shm.cc index 6af27d3ad..f972cd3ea 100644 --- a/gloo/allreduce_shm.cc +++ b/gloo/allreduce_shm.cc @@ -11,6 +11,12 @@ #include #include +#define GPF_PRINT(...) 
do {\ + printf("GPF_DEBUG:");\ + printf(__VA_ARGS__);\ + printf("\n");\ +}while(0) + namespace gloo { diff --git a/gloo/allreduce_shm.h b/gloo/allreduce_shm.h index e9236759c..3271ba4a2 100644 --- a/gloo/allreduce_shm.h +++ b/gloo/allreduce_shm.h @@ -5,4 +5,4 @@ namespace gloo { bool is_intra_node(const int size); void shm(const detail::AllreduceOptionsImpl& opts); -} // namespace gloo \ No newline at end of file +} // namespace gloo From be7da7ce6a9f459874b66b698bf4385ffd34ed19 Mon Sep 17 00:00:00 2001 From: gaopengf Date: Mon, 11 Aug 2025 03:17:51 -0400 Subject: [PATCH 07/16] fix accuracy issue --- gloo/allreduce_shm.cc | 51 ++++--------------------------------------- 1 file changed, 4 insertions(+), 47 deletions(-) diff --git a/gloo/allreduce_shm.cc b/gloo/allreduce_shm.cc index f972cd3ea..9d71d901f 100644 --- a/gloo/allreduce_shm.cc +++ b/gloo/allreduce_shm.cc @@ -132,54 +132,11 @@ void reduce_all_buffers( char* to_buffer, char** buffers, ReductionFunction fn) { - const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size; - int main_elements = num_elements - (num_elements % vector_length); - int remain_elements = num_elements % vector_length; - -#pragma omp parallel for - for (int i = start_elements * element_size; - i < (start_elements + main_elements) * element_size; - i += VECTOR_LENGTH_IN_BYTES) { - memcpy(to_buffer + i, buffers[0] + i, element_size); - switch (world_size){ - case 16: fn(to_buffer + i, to_buffer + i, buffers[15] + i, vector_length); - case 15: fn(to_buffer + i, to_buffer + i, buffers[14] + i, vector_length); - case 14: fn(to_buffer + i, to_buffer + i, buffers[13] + i, vector_length); - case 13: fn(to_buffer + i, to_buffer + i, buffers[12] + i, vector_length); - case 12: fn(to_buffer + i, to_buffer + i, buffers[11] + i, vector_length); - case 11: fn(to_buffer + i, to_buffer + i, buffers[10] + i, vector_length); - case 10: fn(to_buffer + i, to_buffer + i, buffers[9] + i, vector_length); - case 9: fn(to_buffer + i, to_buffer + i, buffers[8] + i, vector_length); - case 8: fn(to_buffer + i, to_buffer + i, buffers[7] + i, vector_length); - case 7: fn(to_buffer + i, to_buffer + i, buffers[6] + i, vector_length); - case 6: fn(to_buffer + i, to_buffer + i, buffers[5] + i, vector_length); - case 5: fn(to_buffer + i, to_buffer + i, buffers[4] + i, vector_length); - case 4: fn(to_buffer + i, to_buffer + i, buffers[3] + i, vector_length); - case 3: fn(to_buffer + i, to_buffer + i, buffers[2] + i, vector_length); - case 2: fn(to_buffer + i, to_buffer + i, buffers[1] + i, vector_length); - case 1: break; - default: - for (int j = 1; j < world_size; j++) { - fn(to_buffer + i, to_buffer + i, buffers[j] + i, vector_length); - } - } - } - - size_t offset = (start_elements + main_elements) * element_size; - while (remain_elements > 0) { - memcpy(to_buffer + offset, buffers[0] + offset, element_size); - for (int j = 1; j < world_size; j++) { - - fn(to_buffer + offset, - to_buffer + offset, - buffers[j] + offset, - 1); - - } - remain_elements--; - offset += element_size; + size_t offset = start_elements * element_size; + memcpy(to_buffer + offset, buffers[0] + offset, num_elements * element_size); + for (int i = 1; i < world_size; i++) { + fn(to_buffer + offset, to_buffer + offset, buffers[i] + offset, num_elements); } - } void shm_initialize(int size, int rank, char* addr_string, char* port_string) { From 5b698dce331235ef13db46755da8ff05e76693c2 Mon Sep 17 00:00:00 2001 From: gaopengf Date: Mon, 18 Aug 2025 04:47:29 -0400 Subject: [PATCH 08/16] move intro-node check to 
allreduce() --- gloo/allreduce.cc | 19 ++++++++++--------- gloo/allreduce.h | 1 + 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/gloo/allreduce.cc b/gloo/allreduce.cc index 4099dd757..bc78a263f 100644 --- a/gloo/allreduce.cc +++ b/gloo/allreduce.cc @@ -132,7 +132,13 @@ void allreduce(const detail::AllreduceOptionsImpl& opts) { return; } - switch (opts.algorithm) { + auto algorithm = opts.algorithm; + if (is_intra_node(context->size)) { + algorithm = detail::AllreduceOptionsImpl::SHM; + } + + + switch (algorithm) { case detail::AllreduceOptionsImpl::UNSPECIFIED: case detail::AllreduceOptionsImpl::RING: ring(opts, reduceInputs, broadcastOutputs); @@ -140,6 +146,9 @@ void allreduce(const detail::AllreduceOptionsImpl& opts) { case detail::AllreduceOptionsImpl::BCUBE: bcube(opts, reduceInputs, broadcastOutputs); break; + case detail::AllreduceOptionsImpl::SHM: + shm(opts); + break; default: GLOO_ENFORCE(false, "Algorithm not handled."); } @@ -154,14 +163,6 @@ void ring( const auto slot = Slot::build(kAllreduceSlotPrefix, opts.tag); const size_t totalBytes = opts.elements * opts.elementSize; - - if (is_intra_node(context->size)) { - shm(opts); - return; - } - - //shm(opts); - //return; // Note: context->size > 1 const auto recvRank = (context->size + context->rank + 1) % context->size; diff --git a/gloo/allreduce.h b/gloo/allreduce.h index 904eb8b32..6f6037ede 100644 --- a/gloo/allreduce.h +++ b/gloo/allreduce.h @@ -39,6 +39,7 @@ struct AllreduceOptionsImpl { UNSPECIFIED = 0, RING = 1, BCUBE = 2, + SHM = 3, }; explicit AllreduceOptionsImpl(const std::shared_ptr& context) From 3564e959f9bff325cbfe75ff55ca2f3ad8ca354b Mon Sep 17 00:00:00 2001 From: gaopengf Date: Wed, 20 Aug 2025 04:15:32 -0400 Subject: [PATCH 09/16] remove debug code --- gloo/allreduce_shm.cc | 21 +++------------------ 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/gloo/allreduce_shm.cc b/gloo/allreduce_shm.cc index 9d71d901f..579d701ef 100644 --- a/gloo/allreduce_shm.cc +++ b/gloo/allreduce_shm.cc @@ -11,13 +11,6 @@ #include #include -#define GPF_PRINT(...) 
do {\ - printf("GPF_DEBUG:");\ - printf(__VA_ARGS__);\ - printf("\n");\ -}while(0) - - namespace gloo { namespace { @@ -87,7 +80,7 @@ static bool is_initialized = false; #define NAME_BUF_SIZE 1000 #define MAX_BUF_SIZE 1048576 * 32 #define NAIVE_ALLREDUCE_THRESHOLD 1048576 -#define SHM_BUFFER_NAME "deepspeed_allreduce_buffer" +#define SHM_BUFFER_NAME "shm_allreduce_buffer" struct allreduce_workspace { enum coll_state states[2]; // idx=0 -- state for symmetric_naive_all_reduce // idx=1 -- state for distributed_naive_all_reduce @@ -393,11 +386,8 @@ void shm(const detail::AllreduceOptionsImpl& opts) { const auto& context = opts.context; if (!is_initialized) { - //int size = context->size; - //int rank = context->rank; - - int size = std::stoi(std::getenv("PMI_SIZE")); - int rank = std::stoi(std::getenv("PMI_RANK")); + int size = context->size; + int rank = context->rank; world_size = size; world_rank = rank; @@ -411,12 +401,7 @@ void shm(const detail::AllreduceOptionsImpl& opts) { if (port_string == NULL) { port_string = ""; } - // std::cout << "size: " << size << std::endl; - // std::cout << "rank: " << rank << std::endl; - // std::cout << "addr_string: " << addr_string << std::endl; - // std::cout << "port_string: " << port_string << std::endl; shm_initialize(size, rank, addr_string, port_string); - GPF_PRINT("SHM reduce has been initialized"); } const size_t data_size = opts.elements * opts.elementSize; From 3ac80659c01cdc3cd8f3613708cc954b11a64dc4 Mon Sep 17 00:00:00 2001 From: gaopengf Date: Mon, 25 Aug 2025 07:25:00 -0400 Subject: [PATCH 10/16] use local_rank to check intra-node condition --- gloo/allreduce.cc | 2 +- gloo/allreduce_shm.cc | 31 ++++++++++++++++++++++++------- gloo/allreduce_shm.h | 2 +- 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/gloo/allreduce.cc b/gloo/allreduce.cc index bc78a263f..cefa14c64 100644 --- a/gloo/allreduce.cc +++ b/gloo/allreduce.cc @@ -133,7 +133,7 @@ void allreduce(const detail::AllreduceOptionsImpl& opts) { } auto algorithm = opts.algorithm; - if (is_intra_node(context->size)) { + if (is_intra_node(opts)) { algorithm = detail::AllreduceOptionsImpl::SHM; } diff --git a/gloo/allreduce_shm.cc b/gloo/allreduce_shm.cc index 579d701ef..af8250f3c 100644 --- a/gloo/allreduce_shm.cc +++ b/gloo/allreduce_shm.cc @@ -1,4 +1,6 @@ #include "gloo/allreduce_shm.h" +#include "gloo/barrier.h" +#include "gloo/broadcast.h" #include #include @@ -369,18 +371,33 @@ void distributed_naive_reduce( } // namespace -bool is_intra_node(const int size) { - // must launch with torchrun - auto local_size_string = std::getenv("LOCAL_WORLD_SIZE"); - int local_size = 0; - if (local_size_string != NULL) { - local_size = std::stoi(local_size_string); +bool is_intra_node(const detail::AllreduceOptionsImpl& opts) { + + // It's difficult to get local_world_size directly in gloo. However, we could get local_rank infos from each rank's connection info. + // In intra-node scenario, the local_rank of last rank is supposed to be equal to world_size - 1. + const auto& context = opts.context; + int rank = context->rank; + int world_size = context->size; + int max_local_rank = 0; + // Get max local rank from pair 0. 
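+  // Clarifying comment (assumption about intent, based on the check below):
+  // only the last rank can read its own local rank here; the value is then
+  // broadcast from root world_size - 1, and every rank compares it against
+  // world_size - 1. If they match, all ranks resolved to the same host, so
+  // the SHM allreduce path is usable.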
+ if (rank == world_size - 1) { + max_local_rank = context->getPair(0)->getLocalRank(); } - return size > 1 && size == local_size; + // Do broadcast + BroadcastOptions broadcast_opts(context); + broadcast_opts.setRoot(world_size - 1); + broadcast_opts.setOutput(&max_local_rank, sizeof(max_local_rank)); + broadcast(broadcast_opts); + + // Do barrier + BarrierOptions barrier_opts(context); + barrier(barrier_opts); + return max_local_rank == world_size - 1; } + void shm(const detail::AllreduceOptionsImpl& opts) { const auto& context = opts.context; diff --git a/gloo/allreduce_shm.h b/gloo/allreduce_shm.h index 3271ba4a2..7f3f21d88 100644 --- a/gloo/allreduce_shm.h +++ b/gloo/allreduce_shm.h @@ -2,7 +2,7 @@ namespace gloo { -bool is_intra_node(const int size); +bool is_intra_node(const detail::AllreduceOptionsImpl& opts); void shm(const detail::AllreduceOptionsImpl& opts); } // namespace gloo From 8d3ae22267011e64ad591502e01c511216447faa Mon Sep 17 00:00:00 2001 From: gaopengf Date: Fri, 29 Aug 2025 06:09:37 -0400 Subject: [PATCH 11/16] add support for multi-thread and code for local reduction --- gloo/allreduce_shm.cc | 60 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 48 insertions(+), 12 deletions(-) diff --git a/gloo/allreduce_shm.cc b/gloo/allreduce_shm.cc index af8250f3c..41a741aae 100644 --- a/gloo/allreduce_shm.cc +++ b/gloo/allreduce_shm.cc @@ -73,9 +73,9 @@ void shared_create( } } -static int world_rank = -1; -static int world_size = -1; -static bool is_initialized = false; +thread_local static int world_rank = -1; +thread_local static int world_size = -1; +thread_local static bool is_initialized = false; // SHM based allreduce helper functions // buffer that holds shm name @@ -97,12 +97,12 @@ struct allreduce_workspace { #define BUFFER1_OFFSET(current_buffer) \ 2 * NAIVE_ALLREDUCE_THRESHOLD + current_buffer* MAX_BUF_SIZE -struct allreduce_workspace** workspace; +thread_local struct allreduce_workspace** workspace; // buffer for small messages, double buffer -char** symmetric_buffer[2]; +thread_local char** symmetric_buffer[2]; // buffer for large messages, double buffer -char** distributed_buffer[2]; +thread_local char** distributed_buffer[2]; void wait_buffer_state_until_2( int index, @@ -243,8 +243,8 @@ void symmetric_naive_all_reduce( size_t chunk_el, ReductionFunction fn) { const int state_group = 0; - static int current_buffer = 0; - static int state_idx = 0; + thread_local static int current_buffer = 0; + thread_local static int state_idx = 0; enum coll_state copy_current, copy_next; @@ -301,8 +301,8 @@ void distributed_naive_reduce( size_t chunk_el, ReductionFunction fn) { const int state_group = 1; - static int current_buffer = 0; - static int state_idx = 0; + thread_local static int current_buffer = 0; + thread_local static int state_idx = 0; enum coll_state copy_current, copy_next, reduce_current; @@ -421,8 +421,37 @@ void shm(const detail::AllreduceOptionsImpl& opts) { shm_initialize(size, rank, addr_string, port_string); } - const size_t data_size = opts.elements * opts.elementSize; - const std::vector>& out = opts.out; + const size_t data_size = opts.elements * opts.elementSize; + auto& in = opts.in; + auto& out = opts.out; + + // Do local reduction + if (in.size() > 0) { + if (in.size() == 1) { + memcpy(static_cast(out[0]->ptr), static_cast(in[0]->ptr), data_size); + } else { + opts.reduce(static_cast(out[0]->ptr), + static_cast(in[0]->ptr), + static_cast(in[1]->ptr), + opts.elements); + for (size_t i = 2; i < in.size(); i++) { + 
opts.reduce(static_cast(out[0]->ptr), + static_cast(out[0]->ptr), + static_cast(in[i]->ptr), + opts.elements); + } + } + } else { + for (size_t i = 1; i < out.size(); i++) { + opts.reduce(static_cast(out[0]->ptr), + static_cast(out[0]->ptr), + static_cast(out[i]->ptr), + opts.elements); + } + } + + + void* data = out[0].get()->ptr; for (int offset = 0; offset < data_size; offset += MAX_BUF_SIZE) { @@ -439,6 +468,13 @@ void shm(const detail::AllreduceOptionsImpl& opts) { } } + if (out.size() > 1) { + for (size_t i = 1; i < out.size(); i++) { + memcpy(static_cast(out[i]->ptr), static_cast(out[0]->ptr), data_size); + } + } + + } } //namespace gloo From 2985c3e82e5f1a460accc2266a450c10b3f7a609 Mon Sep 17 00:00:00 2001 From: gaopengf Date: Mon, 1 Sep 2025 04:05:13 -0400 Subject: [PATCH 12/16] move intra-node check to createAndConnectAllPairs --- gloo/CMakeLists.txt | 7 +++++-- gloo/allreduce.cc | 6 ++++-- gloo/allreduce_shm.cc | 27 --------------------------- gloo/allreduce_shm.h | 1 - gloo/context.cc | 4 ++++ gloo/context.h | 2 ++ gloo/transport/context.cc | 3 +++ gloo/transport/context.h | 7 +++++++ gloo/transport/tcp/context.cc | 3 +++ 9 files changed, 28 insertions(+), 32 deletions(-) diff --git a/gloo/CMakeLists.txt b/gloo/CMakeLists.txt index fb65defd5..0a14cc629 100644 --- a/gloo/CMakeLists.txt +++ b/gloo/CMakeLists.txt @@ -11,7 +11,6 @@ list(APPEND GLOO_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/allgatherv.cc" "${CMAKE_CURRENT_SOURCE_DIR}/allreduce.cc" "${CMAKE_CURRENT_SOURCE_DIR}/allreduce_local.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/allreduce_shm.cc" "${CMAKE_CURRENT_SOURCE_DIR}/alltoall.cc" "${CMAKE_CURRENT_SOURCE_DIR}/alltoallv.cc" "${CMAKE_CURRENT_SOURCE_DIR}/barrier.cc" @@ -35,7 +34,6 @@ list(APPEND GLOO_HDRS "${CMAKE_CURRENT_SOURCE_DIR}/allreduce_local.h" "${CMAKE_CURRENT_SOURCE_DIR}/allreduce_ring.h" "${CMAKE_CURRENT_SOURCE_DIR}/allreduce_ring_chunked.h" - "${CMAKE_CURRENT_SOURCE_DIR}/allreduce_shm.h" "${CMAKE_CURRENT_SOURCE_DIR}/alltoall.h" "${CMAKE_CURRENT_SOURCE_DIR}/alltoallv.h" "${CMAKE_CURRENT_SOURCE_DIR}/barrier.h" @@ -53,6 +51,11 @@ list(APPEND GLOO_HDRS "${CMAKE_CURRENT_SOURCE_DIR}/types.h" ) +if(NOT MSVC) + list(APPEND GLOO_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/allreduce_shm.cc") + list(APPEND GLOO_HDRS "${CMAKE_CURRENT_SOURCE_DIR}/allreduce_shm.h") +endif() + if(USE_CUDA) file(GLOB GLOO_CUDA_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/cuda*.cc" diff --git a/gloo/allreduce.cc b/gloo/allreduce.cc index cefa14c64..f659d58da 100644 --- a/gloo/allreduce.cc +++ b/gloo/allreduce.cc @@ -133,10 +133,12 @@ void allreduce(const detail::AllreduceOptionsImpl& opts) { } auto algorithm = opts.algorithm; - if (is_intra_node(opts)) { + +#ifndef _WIN32 + if (context->isIntraNode()) { algorithm = detail::AllreduceOptionsImpl::SHM; } - +#endif switch (algorithm) { case detail::AllreduceOptionsImpl::UNSPECIFIED: diff --git a/gloo/allreduce_shm.cc b/gloo/allreduce_shm.cc index 41a741aae..9b62d11ad 100644 --- a/gloo/allreduce_shm.cc +++ b/gloo/allreduce_shm.cc @@ -371,33 +371,6 @@ void distributed_naive_reduce( } // namespace -bool is_intra_node(const detail::AllreduceOptionsImpl& opts) { - - // It's difficult to get local_world_size directly in gloo. However, we could get local_rank infos from each rank's connection info. - // In intra-node scenario, the local_rank of last rank is supposed to be equal to world_size - 1. - const auto& context = opts.context; - int rank = context->rank; - int world_size = context->size; - int max_local_rank = 0; - // Get max local rank from pair 0. 
- if (rank == world_size - 1) { - max_local_rank = context->getPair(0)->getLocalRank(); - } - - // Do broadcast - BroadcastOptions broadcast_opts(context); - broadcast_opts.setRoot(world_size - 1); - broadcast_opts.setOutput(&max_local_rank, sizeof(max_local_rank)); - broadcast(broadcast_opts); - - // Do barrier - BarrierOptions barrier_opts(context); - barrier(barrier_opts); - return max_local_rank == world_size - 1; -} - - - void shm(const detail::AllreduceOptionsImpl& opts) { const auto& context = opts.context; diff --git a/gloo/allreduce_shm.h b/gloo/allreduce_shm.h index 7f3f21d88..ca8b14621 100644 --- a/gloo/allreduce_shm.h +++ b/gloo/allreduce_shm.h @@ -2,7 +2,6 @@ namespace gloo { -bool is_intra_node(const detail::AllreduceOptionsImpl& opts); void shm(const detail::AllreduceOptionsImpl& opts); } // namespace gloo diff --git a/gloo/context.cc b/gloo/context.cc index fd9b83c7b..367ad88b4 100644 --- a/gloo/context.cc +++ b/gloo/context.cc @@ -67,4 +67,8 @@ std::chrono::milliseconds Context::getTimeout() const { return timeout_; } +bool Context::isIntraNode() const { + return transportContext_->isIntraNode(); +} + } // namespace gloo diff --git a/gloo/context.h b/gloo/context.h index 0bcf0bef7..0dc07d9a5 100644 --- a/gloo/context.h +++ b/gloo/context.h @@ -51,6 +51,8 @@ class Context { std::chrono::milliseconds getTimeout() const; + bool isIntraNode() const; + protected: std::shared_ptr device_; std::shared_ptr transportContext_; diff --git a/gloo/transport/context.cc b/gloo/transport/context.cc index 4b4cfadbf..edd3729f3 100644 --- a/gloo/transport/context.cc +++ b/gloo/transport/context.cc @@ -42,6 +42,7 @@ void Context::createAndConnectAllPairs(std::shared_ptr store) { const std::vector value(localHostName.begin(), localHostName.end()); store->set(localKey, value); + intraNode_ = true; for (int i = 0; i < size; i++) { if (i == rank) { break; @@ -54,6 +55,8 @@ void Context::createAndConnectAllPairs(std::shared_ptr store) { if (hostName == localHostName) { localRank++; } + + intraNode_ = intraNode_ && hostName == localHostName; } // Create pairs diff --git a/gloo/transport/context.h b/gloo/transport/context.h index ca87ad365..65ff1ab20 100644 --- a/gloo/transport/context.h +++ b/gloo/transport/context.h @@ -68,6 +68,10 @@ class Context { return timeout_; } + bool isIntraNode() const{ + return intraNode_; + } + protected: // Protects access to the pending operations and expected // notifications vectors. These vectors can only be mutated by an @@ -93,6 +97,9 @@ class Context { // any kind of send/recv operation. std::chrono::milliseconds timeout_; + // Whether is intra-node. 
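+  // Set during createAndConnectAllPairs(): remains true only if every remote
+  // rank reports the same hostname as the local rank.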
+ bool intraNode_ = false; + std::vector extractAddress(const std::vector& allAddrs, int i) const; diff --git a/gloo/transport/tcp/context.cc b/gloo/transport/tcp/context.cc index 422896b64..e380ecc05 100644 --- a/gloo/transport/tcp/context.cc +++ b/gloo/transport/tcp/context.cc @@ -123,6 +123,7 @@ void Context::createAndConnectAllPairs(std::shared_ptr store) { }); } + intraNode_ = true; // Connect every pair for (int i = 0; i < size; i++) { if (i == rank) { @@ -140,6 +141,8 @@ void Context::createAndConnectAllPairs(std::shared_ptr store) { ++localRank; } + intraNode_ = intraNode_ && (remoteRankInfo.hostname == localHostName); + const auto& pair = pairs_[i]; auto remoteDeviceAddr = Address(remoteRankInfo.addressBytes).getSockaddr(); From 9ae28428c9f6d2d2b6924b7fc757e6cb7717baf0 Mon Sep 17 00:00:00 2001 From: gaopengf Date: Wed, 3 Sep 2025 04:21:58 -0400 Subject: [PATCH 13/16] add check for gpu input and fix format issue --- gloo/allreduce.cc | 8 ++- gloo/allreduce_shm.cc | 120 +++++++++++++++++++++------------------ gloo/transport/context.h | 2 +- 3 files changed, 72 insertions(+), 58 deletions(-) diff --git a/gloo/allreduce.cc b/gloo/allreduce.cc index f659d58da..e5871ad9a 100644 --- a/gloo/allreduce.cc +++ b/gloo/allreduce.cc @@ -12,10 +12,11 @@ #include #include +#include "gloo/allreduce_shm.h" #include "gloo/common/logging.h" #include "gloo/math.h" +#include "gloo/transport/device.h" #include "gloo/types.h" -#include "gloo/allreduce_shm.h" namespace gloo { @@ -135,7 +136,7 @@ void allreduce(const detail::AllreduceOptionsImpl& opts) { auto algorithm = opts.algorithm; #ifndef _WIN32 - if (context->isIntraNode()) { + if (context->isIntraNode() && !context->getDevice()->hasGPUDirect()) { algorithm = detail::AllreduceOptionsImpl::SHM; } #endif @@ -148,9 +149,11 @@ void allreduce(const detail::AllreduceOptionsImpl& opts) { case detail::AllreduceOptionsImpl::BCUBE: bcube(opts, reduceInputs, broadcastOutputs); break; +#ifndef _WIN32 case detail::AllreduceOptionsImpl::SHM: shm(opts); break; +#endif default: GLOO_ENFORCE(false, "Algorithm not handled."); } @@ -165,7 +168,6 @@ void ring( const auto slot = Slot::build(kAllreduceSlotPrefix, opts.tag); const size_t totalBytes = opts.elements * opts.elementSize; - // Note: context->size > 1 const auto recvRank = (context->size + context->rank + 1) % context->size; const auto sendRank = (context->size + context->rank - 1) % context->size; diff --git a/gloo/allreduce_shm.cc b/gloo/allreduce_shm.cc index 9b62d11ad..88eebbb6b 100644 --- a/gloo/allreduce_shm.cc +++ b/gloo/allreduce_shm.cc @@ -2,6 +2,7 @@ #include "gloo/barrier.h" #include "gloo/broadcast.h" +#include #include #include #include @@ -11,7 +12,6 @@ #include #include #include -#include namespace gloo { @@ -127,14 +127,21 @@ void reduce_all_buffers( char* to_buffer, char** buffers, ReductionFunction fn) { - size_t offset = start_elements * element_size; + size_t offset = start_elements * element_size; memcpy(to_buffer + offset, buffers[0] + offset, num_elements * element_size); for (int i = 1; i < world_size; i++) { - fn(to_buffer + offset, to_buffer + offset, buffers[i] + offset, num_elements); + fn(to_buffer + offset, + to_buffer + offset, + buffers[i] + offset, + num_elements); } } -void shm_initialize(int size, int rank, char* addr_string, char* port_string) { +void shm_initialize( + int size, + int rank, + const char* addr_string, + const char* port_string) { world_size = size; world_rank = rank; @@ -155,7 +162,8 @@ void shm_initialize(int size, int rank, char* addr_string, char* 
port_string) { struct allreduce_workspace* workspace_buf_other; workspace_buf = (struct allreduce_workspace*)malloc(sizeof(struct allreduce_workspace)); - int written = snprintf(shm_name, NAME_BUF_SIZE, "%s_%d", shm_name_prefix, rank); + int written = + snprintf(shm_name, NAME_BUF_SIZE, "%s_%d", shm_name_prefix, rank); if (written >= NAME_BUF_SIZE) { std::cout << "[warning]: written >= NAME_BUF_SIZE" << std::endl; } @@ -179,10 +187,11 @@ void shm_initialize(int size, int rank, char* addr_string, char* port_string) { // map shm of all ranks for (int i = 0; i < size; i++) { if (i != rank) { - int written = snprintf(shm_name, NAME_BUF_SIZE, "%s_%d", shm_name_prefix, i); - if (written >= NAME_BUF_SIZE) { - std::cout << "[warning]: written >= NAME_BUF_SIZE" << std::endl; - } + int written = + snprintf(shm_name, NAME_BUF_SIZE, "%s_%d", shm_name_prefix, i); + if (written >= NAME_BUF_SIZE) { + std::cout << "[warning]: written >= NAME_BUF_SIZE" << std::endl; + } // printf("open %s, %d\n", shm_name, rank); do { shared_open( @@ -372,10 +381,8 @@ void distributed_naive_reduce( } // namespace void shm(const detail::AllreduceOptionsImpl& opts) { - - const auto& context = opts.context; + const auto& context = opts.context; if (!is_initialized) { - int size = context->size; int rank = context->rank; @@ -383,72 +390,77 @@ void shm(const detail::AllreduceOptionsImpl& opts) { world_rank = rank; is_initialized = true; - auto addr_string = std::getenv("MASTER_ADDR"); - if (addr_string == NULL) { - addr_string = ""; + std::string addr_string(""), port_string(""); + const auto& addr_string_env = std::getenv("MASTER_ADDR"); + if (addr_string_env != nullptr) { + addr_string = addr_string_env; } - auto port_string = std::getenv("MASTER_PORT"); - if (port_string == NULL) { - port_string = ""; + const auto port_string_env = std::getenv("MASTER_PORT"); + if (port_string_env != NULL) { + port_string = port_string_env; } - shm_initialize(size, rank, addr_string, port_string); + shm_initialize(size, rank, addr_string.c_str(), port_string.c_str()); } - const size_t data_size = opts.elements * opts.elementSize; + const size_t data_size = opts.elements * opts.elementSize; auto& in = opts.in; auto& out = opts.out; // Do local reduction if (in.size() > 0) { if (in.size() == 1) { - memcpy(static_cast(out[0]->ptr), static_cast(in[0]->ptr), data_size); + memcpy( + static_cast(out[0]->ptr), + static_cast(in[0]->ptr), + data_size); } else { - opts.reduce(static_cast(out[0]->ptr), - static_cast(in[0]->ptr), - static_cast(in[1]->ptr), - opts.elements); - for (size_t i = 2; i < in.size(); i++) { - opts.reduce(static_cast(out[0]->ptr), - static_cast(out[0]->ptr), - static_cast(in[i]->ptr), - opts.elements); - } + opts.reduce( + static_cast(out[0]->ptr), + static_cast(in[0]->ptr), + static_cast(in[1]->ptr), + opts.elements); + for (size_t i = 2; i < in.size(); i++) { + opts.reduce( + static_cast(out[0]->ptr), + static_cast(out[0]->ptr), + static_cast(in[i]->ptr), + opts.elements); + } } } else { for (size_t i = 1; i < out.size(); i++) { - opts.reduce(static_cast(out[0]->ptr), - static_cast(out[0]->ptr), - static_cast(out[i]->ptr), - opts.elements); + opts.reduce( + static_cast(out[0]->ptr), + static_cast(out[0]->ptr), + static_cast(out[i]->ptr), + opts.elements); } } - - void* data = out[0].get()->ptr; - for (int offset = 0; offset < data_size; offset += MAX_BUF_SIZE) { - auto data_ptr = ((char*)(data) + offset); - size_t chunk_size = - data_size - offset > MAX_BUF_SIZE ? 
MAX_BUF_SIZE : data_size - offset; - size_t chunk_el = chunk_size / (data_size / opts.elements); - if (chunk_size < NAIVE_ALLREDUCE_THRESHOLD) { - symmetric_naive_all_reduce( - data_ptr, opts.elementSize, chunk_size, chunk_el, opts.reduce); - } else { - distributed_naive_reduce( - data_ptr, opts.elementSize, chunk_size, chunk_el, opts.reduce); - } + for (int offset = 0; offset < data_size; offset += MAX_BUF_SIZE) { + auto data_ptr = ((char*)(data) + offset); + size_t chunk_size = + data_size - offset > MAX_BUF_SIZE ? MAX_BUF_SIZE : data_size - offset; + size_t chunk_el = chunk_size / (data_size / opts.elements); + if (chunk_size < NAIVE_ALLREDUCE_THRESHOLD) { + symmetric_naive_all_reduce( + data_ptr, opts.elementSize, chunk_size, chunk_el, opts.reduce); + } else { + distributed_naive_reduce( + data_ptr, opts.elementSize, chunk_size, chunk_el, opts.reduce); + } } if (out.size() > 1) { for (size_t i = 1; i < out.size(); i++) { - memcpy(static_cast(out[i]->ptr), static_cast(out[0]->ptr), data_size); + memcpy( + static_cast(out[i]->ptr), + static_cast(out[0]->ptr), + data_size); } } - - } -} //namespace gloo - +} // namespace gloo diff --git a/gloo/transport/context.h b/gloo/transport/context.h index 9e6eaa241..56d8ad8ec 100644 --- a/gloo/transport/context.h +++ b/gloo/transport/context.h @@ -68,7 +68,7 @@ class Context { return timeout_; } - bool isIntraNode() const{ + bool isIntraNode() const { return intraNode_; } From 1b7660a7942ad46e0cad4bdd24d8283647f15404 Mon Sep 17 00:00:00 2001 From: gaopengf Date: Wed, 3 Sep 2025 04:24:37 -0400 Subject: [PATCH 14/16] fix format issue --- gloo/allreduce_shm.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gloo/allreduce_shm.h b/gloo/allreduce_shm.h index ca8b14621..6cf3ccc44 100644 --- a/gloo/allreduce_shm.h +++ b/gloo/allreduce_shm.h @@ -1,7 +1,7 @@ #include "gloo/allreduce.h" namespace gloo { - + void shm(const detail::AllreduceOptionsImpl& opts); } // namespace gloo From ab9c63b939e9dcf52dcdb8e840cb247e8ab85c73 Mon Sep 17 00:00:00 2001 From: "Gao, Pengfei" Date: Fri, 5 Sep 2025 08:09:36 -0400 Subject: [PATCH 15/16] fix timeout ut --- gloo/allreduce_shm.cc | 53 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 42 insertions(+), 11 deletions(-) diff --git a/gloo/allreduce_shm.cc b/gloo/allreduce_shm.cc index 88eebbb6b..f6404e02d 100644 --- a/gloo/allreduce_shm.cc +++ b/gloo/allreduce_shm.cc @@ -12,6 +12,7 @@ #include #include #include +#include namespace gloo { @@ -108,14 +109,28 @@ void wait_buffer_state_until_2( int index, enum coll_state state0, enum coll_state state1, - int state_group) { + int state_group, + std::chrono::milliseconds timeout) { volatile enum coll_state* state_ptr = &(workspace[index]->states[state_group]); - while (1) { + auto total_milliseconds = timeout.count(); + auto count = 0; + while (count < total_milliseconds) { volatile enum coll_state cur_state = *state_ptr; - if (cur_state == state0 || cur_state == state1) + if (cur_state == state0 || cur_state == state1) { break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + count += 10; + } + + volatile enum coll_state cur_state = *state_ptr; + if (!(cur_state == state0 || cur_state == state1)) { + throw ::gloo::IoException(GLOO_ERROR_MSG( + "Timed out waiting", + timeout.count(), + "ms for wait buffer state operation to complete")); } } @@ -250,7 +265,8 @@ void symmetric_naive_all_reduce( int element_size, size_t chunk_size, size_t chunk_el, - ReductionFunction fn) { + ReductionFunction fn, + std::chrono::milliseconds timeout) { 
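+  // The timeout threaded in here is the allreduce options timeout; each
+  // wait_buffer_state_until_2() call below forwards it so a peer that never
+  // reaches the expected state results in an IoException instead of an
+  // indefinite spin.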
const int state_group = 0; thread_local static int current_buffer = 0; thread_local static int state_idx = 0; @@ -283,7 +299,8 @@ void symmetric_naive_all_reduce( for (int i = 0; i < world_size; i++) { // wait until the other rank copy the buffer if (i != world_rank) { - wait_buffer_state_until_2(i, copy_current, copy_next, state_group); + wait_buffer_state_until_2( + i, copy_current, copy_next, state_group, timeout); } } @@ -308,7 +325,8 @@ void distributed_naive_reduce( int element_size, size_t chunk_size, size_t chunk_el, - ReductionFunction fn) { + ReductionFunction fn, + std::chrono::milliseconds timeout) { const int state_group = 1; thread_local static int current_buffer = 0; thread_local static int state_idx = 0; @@ -316,7 +334,8 @@ void distributed_naive_reduce( enum coll_state copy_current, copy_next, reduce_current; // similar to symmetric_naive_allreduce, but here we only need two sets of - // states, because distributed naive reduce has two barriers in the algorithm + // states, because distributed naive reduce has two barriers in the + // algorithm switch (state_idx) { case 0: copy_current = coll_allreduce_naive__copy_in_done; @@ -342,7 +361,8 @@ void distributed_naive_reduce( for (int i = 0; i < world_size; i++) { // wait until all the other ranks copy the buffer if (i != world_rank) - wait_buffer_state_until_2(i, copy_current, reduce_current, state_group); + wait_buffer_state_until_2( + i, copy_current, reduce_current, state_group, timeout); } // reduce scatter @@ -360,7 +380,8 @@ void distributed_naive_reduce( for (int i = 0; i < world_size; i++) { // wait until all the other ranks reduce the buffer if (i != world_rank) - wait_buffer_state_until_2(i, reduce_current, copy_next, state_group); + wait_buffer_state_until_2( + i, reduce_current, copy_next, state_group, timeout); } for (int i = 0; i < world_size; i++) { @@ -446,10 +467,20 @@ void shm(const detail::AllreduceOptionsImpl& opts) { size_t chunk_el = chunk_size / (data_size / opts.elements); if (chunk_size < NAIVE_ALLREDUCE_THRESHOLD) { symmetric_naive_all_reduce( - data_ptr, opts.elementSize, chunk_size, chunk_el, opts.reduce); + data_ptr, + opts.elementSize, + chunk_size, + chunk_el, + opts.reduce, + opts.timeout); } else { distributed_naive_reduce( - data_ptr, opts.elementSize, chunk_size, chunk_el, opts.reduce); + data_ptr, + opts.elementSize, + chunk_size, + chunk_el, + opts.reduce, + opts.timeout); } } From 9f726e9d160879ac0923f318b2f100566fc88933 Mon Sep 17 00:00:00 2001 From: gaopengf Date: Thu, 11 Sep 2025 22:03:13 -0400 Subject: [PATCH 16/16] add fininalize method and fix ut --- gloo/allreduce_shm.cc | 472 +++++++++++++++++++------------------- gloo/allreduce_shm.h | 62 +++++ gloo/context.h | 4 + gloo/transport/context.cc | 2 +- 4 files changed, 309 insertions(+), 231 deletions(-) diff --git a/gloo/allreduce_shm.cc b/gloo/allreduce_shm.cc index f6404e02d..6de573da4 100644 --- a/gloo/allreduce_shm.cc +++ b/gloo/allreduce_shm.cc @@ -1,6 +1,4 @@ #include "gloo/allreduce_shm.h" -#include "gloo/barrier.h" -#include "gloo/broadcast.h" #include #include @@ -19,19 +17,16 @@ namespace gloo { namespace { using ReductionFunction = AllreduceOptions::Func; +using CollState = AllreduceSharedMemoryData::CollState; +using Allreduceworkspace = AllreduceSharedMemoryData::AllreduceWorkspace; -#define VECTOR_LENGTH_IN_BYTES 32 -// states for collectives -enum coll_state { - coll_begin = 0, - coll_allreduce_naive__copy_in_done, - coll_allreduce_naive__reduce_done, - // alternative state when allreduce is working on alternative 
buffer - // of the double buffer. - coll_alt1_allreduce_naive__copy_in_done, - coll_alt2_allreduce_naive__copy_in_done, - coll_alt1_allreduce_naive__reduce_done, -}; +constexpr int VECTOR_LENGTH_IN_BYTES = 32; + +#define BUFFER0_OFFSET(current_buffer) \ + current_buffer* Allreduceworkspace::NAIVE_ALLREDUCE_THRESHOLD +#define BUFFER1_OFFSET(current_buffer) \ + 2 * Allreduceworkspace::NAIVE_ALLREDUCE_THRESHOLD + \ + current_buffer* Allreduceworkspace::MAX_BUF_SIZE // SHM building blocks struct SharedData { @@ -74,63 +69,65 @@ void shared_create( } } -thread_local static int world_rank = -1; -thread_local static int world_size = -1; -thread_local static bool is_initialized = false; - -// SHM based allreduce helper functions -// buffer that holds shm name -#define NAME_BUF_SIZE 1000 -#define MAX_BUF_SIZE 1048576 * 32 -#define NAIVE_ALLREDUCE_THRESHOLD 1048576 -#define SHM_BUFFER_NAME "shm_allreduce_buffer" -struct allreduce_workspace { - enum coll_state states[2]; // idx=0 -- state for symmetric_naive_all_reduce - // idx=1 -- state for distributed_naive_all_reduce - // double buffer to avoid syncing between rounds - // offset=0 -- 2*NAIVE_ALLREDUCE_THRESHOLD : buffer for - // symmetric_naive_all_reduce after that : buffer for - // distributed_naive_all_reduce - char buffer[2 * NAIVE_ALLREDUCE_THRESHOLD + 2 * MAX_BUF_SIZE]; -}; +void wait_buffer_state( + CollState state0, + CollState state1, + int state_group, + std::chrono::milliseconds timeout, + std::shared_ptr shm_data) { + // Create a new thread + auto workspace = shm_data->workspace; + const int rank = shm_data->rank; + const int world_size = shm_data->world_size; -#define BUFFER0_OFFSET(current_buffer) current_buffer* NAIVE_ALLREDUCE_THRESHOLD -#define BUFFER1_OFFSET(current_buffer) \ - 2 * NAIVE_ALLREDUCE_THRESHOLD + current_buffer* MAX_BUF_SIZE + for (int i = 0; i < world_size; i++) { + if (i == rank) { + continue; + } + volatile CollState* state_ptr = &(workspace[i]->states[state_group]); -thread_local struct allreduce_workspace** workspace; + while (true) { + volatile CollState cur_state = *state_ptr; + if (cur_state == state0 || cur_state == state1) { + break; + } + if (shm_data->shutdown) { + return; + } + } + } -// buffer for small messages, double buffer -thread_local char** symmetric_buffer[2]; -// buffer for large messages, double buffer -thread_local char** distributed_buffer[2]; + std::unique_lock lock(shm_data->m); + shm_data->wait_done = true; + lock.unlock(); + shm_data->cv.notify_one(); +} void wait_buffer_state_until_2( - int index, - enum coll_state state0, - enum coll_state state1, + CollState state0, + CollState state1, int state_group, - std::chrono::milliseconds timeout) { - volatile enum coll_state* state_ptr = - &(workspace[index]->states[state_group]); - - auto total_milliseconds = timeout.count(); - auto count = 0; - while (count < total_milliseconds) { - volatile enum coll_state cur_state = *state_ptr; - if (cur_state == state0 || cur_state == state1) { - break; - } - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - count += 10; - } - - volatile enum coll_state cur_state = *state_ptr; - if (!(cur_state == state0 || cur_state == state1)) { + std::chrono::milliseconds timeout, + std::shared_ptr shm_data) { + shm_data->wait_done = false; + shm_data->shutdown = false; + + // Create wait buffer thread. 
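+  // The helper thread performs the shared-memory polling; the caller blocks
+  // on the condition variable with wait_for() so the timeout is enforced. On
+  // timeout, shutdown is set so the poller exits before join(), and an
+  // IoException is thrown.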
+ std::thread t( + wait_buffer_state, state0, state1, state_group, timeout, shm_data); + + std::unique_lock lock(shm_data->m); + auto done = + shm_data->cv.wait_for(lock, timeout, [&] { return shm_data->wait_done; }); + if (!done) { + shm_data->shutdown = true; + t.join(); throw ::gloo::IoException(GLOO_ERROR_MSG( "Timed out waiting", timeout.count(), "ms for wait buffer state operation to complete")); + } else { + t.join(); } } @@ -139,6 +136,7 @@ void reduce_all_buffers( int num_elements, int element_size, int to_buffer_idx, + int world_size, char* to_buffer, char** buffers, ReductionFunction fn) { @@ -152,78 +150,6 @@ void reduce_all_buffers( } } -void shm_initialize( - int size, - int rank, - const char* addr_string, - const char* port_string) { - world_size = size; - world_rank = rank; - - char shm_name_prefix[NAME_BUF_SIZE]; - char shm_name[NAME_BUF_SIZE]; - snprintf( - shm_name_prefix, - NAME_BUF_SIZE, - "%s_%d_%s_%s", - SHM_BUFFER_NAME, - getuid(), - addr_string, - port_string); - // create shared workspace for SHM based allreduce - SharedData allreduce_buffer; - // allocate workspace_buf for current rank - struct allreduce_workspace* workspace_buf; - struct allreduce_workspace* workspace_buf_other; - workspace_buf = - (struct allreduce_workspace*)malloc(sizeof(struct allreduce_workspace)); - int written = - snprintf(shm_name, NAME_BUF_SIZE, "%s_%d", shm_name_prefix, rank); - if (written >= NAME_BUF_SIZE) { - std::cout << "[warning]: written >= NAME_BUF_SIZE" << std::endl; - } - shared_create( - &allreduce_buffer, - shm_name, - workspace_buf, - sizeof(struct allreduce_workspace)); - workspace_buf = (struct allreduce_workspace*)allreduce_buffer.bytes; - workspace_buf->states[0] = coll_alt2_allreduce_naive__copy_in_done; - workspace_buf->states[1] = coll_begin; - - // create the workspace pointer list - workspace = (struct allreduce_workspace**)malloc( - size * sizeof(struct allreduce_workspace*)); - symmetric_buffer[0] = (char**)malloc(size * sizeof(char**)); - symmetric_buffer[1] = (char**)malloc(size * sizeof(char**)); - distributed_buffer[0] = (char**)malloc(size * sizeof(char**)); - distributed_buffer[1] = (char**)malloc(size * sizeof(char**)); - - // map shm of all ranks - for (int i = 0; i < size; i++) { - if (i != rank) { - int written = - snprintf(shm_name, NAME_BUF_SIZE, "%s_%d", shm_name_prefix, i); - if (written >= NAME_BUF_SIZE) { - std::cout << "[warning]: written >= NAME_BUF_SIZE" << std::endl; - } - // printf("open %s, %d\n", shm_name, rank); - do { - shared_open( - &allreduce_buffer, shm_name, sizeof(struct allreduce_workspace)); - } while (allreduce_buffer.descriptor == -1 && errno == ENOENT); - workspace_buf_other = (struct allreduce_workspace*)allreduce_buffer.bytes; - workspace[i] = workspace_buf_other; - } else { - workspace[i] = workspace_buf; - } - symmetric_buffer[0][i] = workspace[i]->buffer + BUFFER0_OFFSET(0); - symmetric_buffer[1][i] = workspace[i]->buffer + BUFFER0_OFFSET(1); - distributed_buffer[0][i] = workspace[i]->buffer + BUFFER1_OFFSET(0); - distributed_buffer[1][i] = workspace[i]->buffer + BUFFER1_OFFSET(1); - } -} - static void parallel_memcpy(void* to, void* from, size_t n_bytes) __attribute__((target("avx512bw"))); static void parallel_memcpy(void* to, void* from, size_t n_bytes) { @@ -241,21 +167,24 @@ static void parallel_memcpy(void* to, void* from, size_t n_bytes) { } } -#define positive_mod(num, mod) ((((num) % (mod)) + (mod)) % (mod)) -#define rank_mod(rank) positive_mod(rank, world_size) -size_t slice_size(size_t chunk_el, int slice_idx) 
{ +size_t slice_size(size_t chunk_el, int slice_idx, int world_size) { size_t slice_size = chunk_el / world_size; return slice_idx == world_size - 1 ? slice_size + (chunk_el % world_size) : slice_size; } -char* slice_data(char* data_ptr, size_t chunk_el, int el_size, int slice_idx) { +char* slice_data( + char* data_ptr, + size_t chunk_el, + int el_size, + int slice_idx, + int world_size) { size_t slice_size = chunk_el / world_size; size_t el_offset = slice_size * slice_idx; return data_ptr + el_offset * el_size; } -size_t slice_el_start(size_t chunk_el, int slice_idx) { +size_t slice_el_start(size_t chunk_el, int slice_idx, int world_size) { size_t slice_size = chunk_el / world_size; return slice_size * slice_idx; } @@ -265,44 +194,45 @@ void symmetric_naive_all_reduce( int element_size, size_t chunk_size, size_t chunk_el, - ReductionFunction fn, - std::chrono::milliseconds timeout) { + const detail::AllreduceOptionsImpl& opts) { + const auto& context = opts.context; + auto& shm_data = context->shmData; + const int rank = shm_data->rank; + const int world_size = shm_data->world_size; + auto symmetric_buffer = shm_data->symmetric_buffer; + auto workspace = shm_data->workspace; + auto& state_idx = shm_data->state_idx; + auto& current_buffer = shm_data->current_buffer; + const int state_group = 0; - thread_local static int current_buffer = 0; - thread_local static int state_idx = 0; - enum coll_state copy_current, copy_next; + CollState copy_current, copy_next; switch (state_idx) { case 0: - copy_current = coll_allreduce_naive__copy_in_done; - copy_next = coll_alt1_allreduce_naive__copy_in_done; + copy_current = CollState::coll_allreduce_naive__copy_in_done; + copy_next = CollState::coll_alt1_allreduce_naive__copy_in_done; break; case 1: - copy_current = coll_alt1_allreduce_naive__copy_in_done; - copy_next = coll_alt2_allreduce_naive__copy_in_done; + copy_current = CollState::coll_alt1_allreduce_naive__copy_in_done; + copy_next = CollState::coll_alt2_allreduce_naive__copy_in_done; break; case 2: - copy_current = coll_alt2_allreduce_naive__copy_in_done; - copy_next = coll_allreduce_naive__copy_in_done; + copy_current = CollState::coll_alt2_allreduce_naive__copy_in_done; + copy_next = CollState::coll_allreduce_naive__copy_in_done; break; default: assert(!"Should not get here."); } state_idx = (state_idx + 1) % 3; - parallel_memcpy( - symmetric_buffer[current_buffer][world_rank], data_ptr, chunk_size); + parallel_memcpy(symmetric_buffer[current_buffer][rank], data_ptr, chunk_size); + std::atomic_thread_fence(std::memory_order_release); - workspace[world_rank]->states[state_group] = copy_current; + workspace[rank]->states[state_group] = copy_current; - for (int i = 0; i < world_size; i++) { - // wait until the other rank copy the buffer - if (i != world_rank) { - wait_buffer_state_until_2( - i, copy_current, copy_next, state_group, timeout); - } - } + wait_buffer_state_until_2( + copy_current, copy_next, state_group, opts.timeout, shm_data); // each rank reduce the buffer independently so therre is no need for // synchronization afterward @@ -310,10 +240,11 @@ void symmetric_naive_all_reduce( 0, chunk_el, element_size, - world_rank, + rank, + world_size, data_ptr, symmetric_buffer[current_buffer], - fn); + opts.reduce); // switch buffer current_buffer = 1 - current_buffer; @@ -325,27 +256,33 @@ void distributed_naive_reduce( int element_size, size_t chunk_size, size_t chunk_el, - ReductionFunction fn, - std::chrono::milliseconds timeout) { + const detail::AllreduceOptionsImpl& opts) { + const 
auto& context = opts.context; + auto& shm_data = context->shmData; + const int rank = shm_data->rank; + const int world_size = shm_data->world_size; + auto distributed_buffer = shm_data->distributed_buffer; + auto workspace = shm_data->workspace; + auto& state_idx = shm_data->state_idx; + auto& current_buffer = shm_data->current_buffer; + const int state_group = 1; - thread_local static int current_buffer = 0; - thread_local static int state_idx = 0; - enum coll_state copy_current, copy_next, reduce_current; + CollState copy_current, copy_next, reduce_current; // similar to symmetric_naive_allreduce, but here we only need two sets of // states, because distributed naive reduce has two barriers in the // algorithm switch (state_idx) { case 0: - copy_current = coll_allreduce_naive__copy_in_done; - reduce_current = coll_allreduce_naive__reduce_done; - copy_next = coll_alt1_allreduce_naive__copy_in_done; + copy_current = CollState::coll_allreduce_naive__copy_in_done; + reduce_current = CollState::coll_allreduce_naive__reduce_done; + copy_next = CollState::coll_alt1_allreduce_naive__copy_in_done; break; case 1: - copy_current = coll_alt1_allreduce_naive__copy_in_done; - reduce_current = coll_alt1_allreduce_naive__reduce_done; - copy_next = coll_allreduce_naive__copy_in_done; + copy_current = CollState::coll_alt1_allreduce_naive__copy_in_done; + reduce_current = CollState::coll_alt1_allreduce_naive__reduce_done; + copy_next = CollState::coll_allreduce_naive__copy_in_done; break; default: assert(!"Should not get here."); @@ -354,46 +291,40 @@ void distributed_naive_reduce( int data_size = chunk_size / chunk_el; parallel_memcpy( - distributed_buffer[current_buffer][world_rank], data_ptr, chunk_size); + distributed_buffer[current_buffer][rank], data_ptr, chunk_size); std::atomic_thread_fence(std::memory_order_release); - workspace[world_rank]->states[state_group] = copy_current; + workspace[rank]->states[state_group] = copy_current; - for (int i = 0; i < world_size; i++) { - // wait until all the other ranks copy the buffer - if (i != world_rank) - wait_buffer_state_until_2( - i, copy_current, reduce_current, state_group, timeout); - } + wait_buffer_state_until_2( + copy_current, reduce_current, state_group, opts.timeout, shm_data); // reduce scatter reduce_all_buffers( - slice_el_start(chunk_el, world_rank), - slice_size(chunk_el, world_rank), + slice_el_start(chunk_el, rank, world_size), + slice_size(chunk_el, rank, world_size), element_size, - world_rank, - distributed_buffer[current_buffer][world_rank], + rank, + world_size, + distributed_buffer[current_buffer][rank], distributed_buffer[current_buffer], - fn); + opts.reduce); std::atomic_thread_fence(std::memory_order_release); - workspace[world_rank]->states[state_group] = reduce_current; + workspace[rank]->states[state_group] = reduce_current; - for (int i = 0; i < world_size; i++) { - // wait until all the other ranks reduce the buffer - if (i != world_rank) - wait_buffer_state_until_2( - i, reduce_current, copy_next, state_group, timeout); - } + wait_buffer_state_until_2( + copy_current, reduce_current, state_group, opts.timeout, shm_data); for (int i = 0; i < world_size; i++) { - int rank = (i + world_rank) % world_size; + int rank = (i + rank) % world_size; parallel_memcpy( - slice_data(data_ptr, chunk_el, data_size, rank), + slice_data(data_ptr, chunk_el, data_size, rank, world_size), slice_data( distributed_buffer[current_buffer][rank], chunk_el, chunk_size / chunk_el, - rank), - slice_size(chunk_el, rank) * data_size); + rank, + 
world_size), + slice_size(chunk_el, rank, world_size) * data_size); } current_buffer = 1 - current_buffer; @@ -401,28 +332,117 @@ void distributed_naive_reduce( } // namespace -void shm(const detail::AllreduceOptionsImpl& opts) { - const auto& context = opts.context; - if (!is_initialized) { - int size = context->size; - int rank = context->rank; - - world_size = size; - world_rank = rank; - is_initialized = true; - - std::string addr_string(""), port_string(""); - const auto& addr_string_env = std::getenv("MASTER_ADDR"); - if (addr_string_env != nullptr) { - addr_string = addr_string_env; +void AllreduceSharedMemoryData::initialize() { + std::string addr_string(""), port_string(""); + const auto& addr_string_env = std::getenv("MASTER_ADDR"); + if (addr_string_env != nullptr) { + addr_string = addr_string_env; + } + const auto port_string_env = std::getenv("MASTER_PORT"); + if (port_string_env != NULL) { + port_string = port_string_env; + } + + char shm_name_prefix[Allreduceworkspace::NAME_BUF_SIZE]; + char shm_name[Allreduceworkspace::NAME_BUF_SIZE]; + snprintf( + shm_name_prefix, + Allreduceworkspace::NAME_BUF_SIZE, + "%s_%d_%s_%s", + "shm_allreduce_buffer", + getuid(), + addr_string.c_str(), + port_string.c_str()); + // create shared workspace for SHM based allreduce + // allocate workspace_buf for current rank + AllreduceWorkspace* workspace_buf; + AllreduceWorkspace* workspace_buf_other; + SharedData allreduce_buffer; + cur_workspace = (AllreduceWorkspace*)malloc(sizeof(AllreduceWorkspace)); + workspace_buf = cur_workspace; + + int written = snprintf( + shm_name, + AllreduceWorkspace::NAME_BUF_SIZE, + "%s_%d", + shm_name_prefix, + rank); + if (written >= AllreduceWorkspace::NAME_BUF_SIZE) { + std::cout << "[warning]: written >= NAME_BUF_SIZE" << std::endl; + } + + shared_create( + &allreduce_buffer, shm_name, workspace_buf, sizeof(AllreduceWorkspace)); + + workspace_buf = (AllreduceWorkspace*)allreduce_buffer.bytes; + workspace_buf->states[0] = coll_alt2_allreduce_naive__copy_in_done; + workspace_buf->states[1] = coll_begin; + workspace_buf->fd = allreduce_buffer.descriptor; + strcpy(workspace_buf->name, shm_name); + + // create the workspace pointer list + workspace = + (AllreduceWorkspace**)malloc(world_size * sizeof(Allreduceworkspace*)); + symmetric_buffer[0] = (char**)malloc(world_size * sizeof(char**)); + symmetric_buffer[1] = (char**)malloc(world_size * sizeof(char**)); + distributed_buffer[0] = (char**)malloc(world_size * sizeof(char**)); + distributed_buffer[1] = (char**)malloc(world_size * sizeof(char**)); + + // map shm of all ranks + for (int i = 0; i < world_size; i++) { + if (i != rank) { + int written = snprintf( + shm_name, + AllreduceWorkspace::NAME_BUF_SIZE, + "%s_%d", + shm_name_prefix, + i); + if (written >= AllreduceWorkspace::NAME_BUF_SIZE) { + std::cout << "[warning]: written >= NAME_BUF_SIZE" << std::endl; + } + + do { + shared_open(&allreduce_buffer, shm_name, sizeof(AllreduceWorkspace)); + } while (allreduce_buffer.descriptor == -1 && errno == ENOENT); + workspace_buf_other = (AllreduceWorkspace*)allreduce_buffer.bytes; + workspace[i] = workspace_buf_other; + } else { + workspace[i] = workspace_buf; } - const auto port_string_env = std::getenv("MASTER_PORT"); - if (port_string_env != NULL) { - port_string = port_string_env; + symmetric_buffer[0][i] = workspace[i]->buffer + BUFFER0_OFFSET(0); + symmetric_buffer[1][i] = workspace[i]->buffer + BUFFER0_OFFSET(1); + distributed_buffer[0][i] = workspace[i]->buffer + BUFFER1_OFFSET(0); + distributed_buffer[1][i] = 
workspace[i]->buffer + BUFFER1_OFFSET(1); + } + is_initialized = true; +} + +AllreduceSharedMemoryData::~AllreduceSharedMemoryData() { + if (is_initialized == true) { + // unlink and munmap shared memory + for (int i = 0; i < world_size; i++) { + std::string shm_name = std::string(workspace[i]->name); + close(workspace[i]->fd); + munmap(workspace[i], sizeof(Allreduceworkspace)); + shm_unlink(shm_name.c_str()); } - shm_initialize(size, rank, addr_string.c_str(), port_string.c_str()); + + free(cur_workspace); + free(workspace); + free(symmetric_buffer[0]); + free(symmetric_buffer[1]); + free(distributed_buffer[0]); + free(distributed_buffer[1]); } +} +void shm(const detail::AllreduceOptionsImpl& opts) { + const auto& context = opts.context; + if (context->shmData == nullptr) { + context->shmData = std::make_shared( + context->rank, context->size); + context->shmData->initialize(); + } const size_t data_size = opts.elements * opts.elementSize; auto& in = opts.in; auto& out = opts.out; @@ -460,27 +480,19 @@ void shm(const detail::AllreduceOptionsImpl& opts) { void* data = out[0].get()->ptr; - for (int offset = 0; offset < data_size; offset += MAX_BUF_SIZE) { + for (int offset = 0; offset < data_size; + offset += Allreduceworkspace::MAX_BUF_SIZE) { auto data_ptr = ((char*)(data) + offset); - size_t chunk_size = - data_size - offset > MAX_BUF_SIZE ? MAX_BUF_SIZE : data_size - offset; + size_t chunk_size = data_size - offset > Allreduceworkspace::MAX_BUF_SIZE + ? Allreduceworkspace::MAX_BUF_SIZE + : data_size - offset; size_t chunk_el = chunk_size / (data_size / opts.elements); - if (chunk_size < NAIVE_ALLREDUCE_THRESHOLD) { + if (chunk_size < Allreduceworkspace::NAIVE_ALLREDUCE_THRESHOLD) { symmetric_naive_all_reduce( - data_ptr, - opts.elementSize, - chunk_size, - chunk_el, - opts.reduce, - opts.timeout); + data_ptr, opts.elementSize, chunk_size, chunk_el, opts); } else { distributed_naive_reduce( - data_ptr, - opts.elementSize, - chunk_size, - chunk_el, - opts.reduce, - opts.timeout); + data_ptr, opts.elementSize, chunk_size, chunk_el, opts); } } diff --git a/gloo/allreduce_shm.h b/gloo/allreduce_shm.h index 6cf3ccc44..a7b92b1c3 100644 --- a/gloo/allreduce_shm.h +++ b/gloo/allreduce_shm.h @@ -1,7 +1,69 @@ + +#pragma once + +#include +#include + #include "gloo/allreduce.h" namespace gloo { +struct AllreduceSharedMemoryData { + enum CollState { + coll_begin = 0, + coll_allreduce_naive__copy_in_done, + coll_allreduce_naive__reduce_done, + // alternative state when allreduce is working on alternative buffer + // of the double buffer. 
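+    // The alt1/alt2 variants give consecutive allreduce rounds distinct
+    // completion values, so a peer cannot confuse the previous round's flag
+    // with the current one while the double buffer is switched.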
+ coll_alt1_allreduce_naive__copy_in_done, + coll_alt2_allreduce_naive__copy_in_done, + coll_alt1_allreduce_naive__reduce_done, + }; + + struct AllreduceWorkspace { + static constexpr size_t MAX_BUF_SIZE = 1048576 * 32; + static constexpr size_t NAIVE_ALLREDUCE_THRESHOLD = 1048576; + static constexpr int NAME_BUF_SIZE = 1000; + + int fd; + enum CollState states[2]; // idx=0 -- state for symmetric_naive_all_reduce + // idx=1 -- state for distributed_naive_all_reduce + // double buffer to avoid syncing between rounds + // offset=0 -- 2*NAIVE_ALLREDUCE_THRESHOLD : buffer for + // symmetric_naive_all_reduce after that : buffer for + // distributed_naive_all_reduce + char name[NAME_BUF_SIZE]; + char buffer[2 * NAIVE_ALLREDUCE_THRESHOLD + 2 * MAX_BUF_SIZE]; + }; + + AllreduceSharedMemoryData(int rank, int world_size) + : rank(rank), + world_size(world_size), + current_buffer(0), + state_idx(0), + is_initialized(false) {} + ~AllreduceSharedMemoryData(); + void initialize(); + + int rank; + int world_size; + int current_buffer; + int state_idx; + bool is_initialized; + + AllreduceWorkspace* cur_workspace; + AllreduceWorkspace** workspace; + // buffer for small messages, double buffer + char** symmetric_buffer[2]; + // buffer for large messages, double buffer + char** distributed_buffer[2]; + + std::mutex m; + std::condition_variable cv; + bool wait_done; + bool shutdown; +}; + void shm(const detail::AllreduceOptionsImpl& opts); } // namespace gloo diff --git a/gloo/context.h b/gloo/context.h index c7844e8ae..c114b8866 100644 --- a/gloo/context.h +++ b/gloo/context.h @@ -24,6 +24,8 @@ class Device; class UnboundBuffer; } // namespace transport +class AllreduceSharedMemoryData; + class Context { public: Context(int rank, int size, int base = 2); @@ -33,6 +35,8 @@ class Context { const int size; int base; + std::shared_ptr shmData; + std::shared_ptr& getDevice(); std::unique_ptr& getPair(int i); diff --git a/gloo/transport/context.cc b/gloo/transport/context.cc index edd3729f3..4b517eba5 100644 --- a/gloo/transport/context.cc +++ b/gloo/transport/context.cc @@ -49,7 +49,7 @@ void Context::createAndConnectAllPairs(std::shared_ptr store) { } std::string key("rank_" + std::to_string(i)); - auto val = store->get(key); + auto val = store->wait_get(key, getTimeout()); auto hostName = std::string((const char*)val.data(), val.size()); if (hostName == localHostName) {