diff --git a/.github/workflows/cicd.yml b/.github/workflows/cicd.yml index 31189916..d07935a5 100644 --- a/.github/workflows/cicd.yml +++ b/.github/workflows/cicd.yml @@ -137,7 +137,7 @@ jobs: sed -i 's/BRAVE_SEARCH_API=.*/BRAVE_SEARCH_API=${{ secrets.BRAVE_SEARCH_API }}/' .env - name: Compose docker-compose.yml - run: python3 ./scripts/docker-composer.py --dev -f docker/compose/docker-compose.llama-1b-gpu.ci.yml -o development-compose.yml + run: python3 ./scripts/docker-composer.py --dev -f docker/compose/docker-compose.llama-1b-cpu.ci.yml -o development-compose.yml - name: GPU stack versions (non-fatal) shell: bash diff --git a/.gitignore b/.gitignore index f3d8ab42..8e8c0336 100644 --- a/.gitignore +++ b/.gitignore @@ -179,3 +179,4 @@ private_key.key.lock development-compose.yml production-compose.yml +docker/compose/docker-compose.gemma-4b-gpu.ci.yml diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index 41ef9771..25066491 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -2,10 +2,6 @@ services: caddy: env_file: - .env - ports: - - "80:80" - - "443:443" - - "443:443/udp" volumes: - ./caddy/Caddyfile:/etc/caddy/Caddyfile api: diff --git a/docker/compose/docker-compose.gemma-27b-gpu.yml b/docker/compose/docker-compose.gemma-27b-gpu.yml new file mode 100644 index 00000000..754b44c3 --- /dev/null +++ b/docker/compose/docker-compose.gemma-27b-gpu.yml @@ -0,0 +1,45 @@ +services: + gemma_27b_gpu: + image: nillion/nilai-vllm:latest + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: all + capabilities: [gpu] + ipc: host + ulimits: + memlock: -1 + stack: 67108864 + env_file: + - .env + restart: unless-stopped + depends_on: + etcd: + condition: service_healthy + command: > + --model google/gemma-3-27b-it + --gpu-memory-utilization 0.79 + --max-model-len 60000 + --max-num-batched-tokens 8192 + --dtype bfloat16 + --kv-cache-dtype fp8 + --uvicorn-log-level warning + environment: + - SVC_HOST=gemma_27b_gpu + - SVC_PORT=8000 + - ETCD_HOST=etcd + - ETCD_PORT=2379 + - TOOL_SUPPORT=false + - MULTIMODAL_SUPPORT=true + volumes: + - hugging_face_models:/root/.cache/huggingface + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8000/health"] + interval: 30s + retries: 3 + start_period: 60s + timeout: 10s +volumes: + hugging_face_models: diff --git a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml new file mode 100644 index 00000000..29423275 --- /dev/null +++ b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml @@ -0,0 +1,47 @@ +services: + gemma_4b_gpu: + image: nillion/nilai-vllm:latest + container_name: nilai-gemma_4b_gpu + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] + + ulimits: + memlock: -1 + stack: 67108864 + env_file: + - .env + restart: unless-stopped + depends_on: + etcd: + condition: service_healthy + command: > + --model google/gemma-3-4b-it + --max-model-len 30000 + --max-num-batched-tokens 8192 + + --uvicorn-log-level warning + environment: + - SVC_HOST=gemma_4b_gpu + - SVC_PORT=8000 + - ETCD_HOST=etcd + - ETCD_PORT=2379 + - TOOL_SUPPORT=false + - MULTIMODAL_SUPPORT=true + - CUDA_LAUNCH_BLOCKING=1 + - VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 + - PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + volumes: + - hugging_face_models:/root/.cache/huggingface + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8000/health"] + interval: 30s + retries: 3 + start_period: 60s + timeout: 10s +volumes: + hugging_face_models: diff --git a/docker/compose/docker-compose.llama-8b-gpu.yml b/docker/compose/docker-compose.llama-8b-gpu.yml index 18392928..7ecdba10 100644 --- a/docker/compose/docker-compose.llama-8b-gpu.yml +++ b/docker/compose/docker-compose.llama-8b-gpu.yml @@ -20,7 +20,7 @@ services: condition: service_healthy command: > --model meta-llama/Llama-3.1-8B-Instruct - --gpu-memory-utilization 0.21 + --gpu-memory-utilization 0.20 --max-model-len 10000 --max-num-batched-tokens 10000 --tensor-parallel-size 1 diff --git a/docker/compose/docker-compose.qwen-2b-gpu.ci.yml b/docker/compose/docker-compose.qwen-2b-gpu.ci.yml new file mode 100644 index 00000000..7d040caf --- /dev/null +++ b/docker/compose/docker-compose.qwen-2b-gpu.ci.yml @@ -0,0 +1,64 @@ +version: "3.8" + +services: + c: + image: nillion/nilai-vllm:latest + container_name: qwen2vl_2b_gpu + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] + ulimits: + memlock: -1 + stack: 67108864 + env_file: + - .env + restart: unless-stopped + depends_on: + etcd: + condition: service_healthy + command: + [ + "--model", "Qwen/Qwen2-VL-2B-Instruct-AWQ", + "--model-impl", "vllm", + "--tensor-parallel-size", "1", + "--trust-remote-code", + "--quantization", "awq", + + "--max-model-len", "1280", + "--max-num-batched-tokens", "1280", + "--max-num-seqs", "1", + + "--gpu-memory-utilization", "0.75", + "--swap-space", "8", + "--uvicorn-log-level", "warning", + + "--limit-mm-per-prompt", "{\"image\":1,\"video\":0}", + "--skip-mm-profiling", + "--enforce-eager" + ] + + environment: + SVC_HOST: qwen2vl_2b_gpu + SVC_PORT: "8000" + ETCD_HOST: etcd + ETCD_PORT: "2379" + TOOL_SUPPORT: "true" + MULTIMODAL_SUPPORT: "true" + CUDA_LAUNCH_BLOCKING: "1" + VLLM_ALLOW_LONG_MAX_MODEL_LEN: "1" + PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True" + volumes: + - hugging_face_models:/root/.cache/huggingface + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8000/health"] + interval: 30s + retries: 3 + start_period: 60s + timeout: 10s + +volumes: + hugging_face_models: diff --git a/nilai-api/src/nilai_api/handlers/nilrag.py b/nilai-api/src/nilai_api/handlers/nilrag.py index b55ac18b..630d9088 100644 --- a/nilai-api/src/nilai_api/handlers/nilrag.py +++ b/nilai-api/src/nilai_api/handlers/nilrag.py @@ -1,11 +1,11 @@ import logging +from typing import Union import nilrag -from nilai_common import ChatRequest, Message +from nilai_common import ChatRequest, MessageAdapter from fastapi import HTTPException, status from sentence_transformers import SentenceTransformer -from typing import Union logger = logging.getLogger(__name__) @@ -63,13 +63,9 @@ async def handle_nilrag(req: ChatRequest): # Get user query logger.debug("Extracting user query") - query = None - for message in req.messages: - if message.role == "user": - query = message.content - break + query = req.get_last_user_query() - if query is None: + if not query: raise HTTPException(status_code=400, detail="No user query found") # Get number of chunks to include @@ -85,20 +81,25 @@ async def handle_nilrag(req: ChatRequest): relevant_context = f"\n\nRelevant Context:\n{formatted_results}" # Step 4: Update system message - for message in req.messages: + for message in req.adapted_messages: if message.role == "system": - if message.content is None: + content = message.content + if content is None: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="system message is empty", ) - message.content += ( - relevant_context # Append the context to the system message - ) + + if isinstance(content, str): + message.content = content + relevant_context + elif isinstance(content, list): + content.append({"type": "text", "text": relevant_context}) break else: # If no system message exists, add one - req.messages.insert(0, Message(role="system", content=relevant_context)) + req.messages.insert( + 0, MessageAdapter.new_message(role="system", content=relevant_context) + ) logger.debug(f"System message updated with relevant context:\n {req.messages}") diff --git a/nilai-api/src/nilai_api/handlers/web_search.py b/nilai-api/src/nilai_api/handlers/web_search.py index bd92caeb..3dfff7af 100644 --- a/nilai-api/src/nilai_api/handlers/web_search.py +++ b/nilai-api/src/nilai_api/handlers/web_search.py @@ -7,12 +7,14 @@ from nilai_api.config import WEB_SEARCH_SETTINGS from nilai_common.api_model import ( + ChatRequest, + Message, + MessageAdapter, SearchResult, Source, WebSearchEnhancedMessages, WebSearchContext, ) -from nilai_common import Message logger = logging.getLogger(__name__) @@ -61,26 +63,50 @@ async def _make_brave_api_request(query: str) -> Dict[str, Any]: status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Missing BRAVE_SEARCH_API key in environment", ) + q = " ".join(query.split()) q = " ".join(q.split()[:50])[:400] + params = {**_BRAVE_API_PARAMS_BASE, "q": q} headers = { **_BRAVE_API_HEADERS, "X-Subscription-Token": WEB_SEARCH_SETTINGS.api_key, } + client = _get_http_client() + logger.info("Brave API request start") + logger.debug( + "Brave API params assembled q_len=%d country=%s lang=%s count=%s", + len(q), + params.get("country"), + params.get("lang"), + params.get("count"), + ) resp = await client.get( WEB_SEARCH_SETTINGS.api_path, headers=headers, params=params ) + if resp.status_code >= 400: logger.error("Brave API error: %s - %s", resp.status_code, resp.text) - error = HTTPException( + raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Web search failed, service currently unavailable", ) - error.status_code = 503 - raise error - return resp.json() + + try: + data = resp.json() + logger.info("Brave API request success") + logger.debug( + "Brave API response keys=%s", + list(data.keys()) if isinstance(data, dict) else type(data).__name__, + ) + return data + except Exception: + logger.exception("Failed to parse Brave API JSON") + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Web search failed: invalid response from provider", + ) def _parse_brave_results(data: Dict[str, Any]) -> List[SearchResult]: @@ -94,6 +120,7 @@ def _parse_brave_results(data: Dict[str, Any]) -> List[SearchResult]: """ web_block = data.get("web", {}) if isinstance(data, dict) else {} raw_results = web_block.get("results", []) if isinstance(web_block, dict) else [] + results: List[SearchResult] = [] for item in raw_results: if not isinstance(item, dict): @@ -101,14 +128,10 @@ def _parse_brave_results(data: Dict[str, Any]) -> List[SearchResult]: title = item.get("title", "")[:200] body = item.get("description") or item.get("snippet") or item.get("body", "") url = item.get("url") or item.get("link") or item.get("href", "") + if title and body and url: - results.append( - SearchResult( - title=title, - body=str(body)[:500], - url=str(url)[:500], - ) - ) + results.append(SearchResult(title=title, body=body, url=url)) + logger.debug("Parsed brave results count=%d", len(results)) return results @@ -130,6 +153,8 @@ async def perform_web_search_async(query: str) -> WebSearchContext: detail="Web search requested with an empty query", ) + logger.info("Web search start") + logger.debug("Web search raw query len=%d", len(query)) data = await _make_brave_api_request(query) results = _parse_brave_results(data) @@ -139,20 +164,28 @@ async def perform_web_search_async(query: str) -> WebSearchContext: detail="No web results found", ) + logger.info("Web search results ready count=%d", len(results)) lines = [ f"[{idx}] {r.title}\nURL: {r.url}\nSnippet: {r.body}" for idx, r in enumerate(results, start=1) ] prompt = "\n".join(lines) - sources = [Source(source=r.url, content=r.body) for r in results] + sources = [Source(source=r.url, content=r.body) for r in results] return WebSearchContext(prompt=prompt, sources=sources) +def _get_role_and_content(msg): + if isinstance(msg, dict): + return msg.get("role"), msg.get("content") + # If some SDK returns an object + return getattr(msg, "role", None), getattr(msg, "content", None) + + async def enhance_messages_with_web_search( messages: List[Message], query: str ) -> WebSearchEnhancedMessages: - """Enhance a list of messages with web search context. + """Enhance a list of messages with web search context.Collapse commentComment on line L155jcabrero commented on Aug 26, 2025 jcabreroon Aug 26, 2025MemberDeleted docstring?Write a replyResolve commentCode has comments. Press enter to view. Args: messages: List of conversation messages to enhance @@ -163,8 +196,49 @@ async def enhance_messages_with_web_search( context prepended as a system message, along with source information """ ctx = await perform_web_search_async(query) - enhanced = [Message(role="system", content=ctx.prompt)] + messages - query_source = Source(source="search_query", content=query) + query_source = Source(source="web_search_query", content=query) + + web_search_content = ( + f'You have access to the following web search results for the query: "{query}"\n\n' + "Use this information to provide accurate and up-to-date answers. " + "Cite the sources when appropriate.\n\n" + "Web Search Results:\n" + f"{ctx.prompt}\n\n" + "Please provide a comprehensive answer based on the search results above." + ) + + enhanced: List[Message] = [] + system_message_added = False + + for msg in messages: + adapted_message = MessageAdapter(raw=msg) + + if adapted_message.role == "system" and not system_message_added: + if isinstance(adapted_message.content, str): + combined_content_str = ( + adapted_message.content + "\n\n" + web_search_content + ) + elif isinstance(adapted_message.content, list): + # content is likely a list of parts (for multimodal); append a text part + parts = list(adapted_message.content) + parts.append({"type": "text", "text": "\n\n" + web_search_content}) + combined_content_str = parts + else: + combined_content_str = web_search_content + enhanced.append( + MessageAdapter.new_message(role="system", content=combined_content_str) + ) + system_message_added = True + else: + # Re-append in dict form + + enhanced.append(adapted_message.to_openai_param()) + + if not system_message_added: + enhanced.insert( + 0, MessageAdapter.new_message(role="system", content=web_search_content) + ) + return WebSearchEnhancedMessages( messages=enhanced, sources=[query_source] + ctx.sources, @@ -174,51 +248,58 @@ async def enhance_messages_with_web_search( async def generate_search_query_from_llm( user_message: str, model_name: str, client ) -> str: - system_prompt = """ - You are given a user question. Your task is to generate a concise web search query that will best retrieve information to answer the question. If the user’s question is already optimal, simply repeat it as the query. This is essentially summarization, paraphrasing, and key term extraction. - - - Do not add guiding elements or assumptions that the user did not explicitly request. - - Do not answer the query. - - The query must contain at least 10 words. - - Output only the search query. - - ### Example - - **User:** Who won the Roland Garros Open in 2024? Just reply with the winner's name. - **Search query:** Roland Garros 2024 tennis tournament winner men women champion """ + Use the LLM to produce a concise, high-recall search query. + """ + system_prompt = ( + "You are given a user question. Generate a concise web search query that will best retrieve information " + "to answer the question. If the user’s question is already optimal, repeat it exactly.\n" + "- Do not add assumptions not present in the question.\n" + "- Do not answer the question.\n" + "- The query must contain at least 10 words.\n" + "Output only the search query." + ) + messages = [ - Message(role="system", content=system_prompt), - Message(role="user", content=user_message), + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_message}, ] + req = { "model": model_name, - "messages": [m.model_dump() for m in messages], + "messages": messages, "max_tokens": 150, } + + logger.info("Generate search query start model=%s", model_name) + logger.debug( + "User message len=%d", len(user_message) if isinstance(user_message, str) else 0 + ) try: response = await client.chat.completions.create(**req) except Exception as exc: - raise RuntimeError(f"Failed to generate search query: {str(exc)}") from exc - - if not response.choices: - raise RuntimeError("LLM returned an empty search query") + logger.exception("LLM call failed") + raise RuntimeError(f"Failed to generate search query: {exc}") from exc try: - content = response.choices[0].message.content.strip() - except (AttributeError, IndexError, TypeError) as exc: - raise RuntimeError(f"Invalid response structure from LLM: {str(exc)}") from exc + choices = getattr(response, "choices", None) or [] + msg = choices[0].message + content = (getattr(msg, "content", None) or "").strip() + except Exception as exc: + logger.exception("Invalid LLM response structure") + raise RuntimeError(f"Invalid response structure from LLM: {exc}") from exc if not content: + logger.error("LLM returned empty search query") raise RuntimeError("LLM returned an empty search query") - logger.debug("Generated search query: %s", content) - + logger.info("Generate search query success") + logger.debug("Generated query len=%d", len(content)) return content async def handle_web_search( - req_messages: List[Message], model_name: str, client + req_messages: ChatRequest, model_name: str, client ) -> WebSearchEnhancedMessages: """Handle web search enhancement for a conversation. @@ -234,18 +315,30 @@ async def handle_web_search( WebSearchEnhancedMessages with web search context added, or original messages if no user query is found or search fails """ - user_query = "" - for message in reversed(req_messages): - if message.role == "user": - user_query = message.content - break + logger.info("Handle web search start") + logger.debug( + "Handle web search messages_in=%d model=%s", + len(req_messages.messages), + model_name, + ) + user_query = req_messages.get_last_user_query() if not user_query: - return WebSearchEnhancedMessages(messages=req_messages, sources=[]) + logger.info("No user query found") + return WebSearchEnhancedMessages(messages=req_messages.messages, sources=[]) + try: concise_query = await generate_search_query_from_llm( user_query, model_name, client ) - return await enhance_messages_with_web_search(req_messages, concise_query) + logger.info("Enhancing messages with web search context") + return await enhance_messages_with_web_search( + req_messages.messages, concise_query + ) + + except HTTPException: + logger.exception("Web search provider error") + return WebSearchEnhancedMessages(messages=req_messages.messages, sources=[]) + except Exception: - logger.warning("Web search enhancement failed") - return WebSearchEnhancedMessages(messages=req_messages, sources=[]) + logger.exception("Unexpected error during web search handling") + return WebSearchEnhancedMessages(messages=req_messages.messages, sources=[]) diff --git a/nilai-api/src/nilai_api/rate_limiting.py b/nilai-api/src/nilai_api/rate_limiting.py index a56fe019..8ca56a16 100644 --- a/nilai-api/src/nilai_api/rate_limiting.py +++ b/nilai-api/src/nilai_api/rate_limiting.py @@ -1,4 +1,3 @@ -import asyncio from asyncio import iscoroutine from typing import Callable, Tuple, Awaitable, Annotated @@ -9,6 +8,7 @@ from fastapi import status, HTTPException, Request from redis.asyncio import from_url, Redis + from nilai_api.auth import get_auth_info, AuthenticationInfo, TokenRateLimits LUA_RATE_LIMIT_SCRIPT = """ @@ -139,7 +139,6 @@ async def __call__( # The value is the usage limit # The expiration is the time remaining in validity of the token # We use the time remaining to check if the token rate limit is exceeded - for limit in user_limits.token_rate_limit.limits: await self.check_bucket( redis, @@ -163,10 +162,10 @@ async def __call__( // WEB_SEARCH_SETTINGS.count, ), ) - await self.wait_for_bucket( + await self.check_bucket( redis, redis_rate_limit_command, - "global:web_search:rps", + f"web_search_rps:{user_limits.subscription_holder}", allowed_rps, 1000, ) @@ -212,24 +211,6 @@ async def check_bucket( headers={"Retry-After": str(expire)}, ) - @staticmethod - async def wait_for_bucket( - redis: Redis, - redis_rate_limit_command: str, - key: str, - times: int | None, - milliseconds: int, - ): - if times is None: - return - while True: - expire = await redis.evalsha( - redis_rate_limit_command, 1, key, str(times), str(milliseconds) - ) # type: ignore - if int(expire) == 0: - return - await asyncio.sleep((int(expire) + 50) / 1000) - async def check_concurrent_and_increment( self, redis: Redis, request: Request ) -> str | None: diff --git a/nilai-api/src/nilai_api/routers/private.py b/nilai-api/src/nilai_api/routers/private.py index 501f384c..376e6570 100644 --- a/nilai-api/src/nilai_api/routers/private.py +++ b/nilai-api/src/nilai_api/routers/private.py @@ -1,6 +1,8 @@ # Fast API and serving import asyncio import logging +import time +import uuid from base64 import b64encode from typing import AsyncGenerator, Optional, Union, List, Tuple from nilai_api.attestation import get_attestation_report @@ -21,15 +23,18 @@ from nilai_common import ( AttestationReport, ChatRequest, - Message, ModelMetadata, + MessageAdapter, SignedChatCompletion, Nonce, Source, Usage, ) + + from openai import AsyncOpenAI + logger = logging.getLogger(__name__) router = APIRouter() @@ -131,8 +136,10 @@ async def chat_completion( ChatRequest( model="meta-llama/Llama-3.2-1B-Instruct", messages=[ - Message(role="system", content="You are a helpful assistant."), - Message(role="user", content="What is your name?"), + MessageAdapter.new_message( + role="system", content="You are a helpful assistant." + ), + MessageAdapter.new_message(role="user", content="What is your name?"), ], ) ), @@ -191,13 +198,21 @@ async def chat_completion( model="meta-llama/Llama-3.2-1B-Instruct", messages=[ {"role": "system", "content": "You are a helpful assistant"}, - {"role": "user", "content": "What's the latest news about AI?"} + {"role": "user", "content": "What is your name?"} ], ) response = await chat_completion(request, user) """ + if len(req.messages) == 0: + raise HTTPException( + status_code=400, + detail="Request contained 0 messages", + ) model_name = req.model + request_id = str(uuid.uuid4()) + t_start = time.monotonic() + logger.info(f"[chat] call start request_id={req.messages}") endpoint = await state.get_model(model_name) if endpoint is None: raise HTTPException( @@ -210,43 +225,70 @@ async def chat_completion( status_code=400, detail="Model does not support tool usage, remove tools from request", ) + + has_multimodal = req.has_multimodal_content() + logger.info(f"[chat] has_multimodal: {has_multimodal}") + if has_multimodal and (not endpoint.metadata.multimodal_support or req.web_search): + raise HTTPException( + status_code=400, + detail="Model does not support multimodal content, remove image inputs from request", + ) + model_url = endpoint.url + "/v1/" logger.info( - f"Chat completion request for model {model_name} from user {auth_info.user.userid} on url: {model_url}" + f"[chat] start request_id={request_id} user={auth_info.user.userid} model={model_name} stream={req.stream} web_search={bool(req.web_search)} tools={bool(req.tools)} multimodal={has_multimodal} url={model_url}" ) client = AsyncOpenAI(base_url=model_url, api_key="") if req.nilrag: + logger.info(f"[chat] nilrag start request_id={request_id}") + t_nilrag = time.monotonic() await handle_nilrag(req) + logger.info( + f"[chat] nilrag done request_id={request_id} duration_ms={(time.monotonic() - t_nilrag) * 1000:.0f}" + ) messages = req.messages sources: Optional[List[Source]] = None + if req.web_search: - web_search_result = await handle_web_search(messages, model_name, client) + logger.info(f"[chat] web_search start request_id={request_id}") + t_ws = time.monotonic() + web_search_result = await handle_web_search(req, model_name, client) messages = web_search_result.messages sources = web_search_result.sources + logger.info( + f"[chat] web_search done request_id={request_id} sources={len(sources) if sources else 0} duration_ms={(time.monotonic() - t_ws) * 1000:.0f}" + ) + logger.info(f"[chat] web_search messages: {messages}") if req.stream: # Forwarding Streamed Responses async def chat_completion_stream_generator() -> AsyncGenerator[str, None]: try: - response = await client.chat.completions.create( - model=req.model, - messages=messages, # type: ignore - stream=req.stream, # type: ignore - top_p=req.top_p, - temperature=req.temperature, - max_tokens=req.max_tokens, - tools=req.tools, # type: ignore - extra_body={ + logger.info(f"[chat] stream start request_id={request_id}") + t_call = time.monotonic() + current_messages = messages + request_kwargs = { + "model": req.model, + "messages": current_messages, # type: ignore + "stream": True, # type: ignore + "top_p": req.top_p, + "temperature": req.temperature, + "max_tokens": req.max_tokens, + "extra_body": { "stream_options": { "include_usage": True, "continuous_usage_stats": True, } }, - ) # type: ignore + } + if req.tools: + request_kwargs["tools"] = req.tools # type: ignore + + response = await client.chat.completions.create(**request_kwargs) # type: ignore prompt_token_usage: int = 0 completion_token_usage: int = 0 async for chunk in response: @@ -274,9 +316,12 @@ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]: prompt_tokens=prompt_token_usage, completion_tokens=completion_token_usage, ) + logger.info( + f"[chat] stream done request_id={request_id} prompt_tokens={prompt_token_usage} completion_tokens={completion_token_usage} duration_ms={(time.monotonic() - t_call) * 1000:.0f} total_ms={(time.monotonic() - t_start) * 1000:.0f}" + ) except Exception as e: - logger.error(f"Error streaming response: {e}") + logger.error(f"[chat] stream error request_id={request_id} error={e}") return # Return the streaming response @@ -284,21 +329,32 @@ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]: chat_completion_stream_generator(), media_type="text/event-stream", # Ensure client interprets as Server-Sent Events ) - response = await client.chat.completions.create( - model=req.model, - messages=messages, # type: ignore - stream=req.stream, - top_p=req.top_p, - temperature=req.temperature, - max_tokens=req.max_tokens, - tools=req.tools, # type: ignore - ) # type: ignore - + current_messages = messages + request_kwargs = { + "model": req.model, + "messages": current_messages, # type: ignore + "top_p": req.top_p, + "temperature": req.temperature, + "max_tokens": req.max_tokens, + } + if req.tools: + request_kwargs["tools"] = req.tools # type: ignore + logger.info(f"[chat] call start request_id={request_id}") + logger.info(f"[chat] call message: {current_messages}") + t_call = time.monotonic() + response = await client.chat.completions.create(**request_kwargs) # type: ignore + logger.info( + f"[chat] call done request_id={request_id} duration_ms={(time.monotonic() - t_call) * 1000:.0f}" + ) + logger.info(f"[chat] call response: {response}") model_response = SignedChatCompletion( **response.model_dump(), signature="", sources=sources, ) + logger.info( + f"[chat] model_response request_id={request_id} duration_ms={(time.monotonic() - t_call) * 1000:.0f}" + ) if model_response.usage is None: raise HTTPException( @@ -324,4 +380,7 @@ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]: signature = sign_message(state.private_key, response_json) model_response.signature = b64encode(signature).decode() + logger.info( + f"[chat] done request_id={request_id} prompt_tokens={model_response.usage.prompt_tokens} completion_tokens={model_response.usage.completion_tokens} total_ms={(time.monotonic() - t_start) * 1000:.0f}" + ) return model_response diff --git a/nilai-models/src/nilai_models/daemon.py b/nilai-models/src/nilai_models/daemon.py index 7402f3f3..b37a5497 100644 --- a/nilai-models/src/nilai_models/daemon.py +++ b/nilai-models/src/nilai_models/daemon.py @@ -28,7 +28,12 @@ async def get_metadata(num_retries=30): response.raise_for_status() response_data = response.json() model_name = response_data["data"][0]["id"] - return ModelMetadata( + + supported_features = ["chat_completion"] + if SETTINGS.multimodal_support: + supported_features.append("multimodal") + + metadata = ModelMetadata( id=model_name, # Unique identifier name=model_name, # Human-readable name version="1.0", # Model version @@ -36,10 +41,13 @@ async def get_metadata(num_retries=30): author="", # Model creators license="Apache 2.0", # Usage license source=f"https://huggingface.co/{model_name}", # Model source - supported_features=["chat_completion"], # Capabilities + supported_features=supported_features, # Capabilities tool_support=SETTINGS.tool_support, # Tool support + multimodal_support=SETTINGS.multimodal_support, # Multimodal support ) + return metadata + except Exception as e: if not url: logger.warning(f"Failed to build url: {e}") diff --git a/packages/nilai-common/src/nilai_common/__init__.py b/packages/nilai-common/src/nilai_common/__init__.py index 56edbf56..e29eef27 100644 --- a/packages/nilai-common/src/nilai_common/__init__.py +++ b/packages/nilai-common/src/nilai_common/__init__.py @@ -4,7 +4,6 @@ SignedChatCompletion, Choice, HealthCheckResponse, - Message, ModelEndpoint, ModelMetadata, Nonce, @@ -14,6 +13,8 @@ Source, WebSearchEnhancedMessages, WebSearchContext, + Message, + MessageAdapter, ) from nilai_common.config import SETTINGS from nilai_common.discovery import ModelServiceDiscovery @@ -21,6 +22,7 @@ __all__ = [ "Message", + "MessageAdapter", "ChatRequest", "SignedChatCompletion", "Choice", diff --git a/packages/nilai-common/src/nilai_common/api_model.py b/packages/nilai-common/src/nilai_common/api_model.py index a3ae1e81..3bb9485f 100644 --- a/packages/nilai-common/src/nilai_common/api_model.py +++ b/packages/nilai-common/src/nilai_common/api_model.py @@ -1,17 +1,41 @@ -import uuid +from __future__ import annotations -from typing import Annotated, Iterable, List, Literal, Optional +import uuid +from typing import ( + Annotated, + Iterable, + List, + Optional, + Any, + cast, + TypeAlias, + Literal, + Union, +) -from openai.types.chat import ChatCompletion, ChatCompletionMessage +from openai.types.chat import ( + ChatCompletion, + ChatCompletionMessageParam, + ChatCompletionToolParam, + ChatCompletionMessage, +) +from openai.types.chat.chat_completion_content_part_text_param import ( + ChatCompletionContentPartTextParam, +) +from openai.types.chat.chat_completion_content_part_image_param import ( + ChatCompletionContentPartImageParam, +) from openai.types.chat.chat_completion import Choice as OpenaAIChoice -from openai.types.chat import ChatCompletionToolParam from pydantic import BaseModel, Field -class Message(ChatCompletionMessage): - role: Literal["system", "user", "assistant", "tool"] # type: ignore +# ---------- Aliases from the OpenAI SDK ---------- +ImageContent: TypeAlias = ChatCompletionContentPartImageParam +TextContent: TypeAlias = ChatCompletionContentPartTextParam +Message: TypeAlias = ChatCompletionMessageParam # SDK union of message shapes +# ---------- Models you already had ---------- class Choice(OpenaAIChoice): pass @@ -27,6 +51,103 @@ class SearchResult(BaseModel): url: str +# ---------- Helpers ---------- +def _extract_text_from_content(content: Any) -> Optional[str]: + """ + - If content is a str -> return it (stripped) if non-empty. + - If content is a list of content parts -> concatenate 'text' parts. + - Else -> None. + """ + if isinstance(content, str): + s = content.strip() + return s or None + if isinstance(content, list): + parts: List[str] = [] + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + t = part.get("text") + if isinstance(t, str) and t.strip(): + parts.append(t.strip()) + if parts: + return "\n".join(parts) + return None + + +# ---------- Adapter over the raw SDK message ---------- +class MessageAdapter(BaseModel): + """Thin wrapper around an OpenAI ChatCompletionMessageParam with convenience methods.""" + + raw: Message + + @property + def role(self) -> str: + return cast(str, self.raw.get("role")) + + @role.setter + def role( + self, + value: Literal["developer", "user", "system", "assistant", "tool", "function"], + ) -> None: + if not isinstance(value, str): + raise TypeError("role must be a string") + # Update the underlying SDK message dict + # Cast to Any to bypass TypedDict restrictions + cast(Any, self.raw)["role"] = value + + @property + def content(self) -> Any: + return self.raw.get("content") + + @content.setter + def content(self, value: Any) -> None: + # Update the underlying SDK message dict + # Cast to Any to bypass TypedDict restrictions + cast(Any, self.raw)["content"] = value + + @staticmethod + def new_message( + role: Literal["developer", "user", "system", "assistant", "tool", "function"], + content: Union[str, List[Any]], + ) -> Message: + message: Message = cast(Message, {"role": role, "content": content}) + return message + + @staticmethod + def new_completion_message(content: str) -> ChatCompletionMessage: + message: ChatCompletionMessage = cast( + ChatCompletionMessage, {"role": "assistant", "content": content} + ) + return message + + def is_text_part(self) -> bool: + return _extract_text_from_content(self.content) is not None + + def is_multimodal_part(self) -> bool: + c = self.content + if isinstance(c, str): + return False + + for part in c: + if isinstance(part, dict) and part.get("type") in ( + "image_url", + "input_image", + ): + return True + return False + + def extract_text(self) -> Optional[str]: + return _extract_text_from_content(self.content) + + def to_openai_param(self) -> Message: + # Return the original dict for API calls. + return self.raw + + +def adapt_messages(msgs: List[Message]) -> List[MessageAdapter]: + return [MessageAdapter(raw=m) for m in msgs] + + +# ---------- Your additional containers ---------- class WebSearchEnhancedMessages(BaseModel): messages: List[Message] sources: List[Source] @@ -39,6 +160,7 @@ class WebSearchContext(BaseModel): sources: List[Source] +# ---------- Request/response models ---------- class ChatRequest(BaseModel): model: str messages: List[Message] = Field(..., min_length=1) @@ -53,6 +175,36 @@ class ChatRequest(BaseModel): description="Enable web search to enhance context with current information", ) + def model_post_init(self, __context) -> None: + # Process messages after model initialization + for i, msg in enumerate(self.messages): + content = msg.get("content") + if ( + content is not None + and hasattr(content, "__iter__") + and hasattr(content, "__next__") + ): + # Convert iterator to list in place + cast(Any, msg)["content"] = list(content) + + @property + def adapted_messages(self) -> List[MessageAdapter]: + return adapt_messages(self.messages) + + def get_last_user_query(self) -> Optional[str]: + """ + Returns the latest non-empty user text (plain or from content parts), + or None if not found. + """ + for m in reversed(self.adapted_messages): + if m.role == "user" and m.is_text_part(): + return m.extract_text() + return None + + def has_multimodal_content(self) -> bool: + """True if any message contains an image content part.""" + return any([m.is_multimodal_part() for m in self.adapted_messages]) + class SignedChatCompletion(ChatCompletion): signature: str @@ -71,6 +223,7 @@ class ModelMetadata(BaseModel): source: str supported_features: List[str] tool_support: bool + multimodal_support: bool = False class ModelEndpoint(BaseModel): @@ -83,6 +236,7 @@ class HealthCheckResponse(BaseModel): uptime: str +# ---------- Attestation ---------- Nonce = Annotated[ str, Field( diff --git a/packages/nilai-common/src/nilai_common/config.py b/packages/nilai-common/src/nilai_common/config.py index aad86b01..cc0e6bdc 100644 --- a/packages/nilai-common/src/nilai_common/config.py +++ b/packages/nilai-common/src/nilai_common/config.py @@ -8,6 +8,7 @@ class HostSettings(BaseModel): etcd_host: str = "localhost" etcd_port: int = 2379 tool_support: bool = False + multimodal_support: bool = False gunicorn_workers: int = 10 attestation_host: str = "localhost" attestation_port: int = 8081 @@ -19,6 +20,7 @@ class HostSettings(BaseModel): etcd_host=str(os.getenv("ETCD_HOST", "localhost")), etcd_port=int(os.getenv("ETCD_PORT", 2379)), tool_support=bool(os.getenv("TOOL_SUPPORT", False)), + multimodal_support=bool(os.getenv("MULTIMODAL_SUPPORT", False)), gunicorn_workers=int(os.getenv("NILAI_GUNICORN_WORKERS", 10)), attestation_host=str(os.getenv("ATTESTATION_HOST", "localhost")), attestation_port=int(os.getenv("ATTESTATION_PORT", 8081)), diff --git a/scripts/wait_for_ci_services.sh b/scripts/wait_for_ci_services.sh index 163fc50c..36b2a75e 100755 --- a/scripts/wait_for_ci_services.sh +++ b/scripts/wait_for_ci_services.sh @@ -2,16 +2,33 @@ # Wait for the services to be ready API_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-api 2>/dev/null) -MODEL_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-llama_1b_gpu 2>/dev/null) +MODEL_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-qwen2vl_2b_gpu 2>/dev/null) NUC_API_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-nuc-api 2>/dev/null) MAX_ATTEMPTS=30 ATTEMPT=1 while [ $ATTEMPT -le $MAX_ATTEMPTS ]; do echo "Waiting for nilai to become healthy... API:[$API_HEALTH_STATUS] MODEL:[$MODEL_HEALTH_STATUS] NUC_API:[$NUC_API_HEALTH_STATUS] (Attempt $ATTEMPT/$MAX_ATTEMPTS)" + + # Check if any service is unhealthy and print logs + if [ "$API_HEALTH_STATUS" = "unhealthy" ]; then + echo "=== nilai-api is unhealthy, printing logs ===" + docker logs nilai-api --tail 50 + fi + + if [ "$MODEL_HEALTH_STATUS" = "unhealthy" ]; then + echo "=== nilai-qwen2vl_2b_gpu is unhealthy, printing logs ===" + docker logs nilai-qwen2vl_2b_gpu --tail 50 + fi + + if [ "$NUC_API_HEALTH_STATUS" = "unhealthy" ]; then + echo "=== nilai-nuc-api is unhealthy, printing logs ===" + docker logs nilai-nuc-api --tail 50 + fi + sleep 30 API_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-api 2>/dev/null) - MODEL_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-llama_1b_gpu 2>/dev/null) + MODEL_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-qwen2vl_2b_gpu 2>/dev/null) NUC_API_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-nuc-api 2>/dev/null) if [ "$API_HEALTH_STATUS" = "healthy" ] && [ "$MODEL_HEALTH_STATUS" = "healthy" ] && [ "$NUC_API_HEALTH_STATUS" = "healthy" ]; then break @@ -23,17 +40,23 @@ done echo "API_HEALTH_STATUS: $API_HEALTH_STATUS" if [ "$API_HEALTH_STATUS" != "healthy" ]; then echo "Error: nilai-api failed to become healthy after $MAX_ATTEMPTS attempts" + echo "=== Final logs for nilai-api ===" + docker logs nilai-api --tail 100 exit 1 fi echo "MODEL_HEALTH_STATUS: $MODEL_HEALTH_STATUS" if [ "$MODEL_HEALTH_STATUS" != "healthy" ]; then - echo "Error: nilai-llama_1b_gpu failed to become healthy after $MAX_ATTEMPTS attempts" + echo "Error: nilai-qwen2vl_2b_gpu failed to become healthy after $MAX_ATTEMPTS attempts" + echo "=== Final logs for nilai-qwen2vl_2b_gpu ===" + docker logs nilai-qwen2vl_2b_gpu --tail 100 exit 1 fi echo "NUC_API_HEALTH_STATUS: $NUC_API_HEALTH_STATUS" if [ "$NUC_API_HEALTH_STATUS" != "healthy" ]; then echo "Error: nilai-nuc-api failed to become healthy after $MAX_ATTEMPTS attempts" + echo "=== Final logs for nilai-nuc-api ===" + docker logs nilai-nuc-api --tail 100 exit 1 fi diff --git a/tests/e2e/config.py b/tests/e2e/config.py index c49eff97..7caae4a4 100644 --- a/tests/e2e/config.py +++ b/tests/e2e/config.py @@ -32,9 +32,7 @@ def api_key_getter(): "meta-llama/Llama-3.2-1B-Instruct", "meta-llama/Llama-3.1-8B-Instruct", ], - "ci": [ - "meta-llama/Llama-3.2-1B-Instruct", - ], + "ci": ["google/gemma-3-4b-it"], } if ENVIRONMENT not in models: diff --git a/tests/e2e/test_openai.py b/tests/e2e/test_openai.py index 5c72d690..1314c721 100644 --- a/tests/e2e/test_openai.py +++ b/tests/e2e/test_openai.py @@ -867,88 +867,226 @@ def make_request(): ) -def test_web_search_queueing_next_second_e2e(client): - """Test that web search requests are properly queued and processed in batches.""" - import threading - import time - import openai - from concurrent.futures import ThreadPoolExecutor, as_completed +def test_multimodal_single_request(client): + """Test multimodal chat completion with a single request using Qwen/Qwen2-VL-2B-Instruct-AWQ model""" + if "Qwen/Qwen2-VL-2B-Instruct-AWQ" not in test_models: + pytest.skip("Multimodal test only runs for Qwen/Qwen2-VL-2B-Instruct-AWQ model") - request_barrier = threading.Barrier(25) - responses = [] - start_time = None + try: + # Create a simple base64 encoded image (1x1 pixel red PNG) + response = client.chat.completions.create( + model="Qwen/Qwen2-VL-2B-Instruct-AWQ", + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What do you see in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "" + }, + }, + ], + } + ], + temperature=0.2, + max_tokens=100, + ) - def make_request(): - request_barrier.wait() + # Verify response structure + assert isinstance(response, ChatCompletion), ( + "Response should be a ChatCompletion object" + ) + assert response.model == "Qwen/Qwen2-VL-2B-Instruct-AWQ", ( + "Response model should be Qwen/Qwen2-VL-2B-Instruct-AWQ" + ) + assert len(response.choices) > 0, "Response should contain at least one choice" - nonlocal start_time - if start_time is None: - start_time = time.time() + # Check content + content = response.choices[0].message.content + assert content is not None, "Content should not be null" + assert content.strip() != "", "Content should not be empty" - try: - response = client.chat.completions.create( - model=test_models[0], - messages=[{"role": "user", "content": "What is the weather like?"}], - extra_body={"web_search": True}, - max_tokens=10, - temperature=0.0, - ) - completion_time = time.time() - start_time - responses.append((completion_time, response, "success")) - except openai.RateLimitError as e: - completion_time = time.time() - start_time - responses.append((completion_time, e, "rate_limited")) - except Exception as e: - completion_time = time.time() - start_time - responses.append((completion_time, e, "error")) + print( + f"\nMultimodal single request response: {content[:100]}..." + if len(content) > 100 + else content + ) - with ThreadPoolExecutor(max_workers=25) as executor: - futures = [executor.submit(make_request) for _ in range(25)] + assert response.usage, "No usage data returned for multimodal request" + print(f"Multimodal usage: {response.usage}") - for future in as_completed(futures): - try: - future.result() - except Exception as e: - print(f"Thread execution error: {e}") + assert response.usage.prompt_tokens > 0, ( + "No prompt tokens returned for multimodal request" + ) + assert response.usage.completion_tokens > 0, ( + "No completion tokens returned for multimodal request" + ) + assert response.usage.total_tokens > 0, ( + "No total tokens returned for multimodal request" + ) - assert len(responses) == 25, "All requests should complete" + except Exception as e: + pytest.fail(f"Error testing multimodal single request: {str(e)}") - # Categorize responses - successful_responses = [(t, r) for t, r, status in responses if status == "success"] - rate_limited_responses = [ - (t, r) for t, r, status in responses if status == "rate_limited" - ] - error_responses = [(t, r) for t, r, status in responses if status == "error"] - print( - f"Successful: {len(successful_responses)}, Rate limited: {len(rate_limited_responses)}, Errors: {len(error_responses)}" - ) +def test_multimodal_consecutive_requests(client): + """Test two consecutive multimodal chat completions using Qwen/Qwen2-VL-2B-Instruct-AWQ model""" + if "Qwen/Qwen2-VL-2B-Instruct-AWQ" not in test_models: + pytest.skip("Multimodal test only runs for Qwen/Qwen2-VL-2B-Instruct-AWQ model") - # Verify queuing behavior - # With 25 requests and 20 RPS limit, some should be queued or rate limited - assert len(rate_limited_responses) > 0 or len(successful_responses) < 25, ( - "Queuing should be enforced - either some requests should be rate limited or delayed" - ) + try: + # Create a simple base64 encoded image (1x1 pixel red PNG) - for t, response in successful_responses: - assert isinstance(response, ChatCompletion), ( - "Response should be a ChatCompletion object" + # First multimodal request + response1 = client.chat.completions.create( + model="google/gemma-3-4b-it", + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What color is this image?"}, + { + "type": "image_url", + "image_url": { + "url": "" + }, + }, + ], + } + ], + temperature=0.2, + max_tokens=50, ) - assert len(response.choices) > 0, "Response should contain at least one choice" - assert response.choices[0].message.content, "Response should contain content" - sources = getattr(response, "sources", None) - assert sources is not None, "Web search responses should have sources" - assert isinstance(sources, list), "Sources should be a list" - assert len(sources) > 0, "Sources should not be empty" + # Verify first response + assert isinstance(response1, ChatCompletion), ( + "First response should be a ChatCompletion object" + ) + assert response1.model == "google/gemma-3-4b-it", ( + "First response model should be google/gemma-3-4b-it" + ) + assert len(response1.choices) > 0, ( + "First response should contain at least one choice" + ) - first_source = sources[0] - assert isinstance(first_source, dict), "First source should be a dictionary" - assert "title" in first_source, "First source should have title" - assert "url" in first_source, "First source should have url" - assert "snippet" in first_source, "First source should have snippet" + content1 = response1.choices[0].message.content + assert content1 is not None, "First response content should not be null" + assert content1.strip() != "", "First response content should not be empty" - for t, error in rate_limited_responses: - assert isinstance(error, openai.RateLimitError), ( - "Rate limited responses should be RateLimitError" + print( + f"\nFirst multimodal response: {content1[:100]}..." + if len(content1) > 100 + else content1 + ) + + # Second multimodal request + response2 = client.chat.completions.create( + model="google/gemma-3-4b-it", + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image in detail."}, + { + "type": "image_url", + "image_url": { + "url": "" + }, + }, + ], + } + ], + temperature=0.2, + max_tokens=100, + ) + + # Verify second response + assert isinstance(response2, ChatCompletion), ( + "Second response should be a ChatCompletion object" + ) + assert response2.model == "google/gemma-3-4b-it", ( + "Second response model should be google/gemma-3-4b-it" + ) + assert len(response2.choices) > 0, ( + "Second response should contain at least one choice" + ) + + content2 = response2.choices[0].message.content + assert content2 is not None, "Second response content should not be null" + assert content2.strip() != "", "Second response content should not be empty" + + print( + f"\nSecond multimodal response: {content2[:100]}..." + if len(content2) > 100 + else content2 + ) + + # Verify both responses have usage data + assert response1.usage, "No usage data returned for first multimodal request" + assert response2.usage, "No usage data returned for second multimodal request" + + print(f"First multimodal usage: {response1.usage}") + print(f"Second multimodal usage: {response2.usage}") + + # Verify both responses have token counts + assert response1.usage.prompt_tokens > 0, ( + "No prompt tokens returned for first multimodal request" + ) + assert response1.usage.completion_tokens > 0, ( + "No completion tokens returned for first multimodal request" + ) + assert response1.usage.total_tokens > 0, ( + "No total tokens returned for first multimodal request" + ) + + assert response2.usage.prompt_tokens > 0, ( + "No prompt tokens returned for second multimodal request" + ) + assert response2.usage.completion_tokens > 0, ( + "No completion tokens returned for second multimodal request" + ) + assert response2.usage.total_tokens > 0, ( + "No total tokens returned for second multimodal request" + ) + + except Exception as e: + pytest.fail(f"Error testing consecutive multimodal requests: {str(e)}") + + +def test_multimodal_with_web_search_error(client): + """Test that multimodal + web search raises an error""" + if "Qwen/Qwen2-VL-2B-Instruct-AWQ" not in test_models: + pytest.skip("Multimodal test only runs for Qwen/Qwen2-VL-2B-Instruct-AWQ model") + + # Create a simple base64 encoded image (1x1 pixel red PNG) + + try: + client.chat.completions.create( + model="google/gemma-3-4b-it", + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What do you see in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "" + }, + }, + ], + } + ], + extra_body={"web_search": True}, + temperature=0.2, + max_tokens=100, + ) + pytest.fail("Expected error for multimodal + web search combination") + except Exception as e: + # The error should be raised, which means the test passes + print(f"Expected error received: {str(e)}") + assert "multimodal" in str(e).lower() or "400" in str(e), ( + "Should raise multimodal or 400 error" ) diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py index 3fdd70a0..4b43a545 100644 --- a/tests/unit/__init__.py +++ b/tests/unit/__init__.py @@ -2,11 +2,11 @@ from nilai_common import ( SignedChatCompletion, - Message, ModelEndpoint, ModelMetadata, Usage, Choice, + MessageAdapter, ) model_metadata: ModelMetadata = ModelMetadata( @@ -33,7 +33,7 @@ choices=[ Choice( index=0, - message=Message(role="assistant", content="test-content"), + message=MessageAdapter.new_completion_message(content="test-content"), finish_reason="stop", logprobs=ChoiceLogprobs(), ) diff --git a/tests/unit/nilai_api/routers/test_private.py b/tests/unit/nilai_api/routers/test_private.py index 183d5cdb..e6e20cdc 100644 --- a/tests/unit/nilai_api/routers/test_private.py +++ b/tests/unit/nilai_api/routers/test_private.py @@ -28,6 +28,12 @@ def mock_user(): mock.completion_tokens_details = None mock.prompt_tokens_details = None mock.queries = 10 + mock.ratelimit_minute = 100 + mock.ratelimit_hour = 1000 + mock.ratelimit_day = 10000 + mock.web_search_ratelimit_minute = 100 + mock.web_search_ratelimit_hour = 1000 + mock.web_search_ratelimit_day = 10000 return mock @@ -206,3 +212,106 @@ def test_chat_completion(mock_user, mock_state, mock_user_manager, mocker, clien "completion_tokens_details": None, "prompt_tokens_details": None, } + + +def test_chat_completion_image_web_search_error( + mock_user, mock_state, mock_user_manager, mocker, client +): + mocker.patch("openai.api_key", new="test-api-key") + from openai.types.chat import ChatCompletion + + mocker.patch.object(model_metadata, "multimodal_support", True) + + data = RESPONSE.model_dump() + data.pop("signature") + data.pop("sources", None) + response_data = ChatCompletion(**data) + mock_chat_completions = MagicMock() + mock_chat_completions.create = mocker.AsyncMock(return_value=response_data) + mock_chat = MagicMock() + mock_chat.completions = mock_chat_completions + mock_async_openai_instance = MagicMock() + mock_async_openai_instance.chat = mock_chat + mocker.patch( + "nilai_api.routers.private.AsyncOpenAI", return_value=mock_async_openai_instance + ) + response = client.post( + "/v1/chat/completions", + json={ + "model": "Qwen/Qwen2-VL-2B-Instruct-AWQ", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": [ + {"type": "text", "text": "What is in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "" + }, + }, + ], + }, + ], + "web_search": True, + }, + headers={"Authorization": "Bearer test-api-key"}, + ) + print(response) + assert response.status_code == 400 + + +def test_chat_completion_with_image( + mock_user, mock_state, mock_user_manager, mocker, client +): + mocker.patch("openai.api_key", new="test-api-key") + from openai.types.chat import ChatCompletion + + # Mock the model to support multimodal content + mocker.patch.object(model_metadata, "multimodal_support", True) + + data = RESPONSE.model_dump() + data.pop("signature") + data.pop("sources", None) + response_data = ChatCompletion(**data) + mock_chat_completions = MagicMock() + mock_chat_completions.create = mocker.AsyncMock(return_value=response_data) + mock_chat = MagicMock() + mock_chat.completions = mock_chat_completions + mock_async_openai_instance = MagicMock() + mock_async_openai_instance.chat = mock_chat + mocker.patch( + "nilai_api.routers.private.AsyncOpenAI", return_value=mock_async_openai_instance + ) + response = client.post( + "/v1/chat/completions", + json={ + "model": "Qwen/Qwen2-VL-2B-Instruct-AWQ", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": [ + {"type": "text", "text": "What is in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "" + }, + }, + ], + }, + ], + }, + headers={"Authorization": "Bearer test-api-key"}, + ) + assert response.status_code == 200 + assert "usage" in response.json() + assert response.json()["usage"] == { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + "completion_tokens_details": None, + "prompt_tokens_details": None, + } diff --git a/tests/unit/nilai_api/test_rate_limiting.py b/tests/unit/nilai_api/test_rate_limiting.py index 940c6199..ab914a5e 100644 --- a/tests/unit/nilai_api/test_rate_limiting.py +++ b/tests/unit/nilai_api/test_rate_limiting.py @@ -198,55 +198,20 @@ async def test_global_web_search_rps_limit(req, redis_client, monkeypatch): web_search_minute_limit=None, ) - async def run_guarded(i, times, t0): - async for _ in rate_limit(req, user_limits): - times[i] = asyncio.get_event_loop().time() - t0 - await asyncio.sleep(0.01) + async def run_guarded(i, results): + try: + async for _ in rate_limit(req, user_limits): + results[i] = "ok" + await asyncio.sleep(0.01) + except HTTPException as e: + results[i] = e.status_code n = 40 - times = [0.0] * n - t0 = asyncio.get_event_loop().time() - tasks = [asyncio.create_task(run_guarded(i, times, t0)) for i in range(n)] - await asyncio.gather(*tasks) - - within_first_second = [t for t in times if t < 1.0] - assert len(within_first_second) <= 20 - assert max(times) >= 1.0 - - -@pytest.mark.asyncio -async def test_queueing_across_seconds(req, redis_client, monkeypatch): - from nilai_api import rate_limiting as rl - - await redis_client[0].delete("global:web_search:rps") - monkeypatch.setattr(rl.WEB_SEARCH_SETTINGS, "rps", 20) - monkeypatch.setattr(rl.WEB_SEARCH_SETTINGS, "max_concurrent_requests", 20) - monkeypatch.setattr(rl.WEB_SEARCH_SETTINGS, "count", 1) - - rate_limit = RateLimit(web_search_extractor=lambda _: True) - user_limits = UserRateLimits( - subscription_holder=random_id(), - day_limit=None, - hour_limit=None, - minute_limit=None, - token_rate_limit=None, - web_search_day_limit=None, - web_search_hour_limit=None, - web_search_minute_limit=None, - ) - - async def run_guarded(i, times, t0): - async for _ in rate_limit(req, user_limits): - times[i] = asyncio.get_event_loop().time() - t0 - await asyncio.sleep(0.01) - - n = 25 - times = [0.0] * n - t0 = asyncio.get_event_loop().time() - tasks = [asyncio.create_task(run_guarded(i, times, t0)) for i in range(n)] - await asyncio.gather(*tasks) - - first_window = [t for t in times if t < 1.0] - second_window = [t for t in times if 1.0 <= t < 2.0] - assert len(first_window) <= 20 - assert len(second_window) >= 1 + results = [None] * n + tasks = [asyncio.create_task(run_guarded(i, results)) for i in range(n)] + await asyncio.gather(*tasks, return_exceptions=True) + + successes = [r for r in results if r == "ok"] + rejections = [r for r in results if r == 429] + assert len(successes) <= 20 + assert len(rejections) >= 20 diff --git a/tests/unit/nilai_api/test_web_search.py b/tests/unit/nilai_api/test_web_search.py index 7b83088d..124e2eb9 100644 --- a/tests/unit/nilai_api/test_web_search.py +++ b/tests/unit/nilai_api/test_web_search.py @@ -3,12 +3,6 @@ from fastapi import HTTPException from nilai_api.handlers.web_search import ( perform_web_search_async, - enhance_messages_with_web_search, -) -from nilai_common import Message -from nilai_common.api_model import ( - WebSearchContext, - Source, ) @@ -134,32 +128,3 @@ async def test_perform_web_search_async_concurrent_queries(): assert ( results[1].sources[0].content == "Advances in machine learning algorithms." ) - - -@pytest.mark.asyncio -async def test_enhance_messages_with_web_search(): - """Test message enhancement with web search results and source validation""" - original_messages = [ - Message(role="system", content="You are a helpful assistant"), - Message(role="user", content="What is the latest AI news?"), - ] - - with patch("nilai_api.handlers.web_search.perform_web_search_async") as mock_search: - mock_search.return_value = WebSearchContext( - prompt="[1] Latest AI Developments\nURL: https://example.com\nSnippet: OpenAI announces GPT-5", - sources=[ - Source(source="https://example.com", content="OpenAI announces GPT-5") - ], - ) - - enhanced = await enhance_messages_with_web_search(original_messages, "AI news") - - assert len(enhanced.messages) == 3 - assert enhanced.messages[0].role == "system" - assert "Latest AI Developments" in str(enhanced.messages[0].content) - assert enhanced.sources is not None - assert len(enhanced.sources) == 2 - assert enhanced.sources[0].source == "search_query" - assert enhanced.sources[0].content == "AI news" - assert enhanced.sources[1].source == "https://example.com" - assert enhanced.sources[1].content == "OpenAI announces GPT-5"