From 860a1d635ce6b150a8de79d4121a898af3582b17 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 22 Jun 2024 10:00:43 -0700 Subject: [PATCH] [core][distributed] improve shared memory broadcast (#5754) --- .../device_communicators/shm_broadcast.py | 42 ++++++++++++++----- 1 file changed, 32 insertions(+), 10 deletions(-) diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 119befcf64052..c44bd2f11ee8b 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -48,6 +48,26 @@ def __init__(self, | written_flag | reader0_flag | reader1_flag | ... | readerN_flag | +--------------+--------------+--------------+-----+--------------+ + The state of metadata is as follows: + + (case 1) 0???...???: the block is not written yet, cannot read, can write + (case 2) 1000...000: the block is just written, can read, cannot write + (case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write + (case 4) 1111...111: the block is written and read by all readers, cannot read, can write + + State transition for readers: + + When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read. + Only after the caller finishes reading the block, the reader can mark the block as read. + Readers only mark the block as read (from 0 to 1), the writer marks the block as ready to read (from 1 to 0). + + State transition for writer: + + When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case + to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer + can reset the reader flags to 0, and mark the block as written (from 0 to 1). + NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct. + During creation, `name` is None and the buffer is created. We can pass the created object to other processes by pickling it. The other processes will get the name of the shared memory and open it, so that they can access the @@ -81,10 +101,6 @@ def __init__(self, lambda *args, **kwargs: None): self.shared_memory = shared_memory.SharedMemory(name=name) assert self.shared_memory.size == self.total_bytes_of_buffer - with memoryview(self.shared_memory.buf[self.metadata_offset:] - ) as metadata_buffer: - tensor = torch.frombuffer(metadata_buffer, dtype=torch.uint8) - assert torch.all(tensor == 0) def __reduce__(self): return ( @@ -163,11 +179,15 @@ def acquire_write(self): yield buf # caller has written to the buffer - # mark the block as written - metadata_buffer[0] = 1 + # NOTE: order is important here + # first set the read flags to 0 + # then set the written flag to 1 + # otherwise, the readers may think they already read the block for i in range(1, self.buffer.n_reader + 1): # set read flag to 0, meaning it is not read yet metadata_buffer[i] = 0 + # mark the block as written + metadata_buffer[0] = 1 break @contextmanager @@ -247,13 +267,15 @@ def create_from_process_group(pg: ProcessGroup, buffer: ShmRingBuffer if group_rank == writer_rank: buffer = ShmRingBuffer(n_reader, max_chunk_bytes, max_chunks) - dist.broadcast_object_list([buffer], src=global_ranks[writer_rank]) - dist.barrier(pg) + dist.broadcast_object_list([buffer], + src=global_ranks[writer_rank], + group=pg) return ShmRingBufferIO(buffer, -1) else: recv = [None] - dist.broadcast_object_list(recv, src=global_ranks[writer_rank]) - dist.barrier(pg) + dist.broadcast_object_list(recv, + src=global_ranks[writer_rank], + group=pg) buffer = recv[0] # type: ignore rest_ranks = [r for r in ranks_inside_group if r != writer_rank] return ShmRingBufferIO(buffer, rest_ranks.index(group_rank))