import pickle
import time
from contextlib import contextmanager
from multiprocessing import shared_memory
from typing import Optional
from unittest.mock import patch

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

import vllm.envs as envs
from vllm.logger import init_logger

VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL

logger = init_logger(__name__)


class ShmRingBuffer:

    def __init__(self,
                 n_reader: int,
                 max_chunk_bytes: int,
                 max_chunks: int,
                 name: Optional[str] = None):
 | 27 | +        """  | 
 | 28 | +        A shared memory ring buffer implementation for broadcast communication.  | 
 | 29 | +        Essentially, it is a queue where only one will `enqueue` and multiple  | 
 | 30 | +        will `dequeue`. The max size of each item, together with the max number  | 
 | 31 | +        of items that can be stored in the buffer are known in advance.  | 
 | 32 | +        In this case, we don't need to synchronize the access to  | 
 | 33 | +         the buffer.  | 
 | 34 | +          | 
 | 35 | +        Buffer memory layout:  | 
 | 36 | +                  data                                 metadata  | 
 | 37 | +                    |                                      |  | 
 | 38 | +                    | (current_idx)                        | (current_idx)  | 
 | 39 | +                    v                                      v  | 
 | 40 | +        +-------------------------------+----------------------------------------+  | 
 | 41 | +        | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |  | 
 | 42 | +        +-------------------------------+----------------------------------------+  | 
 | 43 | +        | max_chunks x max_chunk_bytes  | max_chunks x (1 + n_reader) bytes      |  | 
 | 44 | +
  | 
 | 45 | +        metadata memory layout: each byte is a flag, the first byte is the written  | 
 | 46 | +        flag, and the rest are reader flags. The flags are set to 0 by default.  | 
 | 47 | +        +--------------+--------------+--------------+-----+--------------+  | 
 | 48 | +        | written_flag | reader0_flag | reader1_flag | ... | readerN_flag |  | 
 | 49 | +        +--------------+--------------+--------------+-----+--------------+  | 
 | 50 | +
  | 
 | 51 | +        During creation, `name` is None and the buffer is created. We can pass the  | 
 | 52 | +        created object to other processes by pickling it. The other processes will  | 
 | 53 | +        get the name of the shared memory and open it, so that they can access the  | 
 | 54 | +        same shared memory buffer.  | 
 | 55 | +        """# noqa  | 
        self.n_reader = n_reader
        self.metadata_size = 1 + n_reader
        self.max_chunk_bytes = max_chunk_bytes
        self.max_chunks = max_chunks
        self.total_bytes_of_buffer = (self.max_chunk_bytes +
                                      self.metadata_size) * self.max_chunks
        self.data_offset = 0
        self.metadata_offset = self.max_chunk_bytes * self.max_chunks
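        # Worked example (hypothetical numbers): with n_reader=2,
        # max_chunk_bytes=1024 and max_chunks=4, the data section takes
        # 4 * 1024 = 4096 bytes, the metadata section 4 * (1 + 2) = 12 bytes,
        # so total_bytes_of_buffer = (1024 + 3) * 4 = 4108 bytes and
        # metadata_offset = 4096.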

        if name is None:
            # we are creating a buffer
            self.is_creator = True
            self.shared_memory = shared_memory.SharedMemory(
                create=True, size=self.total_bytes_of_buffer)
            # initialize the metadata section to 0
            with memoryview(self.shared_memory.buf[self.metadata_offset:]
                            ) as metadata_buffer:
                torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
        else:
            # we are opening an existing buffer
            self.is_creator = False
            # fix to https://stackoverflow.com/q/62748654/9191338
            # Python incorrectly tracks shared memory even if it is not
            # created by the process. The following patch is a workaround.
            with patch("multiprocessing.resource_tracker.register",
                       lambda *args, **kwargs: None):
                self.shared_memory = shared_memory.SharedMemory(name=name)
            assert self.shared_memory.size == self.total_bytes_of_buffer
            with memoryview(self.shared_memory.buf[self.metadata_offset:]
                            ) as metadata_buffer:
                tensor = torch.frombuffer(metadata_buffer, dtype=torch.uint8)
                assert torch.all(tensor == 0)

    def __reduce__(self):
        return (
            self.__class__,
            (self.n_reader, self.max_chunk_bytes, self.max_chunks,
             self.shared_memory.name),
        )
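
    # Pickling sketch (illustrative; the concrete numbers are hypothetical):
    # unpickling reconstructs the object from the shared memory name, so a
    # second process attaches to the same segment instead of allocating one.
    #
    #   buffer = ShmRingBuffer(n_reader=2, max_chunk_bytes=1024, max_chunks=4)
    #   clone = pickle.loads(pickle.dumps(buffer))
    #   assert clone.shared_memory.name == buffer.shared_memory.name
    #   assert clone.is_creator is False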

    def __del__(self):
        self.shared_memory.close()
        if self.is_creator:
            self.shared_memory.unlink()

    @contextmanager
    def get_data(self, current_idx: int):
        start = self.data_offset + current_idx * self.max_chunk_bytes
        end = start + self.max_chunk_bytes
        with memoryview(self.shared_memory.buf[start:end]) as buf:
            yield buf

    @contextmanager
    def get_metadata(self, current_idx: int):
        start = self.metadata_offset + current_idx * self.metadata_size
        end = start + self.metadata_size
        with memoryview(self.shared_memory.buf[start:end]) as buf:
            yield buf


class ShmRingBufferIO:

    def __init__(self, buffer: ShmRingBuffer, reader_rank: int):
        self.buffer = buffer
        self.reader_rank = reader_rank
        self._is_writer = self.reader_rank == -1
        self._is_reader = not self._is_writer
        if self._is_reader:
            assert 0 <= self.reader_rank < buffer.n_reader, \
                (f"Invalid reader rank {self.reader_rank} for buffer"
                 f" created with {buffer.n_reader} readers")
        self.current_idx = 0

    @contextmanager
    def acquire_write(self):
        assert self._is_writer, "Only writers can acquire write"
        start_index = self.current_idx
        start_time = time.time()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_count = sum(metadata_buffer[1:])
                written_flag = metadata_buffer[0]
                if written_flag and read_count != self.buffer.n_reader:
                    # this block is written and not yet read by all readers
                    # try to write to the next block
                    self.current_idx = (self.current_idx +
                                        1) % self.buffer.max_chunks
                    if self.current_idx == start_index:
                        # no empty block found
                        if (time.time() - start_time >
                                VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
                            logger.warning(
                                "No available block found in %s seconds.",
                                VLLM_RINGBUFFER_WARNING_INTERVAL)
                            n_warning += 1
                        # wait for a while (0.1 us)
                        time.sleep(1e-7)
                    continue
                # found a block that is either
                # (1) not written
                # (2) read by all readers

                # mark the block as not written
                metadata_buffer[0] = 0
                # let the caller write to the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # the caller has written to the buffer
                # mark the block as written
                metadata_buffer[0] = 1
                for i in range(1, self.buffer.n_reader + 1):
                    # set the read flag to 0, meaning it is not read yet
                    metadata_buffer[i] = 0
                break

    @contextmanager
    def acquire_read(self):
        assert self._is_reader, "Only readers can acquire read"
        start_index = self.current_idx
        start_time = time.time()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_flag = metadata_buffer[self.reader_rank + 1]
                written_flag = metadata_buffer[0]
                if not written_flag or read_flag:
                    # this block is either
                    # (1) not written
                    # (2) already read by this reader
                    # try to read the next block
                    self.current_idx = (self.current_idx +
                                        1) % self.buffer.max_chunks
                    if self.current_idx == start_index:
                        # no readable block found
                        if (time.time() - start_time >
                                VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
                            logger.warning(
                                "No available block found in %s seconds.",
                                VLLM_RINGBUFFER_WARNING_INTERVAL)
                            n_warning += 1
                        # wait for a while (0.1 us)
                        time.sleep(1e-7)
                    continue
                # found a block that is not yet read by this reader
                # let the caller read from the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # the caller has read from the buffer
                # set the read flag
                metadata_buffer[self.reader_rank + 1] = 1
                break
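
    # Flag lifecycle of a single chunk with n_reader = 2 (illustrative):
    #   initial state              -> [0, 0, 0]  (writer may claim the chunk)
    #   writer finishes writing    -> [1, 0, 0]  (readers may now read it)
    #   reader 0 finishes reading  -> [1, 1, 0]
    #   reader 1 finishes reading  -> [1, 1, 1]  (writer may claim it again)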

    def enqueue(self, obj):
        assert self._is_writer, "Only writers can enqueue"
        serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        if len(serialized_obj) > self.buffer.max_chunk_bytes:
            raise RuntimeError(
                f"{len(serialized_obj)=} is larger than the allowed value "
                f"{self.buffer.max_chunk_bytes}. "
                "Please increase the max_chunk_bytes parameter.")
        with self.acquire_write() as buf:
            buf[:len(serialized_obj)] = serialized_obj

    def dequeue(self):
        assert self._is_reader, "Only readers can dequeue"
        with self.acquire_read() as buf:
            # no need to know the size of the serialized object;
            # the pickle format itself contains the size information
            # see https://docs.python.org/3/library/pickle.html
            obj = pickle.loads(buf)
        return obj

    def broadcast_object(self, obj=None):
        if self._is_writer:
            self.enqueue(obj)
            return obj
        else:
            return self.dequeue()
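
    # Usage sketch (illustrative; values are hypothetical): one writer and two
    # readers share the same ShmRingBuffer, and every reader sees every item.
    #
    #   buffer = ShmRingBuffer(n_reader=2, max_chunk_bytes=1024, max_chunks=4)
    #   writer = ShmRingBufferIO(buffer, reader_rank=-1)  # writer process
    #   reader = ShmRingBufferIO(buffer, reader_rank=0)   # reader process 0
    #   writer.enqueue({"msg": "hello"})
    #   assert reader.dequeue() == {"msg": "hello"}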

    @staticmethod
    def create_from_process_group(pg: ProcessGroup,
                                  max_chunk_bytes: int,
                                  max_chunks: int,
                                  writer_rank: int = 0) -> "ShmRingBufferIO":
        group_rank = dist.get_rank(pg)
        group_world_size = dist.get_world_size(pg)
        ranks_inside_group = list(range(group_world_size))
        global_ranks = dist.get_process_group_ranks(pg)
        n_reader = group_world_size - 1
        buffer: ShmRingBuffer
        if group_rank == writer_rank:
            buffer = ShmRingBuffer(n_reader, max_chunk_bytes, max_chunks)
            dist.broadcast_object_list([buffer], src=global_ranks[writer_rank])
            dist.barrier(pg)
            return ShmRingBufferIO(buffer, -1)
        else:
            recv = [None]
            dist.broadcast_object_list(recv, src=global_ranks[writer_rank])
            dist.barrier(pg)
            buffer = recv[0]  # type: ignore
            rest_ranks = [r for r in ranks_inside_group if r != writer_rank]
            return ShmRingBufferIO(buffer, rest_ranks.index(group_rank))
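

# Minimal end-to-end sketch (assumptions not taken from this module: a single
# node, a launch such as `torchrun --nproc_per_node=2 shm_broadcast.py`, and
# the CPU "gloo" backend; any initialized backend on one node should work).
if __name__ == "__main__":
    dist.init_process_group(backend="gloo")
    # use an explicit group containing all ranks
    pg = dist.new_group(ranks=list(range(dist.get_world_size())))
    queue = ShmRingBufferIO.create_from_process_group(
        pg, max_chunk_bytes=1024 * 1024, max_chunks=8)
    if dist.get_rank() == 0:
        # rank 0 is the writer (writer_rank defaults to 0)
        data = queue.broadcast_object({"step": 1})
    else:
        # every other rank is a reader and receives the same object
        data = queue.broadcast_object()
    print(f"rank {dist.get_rank()} got {data}")
    dist.destroy_process_group()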