Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion vllm/distributed/device_communicators/shm_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,9 @@ def __init__(
n_reader, # number of all readers
n_local_reader, # number of local readers through shared memory
local_reader_ranks: list[int] | None = None,
max_chunk_bytes: int = 1024 * 1024 * 24, # 24MiB
# Default of 24MiB chosen to be large enough to accommodate grammar
# bitmask tensors for large batches (1024 requests).
max_chunk_bytes: int = 1024 * 1024 * 24,
max_chunks: int = 10,
connect_ip: str | None = None,
):
Expand Down Expand Up @@ -538,6 +540,10 @@ def oob_callback(buf: PickleBuffer) -> bool:
buf[0] = 1 # overflow
self.local_socket.send_multipart(all_buffers, copy=False)
else:
# Byte 0: overflow flag, 0 here (not overflow)
# Bytes 1-2: Count of buffers
# Then each buffer follows, preceded by 4 bytes containing its length:
# [4-byte int L][L bytes of buffer content] ...
with self.acquire_write(timeout) as buf:
buf[0] = 0 # not overflow
offset = 3
Expand Down
4 changes: 3 additions & 1 deletion vllm/v1/core/sched/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,9 @@ class SchedulerOutput:
# freed from the encoder cache.
free_encoder_mm_hashes: list[str]

# ids of structured outputs requests included in the bitmask, in order.
# IDs of the structured output requests included in the bitmask, in the
# same order as the corresponding stacked rows of the bitmask.
# There may be more than one row per request in the case of speculative
# decoding.
structured_output_request_ids: list[str]
# the bitmask for the whole batch
grammar_bitmask: "npt.NDArray[np.int32] | None"
Expand Down