
Commit 8c71d64

[Refactor]
Signed-off-by: LastZhabka <sakhmoldin.mukhammadarif@gmail.com>
1 parent 35f8655 commit 8c71d64

12 files changed: +325 −157 lines

examples/online_serving/separated_encode/proxy/proxy1e1pd_aiohttp.py

Lines changed: 51 additions & 38 deletions
@@ -45,31 +45,45 @@ async def shutdown_event():
     if decode_session:
         await decode_session.close()
 
+
+def has_mm_input(request_data: dict):
+    if "messages" not in request_data:
+        return False
+    for message in request_data["messages"]:
+        if not isinstance(message.get("content"), list):
+            continue
+        for content_item in message["content"]:
+            if content_item.get("type") in ["image_url", "audio_url", "input_audio"]:
+                return True
+    return False
+
 async def forward_streaming_request(
     request_data: dict,
     request_id: str
 ) -> AsyncIterator[str]:
     headers = {"x-request-id": request_id}
-    task1 = asyncio.create_task(
-        encode_session.post(
-            f"{ENCODE_SERVER_URL}/v1/chat/completions",
-            json=request_data,
-            headers=headers
+    # Skip request to encoder instance if we don't have mm input
+    if has_mm_input(request_data):
+        task1 = asyncio.create_task(
+            encode_session.post(
+                f"{ENCODE_SERVER_URL}/v1/chat/completions",
+                json=request_data,
+                headers=headers
+            )
         )
-    )
-    try:
-        response = await task1
-        if response.status != 200:
-            error_text = await response.text()
+        try:
+            response = await task1
+            if response.status != 200:
+                error_text = await response.text()
+                raise HTTPException(
+                    status_code=response.status,
+                    detail={"error": "Request failed", "message": error_text}
+                )
+        except Exception as e:
             raise HTTPException(
-                status_code=response.status,
-                detail={"error": "Request failed", "message": error_text}
+                status_code=500,
+                detail={"error": "Internal server error", "message": str(e)}
             )
-    except Exception as e:
-        raise HTTPException(
-            status_code=500,
-            detail={"error": "Internal server error", "message": str(e)}
-        )
 
     try:
         async with decode_session.post(
@@ -83,37 +97,37 @@ async def forward_streaming_request(
                 yield chunk.decode('utf-8', errors='ignore')
     except Exception as e:
         logger.error(f"Error in streaming: {e}")
-        task1.cancel()
         raise
 
 async def forward_non_streaming_request(
     request_data: dict,
     request_id: str
 ) -> dict:
     headers = {"x-request-id": request_id}
-
-    # Start request to encode server
-    task1 = asyncio.create_task(
-        encode_session.post(
-            f"{ENCODE_SERVER_URL}/v1/chat/completions",
-            json=request_data,
-            headers=headers
+    # Skip request to encoder instance if we don't have mm input
+    if has_mm_input(request_data):
+        # Start request to encode server
+        task1 = asyncio.create_task(
+            encode_session.post(
+                f"{ENCODE_SERVER_URL}/v1/chat/completions",
+                json=request_data,
+                headers=headers
+            )
         )
-    )
 
-    try:
-        response = await task1
-        if response.status != 200:
-            error_text = await response.text()
+        try:
+            response = await task1
+            if response.status != 200:
+                error_text = await response.text()
+                raise HTTPException(
+                    status_code=response.status,
+                    detail={"error": "Request failed", "message": error_text}
+                )
+        except Exception as e:
             raise HTTPException(
-                status_code=response.status,
-                detail={"error": "Request failed", "message": error_text}
+                status_code=500,
+                detail={"error": "Internal server error", "message": str(e)}
             )
-    except Exception as e:
-        raise HTTPException(
-            status_code=500,
-            detail={"error": "Internal server error", "message": str(e)}
-        )
 
     try:
         # Make request to decode server
@@ -127,7 +141,6 @@ async def forward_non_streaming_request(
             return result
     except Exception as e:
         logger.error(f"Error in non-streaming: {e}")
-        task1.cancel()
         raise
 
 @app.post("/v1/chat/completions")
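Net effect of this change: the proxy only fans a request out to the encode instance when the chat payload actually carries multimodal content; text-only requests go straight to the prefill/decode server, and the unconditional task1.cancel() calls are dropped since task1 may no longer exist. Below is a minimal sketch (not part of the commit) of how the new has_mm_input helper classifies OpenAI-style chat payloads, assuming the helper from the diff above is in scope; the payloads themselves are illustrative.

text_only = {"messages": [{"role": "user", "content": "Describe a cat."}]}
with_image = {"messages": [{
    "role": "user",
    "content": [
        {"type": "text", "text": "Describe this image."},
        {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
    ],
}]}

assert has_mm_input(text_only) is False   # plain string content -> encoder instance is skipped
assert has_mm_input(with_image) is True   # image_url item -> request is also sent to the encoder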

vllm/separated_encode/README.md

Lines changed: 33 additions & 43 deletions
Large diffs are not rendered by default.

vllm/separated_encode/ec_transfer/connector/redis.py

Lines changed: 10 additions & 6 deletions
@@ -3,32 +3,33 @@
 from typing import Callable, Literal, Optional
 
 import msgpack_numpy
-import numpy as np
 import redis
-from numpy.typing import NDArray
 
 from vllm.config import VllmConfig
 from vllm.separated_encode.ec_transfer.connector.template import (
     ECConnectorTemplate)
 from vllm.logger import init_logger
+import torch
 
 logger = init_logger(__name__)
 
 class RedisECConnector(ECConnectorTemplate):
 
     def __init__(self,
                  vllm_config: "VllmConfig",
+                 device: Optional[torch.device],
                  intra_instance_type: Literal["scheduler", "model-runner"],
                  preallocate_callback: Optional[Callable[[str, int, int, str],
                                                          None]],
                  injection_callback: Optional[Callable[
-                     [str, int, NDArray[np.float32], str], None]],
+                     [str, int, torch.Tensor, str], None]],
                  redis_host: str = "localhost",
                  redis_port: int = 6379):
         self.redis_client = redis.StrictRedis(host=redis_host, port=redis_port)
         self.rank = vllm_config.epd_disagg_config.epd_rank
         super().__init__(
             vllm_config,
+            device,
             intra_instance_type,
             preallocate_callback,
             injection_callback,
@@ -71,12 +72,13 @@ def _send_encoder_cache_metas(
 
     def _send_encoder_cache(
             self, request_id: str, input_id: int,
-            encoder_cache: NDArray[np.float32], mm_hash: str) -> None:
+            encoder_cache: torch.Tensor, mm_hash: str) -> None:
         # E -> PD
+        encoder_cache_numpy = encoder_cache.to("cpu", dtype=torch.float16).numpy()
         transfer_data = msgpack_numpy.packb({
             "request_id": request_id,
             "input_id": input_id,
-            "encoder_cache": encoder_cache,
+            "encoder_cache": encoder_cache_numpy,
             "mm_hash": mm_hash
         })
         rank = self._get_request_ranks(request_id)[1]
@@ -113,7 +115,7 @@ def _recv_encoder_cache_metas(
 
     def _recv_encoder_cache(
         self,
-        injection_callback: Callable[[str, int, NDArray[np.float32], str],None]
+        injection_callback: Callable[[str, int, torch.Tensor, str],None]
     ) -> None:
         transfered_data = self.redis_client.blpop(f"cache{self.rank}")[1]
         transfered_data = msgpack_numpy.unpackb(transfered_data, raw=False)
@@ -123,5 +125,7 @@ def _recv_encoder_cache(
             transfered_data["encoder_cache"],
             transfered_data["mm_hash"]
         )
+        encoder_cache = torch.from_numpy(encoder_cache).to(
+            device=self.device, dtype=self.dtype)
         logger.debug(f"Received encoder cache -> {self.rank}, {request_id}")
         injection_callback(request_id, input_id, encoder_cache, mm_hash)
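The connector now keeps tensors end to end: the sender moves the encoder cache to CPU and down-casts to float16 before msgpack serialization, and the receiver rebuilds a tensor on the connector's device and model dtype before invoking the injection callback. A standalone sketch of that round-trip is below (not from the commit; CPU-only and with made-up shapes, request ids, and hashes so it runs anywhere):

import msgpack_numpy
import torch

# Stand-in for an encoder output; real caches come from the vision/audio encoder.
cache = torch.randn(16, 1024, dtype=torch.bfloat16)

# Sender side, mirroring _send_encoder_cache: cast to fp16 on CPU, pack with msgpack_numpy.
payload = msgpack_numpy.packb({
    "request_id": "req-0",
    "input_id": 0,
    "encoder_cache": cache.to("cpu", dtype=torch.float16).numpy(),
    "mm_hash": "0xdeadbeef",
})

# Receiver side, mirroring _recv_encoder_cache: unpack, rebuild a tensor on the
# connector's device/dtype (cpu and bfloat16 are assumed here for the example).
data = msgpack_numpy.unpackb(payload, raw=False)
restored = torch.from_numpy(data["encoder_cache"]).to(device="cpu", dtype=torch.bfloat16)
assert restored.shape == cache.shape

Note that the float16 hop is lossy for bfloat16 or float32 activations; the commit appears to accept that trade-off in exchange for a smaller Redis payload.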

vllm/separated_encode/ec_transfer/connector/template.py

Lines changed: 11 additions & 9 deletions
@@ -6,8 +6,7 @@
 from concurrent.futures import ThreadPoolExecutor
 from typing import Callable, Literal, Optional
 
-import numpy as np
-from numpy.typing import NDArray
+import torch
 
 from vllm.config import EPDDisaggConfig, VllmConfig
 from vllm.logger import init_logger
@@ -37,9 +36,10 @@ class ECConnectorTemplate(ABC):
     def __init__(
         self,
         vllm_config: "VllmConfig",
+        device: Optional[torch.device],
         intra_instance_type: Literal["scheduler", "model-runner"],
         preallocate_callback: Optional[Callable[[str, int, int, str], None]],
-        injection_callback: Optional[Callable[[str, int, NDArray[np.float32], str],
+        injection_callback: Optional[Callable[[str, int, torch.Tensor, str],
                                               None]],
     ):
         callback_mapping = {
@@ -55,12 +55,14 @@ def __init__(
             ("prefill+decode", "model-runner"): (self._recv_encoder_cache,
                                                  injection_callback)
         }
+        self.device = device
+        self.dtype = vllm_config.model_config.dtype
 
         self.epd_disagg_config: EPDDisaggConfig
         self.intra_instance_type: Literal["scheduler", "model-runner"]
         self.inter_instance_type: Literal["encode", "prefill",
                                           "prefill+decode"]
-        self.encoder_cache: dict[str, dict[int, NDArray[np.float32]]]
+        self.encoder_cache: dict[str, dict[int, torch.Tensor]]
         self.send_executors: ThreadPoolExecutor
         self.recv_executors: ThreadPoolExecutor
 
@@ -143,7 +145,7 @@ def _send_encoder_cache_metas(self, request_id: str, input_id: int,
     @abstractmethod
     def _send_encoder_cache(
         self, request_id: str, input_id: int,
-        encoder_cache: NDArray[np.float32], mm_hash: str
+        encoder_cache: torch.Tensor, mm_hash: str
     ) -> None:
         """Send the encoder cache.
 
@@ -204,7 +206,7 @@ def _recv_encoder_cache_metas(
     @abstractmethod
     def _recv_encoder_cache(
         self,
-        injection_callback: Callable[[str, int, NDArray[np.float32], str],None]
+        injection_callback: Callable[[str, int, torch.Tensor, str],None]
     ) -> None:
         """Receives the encoder cache and calls injection callback
 
@@ -224,7 +226,7 @@ def _recv_encoder_cache(
         pass
 
     def add_encoder_cache(self, request_id: str, input_id: int,
-                          encoder_cache: NDArray[np.float32], mm_hash: str):
+                          encoder_cache: torch.Tensor, mm_hash: str):
         """Add an encoder cache to the EC connector.
 
         This method adds the encoder cache to the self.encoder_cache dictionary
@@ -360,7 +362,7 @@ def schedule_send_encoder_cache_metadata(self, request_id: str,
 
     def schedule_send_encoder_cache(
         self, request_id: str, input_id: int,
-        encoder_cache: NDArray[np.float32], mm_hash: str
+        encoder_cache: torch.Tensor, mm_hash: str
     ) -> None:
         """Schedule encoder cache sending
 
@@ -377,7 +379,7 @@ def schedule_send_encoder_cache(
 
     def _finish_wrapper(
         self, callback: Callable, request_id: str, input_id: int,
-        encoder_cache: NDArray[np.float32], mm_hash: str
+        encoder_cache: torch.Tensor, mm_hash: str
     ):
 
         callback(request_id, input_id, encoder_cache, mm_hash)
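With device and dtype now stored on the template, every callback in the transfer path deals in torch.Tensor rather than float32 numpy arrays. A sketch of what an injection callback looks like under the new signature (only the signature comes from the diff; the body here is hypothetical):

import torch

def injection_callback(request_id: str, input_id: int,
                       encoder_cache: torch.Tensor, mm_hash: str) -> None:
    # A prefill/decode model runner would stash the tensor in its encoder cache
    # keyed by request/input id here; printing stands in for that step.
    print(request_id, input_id, tuple(encoder_cache.shape), encoder_cache.dtype, mm_hash)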

Comments (0)