@@ -36,13 +36,33 @@ def __init__(self,
         )

     def _get_request_ranks(self, request_id: str):
-        # request_id format: $ACTUAL_REQUEST_ID|$E_RANK|$PD_RANK
+        """Extract E_RANK and PD_RANK from a proxy-formatted request ID.
+
+        Expects request_id in the format $ACTUAL_REQUEST_ID|$E_RANK|$PD_RANK.
+
+        Args:
+            request_id: The formatted request ID string from the proxy.
+
+        Returns:
+            Tuple containing (E_RANK, PD_RANK).
+        """
         result = request_id.split("|")
-        return int(result[1]), int(result[2])
+        return int(result[-2]), int(result[-1])

     def _send_prealloc_notification(self, request_id: str, input_id: int,
                                     successful: bool, mm_hash: str) -> None:
-        # PD -> E
+        """
+        Send pre-allocation notification from PD to E instance via Redis.
+
+        Notifies the encoder instance whether pre-allocation was successful
+        and whether the encoder cache should be sent.
+
+        Args:
+            request_id: The formatted request ID containing rank information.
+            input_id: Index of the multimodal input within the request.
+            successful: Whether pre-allocation succeeded and the cache should be sent.
+            mm_hash: Hash of the multimodal input.
+        """
         transfer_data = {
             "request_id": request_id,
             "input_id": input_id,
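
For reference, a minimal sketch of the request ID convention this connector relies on: the proxy appends the encoder rank and prefill/decode rank to the original request ID, and the ranks are recovered from the last two "|"-separated fields, which is why the change above switches to negative indices. The format_request_id helper below is hypothetical, not part of this change.

def format_request_id(actual_request_id: str, e_rank: int, pd_rank: int) -> str:
    # $ACTUAL_REQUEST_ID|$E_RANK|$PD_RANK
    return f"{actual_request_id}|{e_rank}|{pd_rank}"

def get_request_ranks(request_id: str) -> tuple[int, int]:
    # Taking the last two fields keeps parsing correct even if the
    # original request ID itself contains "|" characters.
    fields = request_id.split("|")
    return int(fields[-2]), int(fields[-1])

assert get_request_ranks(format_request_id("req-abc|0", 3, 1)) == (3, 1)
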
@@ -58,7 +78,18 @@ def _send_encoder_cache_metas(
         self, request_id: str, input_id: int,
         num_encoder_tokens: int, mm_hash: str
     ) -> None:
-        # E -> PD
+        """
+        Send encoder cache metadata from E to PD instance via Redis.
+
+        Transfers metadata needed for pre-allocating space for the encoder cache
+        on the prefill/decode instance.
+
+        Args:
+            request_id: The formatted request ID containing rank information.
+            input_id: Index of the multimodal input within the request.
+            num_encoder_tokens: Number of tokens in the encoder cache.
+            mm_hash: Hash of the multimodal input.
+        """
         transfer_data = {
             "request_id": request_id,
             "input_id": input_id,
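
The hunk above is truncated before the actual Redis call, so the following is only a hedged sketch of what the send side presumably does: serialize the metadata dict with msgpack_numpy and push it onto a per-rank Redis list matching the blpop() keys used by the receivers further down. The rpush call, key naming, and standalone function are assumptions, not code from this diff.

import msgpack_numpy
import redis

def send_encoder_cache_metas(client: redis.Redis, pd_rank: int, request_id: str,
                             input_id: int, num_encoder_tokens: int,
                             mm_hash: str) -> None:
    # Same fields as the transfer_data dict built above.
    transfer_data = {
        "request_id": request_id,
        "input_id": input_id,
        "num_encoder_tokens": num_encoder_tokens,
        "mm_hash": mm_hash,
    }
    # Assumed: the receiver blocks on blpop(f"cache_metas{rank}"), so the
    # sender pushes onto the list keyed by the destination PD rank.
    client.rpush(f"cache_metas{pd_rank}", msgpack_numpy.packb(transfer_data))
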
@@ -73,7 +104,18 @@ def _send_encoder_cache_metas(
     def _send_encoder_cache(
             self, request_id: str, input_id: int,
             encoder_cache: torch.Tensor, mm_hash: str) -> None:
-        # E -> PD
+        """
+        Send encoder cache tensor from E to PD instance via Redis.
+
+        Converts the encoder cache to a CPU float16 numpy array before sending
+        to reduce transfer size.
+
+        Args:
+            request_id: The formatted request ID containing rank information.
+            input_id: Index of the multimodal input within the request.
+            encoder_cache: The encoder output tensor to transfer.
+            mm_hash: Hash of the multimodal input.
+        """
         encoder_cache_numpy = encoder_cache.to("cpu", dtype=torch.float16).numpy()
         transfer_data = msgpack_numpy.packb({
             "request_id": request_id,
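
A minimal sketch of the dtype round-trip described in the docstring above: the encoder output is moved to CPU float16 and turned into a numpy array for transfer, and the receiving side can restore a torch tensor from it. The restore path, target device, and dtype on the receive side are assumptions.

import numpy as np
import torch

def pack_encoder_cache(encoder_cache: torch.Tensor) -> np.ndarray:
    # Matches the conversion above: CPU + float16 before serialization.
    return encoder_cache.to("cpu", dtype=torch.float16).numpy()

def unpack_encoder_cache(cache_array: np.ndarray,
                         device: str = "cpu") -> torch.Tensor:
    # Assumed restore path on the PD side; device/dtype choices may differ.
    return torch.from_numpy(cache_array).to(device)

cache = torch.randn(4, 8)            # stand-in for an encoder output
restored = unpack_encoder_cache(pack_encoder_cache(cache))
assert restored.shape == cache.shape and restored.dtype == torch.float16
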
@@ -88,6 +130,16 @@ def _send_encoder_cache(
     def _recv_prealloc_notification(
             self, maybe_send_cache_callback: Callable[[str, int, bool, str],
                                                       None]) -> None:
+        """
+        Receive pre-allocation notification on E instance from Redis.
+
+        Blocks until a notification is received, then unpacks the data and
+        invokes the callback to handle cache sending logic.
+
+        Args:
+            maybe_send_cache_callback: Callback to determine whether to send
+                the encoder cache based on the pre-allocation result.
+        """
         transfered_data = self.redis_client.blpop(f"prealloc{self.rank}")[1]
         transfered_data = msgpack_numpy.unpackb(transfered_data, raw=False)
         request_id, input_id, successful, mm_hash = (
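
The receive path above follows a single pattern: block on a per-rank Redis list, unpack the msgpack payload, and hand the fields to the caller-supplied callback. A hedged standalone sketch of that pattern (the function below is illustrative, not the class method itself):

from typing import Callable

import msgpack_numpy
import redis

def recv_prealloc_notification(
        client: redis.Redis, rank: int,
        maybe_send_cache_callback: Callable[[str, int, bool, str], None]) -> None:
    # blpop blocks until an item is available and returns (key, value).
    _, payload = client.blpop(f"prealloc{rank}")
    data = msgpack_numpy.unpackb(payload, raw=False)
    maybe_send_cache_callback(data["request_id"], data["input_id"],
                              data["successful"], data["mm_hash"])
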
@@ -102,6 +154,16 @@ def _recv_prealloc_notification(
     def _recv_encoder_cache_metas(
             self, preallocate_callback: Callable[[str, int, int, str],
                                                  None]) -> None:
+        """
+        Receive encoder cache metadata on PD instance from Redis.
+
+        Blocks until metadata is received, then unpacks the data and invokes
+        the callback to pre-allocate space in the scheduler.
+
+        Args:
+            preallocate_callback: Scheduler callback to pre-allocate space
+                for the incoming encoder cache.
+        """
         transfered_data = self.redis_client.blpop(f"cache_metas{self.rank}")[1]
         transfered_data = msgpack_numpy.unpackb(transfered_data, raw=False)
         request_id, input_id, num_encoder_tokens, mm_hash = (
@@ -117,6 +179,16 @@ def _recv_encoder_cache(
         self,
         injection_callback: Callable[[str, int, torch.Tensor, str], None]
     ) -> None:
+        """
+        Receive encoder cache tensor on PD instance from Redis.
+
+        Blocks until cache data is received, converts it from numpy back to
+        the appropriate torch tensor format, then invokes the injection callback.
+
+        Args:
+            injection_callback: Model runner callback to inject the encoder
+                cache into the cache dictionary.
+        """
         transfered_data = self.redis_client.blpop(f"cache{self.rank}")[1]
         transfered_data = msgpack_numpy.unpackb(transfered_data, raw=False)
         request_id, input_id, encoder_cache, mm_hash = (
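
Since all three _recv_* helpers block on Redis, one plausible way to drive them (an assumption, not shown in this diff) is a daemon thread per channel that loops forever and re-invokes the helper with its callback:

import threading
from typing import Callable

def start_receiver_loop(recv_fn: Callable[[Callable[..., None]], None],
                        callback: Callable[..., None]) -> threading.Thread:
    def loop() -> None:
        while True:
            recv_fn(callback)   # each call blocks on Redis until one item arrives
    thread = threading.Thread(target=loop, daemon=True)
    thread.start()
    return thread
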