[bugfix][distributed] fix shm broadcast when the queue size is full #5801

Merged (5 commits) · Jun 26, 2024
49 changes: 33 additions & 16 deletions tests/distributed/test_shm_broadcast.py
@@ -1,14 +1,24 @@
 import multiprocessing
 import random
 import time
+from typing import List
 
+import numpy as np
 import torch.distributed as dist
 
 from vllm.distributed.device_communicators.shm_broadcast import (
     ShmRingBuffer, ShmRingBufferIO)
 from vllm.utils import update_environment_variables
 
 
+def get_arrays(n: int, seed: int = 0) -> List[np.ndarray]:
+    np.random.seed(seed)
+    sizes = np.random.randint(1, 10_000, n)
+    # on average, each array will have 5k elements
+    # with int64, each array will have 40kb
+    return [np.random.randint(1, 100, i) for i in sizes]
+
+
 def distributed_run(fn, world_size):
     number_of_processes = world_size
     processes = []
@@ -47,24 +57,31 @@ def wrapped_fn(env):
 def worker_fn():
     writer_rank = 2
     broadcaster = ShmRingBufferIO.create_from_process_group(
-        dist.group.WORLD, 1024, 2, writer_rank)
+        dist.group.WORLD, 1024 * 1024, 2, writer_rank)
     if dist.get_rank() == writer_rank:
+        seed = random.randint(0, 1000)
+        dist.broadcast_object_list([seed], writer_rank)
+    else:
+        recv = [None]
+        dist.broadcast_object_list(recv, writer_rank)
+        seed = recv[0]  # type: ignore
+    dist.barrier()
+    # in case we find a race condition
+    # print the seed so that we can reproduce the error
+    print(f"Rank {dist.get_rank()} got seed {seed}")
+    # test broadcasting with about 400MB of data
+    N = 10_000
+    if dist.get_rank() == writer_rank:
-        time.sleep(random.random())
-        broadcaster.broadcast_object(0)
-        time.sleep(random.random())
-        broadcaster.broadcast_object({})
-        time.sleep(random.random())
-        broadcaster.broadcast_object([])
+        arrs = get_arrays(N, seed)
+        for x in arrs:
+            broadcaster.broadcast_object(x)
+            time.sleep(random.random() / 1000)
     else:
-        time.sleep(random.random())
-        a = broadcaster.broadcast_object(None)
-        time.sleep(random.random())
-        b = broadcaster.broadcast_object(None)
-        time.sleep(random.random())
-        c = broadcaster.broadcast_object(None)
-        assert a == 0
-        assert b == {}
-        assert c == []
+        arrs = get_arrays(N, seed)
+        for x in arrs:
+            y = broadcaster.broadcast_object(None)
+            assert np.array_equal(x, y)
+            time.sleep(random.random() / 1000)
     dist.barrier()


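A note on the numbers in this test (mine, not the PR author's): sizes drawn uniformly from [1, 10_000) average about 5,000 elements per array, i.e. roughly 40 KB at 8 bytes per int64 element, and 10,000 such arrays total roughly 400 MB, as the comments claim. All of that flows through a ring buffer of only two chunks, so the writer keeps hitting a full queue, which is the failure mode this PR fixes. A quick standalone check of the arithmetic, reusing the same numpy calls as get_arrays:

import numpy as np

np.random.seed(0)
sizes = np.random.randint(1, 10_000, 10_000)  # same distribution as get_arrays
avg_elements = sizes.mean()                   # ~5,000 elements per array
avg_kb = avg_elements * 8 / 1e3               # int64 -> ~40 KB per array
total_mb = sizes.sum() * 8 / 1e6              # ~400 MB across all messages
print(f"avg {avg_elements:.0f} elements (~{avg_kb:.0f} KB); total ~{total_mb:.0f} MB")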
73 changes: 43 additions & 30 deletions vllm/distributed/device_communicators/shm_broadcast.py
@@ -14,6 +14,12 @@
 
 VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
 
+# time to wait if the queue is full or empty
+# if we sleep for too short, it will consume too much CPU
+# if we sleep for too long, it will slow down the writer/reader
+# 0.1 us is a good balance
+RINGBUFFER_SLEEP_INTERVAL = 1e-7
+
 logger = init_logger(__name__)
 
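An aside on the constant above (not from the PR): 1e-7 s is what the code requests, but the effective floor comes from OS scheduler and timer granularity, so the actual pause per call is typically far longer than 0.1 us. A throwaway micro-benchmark for measuring the real cost on a given machine:

import time

def avg_sleep_seconds(interval: float, iters: int = 1_000) -> float:
    # average wall-clock time of time.sleep(interval) over many calls
    start = time.monotonic()
    for _ in range(iters):
        time.sleep(interval)
    return (time.monotonic() - start) / iters

for interval in (0.0, 1e-7, 1e-6, 1e-3):
    print(f"requested {interval:.0e} s -> ~{avg_sleep_seconds(interval):.2e} s per call")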

@@ -145,28 +151,29 @@ def __init__(self, buffer: ShmRingBuffer, reader_rank: int):
     @contextmanager
     def acquire_write(self):
         assert self._is_writer, "Only writers can acquire write"
-        start_index = self.current_idx
-        start_time = time.time()
+        start_time = time.monotonic()
         n_warning = 1
         while True:
             with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                 read_count = sum(metadata_buffer[1:])
                 written_flag = metadata_buffer[0]
                 if written_flag and read_count != self.buffer.n_reader:
                     # this block is written and not read by all readers
-                    # try to write to the next block
-                    self.current_idx = (self.current_idx +
-                                        1) % self.buffer.max_chunks
-                    if self.current_idx == start_index:
-                        # no empty block found
-                        if time.time(
-                        ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:  # noqa
-                            logger.warning(
-                                "No available block found in %s second. ",
-                                VLLM_RINGBUFFER_WARNING_INTERVAL)
-                            n_warning += 1
-                        # wait for a while (0.1 us)
-                        time.sleep(1e-7)
+                    # for writers, `self.current_idx` is the next block to write
+                    # if this block is not ready to write,
+                    # we need to wait until it is read by all readers
+
+                    # wait for a while
+                    time.sleep(RINGBUFFER_SLEEP_INTERVAL)
+
+                    # if we wait for a long time, we should warn the user
+                    if time.monotonic(
+                    ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:  # noqa
+                        logger.warning(
+                            "No available block found in %s second. ",
+                            VLLM_RINGBUFFER_WARNING_INTERVAL)
+                        n_warning += 1
+
                     continue
                 # found a block that is either
                 # (1) not written
@@ -188,13 +195,14 @@ def acquire_write(self):
                     metadata_buffer[i] = 0
                 # mark the block as written
                 metadata_buffer[0] = 1
+                self.current_idx = (self.current_idx +
+                                    1) % self.buffer.max_chunks
                 break
 
     @contextmanager
     def acquire_read(self):
         assert self._is_reader, "Only readers can acquire read"
-        start_index = self.current_idx
-        start_time = time.time()
+        start_time = time.monotonic()
         n_warning = 1
         while True:
             with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
@@ -204,19 +212,22 @@ def acquire_read(self):
                 # this block is either
                 # (1) not written
                 # (2) already read by this reader
-                # try to read the next block
-                self.current_idx = (self.current_idx +
-                                    1) % self.buffer.max_chunks
-                if self.current_idx == start_index:
-                    # no block found
-                    if time.time(
-                    ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:  # noqa
-                        logger.warning(
-                            "No available block found in %s second. ",
-                            VLLM_RINGBUFFER_WARNING_INTERVAL)
-                        n_warning += 1
-                    # wait for a while (0.1 us)
-                    time.sleep(1e-7)
+
+                # for readers, `self.current_idx` is the next block to read
+                # if this block is not ready,
+                # we need to wait until it is written
+
+                # wait for a while
+                time.sleep(RINGBUFFER_SLEEP_INTERVAL)
+
+                # if we wait for a long time, we should warn the user
+                if time.monotonic(
+                ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:  # noqa
+                    logger.warning(
+                        "No available block found in %s second. ",
+                        VLLM_RINGBUFFER_WARNING_INTERVAL)
+                    n_warning += 1
+
                 continue
                 # found a block that is not read by this reader
                 # let caller read from the buffer
@@ -226,6 +237,8 @@ def acquire_read(self):
                 # caller has read from the buffer
                 # set the read flag
                 metadata_buffer[self.reader_rank + 1] = 1
+                self.current_idx = (self.current_idx +
+                                    1) % self.buffer.max_chunks
                 break
 
     def enqueue(self, obj):
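For readers skimming the diff, the net effect of the change: the old code let a blocked writer or reader scan ahead through the ring looking for a usable block, advancing current_idx before any write or read had actually happened, so once the queue was full the writer and the readers could fall out of step. The new code pins everyone to strict ring order: each side busy-waits on its own current_idx and advances it only after a completed operation. Below is a minimal single-process sketch of that fixed protocol, using threads and plain lists in place of ShmRingBuffer's shared memory; the names and simplifications are mine, not the PR's.

import threading
import time

N_READERS = 2
MAX_CHUNKS = 2
SLEEP = 1e-7  # mirrors RINGBUFFER_SLEEP_INTERVAL

# metadata[idx] = [written_flag, read_flag_reader0, read_flag_reader1, ...]
metadata = [[0] * (1 + N_READERS) for _ in range(MAX_CHUNKS)]
chunks = [None] * MAX_CHUNKS

def writer(items):
    idx = 0
    for item in items:
        # wait until this block is free: never written, or read by all readers
        while metadata[idx][0] == 1 and sum(metadata[idx][1:]) != N_READERS:
            time.sleep(SLEEP)
        chunks[idx] = item
        for r in range(N_READERS):
            metadata[idx][1 + r] = 0      # reset read flags first
        metadata[idx][0] = 1              # then mark the block as written
        idx = (idx + 1) % MAX_CHUNKS      # advance only after a real write

def reader(rank, n_items, out):
    idx = 0
    for _ in range(n_items):
        # wait until this block is written and not yet read by this reader
        while not (metadata[idx][0] == 1 and metadata[idx][1 + rank] == 0):
            time.sleep(SLEEP)
        out.append(chunks[idx])
        metadata[idx][1 + rank] = 1       # mark as read by this reader
        idx = (idx + 1) % MAX_CHUNKS      # advance only after a real read

items = list(range(1_000))
outs = [[] for _ in range(N_READERS)]
threads = [threading.Thread(target=writer, args=(items,))]
threads += [
    threading.Thread(target=reader, args=(r, len(items), outs[r]))
    for r in range(N_READERS)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
# every reader must see every item, in order, even though the
# two-chunk queue is full almost the whole time
assert all(out == items for out in outs)

The invariant that makes this work is the same one visible in the diff: a block's read flags are cleared before its written flag is republished, and each participant's index moves only after its own operation completes.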