
Commit 44a7c30

[Docs] Docstrings for all new methods provided by EPD disagg update
Signed-off-by: LastZhabka <sakhmoldin.mukhammadarif@gmail.com>
1 parent 8c71d64 commit 44a7c30

File tree: 5 files changed, +98 -20 lines

vllm/separated_encode/ec_transfer/connector/redis.py

Lines changed: 77 additions & 5 deletions
@@ -36,13 +36,33 @@ def __init__(self,
         )
 
     def _get_request_ranks(self, request_id: str):
-        # request_id format: $ACTUAL_REQUEST_ID|$E_RANK|$PD_RANK
+        """Extract E_RANK and PD_RANK from a proxy-formatted request ID.
+
+        Parses request IDs of the form $ACTUAL_REQUEST_ID|$E_RANK|$PD_RANK.
+
+        Args:
+            request_id: The formatted request ID string from the proxy.
+
+        Returns:
+            Tuple containing (E_RANK, PD_RANK).
+        """
         result = request_id.split("|")
-        return int(result[1]), int(result[2])
+        return int(result[-2]), int(result[-1])
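Standalone, the new parsing behaves as sketched below; the switch to negative indices presumably makes IDs that themselves contain "|" parse correctly (the helper name is hypothetical):

    def parse_request_ranks(request_id: str) -> tuple[int, int]:
        # Proxy format: $ACTUAL_REQUEST_ID|$E_RANK|$PD_RANK.
        # Negative indices tolerate "|" inside the actual request id,
        # which the old result[1]/result[2] indexing did not.
        parts = request_id.split("|")
        return int(parts[-2]), int(parts[-1])

    assert parse_request_ranks("req-42|0|1") == (0, 1)
    assert parse_request_ranks("a|b|3|7") == (3, 7)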
 
     def _send_prealloc_notification(self, request_id: str, input_id: int,
                                     successful: bool, mm_hash: str) -> None:
-        # PD -> E
+        """
+        Send pre-allocation notification from PD to E instance via Redis.
+
+        Notifies the encoder instance whether pre-allocation was successful
+        and whether the encoder cache should be sent.
+
+        Args:
+            request_id: The formatted request ID containing rank information.
+            input_id: Index of the multimodal input within the request.
+            successful: Whether pre-allocation succeeded and cache should be sent.
+            mm_hash: Hash of the multimodal input.
+        """
         transfer_data = {
             "request_id": request_id,
             "input_id": input_id,
@@ -58,7 +78,18 @@ def _send_encoder_cache_metas(
         self, request_id: str, input_id: int,
         num_encoder_tokens: int, mm_hash: str
     ) -> None:
-        # E -> PD
+        """
+        Send encoder cache metadata from E to PD instance via Redis.
+
+        Transfers the metadata needed to pre-allocate space for the encoder
+        cache on the prefill/decode instance.
+
+        Args:
+            request_id: The formatted request ID containing rank information.
+            input_id: Index of the multimodal input within the request.
+            num_encoder_tokens: Number of tokens in the encoder cache.
+            mm_hash: Hash of the multimodal input.
+        """
         transfer_data = {
             "request_id": request_id,
             "input_id": input_id,
@@ -73,7 +104,18 @@ def _send_encoder_cache_metas(
     def _send_encoder_cache(
         self, request_id: str, input_id: int,
         encoder_cache: torch.Tensor, mm_hash: str) -> None:
-        # E -> PD
+        """
+        Send encoder cache tensor from E to PD instance via Redis.
+
+        Converts the encoder cache to a CPU float16 numpy array before
+        sending to reduce transfer size.
+
+        Args:
+            request_id: The formatted request ID containing rank information.
+            input_id: Index of the multimodal input within the request.
+            encoder_cache: The encoder output tensor to transfer.
+            mm_hash: Hash of the multimodal input.
+        """
         encoder_cache_numpy = encoder_cache.to("cpu", dtype=torch.float16).numpy()
         transfer_data = msgpack_numpy.packb({
             "request_id": request_id,
@@ -88,6 +130,16 @@ def _send_encoder_cache(
     def _recv_prealloc_notification(
         self, maybe_send_cache_callback: Callable[[str, int, bool, str],
                                                   None]) -> None:
+        """
+        Receive pre-allocation notification on E instance from Redis.
+
+        Blocks until a notification is received, then unpacks the data and
+        invokes the callback to handle cache sending logic.
+
+        Args:
+            maybe_send_cache_callback: Callback to determine whether to send
+                the encoder cache based on the pre-allocation result.
+        """
         transfered_data = self.redis_client.blpop(f"prealloc{self.rank}")[1]
         transfered_data = msgpack_numpy.unpackb(transfered_data, raw=False)
         request_id, input_id, successful, mm_hash = (
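All three _recv_* methods follow this blocking pattern; condensed into a sketch (the field names match the send side above, the function name is hypothetical):

    import msgpack_numpy

    def recv_prealloc_notification_sketch(redis_client, rank, callback):
        # blpop blocks until the PD side pushes a notification and returns
        # a (key, value) pair; index [1] is the msgpack payload.
        _, raw = redis_client.blpop(f"prealloc{rank}")
        data = msgpack_numpy.unpackb(raw, raw=False)
        callback(data["request_id"], data["input_id"],
                 data["successful"], data["mm_hash"])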
@@ -102,6 +154,16 @@ def _recv_prealloc_notification(
     def _recv_encoder_cache_metas(
         self, preallocate_callback: Callable[[str, int, int, str],
                                              None]) -> None:
+        """
+        Receive encoder cache metadata on PD instance from Redis.
+
+        Blocks until metadata is received, then unpacks the data and invokes
+        the callback to pre-allocate space in the scheduler.
+
+        Args:
+            preallocate_callback: Scheduler callback to pre-allocate space
+                for the incoming encoder cache.
+        """
         transfered_data = self.redis_client.blpop(f"cache_metas{self.rank}")[1]
         transfered_data = msgpack_numpy.unpackb(transfered_data, raw=False)
         request_id, input_id, num_encoder_tokens, mm_hash = (
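The push side feeding this blpop is not part of this diff; a hedged sketch of what _send_encoder_cache_metas presumably does, mirroring the per-rank queue naming seen here (the key format and rpush usage are assumptions):

    import msgpack_numpy
    import redis

    def send_cache_metas_sketch(client: redis.Redis, pd_rank: int,
                                request_id: str, input_id: int,
                                num_encoder_tokens: int, mm_hash: str) -> None:
        # E -> PD: queue metadata so the PD scheduler can pre-allocate space.
        transfer_data = {
            "request_id": request_id,
            "input_id": input_id,
            "num_encoder_tokens": num_encoder_tokens,
            "mm_hash": mm_hash,
        }
        client.rpush(f"cache_metas{pd_rank}", msgpack_numpy.packb(transfer_data))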
@@ -117,6 +179,16 @@ def _recv_encoder_cache(
         self,
         injection_callback: Callable[[str, int, torch.Tensor, str],None]
     ) -> None:
+        """
+        Receive encoder cache tensor on PD instance from Redis.
+
+        Blocks until cache data is received, converts it from numpy back to
+        the appropriate torch tensor format, then invokes the injection
+        callback.
+
+        Args:
+            injection_callback: Model runner callback to inject the encoder
+                cache into the cache dictionary.
+        """
         transfered_data = self.redis_client.blpop(f"cache{self.rank}")[1]
         transfered_data = msgpack_numpy.unpackb(transfered_data, raw=False)
         request_id, input_id, encoder_cache, mm_hash = (
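On the cache payload itself, the float16 downcast in _send_encoder_cache halves the bytes sent for float32 encoder outputs. A sketch of just the serialization step (the real method also packs the ids shown above; the helper name is hypothetical):

    import msgpack_numpy
    import torch

    def pack_encoder_cache(encoder_cache: torch.Tensor) -> bytes:
        # Move to CPU and downcast to float16 before serializing; for a
        # float32 input this halves the payload pushed through Redis.
        cache_numpy = encoder_cache.to("cpu", dtype=torch.float16).numpy()
        return msgpack_numpy.packb({"encoder_cache": cache_numpy})

    payload = pack_encoder_cache(torch.randn(16, 1024))  # ~32 KiB vs ~64 KiB raw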

vllm/separated_encode/ec_transfer/connector/template.py

Lines changed: 9 additions & 7 deletions
@@ -155,7 +155,7 @@ def _send_encoder_cache(
         Args:
             request_id: id of the encoder cache's request.
             input_id: index of the mm input among the request's mm inputs
-            encoder_cache: cache produced by vision model, in np array form
+            encoder_cache: encoder output
             mm_hash: hash of the mm input
         """
         pass
@@ -371,7 +371,7 @@ def schedule_send_encoder_cache(
         Args:
             request_id: id of the encoder cache's request.
             input_id: index of the mm input among the request's mm inputs
-            encoder_cache: cache produced by vision model, in np array form
+            encoder_cache: encoder output
         """
         self.send_tasks_queue.put_nowait(
             (self._finish_wrapper, (self._send_encoder_cache, request_id,
@@ -381,16 +381,18 @@ def _finish_wrapper(
         self, callback: Callable, request_id: str, input_id: int,
         encoder_cache: torch.Tensor, mm_hash: str
     ):
-
+        """
+        Run the send callback, then record the transferred
+        (request_id, input_id) pair in the transfered_ids list.
+        """
         callback(request_id, input_id, encoder_cache, mm_hash)
         with self.transfered_ids_lock:
             self.transfered_ids.append((request_id, input_id))
 
     def get_transfered_ids(self, ):
+        """
+        Return the accumulated transferred ids and reset the list.
+        """
         with self.transfered_ids_lock:
             transfered_ids = self.transfered_ids
             self.transfered_ids = []
-            return transfered_ids
-
-
-    def finish_request(self, req_id):
-        pass
+            return transfered_ids
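Together, _finish_wrapper and get_transfered_ids implement a lock-and-swap queue: producers append under the lock, and the consumer drains by replacing the whole list. A minimal self-contained illustration of the idiom (class and method names are illustrative only):

    import threading

    class TransferLog:
        def __init__(self):
            self._lock = threading.Lock()
            self._ids: list[tuple[str, int]] = []

        def record(self, request_id: str, input_id: int) -> None:
            with self._lock:
                self._ids.append((request_id, input_id))

        def drain(self) -> list[tuple[str, int]]:
            # Swap in a fresh list so the lock is held only briefly.
            with self._lock:
                ids, self._ids = self._ids, []
                return ids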

vllm/separated_encode/sched/encoder_scheduler.py

Lines changed: 3 additions & 0 deletions
@@ -226,6 +226,8 @@ def update_from_output(
         scheduler_output: SchedulerOutput,
         model_runner_output: ModelRunnerOutput,
     ) -> dict[int, EngineCoreOutputs]:
+
+        # clean up the logical space of mm_data that was transferred
         transfered_mm_data = model_runner_output.transfered_mm_data
 
         for (req_id, input_id) in transfered_mm_data:
@@ -241,6 +243,7 @@ def update_from_output(
 
         outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
 
+        # stop all requests from the current batch
         model_finished = []
         for request in self.running:
             req_id = request.request_id
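A sketch of what the new cleanup comment describes: each transferred (req_id, input_id) pair lets the scheduler release the logical space it reserved for that multimodal input (the manager object and its free method are hypothetical names, not the actual vLLM API):

    def free_transferred_mm_data(scheduler, transfered_mm_data):
        # Hypothetical sketch: release per-input bookkeeping once the
        # connector reports the encoder cache as transferred.
        for req_id, input_id in transfered_mm_data:
            scheduler.encoder_cache_space.free(req_id, input_id)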

vllm/separated_encode/worker/gpu_epd_lm_wrapper.py

Lines changed: 9 additions & 5 deletions
@@ -91,15 +91,19 @@ def execute_model(
         model_runner_output.injected_mm_data = injected_encoder_cache_ids
         return model_runner_output
 
-    def receive_encoder_cache(self, request_id, input_id, encoder_cache, mm_hash):
+    def receive_encoder_cache(
+        self,
+        request_id: str,
+        input_id: int,
+        encoder_cache: torch.Tensor,
+        mm_hash: str
+    ):
         """
         Callback function for receiving encoder cache from remote instances.
 
         This method is invoked by the encoder cache connector when encoder
-        cache data is received from remote encoder instances. It processes
-        the received numpy array by converting it to a PyTorch tensor with
-        the correct device placement and data type, then stores it in the
-        local encoder cache dictionary.
+        cache data is received from remote encoder instances; it then stores
+        the received tensor in the local encoder_cache dictionary.
 
         The method updates the injected encoder cache IDs list to inform the
         scheduler about successful cache injections.
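Per the revised docstring, the callback now receives a ready torch.Tensor rather than a numpy array. A sketch of the storage step it describes, assuming a {request_id: {input_id: tensor}} cache layout (the layout and the list name are assumptions):

    import torch

    def inject_encoder_cache(encoder_cache: dict, injected_ids: list,
                             request_id: str, input_id: int,
                             cache: torch.Tensor, mm_hash: str) -> None:
        # Store the received tensor and record the injection so
        # execute_model can report it back via injected_mm_data.
        encoder_cache.setdefault(request_id, {})[input_id] = cache
        injected_ids.append((request_id, input_id))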

vllm/separated_encode/worker/gpu_epd_vm_wrapper.py

Lines changed: 0 additions & 3 deletions
@@ -110,9 +110,6 @@ def execute_model(
         inputs, and transferring computed encoder caches to remote instances
         via a connector, while providing transfer status information to the
         scheduler.
-
-        It also converts encoder outputs into CPU tensors and then to numpy
-        arrays to prepare data for the transfer.
         """
         self._update_states(scheduler_output)
         self._execute_mm_encoder(scheduler_output)
