diff --git a/tests/entrypoints/openai/test_basic.py b/tests/entrypoints/openai/test_basic.py index 0d44a7611aed..a970981b7562 100644 --- a/tests/entrypoints/openai/test_basic.py +++ b/tests/entrypoints/openai/test_basic.py @@ -156,3 +156,19 @@ async def test_request_cancellation(server: RemoteOpenAIServer): max_tokens=10) assert len(response.choices) == 1 + + +@pytest.mark.asyncio +async def test_request_wrong_content_type(server: RemoteOpenAIServer): + + chat_input = [{"role": "user", "content": "Write a long story"}] + client = server.get_async_client() + + with pytest.raises(openai.APIStatusError): + await client.chat.completions.create( + messages=chat_input, + model=MODEL_NAME, + max_tokens=10000, + extra_headers={ + "Content-Type": "application/x-www-form-urlencoded" + }) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index b8f54d6c7804..06d7bb6c32d7 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -20,7 +20,7 @@ from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union import uvloop -from fastapi import APIRouter, FastAPI, HTTPException, Request +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse @@ -253,6 +253,15 @@ def _cleanup_ipc_path(): multiprocess.mark_process_dead(engine_process.pid) +async def validate_json_request(raw_request: Request): + content_type = raw_request.headers.get("content-type", "").lower() + if content_type != "application/json": + raise HTTPException( + status_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE, + detail="Unsupported Media Type: Only 'application/json' is allowed" + ) + + router = APIRouter() @@ -336,7 +345,7 @@ async def ping(raw_request: Request) -> Response: return await health(raw_request) -@router.post("/tokenize") +@router.post("/tokenize", dependencies=[Depends(validate_json_request)]) @with_cancellation async def tokenize(request: TokenizeRequest, raw_request: Request): handler = tokenization(raw_request) @@ -351,7 +360,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request): assert_never(generator) -@router.post("/detokenize") +@router.post("/detokenize", dependencies=[Depends(validate_json_request)]) @with_cancellation async def detokenize(request: DetokenizeRequest, raw_request: Request): handler = tokenization(raw_request) @@ -380,7 +389,8 @@ async def show_version(): return JSONResponse(content=ver) -@router.post("/v1/chat/completions") +@router.post("/v1/chat/completions", + dependencies=[Depends(validate_json_request)]) @with_cancellation async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request): @@ -401,7 +411,7 @@ async def create_chat_completion(request: ChatCompletionRequest, return StreamingResponse(content=generator, media_type="text/event-stream") -@router.post("/v1/completions") +@router.post("/v1/completions", dependencies=[Depends(validate_json_request)]) @with_cancellation async def create_completion(request: CompletionRequest, raw_request: Request): handler = completion(raw_request) @@ -419,7 +429,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request): return StreamingResponse(content=generator, media_type="text/event-stream") -@router.post("/v1/embeddings") +@router.post("/v1/embeddings", dependencies=[Depends(validate_json_request)]) @with_cancellation async def create_embedding(request: EmbeddingRequest, raw_request: Request): handler = embedding(raw_request) @@ -465,7 +475,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): assert_never(generator) -@router.post("/pooling") +@router.post("/pooling", dependencies=[Depends(validate_json_request)]) @with_cancellation async def create_pooling(request: PoolingRequest, raw_request: Request): handler = pooling(raw_request) @@ -483,7 +493,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request): assert_never(generator) -@router.post("/score") +@router.post("/score", dependencies=[Depends(validate_json_request)]) @with_cancellation async def create_score(request: ScoreRequest, raw_request: Request): handler = score(raw_request) @@ -501,7 +511,7 @@ async def create_score(request: ScoreRequest, raw_request: Request): assert_never(generator) -@router.post("/v1/score") +@router.post("/v1/score", dependencies=[Depends(validate_json_request)]) @with_cancellation async def create_score_v1(request: ScoreRequest, raw_request: Request): logger.warning( @@ -511,7 +521,7 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request): return await create_score(request, raw_request) -@router.post("/rerank") +@router.post("/rerank", dependencies=[Depends(validate_json_request)]) @with_cancellation async def do_rerank(request: RerankRequest, raw_request: Request): handler = rerank(raw_request) @@ -528,7 +538,7 @@ async def do_rerank(request: RerankRequest, raw_request: Request): assert_never(generator) -@router.post("/v1/rerank") +@router.post("/v1/rerank", dependencies=[Depends(validate_json_request)]) @with_cancellation async def do_rerank_v1(request: RerankRequest, raw_request: Request): logger.warning_once( @@ -539,7 +549,7 @@ async def do_rerank_v1(request: RerankRequest, raw_request: Request): return await do_rerank(request, raw_request) -@router.post("/v2/rerank") +@router.post("/v2/rerank", dependencies=[Depends(validate_json_request)]) @with_cancellation async def do_rerank_v2(request: RerankRequest, raw_request: Request): return await do_rerank(request, raw_request) @@ -583,7 +593,7 @@ async def reset_prefix_cache(raw_request: Request): return Response(status_code=200) -@router.post("/invocations") +@router.post("/invocations", dependencies=[Depends(validate_json_request)]) async def invocations(raw_request: Request): """ For SageMaker, routes requests to other handlers based on model `task`. @@ -633,7 +643,8 @@ async def stop_profile(raw_request: Request): "Lora dynamic loading & unloading is enabled in the API server. " "This should ONLY be used for local development!") - @router.post("/v1/load_lora_adapter") + @router.post("/v1/load_lora_adapter", + dependencies=[Depends(validate_json_request)]) async def load_lora_adapter(request: LoadLoraAdapterRequest, raw_request: Request): handler = models(raw_request) @@ -644,7 +655,8 @@ async def load_lora_adapter(request: LoadLoraAdapterRequest, return Response(status_code=200, content=response) - @router.post("/v1/unload_lora_adapter") + @router.post("/v1/unload_lora_adapter", + dependencies=[Depends(validate_json_request)]) async def unload_lora_adapter(request: UnloadLoraAdapterRequest, raw_request: Request): handler = models(raw_request)