
Commit 247348e

Author: ilmarkov
Update test. Disable QR by default. Set fp16 ovfl flag.
Signed-off-by: ilmarkov <imarkov@redhat.com>
1 parent f8bf2e9 commit 247348e

File tree

  csrc/quickreduce/quick_reduce_impl.cuh
  tests/distributed/test_custom_all_reduce.py
  vllm/distributed/device_communicators/custom_all_reduce.py
  vllm/envs.py

4 files changed: +81 additions, -49 deletions

csrc/quickreduce/quick_reduce_impl.cuh

Lines changed: 6 additions & 10 deletions
@@ -12,7 +12,9 @@ struct CodecBase {
   __quickreduce_device_inline__ CodecBase(int thread, int rank)
       : thread(thread),
         rank(rank),
-        group_leader((threadIdx.x / kThreadGroupSize) * kThreadGroupSize) {}
+        group_leader((threadIdx.x / kThreadGroupSize) * kThreadGroupSize) {
+    set_fp16_ovfl(true);
+  }
 };

 // Default full precision codec.
@@ -98,9 +100,7 @@ struct CodecQ4 : public CodecBase {
   static constexpr int kRangeBias = 0x00080008;

   __quickreduce_device_inline__ CodecQ4(int thread, int rank)
-      : CodecBase(thread, rank) {
-    set_fp16_ovfl(true);
-  }
+      : CodecBase(thread, rank) {}

   __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer,
                                           const int32x4_t* __restrict__ data) {
@@ -253,9 +253,7 @@ struct CodecQ6 : public CodecBase {
   static constexpr int kRangeBias = 0x00200020;

   __quickreduce_device_inline__ CodecQ6(int thread, int rank)
-      : CodecBase(thread, rank) {
-    set_fp16_ovfl(true);
-  }
+      : CodecBase(thread, rank) {}

   __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer,
                                           const int32x4_t* __restrict__ data) {
@@ -431,9 +429,7 @@ struct CodecQ8 : public CodecBase {
   static constexpr int kRangeBias = 0x00800080;

   __quickreduce_device_inline__ CodecQ8(int thread, int rank)
-      : CodecBase(thread, rank) {
-    set_fp16_ovfl(true);
-  }
+      : CodecBase(thread, rank) {}

   __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer,
                                           int32x4_t const* __restrict__ data) {

tests/distributed/test_custom_all_reduce.py

Lines changed: 37 additions & 4 deletions
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import os
 import random

 import pytest
@@ -86,7 +87,7 @@ def graph_allreduce(


 @ray.remote(num_gpus=1, max_calls=1)
-def eager_allreduce(
+def eager_custom_allreduce(
     monkeypatch: pytest.MonkeyPatch,
     tp_size,
     pp_size,
@@ -111,19 +112,51 @@ def eager_allreduce(
         inp = torch.ones(sz, dtype=torch.float32, device=device)
         out = inp
         for _ in range(num_communication):
-            out = fa.all_reduce(out, registered=False)
+            out = fa.ca_all_reduce(out, registered=False)
         torch.testing.assert_close(out, inp * (tp_size**num_communication))

         inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device)
         out = inp
         for _ in range(num_communication):
-            out = fa.all_reduce(out, registered=False)
+            out = fa.ca_all_reduce(out, registered=False)
         torch.testing.assert_close(out, inp * (tp_size**num_communication))


+@ray.remote(num_gpus=1, max_calls=1)
+def eager_quickreduce(
+    monkeypatch: pytest.MonkeyPatch,
+    tp_size,
+    pp_size,
+    rank,
+    distributed_init_port,
+):
+    with monkeypatch.context() as m:
+        m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
+        os.environ["VLLM_ROCM_QR_QUANT_REGIME"] = "FP"
+        device = torch.device(f"cuda:{rank}")
+        torch.cuda.set_device(device)
+        init_test_distributed_environment(tp_size, pp_size, rank,
+                                          distributed_init_port)
+
+        sz = 1024 * 1024
+        fa = get_tp_group().device_communicator.ca_comm
+        inp = torch.ones(sz, dtype=torch.float16, device=device)
+        out = inp
+        out = fa.qr_all_reduce(out)
+        torch.testing.assert_close(out, inp * tp_size)
+
+        sz = 1024 * 1024
+        inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device)
+        out = inp
+        out = fa.qr_all_reduce(out)
+        torch.testing.assert_close(out, inp * tp_size)
+
+
 @pytest.mark.parametrize("tp_size", [2])
 @pytest.mark.parametrize("pipeline_parallel_size", [1, 2])
-@pytest.mark.parametrize("test_target", [eager_allreduce, graph_allreduce])
+@pytest.mark.parametrize(
+    "test_target",
+    [eager_custom_allreduce, graph_allreduce, eager_quickreduce])
 def test_custom_allreduce(
     monkeypatch: pytest.MonkeyPatch,
     tp_size,
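
For illustration (not part of the commit): the new eager_quickreduce target pins the quant regime to "FP" through the environment before the distributed setup, then calls qr_all_reduce directly on the tensor-parallel communicator. A minimal sketch of the same call pattern outside the test harness follows; the vllm.distributed import path is an assumption, and the distributed/TP setup must already have run on every rank, exactly as the test does.

    # Sketch only: mirrors the qr_all_reduce call pattern from eager_quickreduce.
    # Assumes VLLM_ROCM_QR_QUANT_REGIME was set (e.g. to "FP") before the
    # communicator was created, and that the distributed/TP setup already ran.
    import torch
    from vllm.distributed import get_tp_group  # import path is an assumption

    fa = get_tp_group().device_communicator.ca_comm
    inp = torch.ones(1024 * 1024, dtype=torch.float16, device="cuda")
    out = fa.qr_all_reduce(inp)  # out-of-place; each element becomes tp_size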

vllm/distributed/device_communicators/custom_all_reduce.py

Lines changed: 36 additions & 33 deletions
@@ -71,11 +71,11 @@ class CustomAllreduce:
     # TODO: We should set a reasonable range for FP.
     MB = 1024 * 1024
     _QR_MIN_SIZE = {
-        (torch.float16, 2): [16 * MB, 2 * MB, 2 * MB, 1 * MB],
-        (torch.float16, 4): [16 * MB, 64 * MB, 4 * MB, 2 * MB],
+        (torch.float16, 2): [1 * MB, 2 * MB, 2 * MB, 1 * MB],
+        (torch.float16, 4): [1 * MB, 64 * MB, 4 * MB, 2 * MB],
         (torch.float16, 8): [16 * MB, 4 * MB, 4 * MB, 2 * MB],
-        (torch.bfloat16, 2): [16 * MB, 8 * MB, 8 * MB, 8 * MB],
-        (torch.bfloat16, 4): [16 * MB, 128 * MB, 128 * MB, 16 * MB],
+        (torch.bfloat16, 2): [2 * MB, 8 * MB, 8 * MB, 8 * MB],
+        (torch.bfloat16, 4): [8 * MB, 128 * MB, 128 * MB, 16 * MB],
         (torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB],
     }

@@ -256,40 +256,43 @@ def init_custom_quick_allreduce(self):
         Initialize a custom quick allreduce implementation for AMD
         based on quick reduce (https://github.com/mk1-project/quickreduce).
         """
+        if not self._QR_SHOULD_INIT:
+            return
+        self.use_fp16_kernels = envs.VLLM_ROCM_QR_CAST_BF16_TO_FP16
+        regime_str = envs.VLLM_ROCM_QR_QUANT_REGIME
+        if regime_str not in QuickReduceRegime.__members__:
+            logger.warning(
+                "Custom quick allreduce:",
+                f"Invalid quantization level: {regime_str}. "
+                "Supported levels: "
+                f"{list(QuickReduceRegime.__members__.keys())}")
+            return
+
+        if regime_str == "NONE":
+            logger.debug("Custom quick allreduce is disabled based "
+                         "on env variable VLLM_ROCM_QR_QUANT_REGIME")
+            return
+
         vllm_config = get_current_vllm_config()
-        # for test mode
-        if vllm_config is not None and hasattr(vllm_config, "model_config"):
+        if vllm_config is not None and \
+            hasattr(vllm_config, "model_config") and \
+                hasattr(vllm_config.model_config, "dtype"):
             dtype = vllm_config.model_config.dtype
             if dtype not in [torch.float16, torch.bfloat16]:
                 self._QR_SHOULD_INIT = False
-        # On RocM bfloat16 kernels are slower than fp16
-        # due to slower match operations
-        # If environment is not set to 1 we convert input to fp16
-        self.use_fp16_kernels: bool = envs.VLLM_ROCM_QR_CAST_BF16_TO_FP16
-        regime_str = envs.VLLM_ROCM_QR_QUANT_REGIME
-        if self._QR_SHOULD_INIT:
-            if regime_str not in QuickReduceRegime.__members__:
-                logger.warning(
-                    "Custom quick allreduce:",
-                    f"Invalid quantization level: {regime_str}. "
-                    "Supported levels: "
-                    f"{list(QuickReduceRegime.__members__.keys())}")
-                return
-
-            if regime_str == "NONE":
-                logger.debug("Custom quick allreduce is disabled based "
-                             "on env variable VLLM_ROCM_QR_QUANT_REGIME")
-                return
-
-            self.qr_quant_level = QuickReduceRegime[regime_str]
-            self._qr_ptr = ops.init_custom_qr(self.rank, self.world_size)
-            self.create_qr_shared_buffer()
+            # On RocM bfloat16 kernels are slower than fp16
+            # due to slower match operations
+            # If environment variable is not set to 1 we convert input to fp16
             if dtype == torch.bfloat16 and not self.use_fp16_kernels:
                 logger.info(
                     "Custom quick allreduce: converting bf16 inputs to "
                     "fp16 can improve performance"
                     "set envs.VLLM_ROCM_QR_CAST_BF16_TO_FP16=1 to turn on.")
-            self.qr_disabled = False
+
+        self.qr_quant_level = QuickReduceRegime[regime_str]
+        self._qr_ptr = ops.init_custom_qr(self.rank, self.world_size)
+        self.create_qr_shared_buffer()
+        self.qr_disabled = False

     @contextmanager
     def capture(self):
@@ -346,7 +349,7 @@ def should_quick_allreduce(self, inp: torch.Tensor):
         if self.use_fp16_kernels:
             dtype = torch.float16
         return inp_size <= self.qr_max_size and \
-            inp_size > self._QR_MIN_SIZE[(dtype, self.world_size)]\
+            inp_size >= self._QR_MIN_SIZE[(dtype, self.world_size)]\
                 [self.qr_quant_level.value]

     def should_custom_allreduce(self, inp: torch.Tensor):
@@ -369,7 +372,7 @@ def should_custom_ar(self, inp: torch.Tensor):
         return self.should_quick_allreduce(
             inp) or self.should_custom_allreduce(inp)

-    def cr_all_reduce(self,
+    def ca_all_reduce(self,
                       inp: torch.Tensor,
                       *,
                       out: torch.Tensor = None,
@@ -411,7 +414,7 @@ def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
         if self.should_custom_allreduce(input):
             if self._IS_CAPTURING:
                 if torch.cuda.is_current_stream_capturing():
-                    return self.cr_all_reduce(input, registered=True)
+                    return self.ca_all_reduce(input, registered=True)
                 else:
                     # If warm up, mimic the allocation pattern since custom
                     # allreduce is out-of-place.
@@ -421,7 +424,7 @@ def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
                 # incurs a cost of cudaMemcpy, which should be small
                 # (<=1% of overall latency) compared to the performance
                 # gain of using custom kernels
-                return self.cr_all_reduce(input, registered=False)
+                return self.ca_all_reduce(input, registered=False)

         return None
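
For illustration (not part of the commit): the lowered _QR_MIN_SIZE entries and the switch from '>' to '>=' decide when quick reduce handles a tensor. A standalone sketch of that gate is below; it assumes inp_size is the tensor's byte size and that the per-level list is indexed in the order FP, INT8, INT6, INT4, matching the comment in vllm/envs.py.

    # Standalone sketch of the should_quick_allreduce gate after this commit.
    import torch

    MB = 1024 * 1024
    QR_MIN_SIZE = {  # values copied from the diff above
        (torch.float16, 2): [1 * MB, 2 * MB, 2 * MB, 1 * MB],
        (torch.float16, 4): [1 * MB, 64 * MB, 4 * MB, 2 * MB],
        (torch.float16, 8): [16 * MB, 4 * MB, 4 * MB, 2 * MB],
        (torch.bfloat16, 2): [2 * MB, 8 * MB, 8 * MB, 8 * MB],
        (torch.bfloat16, 4): [8 * MB, 128 * MB, 128 * MB, 16 * MB],
        (torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB],
    }

    def should_quick_allreduce(inp, world_size, quant_level, qr_max_size,
                               cast_bf16_to_fp16=False):
        inp_size = inp.numel() * inp.element_size()  # tensor size in bytes
        dtype = torch.float16 if cast_bf16_to_fp16 else inp.dtype
        # The commit makes the lower bound inclusive ('>=' instead of '>').
        return (inp_size <= qr_max_size
                and inp_size >= QR_MIN_SIZE[(dtype, world_size)][quant_level])

    # A 1 MiB fp16 tensor on a 2-GPU group now qualifies for the FP regime
    # (level 0): 1 MB >= 1 MB, where the old table required 16 MB.
    x = torch.empty(512 * 1024, dtype=torch.float16)  # 512Ki elems * 2 B = 1 MiB
    print(should_quick_allreduce(x, world_size=2, quant_level=0,
                                 qr_max_size=2048 * MB))  # True (cap chosen arbitrarily here)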

vllm/envs.py

Lines changed: 2 additions & 2 deletions
@@ -129,7 +129,7 @@
     VLLM_SLEEP_WHEN_IDLE: bool = False
     VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16
     VLLM_KV_CACHE_LAYOUT: Optional[str] = None
-    VLLM_ROCM_QR_QUANT_REGIME: str = "FP"
+    VLLM_ROCM_QR_QUANT_REGIME: str = "NONE"
     VLLM_ROCM_QR_CAST_BF16_TO_FP16: bool = False


@@ -677,7 +677,7 @@ def get_vllm_port() -> Optional[int]:
     # Choice of quantization level: FP, INT8, INT6, INT4 or NONE
     # Recommended for large models to get allreduce
     "VLLM_ROCM_QR_QUANT_REGIME":
-    lambda: os.getenv("VLLM_ROCM_QR_QUANT_REGIME", "FP").upper(),
+    lambda: os.getenv("VLLM_ROCM_QR_QUANT_REGIME", "NONE").upper(),

     # Custom quick allreduce kernel for MI3* cards
     # Due to the lack of the bfloat16 asm instruction, bfloat16
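
For illustration (not part of the commit): with the default regime now NONE, quick reduce is opt-in. Because the value is read through os.getenv, setting the variable in the environment before vLLM initializes the communicator (as the new eager_quickreduce test does via os.environ) turns it back on; both variable names and the accepted values come from the diff above.

    import os

    # Opt back in to quick reduce; accepted values per the envs.py comment:
    # FP, INT8, INT6, INT4 or NONE (NONE is now the default).
    os.environ["VLLM_ROCM_QR_QUANT_REGIME"] = "FP"

    # Optional: cast bf16 inputs to the faster fp16 kernels (off by default).
    os.environ["VLLM_ROCM_QR_CAST_BF16_TO_FP16"] = "1"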
