Skip to content

Commit 8e02e5c

Browse files
committed
fix(request-audio): loop through model_names
Signed-off-by: Max Wittig <max.wittig@siemens.com>
1 parent ca9fdd2 commit 8e02e5c

File tree

2 files changed

+20
-17
lines changed

2 files changed

+20
-17
lines changed

src/vllm_router/service_discovery.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ class EndpointInfo:
9494
# Model label
9595
model_label: str
9696

97+
model_type: str
98+
9799
# Endpoint's sleep status
98100
sleep: bool
99101

@@ -306,13 +308,15 @@ def get_endpoint_info(self) -> List[EndpointInfo]:
306308
):
307309
continue
308310
model_label = self.model_labels[i] if self.model_labels else "default"
311+
model_type = self.model_types[i] if self.model_types else "default"
309312
endpoint_info = EndpointInfo(
310313
url=url,
311314
model_names=[model], # Convert single model to list
312315
Id=self.engines_id[i],
313316
sleep=False,
314317
added_timestamp=self.added_timestamp,
315318
model_label=model_label,
319+
model_type=model_type,
316320
model_info=self._get_model_info(model),
317321
)
318322
endpoint_infos.append(endpoint_info)

src/vllm_router/services/request_service/request.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,6 @@ async def route_general_transcriptions(
539539
content={"error": f"Invalid request: missing '{e.args[0]}' in form data."},
540540
)
541541

542-
logger.debug("==== Enter audio_transcriptions ====")
543542
logger.debug("Received upload: %s (%s)", file.filename, file.content_type)
544543
logger.debug(
545544
"Params: model=%s prompt=%r response_format=%r temperature=%r language=%s",
@@ -565,18 +564,16 @@ async def route_general_transcriptions(
565564

566565
endpoints = service_discovery.get_endpoint_info()
567566

568-
logger.debug("==== Total endpoints ====")
569-
logger.debug(endpoints)
570-
logger.debug("==== Total endpoints ====")
571-
572-
# filter the endpoints url by model name and label for transcriptions
573-
transcription_endpoints = [
574-
ep
575-
for ep in endpoints
576-
if model == ep.model_name
577-
and ep.model_label == "transcription"
578-
and not ep.sleep # Added ep.sleep == False
579-
]
567+
# filter the endpoints url by model name and model_type for transcriptions
568+
transcription_endpoints = []
569+
for ep in endpoints:
570+
for model_name in ep.model_names:
571+
if (
572+
model == model_name
573+
and ep.model_type == "transcription"
574+
and not ep.sleep
575+
):
576+
transcription_endpoints.append(ep)
580577

581578
logger.debug("====List of transcription endpoints====")
582579
logger.debug(transcription_endpoints)
@@ -620,10 +617,6 @@ async def route_general_transcriptions(
620617

621618
logger.info("Proxying transcription request for model %s to %s", model, chosen_url)
622619

623-
logger.debug("==== data payload keys ====")
624-
logger.debug(list(data.keys()))
625-
logger.debug("==== data payload keys ====")
626-
627620
try:
628621
client = request.app.state.aiohttp_client_wrapper()
629622

@@ -687,3 +680,9 @@ async def route_general_transcriptions(
687680
status_code=503,
688681
content={"error": f"Failed to connect to backend: {str(client_error)}"},
689682
)
683+
except Exception as e:
684+
logger.error(e)
685+
return JSONResponse(
686+
status_code=500,
687+
content={"error": f"Internal server error"},
688+
)

0 commit comments

Comments
 (0)