Skip to content

Commit

Permalink
[bugfix][distributed] fix shm broadcast when the queue size is full (v…
Browse files Browse the repository at this point in the history
  • Loading branch information
youkaichao authored and prashantgupta24 committed Jul 1, 2024
1 parent 58ba441 commit 9cccdc8
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 46 deletions.
49 changes: 33 additions & 16 deletions tests/distributed/test_shm_broadcast.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
import multiprocessing
import random
import time
from typing import List

import numpy as np
import torch.distributed as dist

from vllm.distributed.device_communicators.shm_broadcast import (
ShmRingBuffer, ShmRingBufferIO)
from vllm.utils import update_environment_variables


def get_arrays(n: int, seed: int = 0) -> List[np.ndarray]:
    """Generate ``n`` reproducible random integer arrays of varying length.

    The same ``(n, seed)`` pair always yields the same arrays, so every
    caller can rebuild the identical data set locally from a shared seed.

    A private ``RandomState`` is used instead of ``np.random.seed`` so the
    process-wide NumPy RNG state is left untouched. The generated values
    are the same as ``np.random.seed(seed)`` followed by the module-level
    ``np.random.randint`` calls would produce.

    Args:
        n: number of arrays to generate.
        seed: seed for the random number generator.

    Returns:
        A list of ``n`` 1-D integer arrays. Each array has between 1 and
        9,999 elements, with values drawn from ``[1, 100)``.
    """
    rng = np.random.RandomState(seed)
    sizes = rng.randint(1, 10_000, n)
    # on average, each array will have 5k elements
    # with int64, each array will have 40kb
    return [rng.randint(1, 100, size) for size in sizes]


def distributed_run(fn, world_size):
number_of_processes = world_size
processes = []
Expand Down Expand Up @@ -47,24 +57,31 @@ def wrapped_fn(env):
def worker_fn():
writer_rank = 2
broadcaster = ShmRingBufferIO.create_from_process_group(
dist.group.WORLD, 1024, 2, writer_rank)
dist.group.WORLD, 1024 * 1024, 2, writer_rank)
if dist.get_rank() == writer_rank:
seed = random.randint(0, 1000)
dist.broadcast_object_list([seed], writer_rank)
else:
recv = [None]
dist.broadcast_object_list(recv, writer_rank)
seed = recv[0] # type: ignore
dist.barrier()
# in case we find a race condition
# print the seed so that we can reproduce the error
print(f"Rank {dist.get_rank()} got seed {seed}")
# test broadcasting with about 400MB of data
N = 10_000
if dist.get_rank() == writer_rank:
time.sleep(random.random())
broadcaster.broadcast_object(0)
time.sleep(random.random())
broadcaster.broadcast_object({})
time.sleep(random.random())
broadcaster.broadcast_object([])
arrs = get_arrays(N, seed)
for x in arrs:
broadcaster.broadcast_object(x)
time.sleep(random.random() / 1000)
else:
time.sleep(random.random())
a = broadcaster.broadcast_object(None)
time.sleep(random.random())
b = broadcaster.broadcast_object(None)
time.sleep(random.random())
c = broadcaster.broadcast_object(None)
assert a == 0
assert b == {}
assert c == []
arrs = get_arrays(N, seed)
for x in arrs:
y = broadcaster.broadcast_object(None)
assert np.array_equal(x, y)
time.sleep(random.random() / 1000)
dist.barrier()


Expand Down
73 changes: 43 additions & 30 deletions vllm/distributed/device_communicators/shm_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@

VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL

# time to wait if the queue is full or empty
# if we sleep for too short a time, polling will consume too much CPU
# if we sleep for too long, it will slow down the writer/reader
# 0.1 us is a good balance
RINGBUFFER_SLEEP_INTERVAL = 1e-7

logger = init_logger(__name__)


Expand Down Expand Up @@ -145,28 +151,29 @@ def __init__(self, buffer: ShmRingBuffer, reader_rank: int):
@contextmanager
def acquire_write(self):
assert self._is_writer, "Only writers can acquire write"
start_index = self.current_idx
start_time = time.time()
start_time = time.monotonic()
n_warning = 1
while True:
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
read_count = sum(metadata_buffer[1:])
written_flag = metadata_buffer[0]
if written_flag and read_count != self.buffer.n_reader:
# this block is written and not read by all readers
# try to write to the next block
self.current_idx = (self.current_idx +
1) % self.buffer.max_chunks
if self.current_idx == start_index:
# no empty block found
if time.time(
) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: # noqa
logger.warning(
"No available block found in %s second. ",
VLLM_RINGBUFFER_WARNING_INTERVAL)
n_warning += 1
# wait for a while (0.1 us)
time.sleep(1e-7)
# for writers, `self.current_idx` is the next block to write
# if this block is not ready to write,
# we need to wait until it is read by all readers

# wait for a while
time.sleep(RINGBUFFER_SLEEP_INTERVAL)

# if we wait for a long time, we should warn the user
if time.monotonic(
) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: # noqa
logger.warning(
"No available block found in %s second. ",
VLLM_RINGBUFFER_WARNING_INTERVAL)
n_warning += 1

continue
# found a block that is either
# (1) not written
Expand All @@ -188,13 +195,14 @@ def acquire_write(self):
metadata_buffer[i] = 0
# mark the block as written
metadata_buffer[0] = 1
self.current_idx = (self.current_idx +
1) % self.buffer.max_chunks
break

@contextmanager
def acquire_read(self):
assert self._is_reader, "Only readers can acquire read"
start_index = self.current_idx
start_time = time.time()
start_time = time.monotonic()
n_warning = 1
while True:
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
Expand All @@ -204,19 +212,22 @@ def acquire_read(self):
# this block is either
# (1) not written
# (2) already read by this reader
# try to read the next block
self.current_idx = (self.current_idx +
1) % self.buffer.max_chunks
if self.current_idx == start_index:
# no block found
if time.time(
) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: # noqa
logger.warning(
"No available block found in %s second. ",
VLLM_RINGBUFFER_WARNING_INTERVAL)
n_warning += 1
# wait for a while (0.1 us)
time.sleep(1e-7)

# for readers, `self.current_idx` is the next block to read
# if this block is not ready,
# we need to wait until it is written

# wait for a while
time.sleep(RINGBUFFER_SLEEP_INTERVAL)

# if we wait for a long time, we should warn the user
if time.monotonic(
) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: # noqa
logger.warning(
"No available block found in %s second. ",
VLLM_RINGBUFFER_WARNING_INTERVAL)
n_warning += 1

continue
# found a block that is not read by this reader
# let caller read from the buffer
Expand All @@ -226,6 +237,8 @@ def acquire_read(self):
# caller has read from the buffer
# set the read flag
metadata_buffer[self.reader_rank + 1] = 1
self.current_idx = (self.current_idx +
1) % self.buffer.max_chunks
break

def enqueue(self, obj):
Expand Down

0 comments on commit 9cccdc8

Please sign in to comment.