Custom all reduce kernels #2192

Merged (61 commits, Jan 27, 2024)
Commits (61 total; the diff below shows changes from 47 commits)
3cc29ca  a working impl & squash commits (hanzhi713, Dec 18, 2023)
89f8b97  add missing cuda free (hanzhi713, Dec 19, 2023)
53ed0f9  link driver (hanzhi713, Dec 19, 2023)
39000b8  add more notes (hanzhi713, Dec 19, 2023)
f4dc283  add todo (hanzhi713, Dec 19, 2023)
2f49454  add flag and format (hanzhi713, Dec 19, 2023)
16447e5  fix (hanzhi713, Dec 19, 2023)
60a51f2  fix arg passing (hanzhi713, Dec 19, 2023)
b82c9d5  Merge branch 'main' into fast_ar_sq (hanzhi713, Dec 20, 2023)
2150a90  trailing comma (hanzhi713, Dec 20, 2023)
cd898ba  use pytest for fast allreduce test (hanzhi713, Dec 20, 2023)
41115c7  Merge branch 'main' into fast_ar_sq (hanzhi713, Dec 20, 2023)
15672a9  small refactor (hanzhi713, Dec 21, 2023)
ad7d220  cleanup code add verify correctness (hanzhi713, Dec 21, 2023)
da5772e  improve test robustness (hanzhi713, Dec 21, 2023)
7644f7c  Apply suggestions from code review (hanzhi713, Dec 26, 2023)
4096c9d  use context manager (hanzhi713, Dec 27, 2023)
af015e7  add p2p check (hanzhi713, Dec 27, 2023)
7f11bc5  address review (hanzhi713, Dec 27, 2023)
47653b5  do not reinit (hanzhi713, Dec 27, 2023)
d7ee1ad  Merge branch 'main' into fast_ar_sq (hanzhi713, Jan 4, 2024)
78546de  format (hanzhi713, Jan 4, 2024)
dbae9b0  Merge branch 'main' into fast_ar_sq (hanzhi713, Jan 15, 2024)
c8d41b3  fix tests and format (hanzhi713, Jan 15, 2024)
364c06e  add a few more comments (hanzhi713, Jan 15, 2024)
8884de8  use untyped storage (hanzhi713, Jan 19, 2024)
cce7c98  move test utils (hanzhi713, Jan 19, 2024)
bdbf4e0  Merge branch 'main' into fast_ar_sq (hanzhi713, Jan 19, 2024)
6b85e42  rename to custom all reduce (hanzhi713, Jan 20, 2024)
6ae050e  add support for eager mode (hanzhi713, Jan 20, 2024)
73ab0a8  format (hanzhi713, Jan 20, 2024)
7652c4b  Merge branch 'main' into fast_ar_sq (hanzhi713, Jan 20, 2024)
bedf60e  add comment (hanzhi713, Jan 20, 2024)
8626b8c  move function (hanzhi713, Jan 20, 2024)
7da0723  format (hanzhi713, Jan 20, 2024)
2ab52d0  add comments (hanzhi713, Jan 20, 2024)
dcf2735  fix name (hanzhi713, Jan 20, 2024)
21f2fcc  Minor fixes on comments (WoosukKwon, Jan 22, 2024)
50cc5f8  Don't compile for ROCm backend (WoosukKwon, Jan 22, 2024)
0581f4e  Minor fix for ROCm backend (WoosukKwon, Jan 22, 2024)
704416a  Use context manager for NVML (WoosukKwon, Jan 22, 2024)
602930a  Minor fix for long comment (WoosukKwon, Jan 22, 2024)
3bfd8fa  Minor (WoosukKwon, Jan 23, 2024)
d8f92bc  Add library=cuda (WoosukKwon, Jan 23, 2024)
b4711a1  Skip ops for ROCm backend (WoosukKwon, Jan 23, 2024)
84ee019  Minor (WoosukKwon, Jan 23, 2024)
8fbb2aa  Fix can_p2p (WoosukKwon, Jan 23, 2024)
c5b4212  Apply suggestions from code review (hanzhi713, Jan 23, 2024)
60e013a  add notes (hanzhi713, Jan 23, 2024)
10a906e  add size check (hanzhi713, Jan 23, 2024)
b896fbd  Update csrc/custom_all_reduce.cuh (hanzhi713, Jan 23, 2024)
aae30fb  Apply suggestions from code review (hanzhi713, Jan 23, 2024)
627a49f  grammar (hanzhi713, Jan 23, 2024)
c7e3704  move test to c++ (hanzhi713, Jan 23, 2024)
036bb68  add warnings and do few renames (hanzhi713, Jan 24, 2024)
c6367aa  Merge branch 'main' into fast_ar_sq (hanzhi713, Jan 24, 2024)
e1e802e  Fix custom all reduce tests (WoosukKwon, Jan 27, 2024)
bbfc263  Move test_utils to tests/distributed/utils (WoosukKwon, Jan 27, 2024)
75efb8a  Merge branch 'main' into fast_ar_sq (WoosukKwon, Jan 27, 2024)
6f61347  Minor (WoosukKwon, Jan 27, 2024)
c09772c  Roll back to test_utils (WoosukKwon, Jan 27, 2024)
138 changes: 138 additions & 0 deletions csrc/custom_all_reduce.cu
@@ -0,0 +1,138 @@
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>

#include "custom_all_reduce.cuh"

// fake pointer type
using fptr_t = uint64_t;
static_assert(sizeof(void *) == sizeof(fptr_t));

fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
                      const std::vector<std::string> &handles,
                      const std::vector<int64_t> &offsets, int rank,
                      bool full_nvlink) {
  int world_size = offsets.size();
  if (world_size > 8)
    throw std::invalid_argument("world size > 8 is not supported");
  if (world_size % 2 != 0)
    throw std::invalid_argument("Odd num gpus is not supported for now");
  if (world_size != handles.size())
    throw std::invalid_argument(
        "handles length should equal to offsets length");
  if (rank < 0 || rank >= world_size)
    throw std::invalid_argument("invalid rank passed in");

  cudaIpcMemHandle_t ipc_handles[8];
  for (int i = 0; i < world_size; i++) {
    std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
  }
  return (fptr_t) new vllm::FastAllreduce(
      reinterpret_cast<vllm::Metadata *>(meta.data_ptr()), rank_data.data_ptr(),
      rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
}

/**
 * Make sure tensor t's data lies completely within
 * [(char *)t.data_ptr(), (char *)t.data_ptr() + t.numel() * t.element_size()).
 * This is slightly weaker than t.is_contiguous() because it allows a transpose
 * of a contiguous slice (i.e. slicing the first dimension). Currently, we
 * require this because stride information is not passed into the kernels and
 * we treat input tensors as flat.
 *
 * Examples
 * A = torch.zeros(3, 3, 3)
 * 1. A: OK
 * 2. A[1:]: OK
 * 3. A.permute(2, 0, 1): OK
 * 4. A[1:].permute(2, 0, 1): OK
 * 5. A[None].expand(2, -1, -1, -1): Not OK
 * 6. A[:, 1:, 1:]: Not OK
 */
bool _is_weak_contiguous(torch::Tensor &t) {
  return t.is_contiguous() ||
         (t.storage().nbytes() - t.storage_offset() * t.element_size() ==
          t.numel() * t.element_size());
}
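
// Illustrative sketch only, assuming libtorch tensor semantics: how the
// numbered examples in the comment above behave under _is_weak_contiguous.
// _weak_contiguous_examples is a hypothetical helper, not part of the kernel API.
static void _weak_contiguous_examples() {
  auto A = torch::zeros({3, 3, 3});
  TORCH_CHECK(_is_weak_contiguous(A));               // 1. A: OK
  auto a1 = A.slice(0, 1);                           // 2. A[1:]: OK
  TORCH_CHECK(_is_weak_contiguous(a1));
  auto a2 = A.permute({2, 0, 1});                    // 3. A.permute(2, 0, 1): OK
  TORCH_CHECK(_is_weak_contiguous(a2));
  auto a3 = A.unsqueeze(0).expand({2, -1, -1, -1});  // 5. A[None].expand(...): Not OK,
  TORCH_CHECK(!_is_weak_contiguous(a3));             //    the view is larger than its storage
}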

void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
                 cudaStream_t stream) {
  auto fa = reinterpret_cast<vllm::FastAllreduce *>(_fa);
  TORCH_CHECK(_is_weak_contiguous(inp));
  TORCH_CHECK(_is_weak_contiguous(out));
  switch (out.scalar_type()) {
    case at::ScalarType::Float: {
      fa->allreduce<float>(stream, reinterpret_cast<float *>(inp.data_ptr()),
                           reinterpret_cast<float *>(out.data_ptr()),
                           out.numel());
      break;
    }
    case at::ScalarType::Half: {
      fa->allreduce<half>(stream, reinterpret_cast<half *>(inp.data_ptr()),
                          reinterpret_cast<half *>(out.data_ptr()),
                          out.numel());
      break;
    }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
    // The bfloat16 path is compiled only for the host pass and for device
    // passes targeting compute capability >= 8.0.
    case at::ScalarType::BFloat16: {
      fa->allreduce<nv_bfloat16>(
          stream, reinterpret_cast<nv_bfloat16 *>(inp.data_ptr()),
          reinterpret_cast<nv_bfloat16 *>(out.data_ptr()), out.numel());
      break;
    }
#endif
    default:
      throw std::runtime_error(
          "custom allreduce only supports float32, float16 and bfloat16");
  }
}

// All-reduce for an input that already resides in memory registered with this
// FastAllreduce instance.
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();
  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  _all_reduce(_fa, inp, out, stream);
}

// All-reduce for an unregistered input: stage it into a pre-registered buffer
// with a device-to-device copy, then reduce from that buffer.
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
                      torch::Tensor &out) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();

  auto input_size = inp.numel() * inp.element_size();
  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
              "registered buffer is too small to contain the input");
  TORCH_CHECK(_is_weak_contiguous(inp));
  AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
                                input_size, cudaMemcpyDeviceToDevice, stream));
  _all_reduce(_fa, reg_buffer, out, stream);
}

void dispose(fptr_t _fa) {
  auto fa = reinterpret_cast<vllm::FastAllreduce *>(_fa);
  delete fa;
}

int meta_size() { return sizeof(vllm::Metadata); }

void register_buffer(fptr_t _fa, torch::Tensor &t,
                     const std::vector<std::string> &handles,
                     const std::vector<int64_t> &offsets) {
  auto fa = reinterpret_cast<vllm::FastAllreduce *>(_fa);
  fa->register_buffer(handles, offsets, t.data_ptr());
}

std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
    fptr_t _fa) {
  auto fa = reinterpret_cast<vllm::FastAllreduce *>(_fa);
  return fa->get_graph_buffer_ipc_meta();
}

void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
                            const std::vector<std::vector<int64_t>> &offsets) {
  auto fa = reinterpret_cast<vllm::FastAllreduce *>(_fa);
  fa->register_graph_buffers(handles, offsets);
}
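
Note on the binding style: the file hands the vllm::FastAllreduce object back to the caller as an opaque uint64_t handle (fptr_t) rather than a typed pointer, so the handle can be held outside C++ (e.g. by the Python wrapper) and passed back into each op, with dispose() reclaiming it at the end. A minimal, self-contained sketch of that fake-pointer pattern, with hypothetical names that are not part of this PR:

#include <cassert>
#include <cstdint>

// Stand-in for vllm::FastAllreduce in this sketch.
struct Handle { int value = 42; };

using fptr_t = uint64_t;
static_assert(sizeof(void *) == sizeof(fptr_t));

// Analogous to init_custom_ar: allocate the object, return it as an integer handle.
fptr_t create() { return (fptr_t) new Handle(); }

// Analogous to all_reduce_reg and friends: recover the pointer from the handle.
int use(fptr_t h) { return reinterpret_cast<Handle *>(h)->value; }

// Analogous to dispose: the caller is responsible for freeing the handle.
void destroy(fptr_t h) { delete reinterpret_cast<Handle *>(h); }

int main() {
  fptr_t h = create();
  assert(use(h) == 42);
  destroy(h);
  return 0;
}

The static_assert mirrors the one at the top of custom_all_reduce.cu: the integer round-trip is only safe because a pointer fits in 64 bits.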