Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
a3a7de7
Add quickreduce as alternative to custom allreduce
Apr 17, 2025
ad8ec75
WIP
May 5, 2025
f4c03ad
Add bf16 support
May 13, 2025
00c6d18
WIP
May 15, 2025
b537c0d
Refactor QuickReduce
May 19, 2025
1bff2fa
Cleanup
May 20, 2025
69f04dc
Remove config param. Add faster low latency OneShot algo
Jun 3, 2025
7effd4d
Some fixes
Jun 3, 2025
6970671
fix bfloat16 recv
lihaoyang-amd Jun 10, 2025
343b272
fix env
lihaoyang-amd Jun 11, 2025
ecd85a0
fix log info
lihaoyang-amd Jun 11, 2025
bd921a5
for env
lihaoyang-amd Jun 11, 2025
f21d4ca
Add int8 quantization. Remove changes to custom_allreduce
Jun 13, 2025
6a0d8b0
Update after review comments
Jun 13, 2025
a2f2922
add Q6 support
lihaoyang-amd Jun 13, 2025
87949fa
Adjusted to static constexpr int
lihaoyang-amd Jun 13, 2025
8eb9e62
Remove useless functions
lihaoyang-amd Jun 13, 2025
0425ac5
fix max size err
lihaoyang-amd Jun 16, 2025
20fc13b
adjust for comments
lihaoyang-amd Jun 16, 2025
982400b
integrate_qr2cr
lihaoyang-amd Jun 16, 2025
af265c1
fix message size
lihaoyang-amd Jun 16, 2025
ff506e1
Fix fp 2GB bug
Jun 16, 2025
796be62
adjust condition
lihaoyang-amd Jun 17, 2025
f524aad
fix vll_config
lihaoyang-amd Jun 17, 2025
41907b1
change comment
lihaoyang-amd Jun 17, 2025
776030b
Update test. Disable QR by default. Set fp16 ovfl flag.
Jun 17, 2025
db3f1d3
Fix CodecQ4
Jun 17, 2025
deb72c6
Update min sizes
Jun 17, 2025
ab99dfd
fix Q4
lihaoyang-amd Jun 18, 2025
0bf6342
move bf2fp to cpp
lihaoyang-amd Jun 19, 2025
ce2b715
fix compile err
lihaoyang-amd Jun 19, 2025
25f8e40
fix qr for cuda
lihaoyang-amd Jun 20, 2025
210358d
fix f-string
lihaoyang-amd Jun 20, 2025
0bada3c
adjust test case for quick allreduce
lihaoyang-amd Jun 21, 2025
2173c38
del TODO and rebase
lihaoyang-amd Jun 23, 2025
a2dd7bd
Optimized format
lihaoyang-amd Jun 24, 2025
816cf2d
add test for multi modes
lihaoyang-amd Jun 24, 2025
876dbec
for fmt
lihaoyang-amd Jun 24, 2025
42a0bdb
Adjustable max_size
lihaoyang-amd Jun 24, 2025
e40a61d
go back to splitting
lihaoyang-amd Jun 25, 2025
a02b2ef
change default of max_size to None
lihaoyang-amd Jun 25, 2025
03f6163
adjust name of var
lihaoyang-amd Jun 25, 2025
2b52580
restore custom allreduce
lihaoyang-amd Jun 25, 2025
a5d7963
check rocm for qr
lihaoyang-amd Jun 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# if CUDA endif
endif()

if (VLLM_GPU_LANG STREQUAL "HIP")
  # QuickReduce kernels are ROCm-only; compile them into the _C extension
  # alongside the other sources when building for HIP.
  list(APPEND VLLM_EXT_SRC
    "csrc/custom_quickreduce.cu"
  )
endif()  # VLLM_GPU_LANG STREQUAL "HIP"

message(STATUS "Enabling C extension.")
define_gpu_extension_target(
_C
Expand Down
114 changes: 114 additions & 0 deletions csrc/custom_quickreduce.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>

#ifdef USE_ROCM

#include "quickreduce/quick_reduce.h"

// Create a QuickReduce communicator for this rank and return it as an
// opaque integer handle. Supported world sizes are the even values
// 2, 4 and 8; anything else (and an out-of-range rank) throws
// std::invalid_argument. `qr_max_size`, when given, caps the maximum
// problem size the communicator allocates buffers for.
quickreduce::fptr_t init_custom_qr(int64_t rank, int64_t world_size,
                                   std::optional<int64_t> qr_max_size) {
  if (world_size > 8) {
    throw std::invalid_argument("world size > 8 is not supported");
  }
  if (world_size == 6) {
    throw std::invalid_argument("world size == 6 is not supported");
  }
  if (world_size % 2 != 0) {
    throw std::invalid_argument("Odd num gpus is not supported for now");
  }
  if (rank < 0 || rank >= world_size) {
    throw std::invalid_argument("invalid rank passed in");
  }
  auto* comms = new quickreduce::DeviceComms();
  comms->init(world_size, rank, qr_max_size);
  return reinterpret_cast<quickreduce::fptr_t>(comms);
}

// Tear down and free a communicator created by init_custom_qr.
// A zero/null handle is a no-op.
void qr_destroy(quickreduce::fptr_t _fa) {
  if (!_fa) return;
  auto* comms = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
  comms->destroy();
  delete comms;
}

// Serialize this rank's IPC memory handle into a CPU uint8 tensor so it
// can be exchanged with the other ranks (see qr_open_handles).
torch::Tensor qr_get_handle(quickreduce::fptr_t _fa) {
  auto* comms = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
  const hipIpcMemHandle_t handle = comms->get_handle();
  constexpr int64_t kHandleBytes =
      static_cast<int64_t>(sizeof(hipIpcMemHandle_t));
  auto data_handle = torch::empty(
      {kHandleBytes},
      torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU));
  std::memcpy(data_handle.data_ptr(), &handle, sizeof(hipIpcMemHandle_t));
  return data_handle;
}

// Import the other ranks' IPC memory handles. Each tensor in `handles`
// holds the raw bytes of one hipIpcMemHandle_t, as produced by
// qr_get_handle on the corresponding rank.
void qr_open_handles(quickreduce::fptr_t _fa,
                     const std::vector<torch::Tensor>& handles) {
  auto* comms = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
  std::vector<hipIpcMemHandle_t> ipc_handles(handles.size());
  for (size_t i = 0; i < handles.size(); ++i) {
    std::memcpy(&ipc_handles[i], handles[i].data_ptr(),
                sizeof(hipIpcMemHandle_t));
  }
  comms->open_ipc_handles(ipc_handles);
}

// Run a quick (optionally quantized) allreduce of `inp` into `out` on the
// current HIP stream. Both tensors must have the same dtype and element
// count, and the problem must fit within the communicator's max size.
// `quant_level` selects the codec (FP/Q4/Q6/Q8 — see the instantiations
// below); `cast_bf2half` makes bf16 inputs run through the fp16 kernel path.
void qr_all_reduce(quickreduce::fptr_t _fa, torch::Tensor& inp,
                   torch::Tensor& out, int64_t quant_level, bool cast_bf2half) {
  auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = at::cuda::getCurrentHIPStreamMasqueradingAsCUDA();

  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK_LE(out.numel(), fa->kMaxProblemSize);
  if (out.scalar_type() == at::ScalarType::Half) {
    fa->allreduce<half, false>(reinterpret_cast<half*>(inp.data_ptr()),
                               reinterpret_cast<half*>(out.data_ptr()),
                               out.numel(), quant_level, stream);
  } else if (out.scalar_type() == at::ScalarType::BFloat16) {
    if (cast_bf2half) {
      // bf16 data is fed through the fp16 kernel instantiation; the
      // cast_bf2half=true template parameter handles the conversion.
      fa->allreduce<half, true>(reinterpret_cast<half*>(inp.data_ptr()),
                                reinterpret_cast<half*>(out.data_ptr()),
                                out.numel(), quant_level, stream);
    } else {
      fa->allreduce<quickreduce::nv_bfloat16, false>(
          reinterpret_cast<quickreduce::nv_bfloat16*>(inp.data_ptr()),
          reinterpret_cast<quickreduce::nv_bfloat16*>(out.data_ptr()),
          out.numel(), quant_level, stream);
    }
  } else {
    throw std::runtime_error(
        "quick allreduce only supports float16 and bfloat16");
  }
}

// Default upper bound on the quick-allreduce problem size in bytes.
int64_t qr_max_size() {
  // 2 GiB = 2^31 = 2,147,483,648 bytes.
  return int64_t{1} << 31;
}

// Explicitly instantiate the two-shot allreduce kernel for world sizes
// 2, 4 and 8 (the only sizes init_custom_qr accepts) for a given element
// type T and codec, with the given bf16->fp16 cast mode.
#define INSTANTIATE_FOR_WORLDSIZE(T, Codec, cast_bf2half) \
  template struct quickreduce::AllReduceTwoshot<T, Codec<T, 2>, \
                                                cast_bf2half>; \
  template struct quickreduce::AllReduceTwoshot<T, Codec<T, 4>, \
                                                cast_bf2half>; \
  template struct quickreduce::AllReduceTwoshot<T, Codec<T, 8>, cast_bf2half>;

// bf16 element type, with and without the fp16 cast path.
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, false)
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, false)
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, false)
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, false)
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, true)
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, true)
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, true)
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, true)

// fp16 element type (cast flag is irrelevant here and fixed to false).
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecFP, false)
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ4, false)
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ6, false)
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ8, false)

#endif // USE_ROCM
11 changes: 11 additions & 0 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -360,3 +360,14 @@ std::tuple<int64_t, torch::Tensor> allocate_shared_buffer_and_handle(
int64_t size);
int64_t open_mem_handle(torch::Tensor& mem_handle);
void free_shared_buffer(int64_t buffer);

#ifdef USE_ROCM
// QuickReduce custom allreduce (ROCm only). Lifecycle: init_custom_qr ->
// qr_get_handle / qr_open_handles (exchange IPC handles across ranks) ->
// qr_all_reduce (repeatedly) -> qr_destroy. qr_max_size reports the
// default maximum problem size in bytes.
fptr_t init_custom_qr(int64_t rank, int64_t world_size,
                      std::optional<int64_t> qr_max_size = std::nullopt);
void qr_destroy(fptr_t _fa);
torch::Tensor qr_get_handle(fptr_t _fa);
void qr_open_handles(fptr_t _fa, const std::vector<torch::Tensor>& handles);
void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                   int64_t quant_level, bool cast_bf2half = false);
int64_t qr_max_size();
#endif
Loading