From 0469e3d7382ac6a7921c38262160ff9e36705992 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Cabrero-Holgueras?=
Date: Tue, 4 Nov 2025 15:22:10 +0100
Subject: [PATCH 1/7] feat: added improved database structure and logging

---
 .gitignore                                     |    2 +
 .vscode/settings.json                          |   22 -
 changes.patch                                  | 3200 +++++++++++++++++
 docker-compose.dev.yml                         |    2 -
 .../dashboards/nuc-query-data.json             |   12 +-
 .../runtime-data/dashboards/query-data.json    |   12 +-
 .../dashboards/testnet-nuc-query-data.json     |   12 +-
 .../runtime-data/dashboards/totals-data.json   |   10 +-
 .../runtime-data/dashboards/usage-data.json    |    8 +-
 ...73468afc_chore_improved_database_schema.py  |  206 ++
 ...b23c73035b_fix_userid_change_to_user_id.py  |   37 +
 nilai-api/examples/users.py                    |   43 -
 nilai-api/pyproject.toml                       |    2 +-
 nilai-api/src/nilai_api/auth/__init__.py       |    2 -
 nilai-api/src/nilai_api/auth/nuc.py            |    6 +-
 nilai-api/src/nilai_api/auth/strategies.py     |   81 +-
 nilai-api/src/nilai_api/commands/add_user.py   |   16 +-
 nilai-api/src/nilai_api/config/__init__.py     |    1 +
 nilai-api/src/nilai_api/credit.py              |    3 +
 nilai-api/src/nilai_api/db/logs.py             |  271 +-
 nilai-api/src/nilai_api/db/users.py            |  210 +-
 nilai-api/src/nilai_api/rate_limiting.py       |   46 +-
 .../src/nilai_api/routers/endpoints/chat.py    |  478 +--
 .../nilai_api/routers/endpoints/responses.py   |  448 ++-
 nilai-api/src/nilai_api/routers/private.py     |   25 +-
 .../nilai-common/src/nilai_common/__init__.py  |    2 +
 .../src/nilai_common/api_models/__init__.py    |    2 +
 .../api_models/chat_completion_model.py        |    3 +-
 .../src/nilai_common/discovery.py              |    2 +-
 .../nilai_api/test_users_db_integration.py     |  122 +-
 tests/unit/nilai_api/__init__.py               |   28 +-
 tests/unit/nilai_api/auth/test_auth.py         |   50 +-
 tests/unit/nilai_api/auth/test_strategies.py   |  107 +-
 .../routers/test_chat_completions_private.py   |   77 +-
 .../nilai_api/routers/test_nildb_endpoints.py  |  103 +-
 .../routers/test_responses_private.py          |   51 +-
 tests/unit/nilai_api/test_db.py                |   10 +-
 tests/unit/nilai_api/test_rate_limiting.py     |   12 +-
 uv.lock                                        |    8 +-
 39 files changed, 4614 insertions(+), 1118 deletions(-)
 delete mode 100644 .vscode/settings.json
 create mode 100644 changes.patch
 create mode 100644 nilai-api/alembic/versions/0ba073468afc_chore_improved_database_schema.py
 create mode 100644 nilai-api/alembic/versions/43b23c73035b_fix_userid_change_to_user_id.py
 delete mode 100644 nilai-api/examples/users.py

diff --git a/.gitignore b/.gitignore
index f3d8ab42..7cbf1bfd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -179,3 +179,5 @@ private_key.key.lock
 
 development-compose.yml
 production-compose.yml
+
+.vscode/
diff --git a/.vscode/settings.json b/.vscode/settings.json
deleted file mode 100644
index 24c9e519..00000000
--- a/.vscode/settings.json
+++ /dev/null
@@ -1,22 +0,0 @@
-{
-  "workbench.colorCustomizations": {
-    "activityBar.activeBackground": "#65c89b",
-    "activityBar.background": "#65c89b",
-    "activityBar.foreground": "#15202b",
-    "activityBar.inactiveForeground": "#15202b99",
-    "activityBarBadge.background": "#945bc4",
-    "activityBarBadge.foreground": "#e7e7e7",
-    "commandCenter.border": "#15202b99",
-    "sash.hoverBorder": "#65c89b",
-    "statusBar.background": "#42b883",
-    "statusBar.foreground": "#15202b",
-    "statusBarItem.hoverBackground": "#359268",
-    "statusBarItem.remoteBackground": "#42b883",
-    "statusBarItem.remoteForeground": "#15202b",
-    "titleBar.activeBackground": "#42b883",
-    "titleBar.activeForeground": "#15202b",
-    "titleBar.inactiveBackground": "#42b88399",
-    "titleBar.inactiveForeground": "#15202b99"
-  },
-  "peacock.color": "#42b883"
-}
diff --git a/changes.patch b/changes.patch
new file mode 100644
index 00000000..e2061aab
--- /dev/null
+++ b/changes.patch
@@ -0,0 +1,3200 @@
+diff --git a/.gitignore b/.gitignore
+index f3d8ab4..7cbf1bf 100644
+--- a/.gitignore
++++ b/.gitignore
+@@ -179,3 +179,5 @@ private_key.key.lock
+ 
+ development-compose.yml
+ production-compose.yml
++
++.vscode/
+diff --git a/QUERY_LOG_DEPENDENCY_MIGRATION.md b/QUERY_LOG_DEPENDENCY_MIGRATION.md
+new file mode 100644
+index 0000000..0eaca24
+--- /dev/null
++++ b/QUERY_LOG_DEPENDENCY_MIGRATION.md
+@@ -0,0 +1,253 @@
++# QueryLog Dependency Migration Guide
++
++## Overview
++
++The `QueryLogManager` has been converted to a FastAPI dependency pattern using `QueryLogContext`. This provides better integration with the request lifecycle and more accurate timing metrics.
++
++## What Changed
++
++### Before (Static Manager)
++```python
++from nilai_api.db.logs import QueryLogManager
++
++# Manual logging with all parameters
++await QueryLogManager.log_query(
++    userid=auth_info.user.userid,
++    model=req.model,
++    prompt_tokens=prompt_tokens,
++    completion_tokens=completion_tokens,
++    response_time_ms=response_time_ms,
++    web_search_calls=len(sources) if sources else 0,
++    was_streamed=req.stream,
++    was_multimodal=has_multimodal,
++    was_nilrag=bool(req.nilrag),
++    was_nildb=bool(auth_info.prompt_document),
++)
++```
++
++### After (Dependency Pattern)
++```python
++from fastapi import Depends
++from nilai_api.db.logs import QueryLogContext, get_query_log_context
++
++@router.post("/endpoint")
++async def endpoint(
++    log_ctx: QueryLogContext = Depends(get_query_log_context),  # Inject dependency
++):
++    # Set context as you go
++    log_ctx.set_user(auth_info.user.userid)
++    log_ctx.set_model(req.model)
++
++    # ... do work ...
++
++    # Commit at the end (calculates timing automatically)
++    await log_ctx.commit()
++```
++
++## Key Features
++
++### 1. Automatic Timing Tracking
++```python
++# Context automatically tracks:
++# - Total request time (from dependency creation)
++# - Model inference time (with start_model_timing/end_model_timing)
++# - Tool execution time (with start_tool_timing/end_tool_timing)
++
++log_ctx.start_model_timing()
++response = await model.generate()
++log_ctx.end_model_timing()
++```
++
++### 2. Incremental Context Building
++```python
++# Set request parameters
++log_ctx.set_request_params(
++    temperature=req.temperature,
++    max_tokens=req.max_tokens,
++    was_streamed=req.stream,
++    was_multimodal=has_multimodal,
++    was_nildb=bool(auth_info.prompt_document),
++    was_nilrag=bool(req.nilrag),
++)
++
++# Set usage metrics (can be called multiple times, last wins)
++log_ctx.set_usage(
++    prompt_tokens=100,
++    completion_tokens=50,
++    tool_calls=2,
++    web_search_calls=1,
++)
++```
++
++### 3. Error Tracking
++```python
++try:
++    # ... process request ...
++except HTTPException as e:
++    log_ctx.set_error(error_code=e.status_code, error_message=str(e.detail))
++    await log_ctx.commit()
++    raise
++```
++
++### 4. Safe Commit (No Breaking)
++```python
++# Commit never raises exceptions - logging failures are logged but don't break requests
++await log_ctx.commit()
++```
++
++## Migration Steps for `/v1/chat/completions`
++
++### Step 1: Add Dependency to Function Signature
++
++```python
++@router.post("/v1/chat/completions", tags=["Chat"], response_model=None)
++async def chat_completion(
++    req: ChatRequest = Body(...),
++    _rate_limit=Depends(RateLimit(...)),
++    auth_info: AuthenticationInfo = Depends(get_auth_info),
++    meter: MeteringContext = Depends(LLMMeter),
++    log_ctx: QueryLogContext = Depends(get_query_log_context),  # ADD THIS
++):
++```
++
++### Step 2: Initialize Context Early
++
++```python
++    # Right after validation
++    log_ctx.set_user(auth_info.user.userid)
++    log_ctx.set_model(req.model)
++    log_ctx.set_request_params(
++        temperature=req.temperature,
++        max_tokens=req.max_tokens,
++        was_streamed=req.stream,
++        was_multimodal=has_multimodal,
++        was_nildb=bool(auth_info.prompt_document),
++        was_nilrag=bool(req.nilrag),
++    )
++```
++
++### Step 3: Track Model Timing
++
++```python
++    # Before model call
++    log_ctx.start_model_timing()
++
++    response = await client.chat.completions.create(...)
++
++    # After model call
++    log_ctx.end_model_timing()
++```
++
++### Step 4: Track Tool Timing (if applicable)
++
++```python
++    if req.tools:
++        log_ctx.start_tool_timing()
++
++        (final_completion, agg_prompt, agg_completion) = await handle_tool_workflow(...)
++
++        log_ctx.end_tool_timing()
++        log_ctx.set_usage(tool_calls=len(response.choices[0].message.tool_calls or []))
++```
++
++### Step 5: Replace QueryLogManager.log_query()
++
++```python
++    # OLD - Remove this:
++    await QueryLogManager.log_query(
++        auth_info.user.userid,
++        model=req.model,
++        prompt_tokens=...,
++        completion_tokens=...,
++        response_time_ms=...,
++        web_search_calls=...,
++    )
++
++    # NEW - Replace with:
++    log_ctx.set_usage(
++        prompt_tokens=model_response.usage.prompt_tokens,
++        completion_tokens=model_response.usage.completion_tokens,
++        web_search_calls=len(sources) if sources else 0,
++    )
++    await log_ctx.commit()
++```
++
++### Step 6: Handle Streaming Case
++
++For streaming responses, commit inside the generator:
++
++```python
++async def chat_completion_stream_generator():
++    try:
++        # ... streaming logic ...
++
++        async for chunk in response:
++            if chunk.usage is not None:
++                prompt_token_usage = chunk.usage.prompt_tokens
++                completion_token_usage = chunk.usage.completion_tokens
++            # ... yield chunks ...
++
++        # At the end of stream
++        log_ctx.set_usage(
++            prompt_tokens=prompt_token_usage,
++            completion_tokens=completion_token_usage,
++            web_search_calls=len(sources) if sources else 0,
++        )
++        await log_ctx.commit()
++    except Exception as e:
++        log_ctx.set_error(error_code=500, error_message=str(e))
++        await log_ctx.commit()
++        raise
++```
++
++## Complete Example
++
++Here's a minimal complete example:
++
++```python
++@router.post("/v1/chat/completions")
++async def chat_completion(
++    req: ChatRequest,
++    auth_info: AuthenticationInfo = Depends(get_auth_info),
++    log_ctx: QueryLogContext = Depends(get_query_log_context),
++):
++    # Setup
++    log_ctx.set_user(auth_info.user.userid)
++    log_ctx.set_model(req.model)
++
++    try:
++        # Process request
++        log_ctx.start_model_timing()
++        response = await process_request(req)
++        log_ctx.end_model_timing()
++
++        # Set usage
++        log_ctx.set_usage(
++            prompt_tokens=response.usage.prompt_tokens,
++            completion_tokens=response.usage.completion_tokens,
++        )
++
++        # Commit
++        await log_ctx.commit()
++
++        return response
++    except HTTPException as e:
++        log_ctx.set_error(e.status_code, str(e.detail))
++        await log_ctx.commit()
++        raise
++```
++
++## Benefits
++
++1. ✅ **Automatic timing** - No manual time.monotonic() tracking needed
++2. ✅ **Granular metrics** - Separate model vs tool timing
++3. ✅ **Error tracking** - Built-in error code and message support
++4. ✅ **Type safety** - Full type hints throughout
++5. ✅ **Non-breaking** - Legacy `QueryLogManager.log_query()` still works
++6. ✅ **Clean separation** - Logging logic separate from business logic
++7. ✅ **Request isolation** - Each request gets its own context instance
++8. ✅ **Flexible updates** - Update metrics as you discover them during request processing
++
++## Backward Compatibility
++
++The old `QueryLogManager.log_query()` static method still works and is marked as "legacy support". You can migrate endpoints gradually without breaking existing functionality.
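++
++Note, however, that the legacy signature itself changed in this patch: `userid` is now `user_id`, and `lockid` and `response_time_ms` are required. A rough sketch of an updated legacy-style call (parameter names follow the new `log_query()` in `nilai-api/src/nilai_api/db/logs.py`; the surrounding values, including `meter.lock_id`, are illustrative only):
++
++```python
++from nilai_api.db.logs import QueryLogManager
++
++# Direct call to the legacy static method with the renamed/extended
++# parameters; optional fields (tool_calls, temperature, max_tokens,
++# timing splits, error info) fall back to their defaults when omitted.
++await QueryLogManager.log_query(
++    user_id=auth_info.user.user_id,   # renamed from `userid`
++    lockid=meter.lock_id,             # new required metering lock id (illustrative source)
++    model=req.model,
++    prompt_tokens=prompt_tokens,
++    completion_tokens=completion_tokens,
++    response_time_ms=response_time_ms,
++    web_search_calls=len(sources) if sources else 0,
++    was_streamed=req.stream,
++    was_multimodal=has_multimodal,
++    was_nilrag=bool(req.nilrag),
++    was_nildb=bool(auth_info.prompt_document),
++)
++```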
+diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml +index d40200e..5784fca 100644 +--- a/docker-compose.dev.yml ++++ b/docker-compose.dev.yml +@@ -33,8 +33,6 @@ services: + condition: service_healthy + nilauth-credit-server: + condition: service_healthy +- environment: +- - POSTGRES_DB=${POSTGRES_DB_NUC} + volumes: + - ./nilai-api/:/app/nilai-api/ + - ./packages/:/app/packages/ +@@ -97,7 +95,7 @@ services: + + nilauth-credit-server: + image: ghcr.io/nillionnetwork/nilauth-credit:sha-cb9e36a +- platform: linux/amd64 # for macOS to force running on Rosetta 2 ++ # platform: linux/amd64 # for macOS to force running on Rosetta 2 + environment: + DATABASE_URL: postgresql://nilauth:nilauth_dev_password@nilauth-postgres:5432/nilauth_credit + HOST: 0.0.0.0 +diff --git a/grafana/runtime-data/dashboards/nuc-query-data.json b/grafana/runtime-data/dashboards/nuc-query-data.json +index d66fd42..c7bbb6b 100644 +--- a/grafana/runtime-data/dashboards/nuc-query-data.json ++++ b/grafana/runtime-data/dashboards/nuc-query-data.json +@@ -126,7 +126,7 @@ + "editorMode": "code", + "format": "time_series", + "rawQuery": true, +- "rawSql": "SELECT \n date_trunc('${time_granularity}', q.query_timestamp) AS \"time\", \n COUNT(q.id) AS \"Queries\"\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:value}' = 'All' OR u.name = '${user_filter:value}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY date_trunc('${time_granularity}', q.query_timestamp)\nORDER BY \"time\";", ++ "rawSql": "SELECT \n date_trunc('${time_granularity}', q.query_timestamp) AS \"time\", \n COUNT(q.id) AS \"Queries\"\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:value}' = 'All' OR u.name = '${user_filter:value}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY date_trunc('${time_granularity}', q.query_timestamp)\nORDER BY \"time\";", + "refId": "A", + "sql": { + "columns": [ +@@ -218,7 +218,7 @@ + "editorMode": "code", + "format": "table", + "rawQuery": true, +- "rawSql": "SELECT \n q.model, \n COUNT(q.id) AS total_queries\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\nGROUP BY q.model\nORDER BY total_queries DESC;", ++ "rawSql": "SELECT \n q.model, \n COUNT(q.id) AS total_queries\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\nGROUP BY q.model\nORDER BY total_queries DESC;", + "refId": "A", + "sql": { + "columns": [ +@@ -352,7 +352,7 @@ + "editorMode": "code", + "format": "table", + "rawQuery": true, +- "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 12 THEN LEFT(u.name, 3) || '...' 
|| RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name\nORDER BY \"Queries\" DESC;", ++ "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 12 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name\nORDER BY \"Queries\" DESC;", + "refId": "A", + "sql": { + "columns": [ +@@ -360,7 +360,7 @@ + "alias": "\"User\"", + "parameters": [ + { +- "name": "userid", ++ "name": "user_id", + "type": "functionParameter" + } + ], +@@ -381,7 +381,7 @@ + "groupBy": [ + { + "property": { +- "name": "userid", ++ "name": "user_id", + "type": "string" + }, + "type": "groupBy" +@@ -481,7 +481,7 @@ + "editorMode": "code", + "format": "table", + "rawQuery": true, +- "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 8 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\",\n q.model AS \"Model\",\n COUNT(q.id) AS \"Queries\",\n MIN(q.query_timestamp) AS \"First Query\",\n MAX(q.query_timestamp) AS \"Last Query\"\nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name, q.model\nORDER BY \"Queries\" DESC\nLIMIT 20;", ++ "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 8 THEN LEFT(u.name, 3) || '...' 
|| RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\",\n q.model AS \"Model\",\n COUNT(q.id) AS \"Queries\",\n MIN(q.query_timestamp) AS \"First Query\",\n MAX(q.query_timestamp) AS \"Last Query\"\nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name, q.model\nORDER BY \"Queries\" DESC\nLIMIT 20;", + "refId": "A", + "sql": { + "columns": [ +diff --git a/grafana/runtime-data/dashboards/query-data.json b/grafana/runtime-data/dashboards/query-data.json +index 8e0b774..f33f87a 100644 +--- a/grafana/runtime-data/dashboards/query-data.json ++++ b/grafana/runtime-data/dashboards/query-data.json +@@ -126,7 +126,7 @@ + "editorMode": "code", + "format": "time_series", + "rawQuery": true, +- "rawSql": "SELECT \n date_trunc('${time_granularity}', q.query_timestamp) AS \"time\", \n COUNT(q.id) AS \"Queries\"\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:value}' = 'All' OR u.name = '${user_filter:value}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY date_trunc('${time_granularity}', q.query_timestamp)\nORDER BY \"time\";", ++ "rawSql": "SELECT \n date_trunc('${time_granularity}', q.query_timestamp) AS \"time\", \n COUNT(q.id) AS \"Queries\"\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:value}' = 'All' OR u.name = '${user_filter:value}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY date_trunc('${time_granularity}', q.query_timestamp)\nORDER BY \"time\";", + "refId": "A", + "sql": { + "columns": [ +@@ -218,7 +218,7 @@ + "editorMode": "code", + "format": "table", + "rawQuery": true, +- "rawSql": "SELECT \n q.model, \n COUNT(q.id) AS total_queries\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\nGROUP BY q.model\nORDER BY total_queries DESC;", ++ "rawSql": "SELECT \n q.model, \n COUNT(q.id) AS total_queries\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\nGROUP BY q.model\nORDER BY total_queries DESC;", + "refId": "A", + "sql": { + "columns": [ +@@ -352,7 +352,7 @@ + "editorMode": "code", + "format": "table", + "rawQuery": true, +- "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 12 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name\nORDER BY \"Queries\" DESC;", ++ "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 12 THEN LEFT(u.name, 3) || '...' 
|| RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name\nORDER BY \"Queries\" DESC;", + "refId": "A", + "sql": { + "columns": [ +@@ -360,7 +360,7 @@ + "alias": "\"User\"", + "parameters": [ + { +- "name": "userid", ++ "name": "user_id", + "type": "functionParameter" + } + ], +@@ -381,7 +381,7 @@ + "groupBy": [ + { + "property": { +- "name": "userid", ++ "name": "user_id", + "type": "string" + }, + "type": "groupBy" +@@ -481,7 +481,7 @@ + "editorMode": "code", + "format": "table", + "rawQuery": true, +- "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 8 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\",\n q.model AS \"Model\",\n COUNT(q.id) AS \"Queries\",\n MIN(q.query_timestamp) AS \"First Query\",\n MAX(q.query_timestamp) AS \"Last Query\"\nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name, q.model\nORDER BY \"Queries\" DESC\nLIMIT 20;", ++ "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 8 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\",\n q.model AS \"Model\",\n COUNT(q.id) AS \"Queries\",\n MIN(q.query_timestamp) AS \"First Query\",\n MAX(q.query_timestamp) AS \"Last Query\"\nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name, q.model\nORDER BY \"Queries\" DESC\nLIMIT 20;", + "refId": "A", + "sql": { + "columns": [ +diff --git a/grafana/runtime-data/dashboards/testnet-nuc-query-data.json b/grafana/runtime-data/dashboards/testnet-nuc-query-data.json +index f98d70e..358ba4e 100644 +--- a/grafana/runtime-data/dashboards/testnet-nuc-query-data.json ++++ b/grafana/runtime-data/dashboards/testnet-nuc-query-data.json +@@ -126,7 +126,7 @@ + "editorMode": "code", + "format": "time_series", + "rawQuery": true, +- "rawSql": "SELECT \n date_trunc('${time_granularity}', q.query_timestamp) AS \"time\", \n COUNT(q.id) AS \"Queries\"\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:value}' = 'All' OR u.name = '${user_filter:value}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY date_trunc('${time_granularity}', q.query_timestamp)\nORDER BY \"time\";", ++ "rawSql": "SELECT \n date_trunc('${time_granularity}', q.query_timestamp) AS \"time\", \n COUNT(q.id) AS \"Queries\"\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:value}' = 'All' OR u.name = '${user_filter:value}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY date_trunc('${time_granularity}', q.query_timestamp)\nORDER BY \"time\";", + "refId": "A", + "sql": { + "columns": [ +@@ 
-218,7 +218,7 @@ + "editorMode": "code", + "format": "table", + "rawQuery": true, +- "rawSql": "SELECT \n q.model, \n COUNT(q.id) AS total_queries\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\nGROUP BY q.model\nORDER BY total_queries DESC;", ++ "rawSql": "SELECT \n q.model, \n COUNT(q.id) AS total_queries\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\nGROUP BY q.model\nORDER BY total_queries DESC;", + "refId": "A", + "sql": { + "columns": [ +@@ -352,7 +352,7 @@ + "editorMode": "code", + "format": "table", + "rawQuery": true, +- "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 12 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name\nORDER BY \"Queries\" DESC;", ++ "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 12 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name\nORDER BY \"Queries\" DESC;", + "refId": "A", + "sql": { + "columns": [ +@@ -360,7 +360,7 @@ + "alias": "\"User\"", + "parameters": [ + { +- "name": "userid", ++ "name": "user_id", + "type": "functionParameter" + } + ], +@@ -381,7 +381,7 @@ + "groupBy": [ + { + "property": { +- "name": "userid", ++ "name": "user_id", + "type": "string" + }, + "type": "groupBy" +@@ -481,7 +481,7 @@ + "editorMode": "code", + "format": "table", + "rawQuery": true, +- "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 8 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\",\n q.model AS \"Model\",\n COUNT(q.id) AS \"Queries\",\n MIN(q.query_timestamp) AS \"First Query\",\n MAX(q.query_timestamp) AS \"Last Query\"\nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name, q.model\nORDER BY \"Queries\" DESC\nLIMIT 20;", ++ "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 8 THEN LEFT(u.name, 3) || '...' 
|| RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\",\n q.model AS \"Model\",\n COUNT(q.id) AS \"Queries\",\n MIN(q.query_timestamp) AS \"First Query\",\n MAX(q.query_timestamp) AS \"Last Query\"\nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name, q.model\nORDER BY \"Queries\" DESC\nLIMIT 20;", + "refId": "A", + "sql": { + "columns": [ +diff --git a/grafana/runtime-data/dashboards/totals-data.json b/grafana/runtime-data/dashboards/totals-data.json +index 2db20c7..ff66ce0 100644 +--- a/grafana/runtime-data/dashboards/totals-data.json ++++ b/grafana/runtime-data/dashboards/totals-data.json +@@ -83,7 +83,7 @@ + "editorMode": "code", + "format": "table", + "rawQuery": true, +- "rawSql": "SELECT \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')", ++ "rawSql": "SELECT \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')", + "refId": "A", + "sql": { + "columns": [ +@@ -165,7 +165,7 @@ + "editorMode": "code", + "format": "table", + "rawQuery": true, +- "rawSql": "SELECT SUM(total_tokens) AS total_tokens\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')", ++ "rawSql": "SELECT SUM(total_tokens) AS total_tokens\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')", + "refId": "A", + "sql": { + "columns": [ +@@ -248,7 +248,7 @@ + "editorMode": "code", + "format": "table", + "rawQuery": true, +- "rawSql": "SELECT SUM(q.prompt_tokens) AS total_tokens\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')", ++ "rawSql": "SELECT SUM(q.prompt_tokens) AS total_tokens\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')", + "refId": "A", + "sql": { + "columns": [ +@@ -331,7 +331,7 @@ + "editorMode": "code", + "format": "table", + "rawQuery": true, +- "rawSql": "SELECT SUM(q.completion_tokens) AS total_tokens\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')", ++ "rawSql": "SELECT SUM(q.completion_tokens) AS total_tokens\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = 
'${user_filter:single}')", + "refId": "A", + "sql": { + "columns": [ +@@ -397,4 +397,4 @@ + "uid": "aex54yzf0nmyoc", + "version": 1, + "weekStart": "" +-} +\ No newline at end of file ++} +diff --git a/grafana/runtime-data/dashboards/usage-data.json b/grafana/runtime-data/dashboards/usage-data.json +index 88857f9..a22bf91 100644 +--- a/grafana/runtime-data/dashboards/usage-data.json ++++ b/grafana/runtime-data/dashboards/usage-data.json +@@ -299,7 +299,7 @@ + "editorMode": "code", + "format": "table", + "rawQuery": true, +- "rawSql": "SELECT \n u.email AS \"User ID\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE q.query_timestamp >= NOW() - INTERVAL '1 hours'\nGROUP BY u.email;", ++ "rawSql": "SELECT \n u.email AS \"User ID\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE q.query_timestamp >= NOW() - INTERVAL '1 hours'\nGROUP BY u.email;", + "refId": "A", + "sql": { + "columns": [ +@@ -307,7 +307,7 @@ + "alias": "\"User ID\"", + "parameters": [ + { +- "name": "userid", ++ "name": "user_id", + "type": "functionParameter" + } + ], +@@ -328,7 +328,7 @@ + "groupBy": [ + { + "property": { +- "name": "userid", ++ "name": "user_id", + "type": "string" + }, + "type": "groupBy" +@@ -430,7 +430,7 @@ + "editorMode": "code", + "format": "table", + "rawQuery": true, +- "rawSql": "SELECT \n u.email AS \"User ID\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE q.query_timestamp >= NOW() - INTERVAL '7 days'\nGROUP BY u.email;", ++ "rawSql": "SELECT \n u.email AS \"User ID\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE q.query_timestamp >= NOW() - INTERVAL '7 days'\nGROUP BY u.email;", + "refId": "A", + "sql": { + "columns": [ +diff --git a/nilai-api/alembic/versions/0ba073468afc_chore_improved_database_schema.py b/nilai-api/alembic/versions/0ba073468afc_chore_improved_database_schema.py +new file mode 100644 +index 0000000..ebaca5a +--- /dev/null ++++ b/nilai-api/alembic/versions/0ba073468afc_chore_improved_database_schema.py +@@ -0,0 +1,206 @@ ++"""chore: merged database schema updates ++ ++Revision ID: 0ba073468afc ++Revises: ea942d6c7a00 ++Create Date: 2025-10-31 09:43:12.022675 ++ ++""" ++ ++from typing import Sequence, Union ++ ++from alembic import op ++import sqlalchemy as sa ++from sqlalchemy.dialects import postgresql ++ ++# revision identifiers, used by Alembic. 
++revision: str = "0ba073468afc" ++down_revision: Union[str, None] = "9ddf28cf6b6f" ++branch_labels: Union[str, Sequence[str], None] = None ++depends_on: Union[str, Sequence[str], None] = None ++ ++ ++def upgrade() -> None: ++ # ### merged commands from ea942d6c7a00 and 0ba073468afc ### ++ # query_logs: new telemetry columns (with defaults to backfill existing rows) ++ op.add_column( ++ "query_logs", ++ sa.Column( ++ "tool_calls", sa.Integer(), server_default=sa.text("0"), nullable=False ++ ), ++ ) ++ op.add_column( ++ "query_logs", ++ sa.Column( ++ "temperature", sa.Float(), server_default=sa.text("0.9"), nullable=True ++ ), ++ ) ++ op.add_column( ++ "query_logs", ++ sa.Column( ++ "max_tokens", sa.Integer(), server_default=sa.text("4096"), nullable=True ++ ), ++ ) ++ op.add_column( ++ "query_logs", ++ sa.Column( ++ "response_time_ms", ++ sa.Integer(), ++ server_default=sa.text("-1"), ++ nullable=False, ++ ), ++ ) ++ op.add_column( ++ "query_logs", ++ sa.Column( ++ "model_response_time_ms", ++ sa.Integer(), ++ server_default=sa.text("-1"), ++ nullable=False, ++ ), ++ ) ++ op.add_column( ++ "query_logs", ++ sa.Column( ++ "tool_response_time_ms", ++ sa.Integer(), ++ server_default=sa.text("-1"), ++ nullable=False, ++ ), ++ ) ++ op.add_column( ++ "query_logs", ++ sa.Column( ++ "was_streamed", ++ sa.Boolean(), ++ server_default=sa.text("False"), ++ nullable=False, ++ ), ++ ) ++ op.add_column( ++ "query_logs", ++ sa.Column( ++ "was_multimodal", ++ sa.Boolean(), ++ server_default=sa.text("False"), ++ nullable=False, ++ ), ++ ) ++ op.add_column( ++ "query_logs", ++ sa.Column( ++ "was_nildb", sa.Boolean(), server_default=sa.text("False"), nullable=False ++ ), ++ ) ++ op.add_column( ++ "query_logs", ++ sa.Column( ++ "was_nilrag", sa.Boolean(), server_default=sa.text("False"), nullable=False ++ ), ++ ) ++ op.add_column( ++ "query_logs", ++ sa.Column( ++ "error_code", sa.Integer(), server_default=sa.text("200"), nullable=False ++ ), ++ ) ++ op.add_column( ++ "query_logs", ++ sa.Column( ++ "error_message", sa.Text(), server_default=sa.text("'OK'"), nullable=False ++ ), ++ ) ++ ++ # query_logs: remove FK to users.userid before dropping the column later ++ op.drop_constraint("query_logs_userid_fkey", "query_logs", type_="foreignkey") ++ ++ # query_logs: add lockid and index, drop legacy userid and its index ++ op.add_column( ++ "query_logs", sa.Column("lockid", sa.String(length=75), nullable=False) ++ ) ++ op.drop_index("ix_query_logs_userid", table_name="query_logs") ++ op.create_index( ++ op.f("ix_query_logs_lockid"), "query_logs", ["lockid"], unique=False ++ ) ++ op.drop_column("query_logs", "userid") ++ ++ # users: drop legacy token counters ++ op.drop_column("users", "prompt_tokens") ++ op.drop_column("users", "completion_tokens") ++ ++ # users: reshape identity columns and indexes ++ op.add_column("users", sa.Column("user_id", sa.String(length=75), nullable=False)) ++ op.drop_index("ix_users_apikey", table_name="users") ++ op.drop_index("ix_users_userid", table_name="users") ++ op.create_index(op.f("ix_users_user_id"), "users", ["user_id"], unique=False) ++ op.drop_column("users", "last_activity") ++ op.drop_column("users", "userid") ++ op.drop_column("users", "apikey") ++ op.drop_column("users", "signup_date") ++ op.drop_column("users", "queries") ++ op.drop_column("users", "name") ++ # ### end merged commands ### ++ ++ ++def downgrade() -> None: ++ # ### revert merged commands back to 9ddf28cf6b6f ### ++ # users: restore legacy columns and indexes ++ op.add_column("users", sa.Column("name", 
sa.VARCHAR(length=100), nullable=False)) ++ op.add_column("users", sa.Column("queries", sa.INTEGER(), nullable=False)) ++ op.add_column( ++ "users", ++ sa.Column( ++ "signup_date", ++ postgresql.TIMESTAMP(timezone=True), ++ server_default=sa.text("now()"), ++ nullable=False, ++ ), ++ ) ++ op.add_column("users", sa.Column("apikey", sa.VARCHAR(length=75), nullable=False)) ++ op.add_column("users", sa.Column("userid", sa.VARCHAR(length=75), nullable=False)) ++ op.add_column( ++ "users", ++ sa.Column("last_activity", postgresql.TIMESTAMP(timezone=True), nullable=True), ++ ) ++ op.drop_index(op.f("ix_users_user_id"), table_name="users") ++ op.create_index("ix_users_userid", "users", ["userid"], unique=False) ++ op.create_index("ix_users_apikey", "users", ["apikey"], unique=False) ++ op.drop_column("users", "user_id") ++ op.add_column( ++ "users", ++ sa.Column( ++ "completion_tokens", ++ sa.INTEGER(), ++ server_default=sa.text("0"), ++ nullable=False, ++ ), ++ ) ++ op.add_column( ++ "users", ++ sa.Column( ++ "prompt_tokens", sa.INTEGER(), server_default=sa.text("0"), nullable=False ++ ), ++ ) ++ ++ # query_logs: restore userid, index and FK; drop new columns ++ op.add_column( ++ "query_logs", sa.Column("userid", sa.VARCHAR(length=75), nullable=False) ++ ) ++ op.drop_index(op.f("ix_query_logs_lockid"), table_name="query_logs") ++ op.create_index("ix_query_logs_userid", "query_logs", ["userid"], unique=False) ++ op.create_foreign_key( ++ "query_logs_userid_fkey", "query_logs", "users", ["userid"], ["userid"] ++ ) ++ op.drop_column("query_logs", "lockid") ++ op.drop_column("query_logs", "error_message") ++ op.drop_column("query_logs", "error_code") ++ op.drop_column("query_logs", "was_nilrag") ++ op.drop_column("query_logs", "was_nildb") ++ op.drop_column("query_logs", "was_multimodal") ++ op.drop_column("query_logs", "was_streamed") ++ op.drop_column("query_logs", "tool_response_time_ms") ++ op.drop_column("query_logs", "model_response_time_ms") ++ op.drop_column("query_logs", "response_time_ms") ++ op.drop_column("query_logs", "max_tokens") ++ op.drop_column("query_logs", "temperature") ++ op.drop_column("query_logs", "tool_calls") ++ # ### end revert ### +diff --git a/nilai-api/alembic/versions/43b23c73035b_fix_userid_change_to_user_id.py b/nilai-api/alembic/versions/43b23c73035b_fix_userid_change_to_user_id.py +new file mode 100644 +index 0000000..4c20bb6 +--- /dev/null ++++ b/nilai-api/alembic/versions/43b23c73035b_fix_userid_change_to_user_id.py +@@ -0,0 +1,37 @@ ++"""fix: userid change to user_id ++ ++Revision ID: 43b23c73035b ++Revises: 0ba073468afc ++Create Date: 2025-11-03 11:33:03.006101 ++ ++""" ++ ++from typing import Sequence, Union ++ ++from alembic import op ++import sqlalchemy as sa ++ ++ ++# revision identifiers, used by Alembic. ++revision: str = "43b23c73035b" ++down_revision: Union[str, None] = "0ba073468afc" ++branch_labels: Union[str, Sequence[str], None] = None ++depends_on: Union[str, Sequence[str], None] = None ++ ++ ++def upgrade() -> None: ++ # ### commands auto generated by Alembic - please adjust! ### ++ op.add_column( ++ "query_logs", sa.Column("user_id", sa.String(length=75), nullable=False) ++ ) ++ op.create_index( ++ op.f("ix_query_logs_user_id"), "query_logs", ["user_id"], unique=False ++ ) ++ # ### end Alembic commands ### ++ ++ ++def downgrade() -> None: ++ # ### commands auto generated by Alembic - please adjust! 
### ++ op.drop_index(op.f("ix_query_logs_user_id"), table_name="query_logs") ++ op.drop_column("query_logs", "user_id") ++ # ### end Alembic commands ### +diff --git a/nilai-api/examples/users.py b/nilai-api/examples/users.py +deleted file mode 100644 +index b6b206d..0000000 +--- a/nilai-api/examples/users.py ++++ /dev/null +@@ -1,43 +0,0 @@ +-#!/usr/bin/python +- +-from nilai_api.db.logs import QueryLogManager +-from nilai_api.db.users import UserManager +- +- +-# Example Usage +-async def main(): +- # Add some users +- bob = await UserManager.insert_user("Bob", "bob@example.com") +- alice = await UserManager.insert_user("Alice", "alice@example.com") +- +- print(f"Bob's details: {bob}") +- print(f"Alice's details: {alice}") +- +- # Check API key +- user_name = await UserManager.check_api_key(bob.apikey) +- print(f"API key validation: {user_name}") +- +- # Update and retrieve token usage +- await UserManager.update_token_usage( +- bob.userid, prompt_tokens=50, completion_tokens=20 +- ) +- usage = await UserManager.get_user_token_usage(bob.userid) +- print(f"Bob's token usage: {usage}") +- +- # Log a query +- await QueryLogManager.log_query( +- userid=bob.userid, +- model="gpt-3.5-turbo", +- prompt_tokens=8, +- completion_tokens=7, +- web_search_calls=1, +- ) +- +- +-if __name__ == "__main__": +- import asyncio +- from dotenv import load_dotenv +- +- load_dotenv() +- +- asyncio.run(main()) +diff --git a/nilai-api/pyproject.toml b/nilai-api/pyproject.toml +index 0bbfba3..9caae2a 100644 +--- a/nilai-api/pyproject.toml ++++ b/nilai-api/pyproject.toml +@@ -35,7 +35,7 @@ dependencies = [ + "trafilatura>=1.7.0", + "secretvaults", + "e2b-code-interpreter>=1.0.3", +- "nilauth-credit-middleware>=0.1.1", ++ "nilauth-credit-middleware>=0.1.2", + ] + + +diff --git a/nilai-api/src/nilai_api/auth/__init__.py b/nilai-api/src/nilai_api/auth/__init__.py +index 2e7cd6f..2123685 100644 +--- a/nilai-api/src/nilai_api/auth/__init__.py ++++ b/nilai-api/src/nilai_api/auth/__init__.py +@@ -4,7 +4,6 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer + from logging import getLogger + + from nilai_api.config import CONFIG +-from nilai_api.db.users import UserManager + from nilai_api.auth.strategies import AuthenticationStrategy + + from nuc.validate import ValidationException +@@ -36,7 +35,6 @@ async def get_auth_info( + ) + + auth_info = await strategy(credentials.credentials) +- await UserManager.update_last_activity(userid=auth_info.user.userid) + return auth_info + except AuthenticationError as e: + raise e +diff --git a/nilai-api/src/nilai_api/auth/nuc.py b/nilai-api/src/nilai_api/auth/nuc.py +index 4645935..614d9ef 100644 +--- a/nilai-api/src/nilai_api/auth/nuc.py ++++ b/nilai-api/src/nilai_api/auth/nuc.py +@@ -86,11 +86,11 @@ def validate_nuc(nuc_token: str) -> Tuple[str, str]: + + # Validate the + # Return the subject of the token, the subscription holder +- subscription_holder = token.subject.public_key.hex() +- user = token.issuer.public_key.hex() ++ subscription_holder = token.subject ++ user = token.issuer + logger.info(f"Subscription holder: {subscription_holder}") + logger.info(f"User: {user}") +- return subscription_holder, user ++ return str(subscription_holder), str(user) + + + def get_token_rate_limit(nuc_token: str) -> Optional[TokenRateLimits]: +diff --git a/nilai-api/src/nilai_api/auth/strategies.py b/nilai-api/src/nilai_api/auth/strategies.py +index 9917ee3..089e7e9 100644 +--- a/nilai-api/src/nilai_api/auth/strategies.py ++++ b/nilai-api/src/nilai_api/auth/strategies.py 
+@@ -1,6 +1,6 @@ + from typing import Callable, Awaitable, Optional +-from datetime import datetime, timezone + ++from fastapi import HTTPException + from nilai_api.db.users import UserManager, UserModel, UserData + from nilai_api.auth.nuc import ( + validate_nuc, +@@ -11,11 +11,18 @@ from nilai_api.config import CONFIG + from nilai_api.auth.common import ( + PromptDocument, + TokenRateLimits, +- AuthenticationInfo, + AuthenticationError, ++ AuthenticationInfo, ++) ++ ++from nilauth_credit_middleware import ( ++ CreditClientSingleton, + ) ++from nilauth_credit_middleware.api_model import ValidateCredentialResponse ++ + + from enum import Enum ++ + # All strategies must return a UserModel + # The strategies can raise any exception, which will be caught and converted to an AuthenticationError + # The exception detail will be passed to the client +@@ -44,18 +51,10 @@ def allow_token( + return await function(token) + + if token == allowed_token: +- user_model: UserModel | None = await UserManager.check_user( +- allowed_token ++ user_model = UserModel( ++ user_id=allowed_token, ++ rate_limits=None, + ) +- if user_model is None: +- user_model = UserModel( +- userid=allowed_token, +- name=allowed_token, +- apikey=allowed_token, +- signup_date=datetime.now(timezone.utc), +- ) +- await UserManager.insert_user_model(user_model) +- + return AuthenticationInfo( + user=UserData.from_sqlalchemy(user_model), + token_rate_limit=None, +@@ -68,16 +67,41 @@ def allow_token( + return decorator + + ++async def validate_credential(credential: str, is_public: bool) -> UserModel: ++ """ ++ Validate a credential with nilauth credit middleware and return the user model ++ """ ++ credit_client = CreditClientSingleton.get_client() ++ try: ++ validate_response: ValidateCredentialResponse = ( ++ await credit_client.validate_credential(credential, is_public=is_public) ++ ) ++ except HTTPException as e: ++ if e.status_code == 404: ++ raise AuthenticationError(f"Credential not found: {e.detail}") ++ elif e.status_code == 401: ++ raise AuthenticationError(f"Credential is inactive: {e.detail}") ++ else: ++ raise AuthenticationError(f"Failed to validate credential: {e.detail}") ++ ++ user_model = await UserManager.check_user(validate_response.user_id) ++ if user_model is None: ++ user_model = UserModel( ++ user_id=validate_response.user_id, ++ rate_limits=None, ++ ) ++ return user_model ++ ++ + @allow_token(CONFIG.docs.token) + async def api_key_strategy(api_key: str) -> AuthenticationInfo: +- user_model: Optional[UserModel] = await UserManager.check_api_key(api_key) +- if user_model: +- return AuthenticationInfo( +- user=UserData.from_sqlalchemy(user_model), +- token_rate_limit=None, +- prompt_document=None, +- ) +- raise AuthenticationError("Missing or invalid API key") ++ user_model = await validate_credential(api_key, is_public=False) ++ ++ return AuthenticationInfo( ++ user=UserData.from_sqlalchemy(user_model), ++ token_rate_limit=None, ++ prompt_document=None, ++ ) + + + @allow_token(CONFIG.docs.token) +@@ -89,20 +113,7 @@ async def nuc_strategy(nuc_token) -> AuthenticationInfo: + token_rate_limits: Optional[TokenRateLimits] = get_token_rate_limit(nuc_token) + prompt_document: Optional[PromptDocument] = get_token_prompt_document(nuc_token) + +- user_model: Optional[UserModel] = await UserManager.check_user(user) +- if user_model: +- return AuthenticationInfo( +- user=UserData.from_sqlalchemy(user_model), +- token_rate_limit=token_rate_limits, +- prompt_document=prompt_document, +- ) +- +- user_model = UserModel( +- 
userid=user, +- name=user, +- apikey=subscription_holder, +- ) +- await UserManager.insert_user_model(user_model) ++ user_model = await validate_credential(subscription_holder, is_public=True) + return AuthenticationInfo( + user=UserData.from_sqlalchemy(user_model), + token_rate_limit=token_rate_limits, +diff --git a/nilai-api/src/nilai_api/commands/add_user.py b/nilai-api/src/nilai_api/commands/add_user.py +index e9f49e5..5bd488b 100644 +--- a/nilai-api/src/nilai_api/commands/add_user.py ++++ b/nilai-api/src/nilai_api/commands/add_user.py +@@ -6,9 +6,7 @@ import click + + + @click.command() +-@click.option("--name", type=str, required=True, help="User Name") +-@click.option("--apikey", type=str, help="API Key") +-@click.option("--userid", type=str, help="User Id") ++@click.option("--user_id", type=str, help="User Id") + @click.option("--ratelimit-day", type=int, help="number of request per day") + @click.option("--ratelimit-hour", type=int, help="number of request per hour") + @click.option("--ratelimit-minute", type=int, help="number of request per minute") +@@ -26,9 +24,7 @@ import click + help="number of web search request per minute", + ) + def main( +- name, +- apikey: str | None, +- userid: str | None, ++ user_id: str | None, + ratelimit_day: int | None, + ratelimit_hour: int | None, + ratelimit_minute: int | None, +@@ -38,9 +34,7 @@ def main( + ): + async def add_user(): + user: UserModel = await UserManager.insert_user( +- name, +- apikey, +- userid, ++ user_id, + RateLimits( + user_rate_limit_day=ratelimit_day, + user_rate_limit_hour=ratelimit_hour, +@@ -52,7 +46,7 @@ def main( + ) + json_user = json.dumps( + { +- "userid": user.userid, ++ "user_id": user.user_id, + "name": user.name, + "apikey": user.apikey, + "ratelimit_day": user.rate_limits_obj.user_rate_limit_day, +diff --git a/nilai-api/src/nilai_api/config/__init__.py b/nilai-api/src/nilai_api/config/__init__.py +index 1939c74..b61a9fe 100644 +--- a/nilai-api/src/nilai_api/config/__init__.py ++++ b/nilai-api/src/nilai_api/config/__init__.py +@@ -64,3 +64,4 @@ __all__ = [ + ] + + logging.info(CONFIG.prettify()) ++print(CONFIG.prettify()) +diff --git a/nilai-api/src/nilai_api/credit.py b/nilai-api/src/nilai_api/credit.py +index 3a06135..b9d7ea6 100644 +--- a/nilai-api/src/nilai_api/credit.py ++++ b/nilai-api/src/nilai_api/credit.py +@@ -20,6 +20,9 @@ logger = logging.getLogger(__name__) + class NoOpMeteringContext: + """A no-op metering context for requests that should skip metering (e.g., Docs Token).""" + ++ def __init__(self): ++ self.lock_id: str = "noop-lock-id" ++ + def set_response(self, response_data: dict) -> None: + """No-op method that does nothing.""" + pass +diff --git a/nilai-api/src/nilai_api/db/logs.py b/nilai-api/src/nilai_api/db/logs.py +index 030c869..4a78c8a 100644 +--- a/nilai-api/src/nilai_api/db/logs.py ++++ b/nilai-api/src/nilai_api/db/logs.py +@@ -1,12 +1,14 @@ + import logging ++import time + from datetime import datetime, timezone ++from typing import Optional + ++from nilai_common import Usage + import sqlalchemy + +-from sqlalchemy import ForeignKey, Integer, String, DateTime, Text ++from sqlalchemy import Integer, String, DateTime, Text, Boolean, Float + from sqlalchemy.exc import SQLAlchemyError + from nilai_api.db import Base, Column, get_db_session +-from nilai_api.db.users import UserModel + + logger = logging.getLogger(__name__) + +@@ -16,9 +18,8 @@ class QueryLog(Base): + __tablename__ = "query_logs" + + id: int = Column(Integer, primary_key=True, autoincrement=True) # type: ignore +- 
userid: str = Column( +- String(75), ForeignKey(UserModel.userid), nullable=False, index=True +- ) # type: ignore ++ user_id: str = Column(String(75), nullable=False, index=True) # type: ignore ++ lockid: str = Column(String(75), nullable=False, index=True) # type: ignore + query_timestamp: datetime = Column( + DateTime(timezone=True), server_default=sqlalchemy.func.now(), nullable=False + ) # type: ignore +@@ -26,51 +27,285 @@ class QueryLog(Base): + prompt_tokens: int = Column(Integer, nullable=False) # type: ignore + completion_tokens: int = Column(Integer, nullable=False) # type: ignore + total_tokens: int = Column(Integer, nullable=False) # type: ignore ++ tool_calls: int = Column(Integer, nullable=False) # type: ignore + web_search_calls: int = Column(Integer, nullable=False) # type: ignore ++ temperature: Optional[float] = Column(Float, nullable=True) # type: ignore ++ max_tokens: Optional[int] = Column(Integer, nullable=True) # type: ignore ++ ++ response_time_ms: int = Column(Integer, nullable=False) # type: ignore ++ model_response_time_ms: int = Column(Integer, nullable=False) # type: ignore ++ tool_response_time_ms: int = Column(Integer, nullable=False) # type: ignore ++ ++ was_streamed: bool = Column(Boolean, nullable=False) # type: ignore ++ was_multimodal: bool = Column(Boolean, nullable=False) # type: ignore ++ was_nildb: bool = Column(Boolean, nullable=False) # type: ignore ++ was_nilrag: bool = Column(Boolean, nullable=False) # type: ignore ++ ++ error_code: int = Column(Integer, nullable=False) # type: ignore ++ error_message: str = Column(Text, nullable=False) # type: ignore + + def __repr__(self): +- return f"" ++ return f"" ++ ++ ++class QueryLogContext: ++ """ ++ Context manager for logging query metrics during a request. ++ Used as a FastAPI dependency to track request metrics. 
++ """ ++ ++ def __init__(self): ++ self.user_id: Optional[str] = None ++ self.lockid: Optional[str] = None ++ self.model: Optional[str] = None ++ self.prompt_tokens: int = 0 ++ self.completion_tokens: int = 0 ++ self.tool_calls: int = 0 ++ self.web_search_calls: int = 0 ++ self.temperature: Optional[float] = None ++ self.max_tokens: Optional[int] = None ++ self.was_streamed: bool = False ++ self.was_multimodal: bool = False ++ self.was_nildb: bool = False ++ self.was_nilrag: bool = False ++ self.error_code: int = 0 ++ self.error_message: str = "" ++ ++ # Timing tracking ++ self.start_time: float = time.monotonic() ++ self.model_start_time: Optional[float] = None ++ self.model_end_time: Optional[float] = None ++ self.tool_start_time: Optional[float] = None ++ self.tool_end_time: Optional[float] = None ++ ++ def set_user(self, user_id: str) -> None: ++ """Set the user ID for this query.""" ++ self.user_id = user_id ++ ++ def set_lockid(self, lockid: str) -> None: ++ """Set the lock ID for this query.""" ++ self.lockid = lockid ++ ++ def set_model(self, model: str) -> None: ++ """Set the model name for this query.""" ++ self.model = model ++ ++ def set_request_params( ++ self, ++ temperature: Optional[float] = None, ++ max_tokens: Optional[int] = None, ++ was_streamed: bool = False, ++ was_multimodal: bool = False, ++ was_nildb: bool = False, ++ was_nilrag: bool = False, ++ ) -> None: ++ """Set request parameters.""" ++ self.temperature = temperature ++ self.max_tokens = max_tokens ++ self.was_streamed = was_streamed ++ self.was_multimodal = was_multimodal ++ self.was_nildb = was_nildb ++ self.was_nilrag = was_nilrag ++ ++ def set_usage( ++ self, ++ prompt_tokens: int = 0, ++ completion_tokens: int = 0, ++ tool_calls: int = 0, ++ web_search_calls: int = 0, ++ ) -> None: ++ """Set token usage and feature usage.""" ++ self.prompt_tokens = prompt_tokens ++ self.completion_tokens = completion_tokens ++ self.tool_calls = tool_calls ++ self.web_search_calls = web_search_calls ++ ++ def set_error(self, error_code: int, error_message: str) -> None: ++ """Set error information.""" ++ self.error_code = error_code ++ self.error_message = error_message ++ ++ def start_model_timing(self) -> None: ++ """Mark the start of model inference.""" ++ self.model_start_time = time.monotonic() ++ ++ def end_model_timing(self) -> None: ++ """Mark the end of model inference.""" ++ self.model_end_time = time.monotonic() ++ ++ def start_tool_timing(self) -> None: ++ """Mark the start of tool execution.""" ++ self.tool_start_time = time.monotonic() ++ ++ def end_tool_timing(self) -> None: ++ """Mark the end of tool execution.""" ++ self.tool_end_time = time.monotonic() ++ ++ def _calculate_timings(self) -> tuple[int, int, int]: ++ """Calculate response times in milliseconds.""" ++ total_ms = int((time.monotonic() - self.start_time) * 1000) ++ ++ model_ms = 0 ++ if self.model_start_time and self.model_end_time: ++ model_ms = int((self.model_end_time - self.model_start_time) * 1000) ++ ++ tool_ms = 0 ++ if self.tool_start_time and self.tool_end_time: ++ tool_ms = int((self.tool_end_time - self.tool_start_time) * 1000) ++ ++ return total_ms, model_ms, tool_ms ++ ++ async def commit(self) -> None: ++ """ ++ Commit the query log to the database. ++ Should be called at the end of the request lifecycle. 
++ """ ++ if not self.user_id or not self.model: ++ logger.warning( ++ "Skipping query log: user_id or model not set " ++ f"(user_id={self.user_id}, model={self.model})" ++ ) ++ return ++ ++ total_ms, model_ms, tool_ms = self._calculate_timings() ++ total_tokens = self.prompt_tokens + self.completion_tokens ++ ++ try: ++ async with get_db_session() as session: ++ query_log = QueryLog( ++ user_id=self.user_id, ++ lockid=self.lockid, ++ model=self.model, ++ prompt_tokens=self.prompt_tokens, ++ completion_tokens=self.completion_tokens, ++ total_tokens=total_tokens, ++ tool_calls=self.tool_calls, ++ web_search_calls=self.web_search_calls, ++ temperature=self.temperature, ++ max_tokens=self.max_tokens, ++ query_timestamp=datetime.now(timezone.utc), ++ response_time_ms=total_ms, ++ model_response_time_ms=model_ms, ++ tool_response_time_ms=tool_ms, ++ was_streamed=self.was_streamed, ++ was_multimodal=self.was_multimodal, ++ was_nilrag=self.was_nilrag, ++ was_nildb=self.was_nildb, ++ error_code=self.error_code, ++ error_message=self.error_message, ++ ) ++ session.add(query_log) ++ await session.commit() ++ logger.info( ++ f"Query logged for user {self.user_id}: model={self.model}, " ++ f"tokens={total_tokens}, total_ms={total_ms}" ++ ) ++ except SQLAlchemyError as e: ++ logger.error(f"Error logging query: {e}") ++ # Don't raise - logging failure shouldn't break the request + + + class QueryLogManager: ++ """Static methods for direct query logging (legacy support).""" ++ + @staticmethod + async def log_query( +- userid: str, ++ user_id: str, ++ lockid: str, + model: str, + prompt_tokens: int, + completion_tokens: int, ++ response_time_ms: int, + web_search_calls: int, ++ was_streamed: bool, ++ was_multimodal: bool, ++ was_nilrag: bool, ++ was_nildb: bool, ++ tool_calls: int = 0, ++ temperature: float = 1.0, ++ max_tokens: int = 0, ++ model_response_time_ms: int = 0, ++ tool_response_time_ms: int = 0, ++ error_code: int = 0, ++ error_message: str = "", + ): + """ +- Log a user's query. +- +- Args: +- userid (str): User's unique ID +- model (str): The model that generated the response +- prompt_tokens (int): Number of input tokens used +- completion_tokens (int): Number of tokens in the generated response ++ Log a user's query (legacy method). ++ Consider using QueryLogContext as a dependency instead. + """ + total_tokens = prompt_tokens + completion_tokens + + try: + async with get_db_session() as session: + query_log = QueryLog( +- userid=userid, ++ user_id=user_id, ++ lockid=lockid, + model=model, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, +- query_timestamp=datetime.now(timezone.utc), ++ tool_calls=tool_calls, + web_search_calls=web_search_calls, ++ temperature=temperature, ++ max_tokens=max_tokens, ++ query_timestamp=datetime.now(timezone.utc), ++ response_time_ms=response_time_ms, ++ model_response_time_ms=model_response_time_ms, ++ tool_response_time_ms=tool_response_time_ms, ++ was_streamed=was_streamed, ++ was_multimodal=was_multimodal, ++ was_nilrag=was_nilrag, ++ was_nildb=was_nildb, ++ error_code=error_code, ++ error_message=error_message, + ) + session.add(query_log) + await session.commit() + logger.info( +- f"Query logged for user {userid} with total tokens {total_tokens}." ++ f"Query logged for user {user_id} with total tokens {total_tokens}." 
+ ) + except SQLAlchemyError as e: + logger.error(f"Error logging query: {e}") + raise + ++ @staticmethod ++ async def get_user_token_usage(user_id: str) -> Optional[Usage]: ++ """ ++ Get aggregated token usage for a specific user using server-side SQL aggregation. ++ This is more efficient than fetching all records and calculating in Python. ++ """ ++ try: ++ async with get_db_session() as session: ++ # Use SQL aggregation functions to calculate on the database server ++ query = ( ++ sqlalchemy.select( ++ sqlalchemy.func.coalesce( ++ sqlalchemy.func.sum(QueryLog.prompt_tokens), 0 ++ ).label("prompt_tokens"), ++ sqlalchemy.func.coalesce( ++ sqlalchemy.func.sum(QueryLog.completion_tokens), 0 ++ ).label("completion_tokens"), ++ sqlalchemy.func.coalesce( ++ sqlalchemy.func.sum(QueryLog.total_tokens), 0 ++ ).label("total_tokens"), ++ sqlalchemy.func.count().label("queries"), ++ ).where(QueryLog.user_id == user_id) # type: ignore[arg-type] ++ ) ++ ++ result = await session.execute(query) ++ row = result.one_or_none() ++ ++ if row is None: ++ return None ++ ++ return Usage( ++ prompt_tokens=int(row.prompt_tokens), ++ completion_tokens=int(row.completion_tokens), ++ total_tokens=int(row.total_tokens), ++ ) ++ except SQLAlchemyError as e: ++ logger.error(f"Error getting token usage: {e}") ++ return None ++ + +-__all__ = ["QueryLogManager", "QueryLog"] ++__all__ = ["QueryLogManager", "QueryLog", "QueryLogContext"] +diff --git a/nilai-api/src/nilai_api/db/users.py b/nilai-api/src/nilai_api/db/users.py +index 515ba38..e475c42 100644 +--- a/nilai-api/src/nilai_api/db/users.py ++++ b/nilai-api/src/nilai_api/db/users.py +@@ -2,11 +2,10 @@ import logging + import uuid + from pydantic import BaseModel, ConfigDict, Field + +-from datetime import datetime, timezone +-from typing import Any, Dict, List, Optional ++from typing import Optional + + import sqlalchemy +-from sqlalchemy import Integer, String, DateTime, JSON ++from sqlalchemy import String, JSON + from sqlalchemy.exc import SQLAlchemyError + + from nilai_api.db import Base, Column, get_db_session +@@ -57,21 +56,11 @@ class RateLimits(BaseModel): + # Enhanced User Model with additional constraints and validation + class UserModel(Base): + __tablename__ = "users" +- +- userid: str = Column(String(75), primary_key=True, index=True) # type: ignore +- name: str = Column(String(100), nullable=False) # type: ignore +- apikey: str = Column(String(75), unique=False, nullable=False, index=True) # type: ignore +- prompt_tokens: int = Column(Integer, default=0, nullable=False) # type: ignore +- completion_tokens: int = Column(Integer, default=0, nullable=False) # type: ignore +- queries: int = Column(Integer, default=0, nullable=False) # type: ignore +- signup_date: datetime = Column( +- DateTime(timezone=True), server_default=sqlalchemy.func.now(), nullable=False +- ) # type: ignore +- last_activity: datetime = Column(DateTime(timezone=True), nullable=True) # type: ignore ++ user_id: str = Column(String(75), primary_key=True, index=True) # type: ignore + rate_limits: dict = Column(JSON, nullable=True) # type: ignore + + def __repr__(self): +- return f"" ++ return f"" + + @property + def rate_limits_obj(self) -> RateLimits: +@@ -85,14 +74,7 @@ class UserModel(Base): + + + class UserData(BaseModel): +- userid: str +- name: str +- apikey: str +- prompt_tokens: int = 0 +- completion_tokens: int = 0 +- queries: int = 0 +- signup_date: datetime +- last_activity: Optional[datetime] = None ++ user_id: str # apikey or subscription holder public key + rate_limits: 
RateLimits = Field(default_factory=RateLimits().get_effective_limits) + + model_config = ConfigDict(from_attributes=True) +@@ -100,21 +82,10 @@ class UserData(BaseModel): + @classmethod + def from_sqlalchemy(cls, user: UserModel) -> "UserData": + return cls( +- userid=user.userid, +- name=user.name, +- apikey=user.apikey, +- prompt_tokens=user.prompt_tokens or 0, +- completion_tokens=user.completion_tokens or 0, +- queries=user.queries or 0, +- signup_date=user.signup_date or datetime.now(timezone.utc), +- last_activity=user.last_activity, ++ user_id=user.user_id, + rate_limits=user.rate_limits_obj, + ) + +- @property +- def is_subscription_owner(self): +- return self.userid == self.apikey +- + + class UserManager: + @staticmethod +@@ -127,31 +98,9 @@ class UserManager: + """Generate a unique API key.""" + return str(uuid.uuid4()) + +- @staticmethod +- async def update_last_activity(userid: str): +- """ +- Update the last activity timestamp for a user. +- +- Args: +- userid (str): User's unique ID +- """ +- try: +- async with get_db_session() as session: +- user = await session.get(UserModel, userid) +- if user: +- user.last_activity = datetime.now(timezone.utc) +- await session.commit() +- logger.info(f"Updated last activity for user {userid}") +- else: +- logger.warning(f"User {userid} not found") +- except SQLAlchemyError as e: +- logger.error(f"Error updating last activity: {e}") +- + @staticmethod + async def insert_user( +- name: str, +- apikey: str | None = None, +- userid: str | None = None, ++ user_id: str | None = None, + rate_limits: RateLimits | None = None, + ) -> UserModel: + """ +@@ -160,19 +109,16 @@ class UserManager: + Args: + name (str): Name of the user + apikey (str): API key for the user +- userid (str): Unique ID for the user ++ user_id (str): Unique ID for the user + rate_limits (RateLimits): Rate limit configuration + + Returns: + UserModel: The created user model + """ +- userid = userid if userid else UserManager.generate_user_id() +- apikey = apikey if apikey else UserManager.generate_api_key() ++ user_id = user_id if user_id else UserManager.generate_user_id() + + user = UserModel( +- userid=userid, +- name=name, +- apikey=apikey, ++ user_id=user_id, + rate_limits=rate_limits.model_dump() if rate_limits else None, + ) + return await UserManager.insert_user_model(user) +@@ -189,35 +135,14 @@ class UserManager: + async with get_db_session() as session: + session.add(user) + await session.commit() +- logger.info(f"User {user.name} added successfully.") ++ logger.info(f"User {user.user_id} added successfully.") + return user + except SQLAlchemyError as e: + logger.error(f"Error inserting user: {e}") + raise + + @staticmethod +- async def check_user(userid: str) -> Optional[UserModel]: +- """ +- Validate a user. +- +- Args: +- userid (str): User ID to validate +- +- Returns: +- User's name if user is valid, None otherwise +- """ +- try: +- async with get_db_session() as session: +- query = sqlalchemy.select(UserModel).filter(UserModel.userid == userid) # type: ignore +- user = await session.execute(query) +- user = user.scalar_one_or_none() +- return user +- except SQLAlchemyError as e: +- logger.error(f"Error checking API key: {e}") +- return None +- +- @staticmethod +- async def check_api_key(api_key: str) -> Optional[UserModel]: ++ async def check_user(user_id: str) -> Optional[UserModel]: + """ + Validate an API key. 
+
+@@ -225,118 +150,27 @@ class UserManager:
+ api_key (str): API key to validate
+
+ Returns:
+- User's name if API key is valid, None otherwise
++ User's rate limits if user id is valid, None otherwise
+ """
+ try:
+ async with get_db_session() as session:
+- query = sqlalchemy.select(UserModel).filter(UserModel.apikey == api_key) # type: ignore
++ query = sqlalchemy.select(UserModel).filter(
++ UserModel.user_id == user_id # type: ignore
++ )
+ user = await session.execute(query)
+ user = user.scalar_one_or_none()
+ return user
+ except SQLAlchemyError as e:
+- logger.error(f"Error checking API key: {e}")
+- return None
+-
+- @staticmethod
+- async def update_token_usage(
+- userid: str, prompt_tokens: int, completion_tokens: int
+- ):
+- """
+- Update token usage for a specific user.
+-
+- Args:
+- userid (str): User's unique ID
+- prompt_tokens (int): Number of input tokens
+- completion_tokens (int): Number of generated tokens
+- """
+- try:
+- async with get_db_session() as session:
+- user = await session.get(UserModel, userid)
+- if user:
+- user.prompt_tokens += prompt_tokens
+- user.completion_tokens += completion_tokens
+- user.queries += 1
+- await session.commit()
+- logger.info(f"Updated token usage for user {userid}")
+- else:
+- logger.warning(f"User {userid} not found")
+- except SQLAlchemyError as e:
+- logger.error(f"Error updating token usage: {e}")
+-
+- @staticmethod
+- async def get_token_usage(userid: str) -> Optional[Dict[str, Any]]:
+- """
+- Get token usage for a specific user.
+-
+- Args:
+- userid (str): User's unique ID
+- """
+- try:
+- async with get_db_session() as session:
+- user = await session.get(UserModel, userid)
+- if user:
+- return {
+- "prompt_tokens": user.prompt_tokens,
+- "completion_tokens": user.completion_tokens,
+- "total_tokens": user.prompt_tokens + user.completion_tokens,
+- "queries": user.queries,
+- }
+- else:
+- logger.warning(f"User {userid} not found")
+- return None
+- except SQLAlchemyError as e:
+- logger.error(f"Error updating token usage: {e}")
+- return None
+-
+- @staticmethod
+- async def get_all_users() -> Optional[List[UserData]]:
+- """
+- Retrieve all users from the database.
+-
+- Returns:
+- List of UserData or None if no users found
+- """
+- try:
+- async with get_db_session() as session:
+- users = await session.execute(sqlalchemy.select(UserModel))
+- users = users.scalars().all()
+- return [UserData.from_sqlalchemy(user) for user in users]
+- except SQLAlchemyError as e:
+- logger.error(f"Error retrieving all users: {e}")
+- return None
+-
+- @staticmethod
+- async def get_user_token_usage(userid: str) -> Optional[Dict[str, int]]:
+- """
+- Retrieve total token usage for a user.
+-
+- Args:
+- userid (str): User's unique ID
+-
+- Returns:
+- Dict of token usage or None if user not found
+- """
+- try:
+- async with get_db_session() as session:
+- user = await session.get(UserModel, userid)
+- if user:
+- return {
+- "prompt_tokens": user.prompt_tokens,
+- "completion_tokens": user.completion_tokens,
+- "queries": user.queries,
+- }
+- return None
+- except SQLAlchemyError as e:
+- logger.error(f"Error retrieving token usage: {e}")
++ logger.error(f"Error checking user id: {e}")
+ return None
+
+ @staticmethod
+- async def update_rate_limits(userid: str, rate_limits: RateLimits) -> bool:
++ async def update_rate_limits(user_id: str, rate_limits: RateLimits) -> bool:
+ """
+ Update rate limits for a specific user.
+ + Args: +- userid (str): User's unique ID ++ user_id (str): User's unique ID + rate_limits (RateLimits): New rate limit configuration + + Returns: +@@ -344,14 +178,14 @@ class UserManager: + """ + try: + async with get_db_session() as session: +- user = await session.get(UserModel, userid) ++ user = await session.get(UserModel, user_id) + if user: + user.rate_limits = rate_limits.model_dump() + await session.commit() +- logger.info(f"Updated rate limits for user {userid}") ++ logger.info(f"Updated rate limits for user {user_id}") + return True + else: +- logger.warning(f"User {userid} not found") ++ logger.warning(f"User {user_id} not found") + return False + except SQLAlchemyError as e: + logger.error(f"Error updating rate limits: {e}") +diff --git a/nilai-api/src/nilai_api/rate_limiting.py b/nilai-api/src/nilai_api/rate_limiting.py +index 8205b55..347f162 100644 +--- a/nilai-api/src/nilai_api/rate_limiting.py ++++ b/nilai-api/src/nilai_api/rate_limiting.py +@@ -53,7 +53,7 @@ async def _extract_coroutine_result(maybe_future, request: Request): + + + class UserRateLimits(BaseModel): +- subscription_holder: str ++ user_id: str + token_rate_limit: TokenRateLimits | None + rate_limits: RateLimits + +@@ -61,14 +61,13 @@ class UserRateLimits(BaseModel): + def get_user_limits( + auth_info: Annotated[AuthenticationInfo, Depends(get_auth_info)], + ) -> UserRateLimits: +- # TODO: When the only allowed strategy is NUC, we can change the apikey name to subscription_holder +- # In apikey mode, the apikey is unique as the userid. +- # In nuc mode, the apikey is associated with a subscription holder and the userid is the user ++ # In apikey mode, the apikey is unique as the user_id. ++ # In nuc mode, the apikey is associated with a subscription holder and the user_id is the user + # For NUCs we want the rate limit to be per subscription holder, not per user +- # In JWT mode, the apikey is the userid too ++ # In JWT mode, the apikey is the user_id too + # So we use the apikey as the id + return UserRateLimits( +- subscription_holder=auth_info.user.apikey, ++ user_id=auth_info.user.user_id, + token_rate_limit=auth_info.token_rate_limit, + rate_limits=auth_info.user.rate_limits, + ) +@@ -106,21 +105,21 @@ class RateLimit: + await self.check_bucket( + redis, + redis_rate_limit_command, +- f"minute:{user_limits.subscription_holder}", ++ f"minute:{user_limits.user_id}", + user_limits.rate_limits.user_rate_limit_minute, + MINUTE_MS, + ) + await self.check_bucket( + redis, + redis_rate_limit_command, +- f"hour:{user_limits.subscription_holder}", ++ f"hour:{user_limits.user_id}", + user_limits.rate_limits.user_rate_limit_hour, + HOUR_MS, + ) + await self.check_bucket( + redis, + redis_rate_limit_command, +- f"day:{user_limits.subscription_holder}", ++ f"day:{user_limits.user_id}", + user_limits.rate_limits.user_rate_limit_day, + DAY_MS, + ) +@@ -128,7 +127,7 @@ class RateLimit: + await self.check_bucket( + redis, + redis_rate_limit_command, +- f"user:{user_limits.subscription_holder}", ++ f"user:{user_limits.user_id}", + user_limits.rate_limits.user_rate_limit, + 0, # No expiration for for-good rate limit + ) +@@ -176,7 +175,7 @@ class RateLimit: + await self.check_bucket( + redis, + redis_rate_limit_command, +- f"web_search:{user_limits.subscription_holder}", ++ f"web_search:{user_limits.user_id}", + user_limits.rate_limits.web_search_rate_limit, + 0, # No expiration for for-good rate limit + ) +@@ -199,7 +198,7 @@ class RateLimit: + await self.check_bucket( + redis, + redis_rate_limit_command, +- 
f"web_search_{time_unit}:{user_limits.subscription_holder}", ++ f"web_search_{time_unit}:{user_limits.user_id}", + limit, + milliseconds, + ) +diff --git a/nilai-api/src/nilai_api/routers/private.py b/nilai-api/src/nilai_api/routers/private.py +index db75067..038f8db 100644 +--- a/nilai-api/src/nilai_api/routers/private.py ++++ b/nilai-api/src/nilai_api/routers/private.py +@@ -11,13 +11,20 @@ from nilai_api.handlers.nilrag import handle_nilrag + from nilai_api.handlers.web_search import handle_web_search + from nilai_api.handlers.tools.tool_router import handle_tool_workflow + +-from fastapi import APIRouter, Body, Depends, HTTPException, status, Request ++from fastapi import ( ++ APIRouter, ++ BackgroundTasks, ++ Body, ++ Depends, ++ HTTPException, ++ status, ++ Request, ++) + from fastapi.responses import StreamingResponse + from nilai_api.auth import get_auth_info, AuthenticationInfo + from nilai_api.config import CONFIG + from nilai_api.crypto import sign_message +-from nilai_api.db.logs import QueryLogManager +-from nilai_api.db.users import UserManager ++from nilai_api.db.logs import QueryLogContext, QueryLogManager + from nilai_api.rate_limiting import RateLimit + from nilai_api.state import state + +@@ -53,14 +60,10 @@ router = APIRouter() + @router.get("/v1/delegation") + async def get_prompt_store_delegation( + prompt_delegation_request: PromptDelegationRequest, +- auth_info: AuthenticationInfo = Depends(get_auth_info), ++ _: AuthenticationInfo = Depends( ++ get_auth_info ++ ), # This is to satisfy that the user is authenticated + ) -> PromptDelegationToken: +- if not auth_info.user.is_subscription_owner: +- raise HTTPException( +- status_code=status.HTTP_403_FORBIDDEN, +- detail=f"Prompt storage is reserved to subscription owners: {auth_info.user} is not a subscription owner, apikey: {auth_info.user}", +- ) +- + try: + return await get_nildb_delegation_token(prompt_delegation_request) + except Exception as e: +@@ -84,12 +87,15 @@ async def get_usage(auth_info: AuthenticationInfo = Depends(get_auth_info)) -> U + usage = await get_usage(user) + ``` + """ +- return Usage( +- prompt_tokens=auth_info.user.prompt_tokens, +- completion_tokens=auth_info.user.completion_tokens, +- total_tokens=auth_info.user.prompt_tokens + auth_info.user.completion_tokens, +- queries=auth_info.user.queries, # type: ignore # FIXME this field is not part of Usage ++ user_usage: Optional[Usage] = await QueryLogManager.get_user_token_usage( ++ auth_info.user.user_id + ) ++ if user_usage is None: ++ raise HTTPException( ++ status_code=status.HTTP_404_NOT_FOUND, ++ detail="User not found", ++ ) ++ return user_usage + + + @router.get("/v1/attestation/report", tags=["Attestation"]) +@@ -173,6 +179,7 @@ async def chat_completion( + ], + ) + ), ++ background_tasks: BackgroundTasks = BackgroundTasks(), + _rate_limit=Depends( + RateLimit( + concurrent_extractor=chat_completion_concurrent_rate_limit, +@@ -181,6 +188,7 @@ async def chat_completion( + ), + auth_info: AuthenticationInfo = Depends(get_auth_info), + meter: MeteringContext = Depends(LLMMeter), ++ log_ctx: QueryLogContext = Depends(QueryLogContext), + ) -> Union[SignedChatCompletion, StreamingResponse]: + """ + Generate a chat completion response from the AI model. 
+@@ -234,249 +242,312 @@
+ )
+ response = await chat_completion(request, user)
+ """
+-
+- if len(req.messages) == 0:
+- raise HTTPException(
+- status_code=400,
+- detail="Request contained 0 messages",
+- )
++ # Initialize log context early so we can log any errors
++ log_ctx.set_user(auth_info.user.user_id)
++ log_ctx.set_lockid(meter.lock_id)
+ model_name = req.model
+ request_id = str(uuid.uuid4())
+ t_start = time.monotonic()
+- logger.info(f"[chat] call start request_id={req.messages}")
+- endpoint = await state.get_model(model_name)
+- if endpoint is None:
+- raise HTTPException(
+- status_code=status.HTTP_400_BAD_REQUEST,
+- detail=f"Invalid model name {model_name}, check /v1/models for options",
+- )
+-
+- if not endpoint.metadata.tool_support and req.tools:
+- raise HTTPException(
+- status_code=400,
+- detail="Model does not support tool usage, remove tools from request",
+- )
+
+- has_multimodal = req.has_multimodal_content()
+- logger.info(f"[chat] has_multimodal: {has_multimodal}")
+- if has_multimodal and (not endpoint.metadata.multimodal_support or req.web_search):
+- raise HTTPException(
+- status_code=400,
+- detail="Model does not support multimodal content, remove image inputs from request",
+- )
+-
+- model_url = endpoint.url + "/v1/"
++ try:
++ if len(req.messages) == 0:
++ raise HTTPException(
++ status_code=400,
++ detail="Request contained 0 messages",
++ )
++ logger.info(f"[chat] call start request_id={request_id}")
++ endpoint = await state.get_model(model_name)
++ if endpoint is None:
++ raise HTTPException(
++ status_code=status.HTTP_400_BAD_REQUEST,
++ detail=f"Invalid model name {model_name}, check /v1/models for options",
++ )
+
+- logger.info(
+- f"[chat] start request_id={request_id} user={auth_info.user.userid} model={model_name} stream={req.stream} web_search={bool(req.web_search)} tools={bool(req.tools)} multimodal={has_multimodal} url={model_url}"
+- )
++ # Now we have a valid model, set it in log context
++ log_ctx.set_model(model_name)
+
+- client = AsyncOpenAI(base_url=model_url, api_key="")
+- if auth_info.prompt_document:
+- try:
+- nildb_prompt: str = await get_prompt_from_nildb(auth_info.prompt_document)
+- req.messages.insert(
+- 0, MessageAdapter.new_message(role="system", content=nildb_prompt)
+- )
+- except Exception as e:
++ if not endpoint.metadata.tool_support and req.tools:
+ raise HTTPException(
+- status_code=status.HTTP_403_FORBIDDEN,
+- detail=f"Unable to extract prompt from nilDB: {str(e)}",
++ status_code=400,
++ detail="Model does not support tool usage, remove tools from request",
+ )
+
+- if req.nilrag:
+- logger.info(f"[chat] nilrag start request_id={request_id}")
+- t_nilrag = time.monotonic()
+- await handle_nilrag(req)
+- logger.info(
+- f"[chat] nilrag done request_id={request_id} duration_ms={(time.monotonic() - t_nilrag) * 1000:.0f}"
+- )
++ has_multimodal = req.has_multimodal_content()
++ logger.info(f"[chat] has_multimodal: {has_multimodal}")
++ if has_multimodal and (
++ not endpoint.metadata.multimodal_support or req.web_search
++ ):
++ raise HTTPException(
++ status_code=400,
++ detail="Model does not support multimodal content, remove image inputs from request",
++ )
+
+- messages = req.messages
+- sources: Optional[List[Source]] = None
++ model_url = endpoint.url + "/v1/"
+
+- if req.web_search:
+- logger.info(f"[chat] web_search start request_id={request_id}")
+- t_ws = time.monotonic()
+- web_search_result = await handle_web_search(req, model_name, client)
+- messages = 
web_search_result.messages +- sources = web_search_result.sources + logger.info( +- f"[chat] web_search done request_id={request_id} sources={len(sources) if sources else 0} duration_ms={(time.monotonic() - t_ws) * 1000:.0f}" ++ f"[chat] start request_id={request_id} user={auth_info.user.user_id} model={model_name} stream={req.stream} web_search={bool(req.web_search)} tools={bool(req.tools)} multimodal={has_multimodal} url={model_url}" ++ ) ++ log_ctx.set_request_params( ++ temperature=req.temperature, ++ max_tokens=req.max_tokens, ++ was_streamed=req.stream or False, ++ was_multimodal=has_multimodal, ++ was_nildb=bool(auth_info.prompt_document), ++ was_nilrag=bool(req.nilrag), + ) +- logger.info(f"[chat] web_search messages: {messages}") + +- if req.stream: ++ client = AsyncOpenAI(base_url=model_url, api_key="") ++ if auth_info.prompt_document: ++ try: ++ nildb_prompt: str = await get_prompt_from_nildb( ++ auth_info.prompt_document ++ ) ++ req.messages.insert( ++ 0, MessageAdapter.new_message(role="system", content=nildb_prompt) ++ ) ++ except Exception as e: ++ raise HTTPException( ++ status_code=status.HTTP_403_FORBIDDEN, ++ detail=f"Unable to extract prompt from nilDB: {str(e)}", ++ ) + +- async def chat_completion_stream_generator() -> AsyncGenerator[str, None]: +- t_call = time.monotonic() +- prompt_token_usage = 0 +- completion_token_usage = 0 ++ if req.nilrag: ++ logger.info(f"[chat] nilrag start request_id={request_id}") ++ t_nilrag = time.monotonic() ++ await handle_nilrag(req) ++ logger.info( ++ f"[chat] nilrag done request_id={request_id} duration_ms={(time.monotonic() - t_nilrag) * 1000:.0f}" ++ ) + +- try: +- logger.info(f"[chat] stream start request_id={request_id}") +- +- request_kwargs = { +- "model": req.model, +- "messages": messages, +- "stream": True, +- "top_p": req.top_p, +- "temperature": req.temperature, +- "max_tokens": req.max_tokens, +- "extra_body": { +- "stream_options": { +- "include_usage": True, +- "continuous_usage_stats": False, +- } +- }, +- } +- if req.tools: +- request_kwargs["tools"] = req.tools ++ messages = req.messages ++ sources: Optional[List[Source]] = None ++ ++ if req.web_search: ++ logger.info(f"[chat] web_search start request_id={request_id}") ++ t_ws = time.monotonic() ++ web_search_result = await handle_web_search(req, model_name, client) ++ messages = web_search_result.messages ++ sources = web_search_result.sources ++ logger.info( ++ f"[chat] web_search done request_id={request_id} sources={len(sources) if sources else 0} duration_ms={(time.monotonic() - t_ws) * 1000:.0f}" ++ ) ++ logger.info(f"[chat] web_search messages: {messages}") ++ ++ if req.stream: ++ ++ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]: ++ t_call = time.monotonic() ++ prompt_token_usage = 0 ++ completion_token_usage = 0 ++ ++ try: ++ logger.info(f"[chat] stream start request_id={request_id}") ++ ++ log_ctx.start_model_timing() ++ ++ request_kwargs = { ++ "model": req.model, ++ "messages": messages, ++ "stream": True, ++ "top_p": req.top_p, ++ "temperature": req.temperature, ++ "max_tokens": req.max_tokens, ++ "extra_body": { ++ "stream_options": { ++ "include_usage": True, ++ "continuous_usage_stats": False, ++ } ++ }, ++ } ++ if req.tools: ++ request_kwargs["tools"] = req.tools + +- response = await client.chat.completions.create(**request_kwargs) ++ response = await client.chat.completions.create(**request_kwargs) + +- async for chunk in response: +- if chunk.usage is not None: +- prompt_token_usage = chunk.usage.prompt_tokens +- 
completion_token_usage = chunk.usage.completion_tokens ++ async for chunk in response: ++ if chunk.usage is not None: ++ prompt_token_usage = chunk.usage.prompt_tokens ++ completion_token_usage = chunk.usage.completion_tokens + +- payload = chunk.model_dump(exclude_unset=True) ++ payload = chunk.model_dump(exclude_unset=True) + +- if chunk.usage is not None and sources: +- payload["sources"] = [ +- s.model_dump(mode="json") for s in sources +- ] ++ if chunk.usage is not None and sources: ++ payload["sources"] = [ ++ s.model_dump(mode="json") for s in sources ++ ] + +- yield f"data: {json.dumps(payload)}\n\n" ++ yield f"data: {json.dumps(payload)}\n\n" + +- await UserManager.update_token_usage( +- auth_info.user.userid, +- prompt_tokens=prompt_token_usage, +- completion_tokens=completion_token_usage, +- ) +- meter.set_response( +- { +- "usage": LLMUsage( +- prompt_tokens=prompt_token_usage, +- completion_tokens=completion_token_usage, +- web_searches=len(sources) if sources else 0, +- ) +- } +- ) +- await QueryLogManager.log_query( +- auth_info.user.userid, +- model=req.model, +- prompt_tokens=prompt_token_usage, +- completion_tokens=completion_token_usage, +- web_search_calls=len(sources) if sources else 0, +- ) +- logger.info( +- "[chat] stream done request_id=%s prompt_tokens=%d completion_tokens=%d " +- "duration_ms=%.0f total_ms=%.0f", +- request_id, +- prompt_token_usage, +- completion_token_usage, +- (time.monotonic() - t_call) * 1000, +- (time.monotonic() - t_start) * 1000, +- ) ++ log_ctx.end_model_timing() ++ meter.set_response( ++ { ++ "usage": LLMUsage( ++ prompt_tokens=prompt_token_usage, ++ completion_tokens=completion_token_usage, ++ web_searches=len(sources) if sources else 0, ++ ) ++ } ++ ) ++ log_ctx.set_usage( ++ prompt_tokens=prompt_token_usage, ++ completion_tokens=completion_token_usage, ++ web_search_calls=len(sources) if sources else 0, ++ ) ++ await log_ctx.commit() ++ logger.info( ++ "[chat] stream done request_id=%s prompt_tokens=%d completion_tokens=%d " ++ "duration_ms=%.0f total_ms=%.0f", ++ request_id, ++ prompt_token_usage, ++ completion_token_usage, ++ (time.monotonic() - t_call) * 1000, ++ (time.monotonic() - t_start) * 1000, ++ ) ++ ++ except Exception as e: ++ logger.error( ++ "[chat] stream error request_id=%s error=%s", request_id, e ++ ) ++ log_ctx.set_error(error_code=500, error_message=str(e)) ++ await log_ctx.commit() ++ yield f"data: {json.dumps({'error': 'stream_failed', 'message': str(e)})}\n\n" ++ ++ return StreamingResponse( ++ chat_completion_stream_generator(), ++ media_type="text/event-stream", ++ ) + +- except Exception as e: +- logger.error( +- "[chat] stream error request_id=%s error=%s", request_id, e +- ) +- yield f"data: {json.dumps({'error': 'stream_failed', 'message': str(e)})}\n\n" ++ current_messages = messages ++ request_kwargs = { ++ "model": req.model, ++ "messages": current_messages, # type: ignore ++ "top_p": req.top_p, ++ "temperature": req.temperature, ++ "max_tokens": req.max_tokens, ++ } ++ if req.tools: ++ request_kwargs["tools"] = req.tools # type: ignore ++ request_kwargs["tool_choice"] = req.tool_choice ++ ++ logger.info(f"[chat] call start request_id={request_id}") ++ logger.info(f"[chat] call message: {current_messages}") ++ t_call = time.monotonic() ++ log_ctx.start_model_timing() ++ response = await client.chat.completions.create(**request_kwargs) # type: ignore ++ log_ctx.end_model_timing() ++ logger.info( ++ f"[chat] call done request_id={request_id} duration_ms={(time.monotonic() - t_call) * 1000:.0f}" ++ ) ++ 
logger.info(f"[chat] call response: {response}") ++ ++ # Handle tool workflow fully inside tools.router ++ log_ctx.start_tool_timing() ++ ( ++ final_completion, ++ agg_prompt_tokens, ++ agg_completion_tokens, ++ ) = await handle_tool_workflow(client, req, current_messages, response) ++ log_ctx.end_tool_timing() ++ logger.info(f"[chat] call final_completion: {final_completion}") ++ model_response = SignedChatCompletion( ++ **final_completion.model_dump(), ++ signature="", ++ sources=sources, ++ ) + +- return StreamingResponse( +- chat_completion_stream_generator(), +- media_type="text/event-stream", ++ logger.info( ++ f"[chat] model_response request_id={request_id} duration_ms={(time.monotonic() - t_call) * 1000:.0f}" + ) + +- current_messages = messages +- request_kwargs = { +- "model": req.model, +- "messages": current_messages, # type: ignore +- "top_p": req.top_p, +- "temperature": req.temperature, +- "max_tokens": req.max_tokens, +- } +- if req.tools: +- request_kwargs["tools"] = req.tools # type: ignore +- request_kwargs["tool_choice"] = req.tool_choice +- +- logger.info(f"[chat] call start request_id={request_id}") +- logger.info(f"[chat] call message: {current_messages}") +- t_call = time.monotonic() +- response = await client.chat.completions.create(**request_kwargs) # type: ignore +- logger.info( +- f"[chat] call done request_id={request_id} duration_ms={(time.monotonic() - t_call) * 1000:.0f}" +- ) +- logger.info(f"[chat] call response: {response}") +- +- # Handle tool workflow fully inside tools.router +- ( +- final_completion, +- agg_prompt_tokens, +- agg_completion_tokens, +- ) = await handle_tool_workflow(client, req, current_messages, response) +- logger.info(f"[chat] call final_completion: {final_completion}") +- model_response = SignedChatCompletion( +- **final_completion.model_dump(), +- signature="", +- sources=sources, +- ) ++ if model_response.usage is None: ++ raise HTTPException( ++ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, ++ detail="Model response does not contain usage statistics", ++ ) + +- logger.info( +- f"[chat] model_response request_id={request_id} duration_ms={(time.monotonic() - t_call) * 1000:.0f}" +- ) ++ if agg_prompt_tokens or agg_completion_tokens: ++ total_prompt_tokens = response.usage.prompt_tokens ++ total_completion_tokens = response.usage.completion_tokens + +- if model_response.usage is None: +- raise HTTPException( +- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, +- detail="Model response does not contain usage statistics", +- ) ++ total_prompt_tokens += agg_prompt_tokens ++ total_completion_tokens += agg_completion_tokens + +- if agg_prompt_tokens or agg_completion_tokens: +- total_prompt_tokens = response.usage.prompt_tokens +- total_completion_tokens = response.usage.completion_tokens ++ model_response.usage.prompt_tokens = total_prompt_tokens ++ model_response.usage.completion_tokens = total_completion_tokens ++ model_response.usage.total_tokens = ( ++ total_prompt_tokens + total_completion_tokens ++ ) ++ ++ # Update token usage in DB ++ meter.set_response( ++ { ++ "usage": LLMUsage( ++ prompt_tokens=model_response.usage.prompt_tokens, ++ completion_tokens=model_response.usage.completion_tokens, ++ web_searches=len(sources) if sources else 0, ++ ) ++ } ++ ) + +- total_prompt_tokens += agg_prompt_tokens +- total_completion_tokens += agg_completion_tokens ++ # Log query with context ++ tool_calls_count = 0 ++ if final_completion.choices and final_completion.choices[0].message.tool_calls: ++ tool_calls_count = 
len(final_completion.choices[0].message.tool_calls) + +- model_response.usage.prompt_tokens = total_prompt_tokens +- model_response.usage.completion_tokens = total_completion_tokens +- model_response.usage.total_tokens = ( +- total_prompt_tokens + total_completion_tokens ++ log_ctx.set_usage( ++ prompt_tokens=model_response.usage.prompt_tokens, ++ completion_tokens=model_response.usage.completion_tokens, ++ tool_calls=tool_calls_count, ++ web_search_calls=len(sources) if sources else 0, + ) ++ # Use background task for successful requests to avoid blocking response ++ background_tasks.add_task(log_ctx.commit) + +- # Update token usage in DB +- await UserManager.update_token_usage( +- auth_info.user.userid, +- prompt_tokens=model_response.usage.prompt_tokens, +- completion_tokens=model_response.usage.completion_tokens, +- ) +- meter.set_response( +- { +- "usage": LLMUsage( +- prompt_tokens=model_response.usage.prompt_tokens, +- completion_tokens=model_response.usage.completion_tokens, +- web_searches=len(sources) if sources else 0, +- ) +- } +- ) +- await QueryLogManager.log_query( +- auth_info.user.userid, +- model=req.model, +- prompt_tokens=model_response.usage.prompt_tokens, +- completion_tokens=model_response.usage.completion_tokens, +- web_search_calls=len(sources) if sources else 0, +- ) ++ # Sign the response ++ response_json = model_response.model_dump_json() ++ signature = sign_message(state.private_key, response_json) ++ model_response.signature = b64encode(signature).decode() + +- # Sign the response +- response_json = model_response.model_dump_json() +- signature = sign_message(state.private_key, response_json) +- model_response.signature = b64encode(signature).decode() ++ logger.info( ++ f"[chat] done request_id={request_id} prompt_tokens={model_response.usage.prompt_tokens} completion_tokens={model_response.usage.completion_tokens} total_ms={(time.monotonic() - t_start) * 1000:.0f}" ++ ) ++ return model_response ++ except HTTPException as e: ++ # Extract error code from HTTPException, default to status code ++ error_code = e.status_code ++ error_message = str(e.detail) if e.detail else str(e) ++ logger.error( ++ f"[chat] HTTPException request_id={request_id} user={auth_info.user.user_id} " ++ f"model={model_name} error_code={error_code} error={error_message}", ++ exc_info=True, ++ ) + +- logger.info( +- f"[chat] done request_id={request_id} prompt_tokens={model_response.usage.prompt_tokens} completion_tokens={model_response.usage.completion_tokens} total_ms={(time.monotonic() - t_start) * 1000:.0f}" +- ) +- return model_response ++ # Only log server errors (5xx) to database to prevent DoS attacks via client errors ++ # Client errors (4xx) are logged to application logs only ++ if error_code >= 500: ++ # Set model if not already set (e.g., for validation errors before model validation) ++ if log_ctx.model is None: ++ log_ctx.set_model(model_name) ++ log_ctx.set_error(error_code=error_code, error_message=error_message) ++ await log_ctx.commit() ++ # For 4xx errors, we skip DB logging - they're logged above via logger.error() ++ # This prevents DoS attacks where attackers send many invalid requests ++ ++ raise ++ except Exception as e: ++ # Catch any other unexpected exceptions ++ error_message = str(e) ++ logger.error( ++ f"[chat] unexpected error request_id={request_id} user={auth_info.user.user_id} " ++ f"model={model_name} error={error_message}", ++ exc_info=True, ++ ) ++ # Set model if not already set ++ if log_ctx.model is None: ++ log_ctx.set_model(model_name) ++ 
log_ctx.set_error(error_code=500, error_message=error_message) ++ await log_ctx.commit() ++ raise HTTPException( ++ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, ++ detail=f"Internal server error: {error_message}", ++ ) +diff --git a/tests/e2e/config.py b/tests/e2e/config.py +index e06f9d4..3111902 100644 +--- a/tests/e2e/config.py ++++ b/tests/e2e/config.py +@@ -38,7 +38,7 @@ models = { + "meta-llama/Llama-3.1-8B-Instruct", + ], + "ci": [ +- "meta-llama/Llama-3.2-1B-Instruct", ++ "llama-3.2-1b-instruct", + ], + } + +diff --git a/tests/integration/nilai_api/test_users_db_integration.py b/tests/integration/nilai_api/test_users_db_integration.py +index 82d8d02..892a3af 100644 +--- a/tests/integration/nilai_api/test_users_db_integration.py ++++ b/tests/integration/nilai_api/test_users_db_integration.py +@@ -17,37 +17,17 @@ class TestUserManagerIntegration: + async def test_simple_user_creation(self, clean_database): + """Test creating a simple user and retrieving it.""" + # Insert user with minimal data +- user = await UserManager.insert_user(name="Simple Test User") ++ user = await UserManager.insert_user(user_id="Simple Test User") + + # Verify user creation +- assert user.name == "Simple Test User" +- assert user.userid is not None +- assert user.apikey is not None +- assert user.userid != user.apikey # Should be different UUIDs ++ assert user.user_id == "Simple Test User" ++ assert user.rate_limits is not None + + # Retrieve user by ID +- found_user = await UserManager.check_user(user.userid) ++ found_user = await UserManager.check_user(user.user_id) + assert found_user is not None +- assert found_user.userid == user.userid +- assert found_user.name == "Simple Test User" +- assert found_user.apikey == user.apikey +- +- @pytest.mark.asyncio +- async def test_api_key_validation(self, clean_database): +- """Test API key validation functionality.""" +- # Create user +- user = await UserManager.insert_user("API Test User") +- +- # Validate correct API key +- api_user = await UserManager.check_api_key(user.apikey) +- assert api_user is not None +- assert api_user.apikey == user.apikey +- assert api_user.userid == user.userid +- assert api_user.name == "API Test User" +- +- # Test invalid API key +- invalid_user = await UserManager.check_api_key("invalid-api-key") +- assert invalid_user is None ++ assert found_user.user_id == user.user_id ++ assert found_user.rate_limits == user.rate_limits + + @pytest.mark.asyncio + async def test_rate_limits_json_crud_basic(self, clean_database): +@@ -66,14 +46,14 @@ class TestUserManagerIntegration: + + # CREATE: Insert user with rate limits + user = await UserManager.insert_user( +- name="Rate Limits Test User", rate_limits=rate_limits ++ user_id="Rate Limits Test User", rate_limits=rate_limits + ) + + # Verify rate limits are stored as JSON + assert user.rate_limits == rate_limits.model_dump() + + # READ: Retrieve user and verify rate limits JSON +- retrieved_user = await UserManager.check_user(user.userid) ++ retrieved_user = await UserManager.check_user(user.user_id) + assert retrieved_user is not None + assert retrieved_user.rate_limits == rate_limits.model_dump() + +@@ -98,11 +78,11 @@ class TestUserManagerIntegration: + ) + + user = await UserManager.insert_user( +- name="Update Rate Limits User", rate_limits=initial_rate_limits ++ user_id="Update Rate Limits User", rate_limits=initial_rate_limits + ) + + # Verify initial rate limits +- retrieved_user = await UserManager.check_user(user.userid) ++ retrieved_user = await 
UserManager.check_user(user.user_id) + assert retrieved_user is not None + assert retrieved_user.rate_limits == initial_rate_limits.model_dump() + +@@ -125,19 +105,19 @@ class TestUserManagerIntegration: + stmt = sa.text(""" + UPDATE users + SET rate_limits = :rate_limits_json +- WHERE userid = :userid ++ WHERE user_id = :user_id + """) + await session.execute( + stmt, + { + "rate_limits_json": updated_rate_limits.model_dump_json(), +- "userid": user.userid, ++ "user_id": user.user_id, + }, + ) + await session.commit() + + # READ: Verify the update worked +- updated_user = await UserManager.check_user(user.userid) ++ updated_user = await UserManager.check_user(user.user_id) + assert updated_user is not None + assert updated_user.rate_limits == updated_rate_limits.model_dump() + +@@ -162,11 +142,11 @@ class TestUserManagerIntegration: + ) + + user = await UserManager.insert_user( +- name="Partial Rate Limits User", rate_limits=partial_rate_limits ++ user_id="Partial Rate Limits User", rate_limits=partial_rate_limits + ) + + # Verify partial data is stored correctly +- retrieved_user = await UserManager.check_user(user.userid) ++ retrieved_user = await UserManager.check_user(user.user_id) + assert retrieved_user is not None + assert retrieved_user.rate_limits == partial_rate_limits.model_dump() + +@@ -183,13 +163,13 @@ class TestUserManagerIntegration: + '{user_rate_limit_hour}', + '75' + ) +- WHERE userid = :userid ++ WHERE user_id = :user_id + """) +- await session.execute(stmt, {"userid": user.userid}) ++ await session.execute(stmt, {"user_id": user.user_id}) + await session.commit() + + # Verify partial update worked +- updated_user = await UserManager.check_user(user.userid) ++ updated_user = await UserManager.check_user(user.user_id) + assert updated_user is not None + + expected_data = partial_rate_limits.model_dump() +@@ -211,7 +191,7 @@ class TestUserManagerIntegration: + ) + + user = await UserManager.insert_user( +- name="Delete Rate Limits User", rate_limits=rate_limits ++ user_id="Delete Rate Limits User", rate_limits=rate_limits + ) + + # DELETE: Set rate_limits to NULL +@@ -219,12 +199,14 @@ class TestUserManagerIntegration: + import sqlalchemy as sa + + async with get_db_session() as session: +- stmt = sa.text("UPDATE users SET rate_limits = NULL WHERE userid = :userid") +- await session.execute(stmt, {"userid": user.userid}) ++ stmt = sa.text( ++ "UPDATE users SET rate_limits = NULL WHERE user_id = :user_id" ++ ) ++ await session.execute(stmt, {"user_id": user.user_id}) + await session.commit() + + # Verify NULL handling +- null_user = await UserManager.check_user(user.userid) ++ null_user = await UserManager.check_user(user.user_id) + assert null_user is not None + assert null_user.rate_limits is None + +@@ -239,15 +221,15 @@ class TestUserManagerIntegration: + # First set some data + new_data = {"user_rate_limit_day": 500, "web_search_rate_limit_day": 25} + stmt = sa.text( +- "UPDATE users SET rate_limits = :data WHERE userid = :userid" ++ "UPDATE users SET rate_limits = :data WHERE user_id = :user_id" + ) + await session.execute( +- stmt, {"data": json.dumps(new_data), "userid": user.userid} ++ stmt, {"data": json.dumps(new_data), "user_id": user.user_id} + ) + await session.commit() + + # Verify data was set +- updated_user = await UserManager.check_user(user.userid) ++ updated_user = await UserManager.check_user(user.user_id) + assert updated_user is not None + assert updated_user.rate_limits == new_data + +@@ -256,13 +238,13 @@ class TestUserManagerIntegration: + stmt 
= sa.text(""" + UPDATE users + SET rate_limits = rate_limits::jsonb - 'web_search_rate_limit_day' +- WHERE userid = :userid ++ WHERE user_id = :user_id + """) +- await session.execute(stmt, {"userid": user.userid}) ++ await session.execute(stmt, {"user_id": user.user_id}) + await session.commit() + + # Verify field was removed +- final_user = await UserManager.check_user(user.userid) ++ final_user = await UserManager.check_user(user.user_id) + expected_final_data = {"user_rate_limit_day": 500} + assert final_user is not None + assert final_user.rate_limits == expected_final_data +@@ -293,15 +275,15 @@ class TestUserManagerIntegration: + for i, test_data in enumerate(test_cases): + async with get_db_session() as session: + stmt = sa.text( +- "UPDATE users SET rate_limits = :data WHERE userid = :userid" ++ "UPDATE users SET rate_limits = :data WHERE user_id = :user_id" + ) + await session.execute( +- stmt, {"data": json.dumps(test_data), "userid": user.userid} ++ stmt, {"data": json.dumps(test_data), "user_id": user.user_id} + ) + await session.commit() + + # Retrieve and verify +- updated_user = await UserManager.check_user(user.userid) ++ updated_user = await UserManager.check_user(user.user_id) + assert updated_user is not None + assert updated_user.rate_limits == test_data + +@@ -327,11 +309,13 @@ class TestUserManagerIntegration: + + # Test empty JSON object + async with get_db_session() as session: +- stmt = sa.text("UPDATE users SET rate_limits = '{}' WHERE userid = :userid") +- await session.execute(stmt, {"userid": user.userid}) ++ stmt = sa.text( ++ "UPDATE users SET rate_limits = '{}' WHERE user_id = :user_id" ++ ) ++ await session.execute(stmt, {"user_id": user.user_id}) + await session.commit() + +- empty_user = await UserManager.check_user(user.userid) ++ empty_user = await UserManager.check_user(user.user_id) + assert empty_user is not None + assert empty_user.rate_limits == {} + empty_rate_limits_obj = empty_user.rate_limits_obj +@@ -343,18 +327,18 @@ class TestUserManagerIntegration: + async with get_db_session() as session: + # This should work as PostgreSQL JSONB validates JSON + stmt = sa.text( +- "UPDATE users SET rate_limits = :data WHERE userid = :userid" ++ "UPDATE users SET rate_limits = :data WHERE user_id = :user_id" + ) + await session.execute( + stmt, + { + "data": '{"user_rate_limit_day": 5000}', # Valid JSON string +- "userid": user.userid, ++ "user_id": user.user_id, + }, + ) + await session.commit() + +- json_string_user = await UserManager.check_user(user.userid) ++ json_string_user = await UserManager.check_user(user.user_id) + assert json_string_user is not None + assert json_string_user.rate_limits == {"user_rate_limit_day": 5000} + +@@ -366,16 +350,16 @@ class TestUserManagerIntegration: + async def test_rate_limits_update_workflow(self, clean_database): + """Test complete workflow: create user with no rate limits -> update rate limits -> verify update.""" + # Step 1: Create user with NO rate limits +- user = await UserManager.insert_user(name="Rate Limits Workflow User") ++ user = await UserManager.insert_user(user_id="Rate Limits Workflow User") + + # Verify user was created with no rate limits + assert user.name == "Rate Limits Workflow User" +- assert user.userid is not None ++ assert user.user_id is not None + assert user.apikey is not None + assert user.rate_limits is None # No rate limits initially + + # Step 2: Retrieve user and confirm no rate limits +- retrieved_user = await UserManager.check_user(user.userid) ++ retrieved_user = await 
UserManager.check_user(user.user_id) + assert retrieved_user is not None + print(retrieved_user.to_pydantic()) + assert retrieved_user is not None +@@ -401,12 +385,12 @@ class TestUserManagerIntegration: + + # Step 4: Update the user's rate limits using the new function + update_success = await UserManager.update_rate_limits( +- user.userid, new_rate_limits ++ user.user_id, new_rate_limits + ) + assert update_success is True + + # Step 5: Retrieve user again and verify rate limits were updated +- updated_user = await UserManager.check_user(user.userid) ++ updated_user = await UserManager.check_user(user.user_id) + assert updated_user is not None + assert updated_user.rate_limits is not None + assert updated_user.rate_limits == new_rate_limits.model_dump() +@@ -431,12 +415,12 @@ class TestUserManagerIntegration: + ) + + partial_update_success = await UserManager.update_rate_limits( +- user.userid, partial_rate_limits ++ user.user_id, partial_rate_limits + ) + assert partial_update_success is True + + # Step 8: Verify partial update worked +- final_user = await UserManager.check_user(user.userid) ++ final_user = await UserManager.check_user(user.user_id) + assert final_user is not None + assert final_user.rate_limits == partial_rate_limits.model_dump() + +@@ -447,8 +431,8 @@ class TestUserManagerIntegration: + # Other fields should have config defaults (not None due to get_effective_limits) + + # Step 9: Test error case - update non-existent user +- fake_userid = "non-existent-user-id" ++ fake_user_id = "non-existent-user-id" + error_update = await UserManager.update_rate_limits( +- fake_userid, new_rate_limits ++ fake_user_id, new_rate_limits + ) + assert error_update is False +diff --git a/tests/unit/nilai_api/__init__.py b/tests/unit/nilai_api/__init__.py +index 0be5261..7cbc123 100644 +--- a/tests/unit/nilai_api/__init__.py ++++ b/tests/unit/nilai_api/__init__.py +@@ -21,11 +21,11 @@ class MockUserDatabase: + + async def insert_user(self, name: str, email: str) -> Dict[str, str]: + """Insert a new user into the mock database.""" +- userid = self.generate_user_id() ++ user_id = self.generate_user_id() + apikey = self.generate_api_key() + + user_data = { +- "userid": userid, ++ "user_id": user_id, + "name": name, + "email": email, + "apikey": apikey, +@@ -36,34 +36,34 @@ class MockUserDatabase: + "last_activity": None, + } + +- self.users[userid] = user_data +- return {"userid": userid, "apikey": apikey} ++ self.users[user_id] = user_data ++ return {"user_id": user_id, "apikey": apikey} + + async def check_api_key(self, api_key: str) -> Optional[dict]: + """Validate an API key in the mock database.""" + for user in self.users.values(): + if user["apikey"] == api_key: +- return {"name": user["name"], "userid": user["userid"]} ++ return {"name": user["name"], "user_id": user["user_id"]} + return None + + async def update_token_usage( +- self, userid: str, prompt_tokens: int, completion_tokens: int ++ self, user_id: str, prompt_tokens: int, completion_tokens: int + ): + """Update token usage for a specific user.""" +- if userid in self.users: +- user = self.users[userid] ++ if user_id in self.users: ++ user = self.users[user_id] + user["prompt_tokens"] += prompt_tokens + user["completion_tokens"] += completion_tokens + user["queries"] += 1 + user["last_activity"] = datetime.now(timezone.utc) + + async def log_query( +- self, userid: str, model: str, prompt_tokens: int, completion_tokens: int ++ self, user_id: str, model: str, prompt_tokens: int, completion_tokens: int + ): + """Log a user's 
query in the mock database.""" + query_log = { + "id": self._next_query_log_id, +- "userid": userid, ++ "user_id": user_id, + "query_timestamp": datetime.now(timezone.utc), + "model": model, + "prompt_tokens": prompt_tokens, +@@ -74,9 +74,9 @@ class MockUserDatabase: + self.query_logs[self._next_query_log_id] = query_log + self._next_query_log_id += 1 + +- async def get_token_usage(self, userid: str) -> Optional[Dict[str, Any]]: ++ async def get_token_usage(self, user_id: str) -> Optional[Dict[str, Any]]: + """Get token usage for a specific user.""" +- user = self.users.get(userid) ++ user = self.users.get(user_id) + if user: + return { + "prompt_tokens": user["prompt_tokens"], +@@ -90,9 +90,9 @@ class MockUserDatabase: + """Retrieve all users from the mock database.""" + return list(self.users.values()) if self.users else None + +- async def get_user_token_usage(self, userid: str) -> Optional[Dict[str, int]]: ++ async def get_user_token_usage(self, user_id: str) -> Optional[Dict[str, int]]: + """Retrieve total token usage for a user.""" +- user = self.users.get(userid) ++ user = self.users.get(user_id) + if user: + return { + "prompt_tokens": user["prompt_tokens"], +diff --git a/tests/unit/nilai_api/auth/test_auth.py b/tests/unit/nilai_api/auth/test_auth.py +index 591c447..ec1aabc 100644 +--- a/tests/unit/nilai_api/auth/test_auth.py ++++ b/tests/unit/nilai_api/auth/test_auth.py +@@ -29,7 +29,7 @@ def mock_user_model(): + + mock = MagicMock(spec=UserModel) + mock.name = "Test User" +- mock.userid = "test-user-id" ++ mock.user_id = "test-user-id" + mock.apikey = "test-api-key" + mock.prompt_tokens = 0 + mock.completion_tokens = 0 +@@ -72,11 +72,9 @@ async def test_get_auth_info_valid_token( + + auth_info = await get_auth_info(credentials) + print(auth_info) +- assert auth_info.user.name == "Test User", ( +- f"Expected Test User but got {auth_info.user.name}" +- ) +- assert auth_info.user.userid == "test-user-id", ( +- f"Expected test-user-id but got {auth_info.user.userid}" ++ ++ assert auth_info.user.user_id == "test-user-id", ( ++ f"Expected test-user-id but got {auth_info.user.user_id}" + ) + + +diff --git a/tests/unit/nilai_api/auth/test_strategies.py b/tests/unit/nilai_api/auth/test_strategies.py +index 0c169f5..d362786 100644 +--- a/tests/unit/nilai_api/auth/test_strategies.py ++++ b/tests/unit/nilai_api/auth/test_strategies.py +@@ -16,7 +16,7 @@ class TestAuthStrategies: + """Mock UserModel fixture""" + mock = MagicMock(spec=UserModel) + mock.name = "Test User" +- mock.userid = "test-user-id" ++ mock.user_id = "test-user-id" + mock.apikey = "test-api-key" + mock.prompt_tokens = 0 + mock.completion_tokens = 0 +@@ -43,7 +43,6 @@ class TestAuthStrategies: + result = await api_key_strategy("test-api-key") + + assert isinstance(result, AuthenticationInfo) +- assert result.user.name == "Test User" + assert result.token_rate_limit is None + assert result.prompt_document is None + +@@ -84,7 +83,6 @@ class TestAuthStrategies: + result = await nuc_strategy("nuc-token") + + assert isinstance(result, AuthenticationInfo) +- assert result.user.name == "Test User" + assert result.token_rate_limit is None + assert result.prompt_document == mock_prompt_document + +@@ -154,7 +152,6 @@ class TestAuthStrategies: + result = await nuc_strategy("nuc-token") + + assert isinstance(result, AuthenticationInfo) +- assert result.user.name == "Test User" + assert result.token_rate_limit is None + assert result.prompt_document is None + +@@ -201,7 +198,7 @@ class TestAuthStrategies: + """Test that all strategies 
return AuthenticationInfo with prompt_document field""" + mock_user_model = MagicMock(spec=UserModel) + mock_user_model.name = "Test" +- mock_user_model.userid = "test" ++ mock_user_model.user_id = "test" + mock_user_model.apikey = "test" + mock_user_model.prompt_tokens = 0 + mock_user_model.completion_tokens = 0 +diff --git a/tests/unit/nilai_api/routers/test_nildb_endpoints.py b/tests/unit/nilai_api/routers/test_nildb_endpoints.py +index c0103ea..0648980 100644 +--- a/tests/unit/nilai_api/routers/test_nildb_endpoints.py ++++ b/tests/unit/nilai_api/routers/test_nildb_endpoints.py +@@ -18,8 +18,8 @@ class TestNilDBEndpoints: + """Mock user data for subscription owner""" + mock_user_model = MagicMock(spec=UserModel) + mock_user_model.name = "Subscription Owner" +- mock_user_model.userid = "owner-id" +- mock_user_model.apikey = "owner-id" # Same as userid for subscription owner ++ mock_user_model.user_id = "owner-id" ++ mock_user_model.apikey = "owner-id" # Same as user_id for subscription owner + mock_user_model.prompt_tokens = 0 + mock_user_model.completion_tokens = 0 + mock_user_model.queries = 0 +@@ -37,8 +37,8 @@ class TestNilDBEndpoints: + """Mock user data for regular user (not subscription owner)""" + mock_user_model = MagicMock(spec=UserModel) + mock_user_model.name = "Regular User" +- mock_user_model.userid = "user-id" +- mock_user_model.apikey = "different-api-key" # Different from userid ++ mock_user_model.user_id = "user-id" ++ mock_user_model.apikey = "different-api-key" # Different from user_id + mock_user_model.prompt_tokens = 0 + mock_user_model.completion_tokens = 0 + mock_user_model.queries = 0 +@@ -149,7 +149,7 @@ class TestNilDBEndpoints: + ) + + mock_user = MagicMock() +- mock_user.userid = "test-user-id" ++ mock_user.user_id = "test-user-id" + mock_user.name = "Test User" + mock_user.apikey = "test-api-key" + mock_user.rate_limits = RateLimits().get_effective_limits() +@@ -256,7 +256,7 @@ class TestNilDBEndpoints: + ) + + mock_user = MagicMock() +- mock_user.userid = "test-user-id" ++ mock_user.user_id = "test-user-id" + mock_user.name = "Test User" + mock_user.apikey = "test-api-key" + mock_user.rate_limits = RateLimits().get_effective_limits() +@@ -304,7 +304,7 @@ class TestNilDBEndpoints: + from nilai_common import ChatRequest + + mock_user = MagicMock() +- mock_user.userid = "test-user-id" ++ mock_user.user_id = "test-user-id" + mock_user.name = "Test User" + mock_user.apikey = "test-api-key" + mock_user.rate_limits = RateLimits().get_effective_limits() +@@ -419,8 +419,8 @@ class TestNilDBEndpoints: + self, mock_subscription_owner_user, mock_regular_user + ): + """Test the is_subscription_owner property""" +- # Subscription owner (userid == apikey) ++ # Subscription owner (user_id == apikey) + assert mock_subscription_owner_user.is_subscription_owner is True + +- # Regular user (userid != apikey) ++ # Regular user (user_id != apikey) + assert mock_regular_user.is_subscription_owner is False +diff --git a/tests/unit/nilai_api/routers/test_private.py b/tests/unit/nilai_api/routers/test_private.py +index 1978e83..daafc86 100644 +--- a/tests/unit/nilai_api/routers/test_private.py ++++ b/tests/unit/nilai_api/routers/test_private.py +@@ -20,7 +20,7 @@ async def test_runs_in_a_loop(): + @pytest.fixture + def mock_user(): + mock = MagicMock(spec=UserModel) +- mock.userid = "test-user-id" ++ mock.user_id = "test-user-id" + mock.name = "Test User" + mock.apikey = "test-api-key" + mock.prompt_tokens = 100 +@@ -66,7 +66,7 @@ def mock_user_manager(mock_user, mocker): + 
UserManager, + "insert_user", + return_value={ +- "userid": "test-user-id", ++ "user_id": "test-user-id", + "apikey": "test-api-key", + "rate_limits": RateLimits().get_effective_limits().model_dump_json(), + }, +@@ -81,12 +81,12 @@ def mock_user_manager(mock_user, mocker): + "get_all_users", + return_value=[ + { +- "userid": "test-user-id", ++ "user_id": "test-user-id", + "apikey": "test-api-key", + "rate_limits": RateLimits().get_effective_limits().model_dump_json(), + }, + { +- "userid": "test-user-id-2", ++ "user_id": "test-user-id-2", + "apikey": "test-api-key", + "rate_limits": RateLimits().get_effective_limits().model_dump_json(), + }, +diff --git a/tests/unit/nilai_api/test_db.py b/tests/unit/nilai_api/test_db.py +index dff0fd8..3979321 100644 +--- a/tests/unit/nilai_api/test_db.py ++++ b/tests/unit/nilai_api/test_db.py +@@ -15,7 +15,7 @@ async def test_insert_user(mock_db): + """Test user insertion functionality.""" + user = await mock_db.insert_user("Test User", "test@example.com") + +- assert "userid" in user ++ assert "user_id" in user + assert "apikey" in user + assert len(mock_db.users) == 1 + +@@ -38,9 +38,9 @@ async def test_token_usage(mock_db): + """Test token usage tracking.""" + user = await mock_db.insert_user("Test User", "test@example.com") + +- await mock_db.update_token_usage(user["userid"], 50, 20) ++ await mock_db.update_token_usage(user["user_id"], 50, 20) + +- token_usage = await mock_db.get_token_usage(user["userid"]) ++ token_usage = await mock_db.get_token_usage(user["user_id"]) + assert token_usage["prompt_tokens"] == 50 + assert token_usage["completion_tokens"] == 20 + assert token_usage["queries"] == 1 +@@ -51,9 +51,9 @@ async def test_query_logging(mock_db): + """Test query logging functionality.""" + user = await mock_db.insert_user("Test User", "test@example.com") + +- await mock_db.log_query(user["userid"], "test-model", 10, 15) ++ await mock_db.log_query(user["user_id"], "test-model", 10, 15) + + assert len(mock_db.query_logs) == 1 + log_entry = list(mock_db.query_logs.values())[0] +- assert log_entry["userid"] == user["userid"] ++ assert log_entry["user_id"] == user["user_id"] + assert log_entry["model"] == "test-model" +diff --git a/tests/unit/nilai_api/test_rate_limiting.py b/tests/unit/nilai_api/test_rate_limiting.py +index 4cf53b0..82d2119 100644 +--- a/tests/unit/nilai_api/test_rate_limiting.py ++++ b/tests/unit/nilai_api/test_rate_limiting.py +@@ -44,7 +44,7 @@ async def test_concurrent_rate_limit(req): + rate_limit = RateLimit(concurrent_extractor=lambda _: (5, "test")) + + user_limits = UserRateLimits( +- subscription_holder=random_id(), ++ user_id=random_id(), + token_rate_limit=None, + rate_limits=RateLimits( + user_rate_limit_day=None, +@@ -77,7 +77,7 @@ async def test_concurrent_rate_limit(req): + "user_limits", + [ + UserRateLimits( +- subscription_holder=random_id(), ++ user_id=random_id(), + token_rate_limit=None, + rate_limits=RateLimits( + user_rate_limit_day=10, +@@ -91,7 +91,7 @@ async def test_concurrent_rate_limit(req): + ), + ), + UserRateLimits( +- subscription_holder=random_id(), ++ user_id=random_id(), + token_rate_limit=None, + rate_limits=RateLimits( + user_rate_limit_day=None, +@@ -105,7 +105,7 @@ async def test_concurrent_rate_limit(req): + ), + ), + UserRateLimits( +- subscription_holder=random_id(), ++ user_id=random_id(), + token_rate_limit=None, + rate_limits=RateLimits( + user_rate_limit_day=None, +@@ -119,7 +119,7 @@ async def test_concurrent_rate_limit(req): + ), + ), + UserRateLimits( +- 
subscription_holder=random_id(), ++ user_id=random_id(), + token_rate_limit=TokenRateLimits( + limits=[ + TokenRateLimit( +@@ -180,7 +180,7 @@ async def test_web_search_rate_limits(redis_client): + + rate_limit = RateLimit(web_search_extractor=web_search_extractor) + user_limits = UserRateLimits( +- subscription_holder=apikey, ++ user_id=apikey, + token_rate_limit=None, + rate_limits=RateLimits( + user_rate_limit_day=None, +@@ -212,7 +212,7 @@ async def test_global_web_search_rps_limit(req, redis_client, monkeypatch): + + rate_limit = RateLimit(web_search_extractor=lambda _: True) + user_limits = UserRateLimits( +- subscription_holder=random_id(), ++ user_id=random_id(), + token_rate_limit=None, + rate_limits=RateLimits( + user_rate_limit_day=None, +@@ -253,7 +253,7 @@ async def test_queueing_across_seconds(req, redis_client, monkeypatch): + + rate_limit = RateLimit(web_search_extractor=lambda _: True) + user_limits = UserRateLimits( +- subscription_holder=random_id(), ++ user_id=random_id(), + token_rate_limit=None, + rate_limits=RateLimits( + user_rate_limit_day=None, +diff --git a/uv.lock b/uv.lock +index d54a2ef..1482449 100644 +--- a/uv.lock ++++ b/uv.lock +@@ -1649,7 +1649,7 @@ requires-dist = [ + { name = "gunicorn", specifier = ">=23.0.0" }, + { name = "httpx", specifier = ">=0.27.2" }, + { name = "nilai-common", editable = "packages/nilai-common" }, +- { name = "nilauth-credit-middleware", specifier = ">=0.1.1" }, ++ { name = "nilauth-credit-middleware", specifier = ">=0.1.2" }, + { name = "nilrag", specifier = ">=0.1.11" }, + { name = "nuc", specifier = ">=0.1.0" }, + { name = "openai", specifier = ">=1.59.9" }, +@@ -1658,7 +1658,7 @@ requires-dist = [ + { name = "python-dotenv", specifier = ">=1.0.1" }, + { name = "pyyaml", specifier = ">=6.0.1" }, + { name = "redis", specifier = ">=5.2.1" }, +- { name = "secretvaults", git = "https://github.com/NillionNetwork/secretvaults-py?rev=feat%2Fbackport-did-key-and-ethr-parsing" }, ++ { name = "secretvaults", git = "https://github.com/jcabrero/secretvaults-py?rev=main" }, + { name = "sqlalchemy", specifier = ">=2.0.36" }, + { name = "trafilatura", specifier = ">=1.7.0" }, + { name = "uvicorn", specifier = ">=0.32.1" }, +@@ -1739,7 +1739,7 @@ dev = [ + + [[package]] + name = "nilauth-credit-middleware" +-version = "0.1.1" ++version = "0.1.2" + source = { registry = "https://pypi.org/simple" } + dependencies = [ + { name = "fastapi", extra = ["standard"] }, +@@ -1747,9 +1747,9 @@ dependencies = [ + { name = "nuc" }, + { name = "pydantic" }, + ] +-sdist = { url = "https://files.pythonhosted.org/packages/9f/cf/7716fa5f4aca83ef39d6f9f8bebc1d80d194c52c9ce6e75ee6bd1f401217/nilauth_credit_middleware-0.1.1.tar.gz", hash = "sha256:ae32c4c1e6bc083c8a7581d72a6da271ce9c0f0f9271a1694acb81ccd0a4a8bd", size = 10259, upload-time = "2025-10-16T11:15:03.918Z" } ++sdist = { url = "https://files.pythonhosted.org/packages/46/bc/ae9b2c26919151fc7193b406a98831eeef197f6ec46b0c075138e66ec016/nilauth_credit_middleware-0.1.2.tar.gz", hash = "sha256:66423a4d18aba1eb5f5d47a04c8f7ae6a19ab4e34433475aa9dc1ba398483fdd", size = 11979, upload-time = "2025-10-30T16:21:20.538Z" } + wheels = [ +- { url = "https://files.pythonhosted.org/packages/a7/b5/6e4090ae2ae8848d12e43f82d8d995cd1dff9de8e947cf5fb2b8a72a828e/nilauth_credit_middleware-0.1.1-py3-none-any.whl", hash = "sha256:10a0fda4ac11f51b9a5dd7b3a8fbabc0b28ff92a170a7729ac11eb15c7b37887", size = 14919, upload-time = "2025-10-16T11:15:02.201Z" }, ++ { url = 
"https://files.pythonhosted.org/packages/05/c3/73d55667aad701a64f3d1330d66c90a8c292fd19f054093ca74960aca1fb/nilauth_credit_middleware-0.1.2-py3-none-any.whl", hash = "sha256:31f3233e6706c6167b6246a4edb9a405d587eccb1399231223f95c0cdf1ce57c", size = 18121, upload-time = "2025-10-30T16:21:19.547Z" }, + ] + + [[package]] +@@ -2854,8 +2854,8 @@ sdist = { url = "https://files.pythonhosted.org/packages/9b/41/bb668a6e419230354 + + [[package]] + name = "secretvaults" +-version = "0.3.0" +-source = { git = "https://github.com/NillionNetwork/secretvaults-py?rev=feat%2Fbackport-did-key-and-ethr-parsing#b40aebf572c6d4c94dc381e022b82724d727df23" } ++version = "0.2.1" ++source = { git = "https://github.com/jcabrero/secretvaults-py?rev=main#498ee5304fdcc730d1810fcf6172e56fa6dd7d16" } + dependencies = [ + { name = "aiohttp" }, + { name = "blindfold" }, diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index a3c77e01..33d9174c 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -33,8 +33,6 @@ services: condition: service_healthy nilauth-credit-server: condition: service_healthy - environment: - - POSTGRES_DB=${POSTGRES_DB_NUC} volumes: - ./nilai-api/:/app/nilai-api/ - ./packages/:/app/packages/ diff --git a/grafana/runtime-data/dashboards/nuc-query-data.json b/grafana/runtime-data/dashboards/nuc-query-data.json index d66fd428..c7bbb6b2 100644 --- a/grafana/runtime-data/dashboards/nuc-query-data.json +++ b/grafana/runtime-data/dashboards/nuc-query-data.json @@ -126,7 +126,7 @@ "editorMode": "code", "format": "time_series", "rawQuery": true, - "rawSql": "SELECT \n date_trunc('${time_granularity}', q.query_timestamp) AS \"time\", \n COUNT(q.id) AS \"Queries\"\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:value}' = 'All' OR u.name = '${user_filter:value}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY date_trunc('${time_granularity}', q.query_timestamp)\nORDER BY \"time\";", + "rawSql": "SELECT \n date_trunc('${time_granularity}', q.query_timestamp) AS \"time\", \n COUNT(q.id) AS \"Queries\"\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:value}' = 'All' OR u.name = '${user_filter:value}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY date_trunc('${time_granularity}', q.query_timestamp)\nORDER BY \"time\";", "refId": "A", "sql": { "columns": [ @@ -218,7 +218,7 @@ "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT \n q.model, \n COUNT(q.id) AS total_queries\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\nGROUP BY q.model\nORDER BY total_queries DESC;", + "rawSql": "SELECT \n q.model, \n COUNT(q.id) AS total_queries\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\nGROUP BY q.model\nORDER BY total_queries DESC;", "refId": "A", "sql": { "columns": [ @@ -352,7 +352,7 @@ "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 12 THEN 
LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name\nORDER BY \"Queries\" DESC;", + "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 12 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name\nORDER BY \"Queries\" DESC;", "refId": "A", "sql": { "columns": [ @@ -360,7 +360,7 @@ "alias": "\"User\"", "parameters": [ { - "name": "userid", + "name": "user_id", "type": "functionParameter" } ], @@ -381,7 +381,7 @@ "groupBy": [ { "property": { - "name": "userid", + "name": "user_id", "type": "string" }, "type": "groupBy" @@ -481,7 +481,7 @@ "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 8 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\",\n q.model AS \"Model\",\n COUNT(q.id) AS \"Queries\",\n MIN(q.query_timestamp) AS \"First Query\",\n MAX(q.query_timestamp) AS \"Last Query\"\nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name, q.model\nORDER BY \"Queries\" DESC\nLIMIT 20;", + "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 8 THEN LEFT(u.name, 3) || '...' 
|| RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\",\n q.model AS \"Model\",\n COUNT(q.id) AS \"Queries\",\n MIN(q.query_timestamp) AS \"First Query\",\n MAX(q.query_timestamp) AS \"Last Query\"\nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name, q.model\nORDER BY \"Queries\" DESC\nLIMIT 20;", "refId": "A", "sql": { "columns": [ diff --git a/grafana/runtime-data/dashboards/query-data.json b/grafana/runtime-data/dashboards/query-data.json index 8e0b774f..f33f87a8 100644 --- a/grafana/runtime-data/dashboards/query-data.json +++ b/grafana/runtime-data/dashboards/query-data.json @@ -126,7 +126,7 @@ "editorMode": "code", "format": "time_series", "rawQuery": true, - "rawSql": "SELECT \n date_trunc('${time_granularity}', q.query_timestamp) AS \"time\", \n COUNT(q.id) AS \"Queries\"\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:value}' = 'All' OR u.name = '${user_filter:value}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY date_trunc('${time_granularity}', q.query_timestamp)\nORDER BY \"time\";", + "rawSql": "SELECT \n date_trunc('${time_granularity}', q.query_timestamp) AS \"time\", \n COUNT(q.id) AS \"Queries\"\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:value}' = 'All' OR u.name = '${user_filter:value}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY date_trunc('${time_granularity}', q.query_timestamp)\nORDER BY \"time\";", "refId": "A", "sql": { "columns": [ @@ -218,7 +218,7 @@ "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT \n q.model, \n COUNT(q.id) AS total_queries\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\nGROUP BY q.model\nORDER BY total_queries DESC;", + "rawSql": "SELECT \n q.model, \n COUNT(q.id) AS total_queries\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\nGROUP BY q.model\nORDER BY total_queries DESC;", "refId": "A", "sql": { "columns": [ @@ -352,7 +352,7 @@ "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 12 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name\nORDER BY \"Queries\" DESC;", + "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 12 THEN LEFT(u.name, 3) || '...' 
|| RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name\nORDER BY \"Queries\" DESC;", "refId": "A", "sql": { "columns": [ @@ -360,7 +360,7 @@ "alias": "\"User\"", "parameters": [ { - "name": "userid", + "name": "user_id", "type": "functionParameter" } ], @@ -381,7 +381,7 @@ "groupBy": [ { "property": { - "name": "userid", + "name": "user_id", "type": "string" }, "type": "groupBy" @@ -481,7 +481,7 @@ "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 8 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\",\n q.model AS \"Model\",\n COUNT(q.id) AS \"Queries\",\n MIN(q.query_timestamp) AS \"First Query\",\n MAX(q.query_timestamp) AS \"Last Query\"\nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name, q.model\nORDER BY \"Queries\" DESC\nLIMIT 20;", + "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 8 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\",\n q.model AS \"Model\",\n COUNT(q.id) AS \"Queries\",\n MIN(q.query_timestamp) AS \"First Query\",\n MAX(q.query_timestamp) AS \"Last Query\"\nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name, q.model\nORDER BY \"Queries\" DESC\nLIMIT 20;", "refId": "A", "sql": { "columns": [ diff --git a/grafana/runtime-data/dashboards/testnet-nuc-query-data.json b/grafana/runtime-data/dashboards/testnet-nuc-query-data.json index f98d70e9..358ba4eb 100644 --- a/grafana/runtime-data/dashboards/testnet-nuc-query-data.json +++ b/grafana/runtime-data/dashboards/testnet-nuc-query-data.json @@ -126,7 +126,7 @@ "editorMode": "code", "format": "time_series", "rawQuery": true, - "rawSql": "SELECT \n date_trunc('${time_granularity}', q.query_timestamp) AS \"time\", \n COUNT(q.id) AS \"Queries\"\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:value}' = 'All' OR u.name = '${user_filter:value}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY date_trunc('${time_granularity}', q.query_timestamp)\nORDER BY \"time\";", + "rawSql": "SELECT \n date_trunc('${time_granularity}', q.query_timestamp) AS \"time\", \n COUNT(q.id) AS \"Queries\"\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:value}' = 'All' OR u.name = '${user_filter:value}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY date_trunc('${time_granularity}', q.query_timestamp)\nORDER BY \"time\";", "refId": "A", "sql": { "columns": [ @@ -218,7 +218,7 @@ "editorMode": "code", "format": "table", "rawQuery": 
true, - "rawSql": "SELECT \n q.model, \n COUNT(q.id) AS total_queries\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\nGROUP BY q.model\nORDER BY total_queries DESC;", + "rawSql": "SELECT \n q.model, \n COUNT(q.id) AS total_queries\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\nGROUP BY q.model\nORDER BY total_queries DESC;", "refId": "A", "sql": { "columns": [ @@ -352,7 +352,7 @@ "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 12 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name\nORDER BY \"Queries\" DESC;", + "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 12 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name\nORDER BY \"Queries\" DESC;", "refId": "A", "sql": { "columns": [ @@ -360,7 +360,7 @@ "alias": "\"User\"", "parameters": [ { - "name": "userid", + "name": "user_id", "type": "functionParameter" } ], @@ -381,7 +381,7 @@ "groupBy": [ { "property": { - "name": "userid", + "name": "user_id", "type": "string" }, "type": "groupBy" @@ -481,7 +481,7 @@ "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 8 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\",\n q.model AS \"Model\",\n COUNT(q.id) AS \"Queries\",\n MIN(q.query_timestamp) AS \"First Query\",\n MAX(q.query_timestamp) AS \"Last Query\"\nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name, q.model\nORDER BY \"Queries\" DESC\nLIMIT 20;", + "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 8 THEN LEFT(u.name, 3) || '...' 
|| RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\",\n q.model AS \"Model\",\n COUNT(q.id) AS \"Queries\",\n MIN(q.query_timestamp) AS \"First Query\",\n MAX(q.query_timestamp) AS \"Last Query\"\nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name, q.model\nORDER BY \"Queries\" DESC\nLIMIT 20;", "refId": "A", "sql": { "columns": [ diff --git a/grafana/runtime-data/dashboards/totals-data.json b/grafana/runtime-data/dashboards/totals-data.json index 2db20c7d..ff66ce0e 100644 --- a/grafana/runtime-data/dashboards/totals-data.json +++ b/grafana/runtime-data/dashboards/totals-data.json @@ -83,7 +83,7 @@ "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')", + "rawSql": "SELECT \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')", "refId": "A", "sql": { "columns": [ @@ -165,7 +165,7 @@ "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT SUM(total_tokens) AS total_tokens\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')", + "rawSql": "SELECT SUM(total_tokens) AS total_tokens\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')", "refId": "A", "sql": { "columns": [ @@ -248,7 +248,7 @@ "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT SUM(q.prompt_tokens) AS total_tokens\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')", + "rawSql": "SELECT SUM(q.prompt_tokens) AS total_tokens\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')", "refId": "A", "sql": { "columns": [ @@ -331,7 +331,7 @@ "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT SUM(q.completion_tokens) AS total_tokens\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')", + "rawSql": "SELECT SUM(q.completion_tokens) AS total_tokens\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')", "refId": "A", "sql": { "columns": [ @@ 
-397,4 +397,4 @@ "uid": "aex54yzf0nmyoc", "version": 1, "weekStart": "" -} \ No newline at end of file +} diff --git a/grafana/runtime-data/dashboards/usage-data.json b/grafana/runtime-data/dashboards/usage-data.json index 88857f91..a22bf914 100644 --- a/grafana/runtime-data/dashboards/usage-data.json +++ b/grafana/runtime-data/dashboards/usage-data.json @@ -299,7 +299,7 @@ "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT \n u.email AS \"User ID\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE q.query_timestamp >= NOW() - INTERVAL '1 hours'\nGROUP BY u.email;", + "rawSql": "SELECT \n u.email AS \"User ID\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE q.query_timestamp >= NOW() - INTERVAL '1 hours'\nGROUP BY u.email;", "refId": "A", "sql": { "columns": [ @@ -307,7 +307,7 @@ "alias": "\"User ID\"", "parameters": [ { - "name": "userid", + "name": "user_id", "type": "functionParameter" } ], @@ -328,7 +328,7 @@ "groupBy": [ { "property": { - "name": "userid", + "name": "user_id", "type": "string" }, "type": "groupBy" @@ -430,7 +430,7 @@ "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT \n u.email AS \"User ID\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE q.query_timestamp >= NOW() - INTERVAL '7 days'\nGROUP BY u.email;", + "rawSql": "SELECT \n u.email AS \"User ID\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE q.query_timestamp >= NOW() - INTERVAL '7 days'\nGROUP BY u.email;", "refId": "A", "sql": { "columns": [ diff --git a/nilai-api/alembic/versions/0ba073468afc_chore_improved_database_schema.py b/nilai-api/alembic/versions/0ba073468afc_chore_improved_database_schema.py new file mode 100644 index 00000000..ebaca5a6 --- /dev/null +++ b/nilai-api/alembic/versions/0ba073468afc_chore_improved_database_schema.py @@ -0,0 +1,206 @@ +"""chore: merged database schema updates + +Revision ID: 0ba073468afc +Revises: 9ddf28cf6b6f +Create Date: 2025-10-31 09:43:12.022675 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic.
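This revision folds the commands from ea942d6c7a00 into a single migration on top of 9ddf28cf6b6f, so applying or reverting it walks the chain together with 43b23c73035b further below. A minimal sketch of driving that chain from Python; the alembic.ini path is an assumption, not taken from this patch:

    from alembic import command
    from alembic.config import Config

    # Assumed location of the project's Alembic configuration file.
    cfg = Config("nilai-api/alembic.ini")

    # Apply 0ba073468afc and then 43b23c73035b on top of 9ddf28cf6b6f.
    command.upgrade(cfg, "43b23c73035b")

    # Revert both merged revisions in one step.
    command.downgrade(cfg, "9ddf28cf6b6f")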
+revision: str = "0ba073468afc" +down_revision: Union[str, None] = "9ddf28cf6b6f" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### merged commands from ea942d6c7a00 and 0ba073468afc ### + # query_logs: new telemetry columns (with defaults to backfill existing rows) + op.add_column( + "query_logs", + sa.Column( + "tool_calls", sa.Integer(), server_default=sa.text("0"), nullable=False + ), + ) + op.add_column( + "query_logs", + sa.Column( + "temperature", sa.Float(), server_default=sa.text("0.9"), nullable=True + ), + ) + op.add_column( + "query_logs", + sa.Column( + "max_tokens", sa.Integer(), server_default=sa.text("4096"), nullable=True + ), + ) + op.add_column( + "query_logs", + sa.Column( + "response_time_ms", + sa.Integer(), + server_default=sa.text("-1"), + nullable=False, + ), + ) + op.add_column( + "query_logs", + sa.Column( + "model_response_time_ms", + sa.Integer(), + server_default=sa.text("-1"), + nullable=False, + ), + ) + op.add_column( + "query_logs", + sa.Column( + "tool_response_time_ms", + sa.Integer(), + server_default=sa.text("-1"), + nullable=False, + ), + ) + op.add_column( + "query_logs", + sa.Column( + "was_streamed", + sa.Boolean(), + server_default=sa.text("False"), + nullable=False, + ), + ) + op.add_column( + "query_logs", + sa.Column( + "was_multimodal", + sa.Boolean(), + server_default=sa.text("False"), + nullable=False, + ), + ) + op.add_column( + "query_logs", + sa.Column( + "was_nildb", sa.Boolean(), server_default=sa.text("False"), nullable=False + ), + ) + op.add_column( + "query_logs", + sa.Column( + "was_nilrag", sa.Boolean(), server_default=sa.text("False"), nullable=False + ), + ) + op.add_column( + "query_logs", + sa.Column( + "error_code", sa.Integer(), server_default=sa.text("200"), nullable=False + ), + ) + op.add_column( + "query_logs", + sa.Column( + "error_message", sa.Text(), server_default=sa.text("'OK'"), nullable=False + ), + ) + + # query_logs: remove FK to users.userid before dropping the column later + op.drop_constraint("query_logs_userid_fkey", "query_logs", type_="foreignkey") + + # query_logs: add lockid and index, drop legacy userid and its index + op.add_column( + "query_logs", sa.Column("lockid", sa.String(length=75), nullable=False) + ) + op.drop_index("ix_query_logs_userid", table_name="query_logs") + op.create_index( + op.f("ix_query_logs_lockid"), "query_logs", ["lockid"], unique=False + ) + op.drop_column("query_logs", "userid") + + # users: drop legacy token counters + op.drop_column("users", "prompt_tokens") + op.drop_column("users", "completion_tokens") + + # users: reshape identity columns and indexes + op.add_column("users", sa.Column("user_id", sa.String(length=75), nullable=False)) + op.drop_index("ix_users_apikey", table_name="users") + op.drop_index("ix_users_userid", table_name="users") + op.create_index(op.f("ix_users_user_id"), "users", ["user_id"], unique=False) + op.drop_column("users", "last_activity") + op.drop_column("users", "userid") + op.drop_column("users", "apikey") + op.drop_column("users", "signup_date") + op.drop_column("users", "queries") + op.drop_column("users", "name") + # ### end merged commands ### + + +def downgrade() -> None: + # ### revert merged commands back to 9ddf28cf6b6f ### + # users: restore legacy columns and indexes + op.add_column("users", sa.Column("name", sa.VARCHAR(length=100), nullable=False)) + op.add_column("users", sa.Column("queries", sa.INTEGER(), nullable=False)) + op.add_column( + 
"users", + sa.Column( + "signup_date", + postgresql.TIMESTAMP(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + ) + op.add_column("users", sa.Column("apikey", sa.VARCHAR(length=75), nullable=False)) + op.add_column("users", sa.Column("userid", sa.VARCHAR(length=75), nullable=False)) + op.add_column( + "users", + sa.Column("last_activity", postgresql.TIMESTAMP(timezone=True), nullable=True), + ) + op.drop_index(op.f("ix_users_user_id"), table_name="users") + op.create_index("ix_users_userid", "users", ["userid"], unique=False) + op.create_index("ix_users_apikey", "users", ["apikey"], unique=False) + op.drop_column("users", "user_id") + op.add_column( + "users", + sa.Column( + "completion_tokens", + sa.INTEGER(), + server_default=sa.text("0"), + nullable=False, + ), + ) + op.add_column( + "users", + sa.Column( + "prompt_tokens", sa.INTEGER(), server_default=sa.text("0"), nullable=False + ), + ) + + # query_logs: restore userid, index and FK; drop new columns + op.add_column( + "query_logs", sa.Column("userid", sa.VARCHAR(length=75), nullable=False) + ) + op.drop_index(op.f("ix_query_logs_lockid"), table_name="query_logs") + op.create_index("ix_query_logs_userid", "query_logs", ["userid"], unique=False) + op.create_foreign_key( + "query_logs_userid_fkey", "query_logs", "users", ["userid"], ["userid"] + ) + op.drop_column("query_logs", "lockid") + op.drop_column("query_logs", "error_message") + op.drop_column("query_logs", "error_code") + op.drop_column("query_logs", "was_nilrag") + op.drop_column("query_logs", "was_nildb") + op.drop_column("query_logs", "was_multimodal") + op.drop_column("query_logs", "was_streamed") + op.drop_column("query_logs", "tool_response_time_ms") + op.drop_column("query_logs", "model_response_time_ms") + op.drop_column("query_logs", "response_time_ms") + op.drop_column("query_logs", "max_tokens") + op.drop_column("query_logs", "temperature") + op.drop_column("query_logs", "tool_calls") + # ### end revert ### diff --git a/nilai-api/alembic/versions/43b23c73035b_fix_userid_change_to_user_id.py b/nilai-api/alembic/versions/43b23c73035b_fix_userid_change_to_user_id.py new file mode 100644 index 00000000..4c20bb6d --- /dev/null +++ b/nilai-api/alembic/versions/43b23c73035b_fix_userid_change_to_user_id.py @@ -0,0 +1,37 @@ +"""fix: userid change to user_id + +Revision ID: 43b23c73035b +Revises: 0ba073468afc +Create Date: 2025-11-03 11:33:03.006101 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "43b23c73035b" +down_revision: Union[str, None] = "0ba073468afc" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "query_logs", sa.Column("user_id", sa.String(length=75), nullable=False) + ) + op.create_index( + op.f("ix_query_logs_user_id"), "query_logs", ["user_id"], unique=False + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_index(op.f("ix_query_logs_user_id"), table_name="query_logs") + op.drop_column("query_logs", "user_id") + # ### end Alembic commands ### diff --git a/nilai-api/examples/users.py b/nilai-api/examples/users.py deleted file mode 100644 index b6b206d5..00000000 --- a/nilai-api/examples/users.py +++ /dev/null @@ -1,43 +0,0 @@ -#!/usr/bin/python - -from nilai_api.db.logs import QueryLogManager -from nilai_api.db.users import UserManager - - -# Example Usage -async def main(): - # Add some users - bob = await UserManager.insert_user("Bob", "bob@example.com") - alice = await UserManager.insert_user("Alice", "alice@example.com") - - print(f"Bob's details: {bob}") - print(f"Alice's details: {alice}") - - # Check API key - user_name = await UserManager.check_api_key(bob.apikey) - print(f"API key validation: {user_name}") - - # Update and retrieve token usage - await UserManager.update_token_usage( - bob.userid, prompt_tokens=50, completion_tokens=20 - ) - usage = await UserManager.get_user_token_usage(bob.userid) - print(f"Bob's token usage: {usage}") - - # Log a query - await QueryLogManager.log_query( - userid=bob.userid, - model="gpt-3.5-turbo", - prompt_tokens=8, - completion_tokens=7, - web_search_calls=1, - ) - - -if __name__ == "__main__": - import asyncio - from dotenv import load_dotenv - - load_dotenv() - - asyncio.run(main()) diff --git a/nilai-api/pyproject.toml b/nilai-api/pyproject.toml index 42a1cf4f..fd6f1eef 100644 --- a/nilai-api/pyproject.toml +++ b/nilai-api/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ "trafilatura>=1.7.0", "secretvaults", "e2b-code-interpreter>=1.0.3", - "nilauth-credit-middleware==0.1.1", + "nilauth-credit-middleware>=0.1.2", ] diff --git a/nilai-api/src/nilai_api/auth/__init__.py b/nilai-api/src/nilai_api/auth/__init__.py index 2e7cd6f7..2123685a 100644 --- a/nilai-api/src/nilai_api/auth/__init__.py +++ b/nilai-api/src/nilai_api/auth/__init__.py @@ -4,7 +4,6 @@ from logging import getLogger from nilai_api.config import CONFIG -from nilai_api.db.users import UserManager from nilai_api.auth.strategies import AuthenticationStrategy from nuc.validate import ValidationException @@ -36,7 +35,6 @@ async def get_auth_info( ) auth_info = await strategy(credentials.credentials) - await UserManager.update_last_activity(userid=auth_info.user.userid) return auth_info except AuthenticationError as e: raise e diff --git a/nilai-api/src/nilai_api/auth/nuc.py b/nilai-api/src/nilai_api/auth/nuc.py index 46459356..614d9ef1 100644 --- a/nilai-api/src/nilai_api/auth/nuc.py +++ b/nilai-api/src/nilai_api/auth/nuc.py @@ -86,11 +86,11 @@ def validate_nuc(nuc_token: str) -> Tuple[str, str]: # Validate the # Return the subject of the token, the subscription holder - subscription_holder = token.subject.public_key.hex() - user = token.issuer.public_key.hex() + subscription_holder = token.subject + user = token.issuer logger.info(f"Subscription holder: {subscription_holder}") logger.info(f"User: {user}") - return subscription_holder, user + return str(subscription_holder), str(user) def get_token_rate_limit(nuc_token: str) -> Optional[TokenRateLimits]: diff --git a/nilai-api/src/nilai_api/auth/strategies.py b/nilai-api/src/nilai_api/auth/strategies.py index 9917ee39..089e7e94 100644 --- a/nilai-api/src/nilai_api/auth/strategies.py +++ b/nilai-api/src/nilai_api/auth/strategies.py @@ -1,6 +1,6 @@ from typing import Callable, Awaitable, Optional -from datetime import datetime, timezone +from fastapi import HTTPException from nilai_api.db.users import UserManager, UserModel, 
UserData from nilai_api.auth.nuc import ( validate_nuc, @@ -11,11 +11,18 @@ from nilai_api.auth.common import ( PromptDocument, TokenRateLimits, - AuthenticationInfo, AuthenticationError, + AuthenticationInfo, +) + +from nilauth_credit_middleware import ( + CreditClientSingleton, ) +from nilauth_credit_middleware.api_model import ValidateCredentialResponse + from enum import Enum + # All strategies must return a UserModel # The strategies can raise any exception, which will be caught and converted to an AuthenticationError # The exception detail will be passed to the client @@ -44,18 +51,10 @@ async def wrapper(token) -> AuthenticationInfo: return await function(token) if token == allowed_token: - user_model: UserModel | None = await UserManager.check_user( - allowed_token + user_model = UserModel( + user_id=allowed_token, + rate_limits=None, ) - if user_model is None: - user_model = UserModel( - userid=allowed_token, - name=allowed_token, - apikey=allowed_token, - signup_date=datetime.now(timezone.utc), - ) - await UserManager.insert_user_model(user_model) - return AuthenticationInfo( user=UserData.from_sqlalchemy(user_model), token_rate_limit=None, @@ -68,16 +67,41 @@ async def wrapper(token) -> AuthenticationInfo: return decorator +async def validate_credential(credential: str, is_public: bool) -> UserModel: + """ + Validate a credential with nilauth credit middleware and return the user model + """ + credit_client = CreditClientSingleton.get_client() + try: + validate_response: ValidateCredentialResponse = ( + await credit_client.validate_credential(credential, is_public=is_public) + ) + except HTTPException as e: + if e.status_code == 404: + raise AuthenticationError(f"Credential not found: {e.detail}") + elif e.status_code == 401: + raise AuthenticationError(f"Credential is inactive: {e.detail}") + else: + raise AuthenticationError(f"Failed to validate credential: {e.detail}") + + user_model = await UserManager.check_user(validate_response.user_id) + if user_model is None: + user_model = UserModel( + user_id=validate_response.user_id, + rate_limits=None, + ) + return user_model + + @allow_token(CONFIG.docs.token) async def api_key_strategy(api_key: str) -> AuthenticationInfo: - user_model: Optional[UserModel] = await UserManager.check_api_key(api_key) - if user_model: - return AuthenticationInfo( - user=UserData.from_sqlalchemy(user_model), - token_rate_limit=None, - prompt_document=None, - ) - raise AuthenticationError("Missing or invalid API key") + user_model = await validate_credential(api_key, is_public=False) + + return AuthenticationInfo( + user=UserData.from_sqlalchemy(user_model), + token_rate_limit=None, + prompt_document=None, + ) @allow_token(CONFIG.docs.token) @@ -89,20 +113,7 @@ async def nuc_strategy(nuc_token) -> AuthenticationInfo: token_rate_limits: Optional[TokenRateLimits] = get_token_rate_limit(nuc_token) prompt_document: Optional[PromptDocument] = get_token_prompt_document(nuc_token) - user_model: Optional[UserModel] = await UserManager.check_user(user) - if user_model: - return AuthenticationInfo( - user=UserData.from_sqlalchemy(user_model), - token_rate_limit=token_rate_limits, - prompt_document=prompt_document, - ) - - user_model = UserModel( - userid=user, - name=user, - apikey=subscription_holder, - ) - await UserManager.insert_user_model(user_model) + user_model = await validate_credential(subscription_holder, is_public=True) return AuthenticationInfo( user=UserData.from_sqlalchemy(user_model), token_rate_limit=token_rate_limits, diff --git 
a/nilai-api/src/nilai_api/commands/add_user.py b/nilai-api/src/nilai_api/commands/add_user.py index e9f49e55..202b70d4 100644 --- a/nilai-api/src/nilai_api/commands/add_user.py +++ b/nilai-api/src/nilai_api/commands/add_user.py @@ -6,9 +6,7 @@ @click.command() -@click.option("--name", type=str, required=True, help="User Name") -@click.option("--apikey", type=str, help="API Key") -@click.option("--userid", type=str, help="User Id") +@click.option("--user_id", type=str, help="User Id") @click.option("--ratelimit-day", type=int, help="number of request per day") @click.option("--ratelimit-hour", type=int, help="number of request per hour") @click.option("--ratelimit-minute", type=int, help="number of request per minute") @@ -26,9 +24,7 @@ help="number of web search request per minute", ) def main( - name, - apikey: str | None, - userid: str | None, + user_id: str | None, ratelimit_day: int | None, ratelimit_hour: int | None, ratelimit_minute: int | None, @@ -38,9 +34,7 @@ def main( ): async def add_user(): user: UserModel = await UserManager.insert_user( - name, - apikey, - userid, + user_id, RateLimits( user_rate_limit_day=ratelimit_day, user_rate_limit_hour=ratelimit_hour, @@ -52,9 +46,7 @@ async def add_user(): ) json_user = json.dumps( { - "userid": user.userid, - "name": user.name, - "apikey": user.apikey, + "user_id": user.user_id, "ratelimit_day": user.rate_limits_obj.user_rate_limit_day, "ratelimit_hour": user.rate_limits_obj.user_rate_limit_hour, "ratelimit_minute": user.rate_limits_obj.user_rate_limit_minute, diff --git a/nilai-api/src/nilai_api/config/__init__.py b/nilai-api/src/nilai_api/config/__init__.py index 7af72318..3f19f85e 100644 --- a/nilai-api/src/nilai_api/config/__init__.py +++ b/nilai-api/src/nilai_api/config/__init__.py @@ -70,3 +70,4 @@ def prettify(self): ] logging.info(CONFIG.prettify()) +print(CONFIG.prettify()) diff --git a/nilai-api/src/nilai_api/credit.py b/nilai-api/src/nilai_api/credit.py index 3a06135e..b9d7ea6f 100644 --- a/nilai-api/src/nilai_api/credit.py +++ b/nilai-api/src/nilai_api/credit.py @@ -20,6 +20,9 @@ class NoOpMeteringContext: """A no-op metering context for requests that should skip metering (e.g., Docs Token).""" + def __init__(self): + self.lock_id: str = "noop-lock-id" + def set_response(self, response_data: dict) -> None: """No-op method that does nothing.""" pass diff --git a/nilai-api/src/nilai_api/db/logs.py b/nilai-api/src/nilai_api/db/logs.py index 030c8696..4a78c8a7 100644 --- a/nilai-api/src/nilai_api/db/logs.py +++ b/nilai-api/src/nilai_api/db/logs.py @@ -1,12 +1,14 @@ import logging +import time from datetime import datetime, timezone +from typing import Optional +from nilai_common import Usage import sqlalchemy -from sqlalchemy import ForeignKey, Integer, String, DateTime, Text +from sqlalchemy import Integer, String, DateTime, Text, Boolean, Float from sqlalchemy.exc import SQLAlchemyError from nilai_api.db import Base, Column, get_db_session -from nilai_api.db.users import UserModel logger = logging.getLogger(__name__) @@ -16,9 +18,8 @@ class QueryLog(Base): __tablename__ = "query_logs" id: int = Column(Integer, primary_key=True, autoincrement=True) # type: ignore - userid: str = Column( - String(75), ForeignKey(UserModel.userid), nullable=False, index=True - ) # type: ignore + user_id: str = Column(String(75), nullable=False, index=True) # type: ignore + lockid: str = Column(String(75), nullable=False, index=True) # type: ignore query_timestamp: datetime = Column( DateTime(timezone=True), 
server_default=sqlalchemy.func.now(), nullable=False ) # type: ignore @@ -26,51 +27,285 @@ class QueryLog(Base): prompt_tokens: int = Column(Integer, nullable=False) # type: ignore completion_tokens: int = Column(Integer, nullable=False) # type: ignore total_tokens: int = Column(Integer, nullable=False) # type: ignore + tool_calls: int = Column(Integer, nullable=False) # type: ignore web_search_calls: int = Column(Integer, nullable=False) # type: ignore + temperature: Optional[float] = Column(Float, nullable=True) # type: ignore + max_tokens: Optional[int] = Column(Integer, nullable=True) # type: ignore + + response_time_ms: int = Column(Integer, nullable=False) # type: ignore + model_response_time_ms: int = Column(Integer, nullable=False) # type: ignore + tool_response_time_ms: int = Column(Integer, nullable=False) # type: ignore + + was_streamed: bool = Column(Boolean, nullable=False) # type: ignore + was_multimodal: bool = Column(Boolean, nullable=False) # type: ignore + was_nildb: bool = Column(Boolean, nullable=False) # type: ignore + was_nilrag: bool = Column(Boolean, nullable=False) # type: ignore + + error_code: int = Column(Integer, nullable=False) # type: ignore + error_message: str = Column(Text, nullable=False) # type: ignore def __repr__(self): - return f"" + return f"" + + +class QueryLogContext: + """ + Context manager for logging query metrics during a request. + Used as a FastAPI dependency to track request metrics. + """ + + def __init__(self): + self.user_id: Optional[str] = None + self.lockid: Optional[str] = None + self.model: Optional[str] = None + self.prompt_tokens: int = 0 + self.completion_tokens: int = 0 + self.tool_calls: int = 0 + self.web_search_calls: int = 0 + self.temperature: Optional[float] = None + self.max_tokens: Optional[int] = None + self.was_streamed: bool = False + self.was_multimodal: bool = False + self.was_nildb: bool = False + self.was_nilrag: bool = False + self.error_code: int = 0 + self.error_message: str = "" + + # Timing tracking + self.start_time: float = time.monotonic() + self.model_start_time: Optional[float] = None + self.model_end_time: Optional[float] = None + self.tool_start_time: Optional[float] = None + self.tool_end_time: Optional[float] = None + + def set_user(self, user_id: str) -> None: + """Set the user ID for this query.""" + self.user_id = user_id + + def set_lockid(self, lockid: str) -> None: + """Set the lock ID for this query.""" + self.lockid = lockid + + def set_model(self, model: str) -> None: + """Set the model name for this query.""" + self.model = model + + def set_request_params( + self, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + was_streamed: bool = False, + was_multimodal: bool = False, + was_nildb: bool = False, + was_nilrag: bool = False, + ) -> None: + """Set request parameters.""" + self.temperature = temperature + self.max_tokens = max_tokens + self.was_streamed = was_streamed + self.was_multimodal = was_multimodal + self.was_nildb = was_nildb + self.was_nilrag = was_nilrag + + def set_usage( + self, + prompt_tokens: int = 0, + completion_tokens: int = 0, + tool_calls: int = 0, + web_search_calls: int = 0, + ) -> None: + """Set token usage and feature usage.""" + self.prompt_tokens = prompt_tokens + self.completion_tokens = completion_tokens + self.tool_calls = tool_calls + self.web_search_calls = web_search_calls + + def set_error(self, error_code: int, error_message: str) -> None: + """Set error information.""" + self.error_code = error_code + self.error_message = 
error_message + + def start_model_timing(self) -> None: + """Mark the start of model inference.""" + self.model_start_time = time.monotonic() + + def end_model_timing(self) -> None: + """Mark the end of model inference.""" + self.model_end_time = time.monotonic() + + def start_tool_timing(self) -> None: + """Mark the start of tool execution.""" + self.tool_start_time = time.monotonic() + + def end_tool_timing(self) -> None: + """Mark the end of tool execution.""" + self.tool_end_time = time.monotonic() + + def _calculate_timings(self) -> tuple[int, int, int]: + """Calculate response times in milliseconds.""" + total_ms = int((time.monotonic() - self.start_time) * 1000) + + model_ms = 0 + if self.model_start_time and self.model_end_time: + model_ms = int((self.model_end_time - self.model_start_time) * 1000) + + tool_ms = 0 + if self.tool_start_time and self.tool_end_time: + tool_ms = int((self.tool_end_time - self.tool_start_time) * 1000) + + return total_ms, model_ms, tool_ms + + async def commit(self) -> None: + """ + Commit the query log to the database. + Should be called at the end of the request lifecycle. + """ + if not self.user_id or not self.model: + logger.warning( + "Skipping query log: user_id or model not set " + f"(user_id={self.user_id}, model={self.model})" + ) + return + + total_ms, model_ms, tool_ms = self._calculate_timings() + total_tokens = self.prompt_tokens + self.completion_tokens + + try: + async with get_db_session() as session: + query_log = QueryLog( + user_id=self.user_id, + lockid=self.lockid, + model=self.model, + prompt_tokens=self.prompt_tokens, + completion_tokens=self.completion_tokens, + total_tokens=total_tokens, + tool_calls=self.tool_calls, + web_search_calls=self.web_search_calls, + temperature=self.temperature, + max_tokens=self.max_tokens, + query_timestamp=datetime.now(timezone.utc), + response_time_ms=total_ms, + model_response_time_ms=model_ms, + tool_response_time_ms=tool_ms, + was_streamed=self.was_streamed, + was_multimodal=self.was_multimodal, + was_nilrag=self.was_nilrag, + was_nildb=self.was_nildb, + error_code=self.error_code, + error_message=self.error_message, + ) + session.add(query_log) + await session.commit() + logger.info( + f"Query logged for user {self.user_id}: model={self.model}, " + f"tokens={total_tokens}, total_ms={total_ms}" + ) + except SQLAlchemyError as e: + logger.error(f"Error logging query: {e}") + # Don't raise - logging failure shouldn't break the request class QueryLogManager: + """Static methods for direct query logging (legacy support).""" + @staticmethod async def log_query( - userid: str, + user_id: str, + lockid: str, model: str, prompt_tokens: int, completion_tokens: int, + response_time_ms: int, web_search_calls: int, + was_streamed: bool, + was_multimodal: bool, + was_nilrag: bool, + was_nildb: bool, + tool_calls: int = 0, + temperature: float = 1.0, + max_tokens: int = 0, + model_response_time_ms: int = 0, + tool_response_time_ms: int = 0, + error_code: int = 0, + error_message: str = "", ): """ - Log a user's query. - - Args: - userid (str): User's unique ID - model (str): The model that generated the response - prompt_tokens (int): Number of input tokens used - completion_tokens (int): Number of tokens in the generated response + Log a user's query (legacy method). + Consider using QueryLogContext as a dependency instead. 
""" total_tokens = prompt_tokens + completion_tokens try: async with get_db_session() as session: query_log = QueryLog( - userid=userid, + user_id=user_id, + lockid=lockid, model=model, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, - query_timestamp=datetime.now(timezone.utc), + tool_calls=tool_calls, web_search_calls=web_search_calls, + temperature=temperature, + max_tokens=max_tokens, + query_timestamp=datetime.now(timezone.utc), + response_time_ms=response_time_ms, + model_response_time_ms=model_response_time_ms, + tool_response_time_ms=tool_response_time_ms, + was_streamed=was_streamed, + was_multimodal=was_multimodal, + was_nilrag=was_nilrag, + was_nildb=was_nildb, + error_code=error_code, + error_message=error_message, ) session.add(query_log) await session.commit() logger.info( - f"Query logged for user {userid} with total tokens {total_tokens}." + f"Query logged for user {user_id} with total tokens {total_tokens}." ) except SQLAlchemyError as e: logger.error(f"Error logging query: {e}") raise + @staticmethod + async def get_user_token_usage(user_id: str) -> Optional[Usage]: + """ + Get aggregated token usage for a specific user using server-side SQL aggregation. + This is more efficient than fetching all records and calculating in Python. + """ + try: + async with get_db_session() as session: + # Use SQL aggregation functions to calculate on the database server + query = ( + sqlalchemy.select( + sqlalchemy.func.coalesce( + sqlalchemy.func.sum(QueryLog.prompt_tokens), 0 + ).label("prompt_tokens"), + sqlalchemy.func.coalesce( + sqlalchemy.func.sum(QueryLog.completion_tokens), 0 + ).label("completion_tokens"), + sqlalchemy.func.coalesce( + sqlalchemy.func.sum(QueryLog.total_tokens), 0 + ).label("total_tokens"), + sqlalchemy.func.count().label("queries"), + ).where(QueryLog.user_id == user_id) # type: ignore[arg-type] + ) + + result = await session.execute(query) + row = result.one_or_none() + + if row is None: + return None + + return Usage( + prompt_tokens=int(row.prompt_tokens), + completion_tokens=int(row.completion_tokens), + total_tokens=int(row.total_tokens), + ) + except SQLAlchemyError as e: + logger.error(f"Error getting token usage: {e}") + return None + -__all__ = ["QueryLogManager", "QueryLog"] +__all__ = ["QueryLogManager", "QueryLog", "QueryLogContext"] diff --git a/nilai-api/src/nilai_api/db/users.py b/nilai-api/src/nilai_api/db/users.py index 515ba389..e475c424 100644 --- a/nilai-api/src/nilai_api/db/users.py +++ b/nilai-api/src/nilai_api/db/users.py @@ -2,11 +2,10 @@ import uuid from pydantic import BaseModel, ConfigDict, Field -from datetime import datetime, timezone -from typing import Any, Dict, List, Optional +from typing import Optional import sqlalchemy -from sqlalchemy import Integer, String, DateTime, JSON +from sqlalchemy import String, JSON from sqlalchemy.exc import SQLAlchemyError from nilai_api.db import Base, Column, get_db_session @@ -57,21 +56,11 @@ def get_effective_limits(self) -> "RateLimits": # Enhanced User Model with additional constraints and validation class UserModel(Base): __tablename__ = "users" - - userid: str = Column(String(75), primary_key=True, index=True) # type: ignore - name: str = Column(String(100), nullable=False) # type: ignore - apikey: str = Column(String(75), unique=False, nullable=False, index=True) # type: ignore - prompt_tokens: int = Column(Integer, default=0, nullable=False) # type: ignore - completion_tokens: int = Column(Integer, default=0, nullable=False) # type: ignore - 
queries: int = Column(Integer, default=0, nullable=False) # type: ignore - signup_date: datetime = Column( - DateTime(timezone=True), server_default=sqlalchemy.func.now(), nullable=False - ) # type: ignore - last_activity: datetime = Column(DateTime(timezone=True), nullable=True) # type: ignore + user_id: str = Column(String(75), primary_key=True, index=True) # type: ignore rate_limits: dict = Column(JSON, nullable=True) # type: ignore def __repr__(self): - return f"" + return f"" @property def rate_limits_obj(self) -> RateLimits: @@ -85,14 +74,7 @@ def to_pydantic(self) -> "UserData": class UserData(BaseModel): - userid: str - name: str - apikey: str - prompt_tokens: int = 0 - completion_tokens: int = 0 - queries: int = 0 - signup_date: datetime - last_activity: Optional[datetime] = None + user_id: str # apikey or subscription holder public key rate_limits: RateLimits = Field(default_factory=RateLimits().get_effective_limits) model_config = ConfigDict(from_attributes=True) @@ -100,21 +82,10 @@ class UserData(BaseModel): @classmethod def from_sqlalchemy(cls, user: UserModel) -> "UserData": return cls( - userid=user.userid, - name=user.name, - apikey=user.apikey, - prompt_tokens=user.prompt_tokens or 0, - completion_tokens=user.completion_tokens or 0, - queries=user.queries or 0, - signup_date=user.signup_date or datetime.now(timezone.utc), - last_activity=user.last_activity, + user_id=user.user_id, rate_limits=user.rate_limits_obj, ) - @property - def is_subscription_owner(self): - return self.userid == self.apikey - class UserManager: @staticmethod @@ -127,31 +98,9 @@ def generate_api_key() -> str: """Generate a unique API key.""" return str(uuid.uuid4()) - @staticmethod - async def update_last_activity(userid: str): - """ - Update the last activity timestamp for a user. 
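Between these two file diffs, a note on intent: the QueryLogContext added to nilai_api/db/logs.py above is designed to be created once per request and committed once at the end, replacing the per-user counters that users.py drops here. A minimal sketch of the intended FastAPI wiring; the route, request shape, and placeholder values are illustrative assumptions, and only the QueryLogContext methods themselves come from this patch:

    from fastapi import Depends, FastAPI

    from nilai_api.db.logs import QueryLogContext

    app = FastAPI()

    @app.post("/v1/chat/completions")  # illustrative route
    async def chat(body: dict, ctx: QueryLogContext = Depends(QueryLogContext)):
        ctx.set_user("some-user-id")    # normally auth_info.user.user_id
        ctx.set_lockid("noop-lock-id")  # lock id from the metering context
        ctx.set_model(body.get("model", "unknown"))
        ctx.set_request_params(temperature=0.9, was_streamed=False)
        ctx.start_model_timing()
        result = {"choices": []}        # placeholder for the actual model call
        ctx.end_model_timing()
        ctx.set_usage(prompt_tokens=10, completion_tokens=5)
        await ctx.commit()              # writes one query_logs row
        return result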
- - Args: - userid (str): User's unique ID - """ - try: - async with get_db_session() as session: - user = await session.get(UserModel, userid) - if user: - user.last_activity = datetime.now(timezone.utc) - await session.commit() - logger.info(f"Updated last activity for user {userid}") - else: - logger.warning(f"User {userid} not found") - except SQLAlchemyError as e: - logger.error(f"Error updating last activity: {e}") - @staticmethod async def insert_user( - name: str, - apikey: str | None = None, - userid: str | None = None, + user_id: str | None = None, rate_limits: RateLimits | None = None, ) -> UserModel: """ @@ -160,19 +109,16 @@ async def insert_user( Args: - name (str): Name of the user - apikey (str): API key for the user - userid (str): Unique ID for the user + user_id (str): Unique ID for the user rate_limits (RateLimits): Rate limit configuration Returns: UserModel: The created user model """ - userid = userid if userid else UserManager.generate_user_id() - apikey = apikey if apikey else UserManager.generate_api_key() + user_id = user_id if user_id else UserManager.generate_user_id() user = UserModel( - userid=userid, - name=name, - apikey=apikey, + user_id=user_id, rate_limits=rate_limits.model_dump() if rate_limits else None, ) return await UserManager.insert_user_model(user) @@ -189,35 +135,14 @@ async def insert_user_model(user: UserModel) -> UserModel: async with get_db_session() as session: session.add(user) await session.commit() - logger.info(f"User {user.name} added successfully.") + logger.info(f"User {user.user_id} added successfully.") return user except SQLAlchemyError as e: logger.error(f"Error inserting user: {e}") raise @staticmethod - async def check_user(userid: str) -> Optional[UserModel]: - """ - Validate a user. - - Args: - userid (str): User ID to validate - - Returns: - User's name if user is valid, None otherwise - """ - try: - async with get_db_session() as session: - query = sqlalchemy.select(UserModel).filter(UserModel.userid == userid) # type: ignore - user = await session.execute(query) - user = user.scalar_one_or_none() - return user - except SQLAlchemyError as e: - logger.error(f"Error checking API key: {e}") - return None - - @staticmethod - async def check_api_key(api_key: str) -> Optional[UserModel]: + async def check_user(user_id: str) -> Optional[UserModel]: """ - Validate an API key. + Validate a user id. @@ -225,118 +150,27 @@ async def check_api_key(api_key: str) -> Optional[UserModel]: - api_key (str): API key to validate + user_id (str): User id to validate Returns: - User's name if API key is valid, None otherwise + User's rate limits if user id is valid, None otherwise """ try: async with get_db_session() as session: - query = sqlalchemy.select(UserModel).filter(UserModel.apikey == api_key) # type: ignore + query = sqlalchemy.select(UserModel).filter( UserModel.user_id == user_id # type: ignore ) user = await session.execute(query) user = user.scalar_one_or_none() return user except SQLAlchemyError as e: - logger.error(f"Error checking API key: {e}") - return None - - @staticmethod - async def update_token_usage( - userid: str, prompt_tokens: int, completion_tokens: int - ): - """ - Update token usage for a specific user.
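Since examples/users.py was deleted earlier in this patch, the equivalent end-to-end flow against the slimmed-down UserManager reduces to roughly the following sketch (the rate-limit values are arbitrary, and the remaining RateLimits fields are assumed to default to None):

    import asyncio

    from nilai_api.db.users import RateLimits, UserManager

    async def main():
        # user_id is generated via generate_user_id() when omitted.
        user = await UserManager.insert_user(
            rate_limits=RateLimits(user_rate_limit_day=100)
        )
        fetched = await UserManager.check_user(user.user_id)
        assert fetched is not None
        await UserManager.update_rate_limits(
            user.user_id, RateLimits(user_rate_limit_day=200)
        )

    asyncio.run(main())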
- - Args: - userid (str): User's unique ID - prompt_tokens (int): Number of input tokens - completion_tokens (int): Number of generated tokens - """ - try: - async with get_db_session() as session: - user = await session.get(UserModel, userid) - if user: - user.prompt_tokens += prompt_tokens - user.completion_tokens += completion_tokens - user.queries += 1 - await session.commit() - logger.info(f"Updated token usage for user {userid}") - else: - logger.warning(f"User {userid} not found") - except SQLAlchemyError as e: - logger.error(f"Error updating token usage: {e}") - - @staticmethod - async def get_token_usage(userid: str) -> Optional[Dict[str, Any]]: - """ - Get token usage for a specific user. - - Args: - userid (str): User's unique ID - """ - try: - async with get_db_session() as session: - user = await session.get(UserModel, userid) - if user: - return { - "prompt_tokens": user.prompt_tokens, - "completion_tokens": user.completion_tokens, - "total_tokens": user.prompt_tokens + user.completion_tokens, - "queries": user.queries, - } - else: - logger.warning(f"User {userid} not found") - return None - except SQLAlchemyError as e: - logger.error(f"Error updating token usage: {e}") - return None - - @staticmethod - async def get_all_users() -> Optional[List[UserData]]: - """ - Retrieve all users from the database. - - Returns: - List of UserData or None if no users found - """ - try: - async with get_db_session() as session: - users = await session.execute(sqlalchemy.select(UserModel)) - users = users.scalars().all() - return [UserData.from_sqlalchemy(user) for user in users] - except SQLAlchemyError as e: - logger.error(f"Error retrieving all users: {e}") - return None - - @staticmethod - async def get_user_token_usage(userid: str) -> Optional[Dict[str, int]]: - """ - Retrieve total token usage for a user. - - Args: - userid (str): User's unique ID - - Returns: - Dict of token usage or None if user not found - """ - try: - async with get_db_session() as session: - user = await session.get(UserModel, userid) - if user: - return { - "prompt_tokens": user.prompt_tokens, - "completion_tokens": user.completion_tokens, - "queries": user.queries, - } - return None - except SQLAlchemyError as e: - logger.error(f"Error retrieving token usage: {e}") + logger.error(f"Error checking user id: {e}") return None @staticmethod - async def update_rate_limits(userid: str, rate_limits: RateLimits) -> bool: + async def update_rate_limits(user_id: str, rate_limits: RateLimits) -> bool: """ Update rate limits for a specific user.
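With the per-user counters gone, usage has to be aggregated from the query log instead. A sketch of that aggregation follows; the real implementation is `QueryLogManager.get_user_token_usage` in `db/logs.py`, which this diff does not show, so the `QueryLog` columns here are assumptions mirroring what `QueryLogContext` records:

```python
import sqlalchemy

# Assumed: a QueryLog model with per-request prompt/completion token columns
# and the same get_db_session() helper used above.


async def get_user_token_usage(user_id: str) -> dict:
    async with get_db_session() as session:
        stmt = sqlalchemy.select(
            sqlalchemy.func.coalesce(sqlalchemy.func.sum(QueryLog.prompt_tokens), 0),
            sqlalchemy.func.coalesce(sqlalchemy.func.sum(QueryLog.completion_tokens), 0),
            sqlalchemy.func.count(QueryLog.id),
        ).where(QueryLog.user_id == user_id)
        prompt, completion, queries = (await session.execute(stmt)).one()
        return {
            "prompt_tokens": prompt,
            "completion_tokens": completion,
            "total_tokens": prompt + completion,
            "queries": queries,
        }
```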
Args: - userid (str): User's unique ID + user_id (str): User's unique ID rate_limits (RateLimits): New rate limit configuration Returns: @@ -344,14 +178,14 @@ async def update_rate_limits(userid: str, rate_limits: RateLimits) -> bool: """ try: async with get_db_session() as session: - user = await session.get(UserModel, userid) + user = await session.get(UserModel, user_id) if user: user.rate_limits = rate_limits.model_dump() await session.commit() - logger.info(f"Updated rate limits for user {userid}") + logger.info(f"Updated rate limits for user {user_id}") return True else: - logger.warning(f"User {userid} not found") + logger.warning(f"User {user_id} not found") return False except SQLAlchemyError as e: logger.error(f"Error updating rate limits: {e}") diff --git a/nilai-api/src/nilai_api/rate_limiting.py b/nilai-api/src/nilai_api/rate_limiting.py index c2d03273..ae1e63ae 100644 --- a/nilai-api/src/nilai_api/rate_limiting.py +++ b/nilai-api/src/nilai_api/rate_limiting.py @@ -1,3 +1,4 @@ +import logging from asyncio import iscoroutine from typing import Callable, Tuple, Awaitable, Annotated @@ -11,6 +12,8 @@ from nilai_api.auth import get_auth_info, AuthenticationInfo, TokenRateLimits from nilai_api.config import CONFIG +logger = logging.getLogger(__name__) + LUA_RATE_LIMIT_SCRIPT = """ local key = KEYS[1] local limit = tonumber(ARGV[1]) @@ -52,7 +55,7 @@ async def _extract_coroutine_result(maybe_future, request: Request): class UserRateLimits(BaseModel): - subscription_holder: str + user_id: str token_rate_limit: TokenRateLimits | None rate_limits: RateLimits @@ -60,14 +63,13 @@ class UserRateLimits(BaseModel): def get_user_limits( auth_info: Annotated[AuthenticationInfo, Depends(get_auth_info)], ) -> UserRateLimits: - # TODO: When the only allowed strategy is NUC, we can change the apikey name to subscription_holder - # In apikey mode, the apikey is unique as the userid. - # In nuc mode, the apikey is associated with a subscription holder and the userid is the user + # In apikey mode, the apikey is unique as the user_id. 
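The Lua-backed bucket check amounts to the following in miniature (an illustrative sketch, not the module's code: the script returns the milliseconds until reset when a key is over its limit and 0 otherwise, and the window constants are assumed values for MINUTE_MS/HOUR_MS/DAY_MS):

```python
from fastapi import HTTPException, status

MINUTE_MS, HOUR_MS, DAY_MS = 60_000, 3_600_000, 86_400_000  # assumed values


async def enforce(redis, script_sha: str, user_id: str, limits) -> None:
    # One Redis key per (window, identity); each bucket is checked atomically
    # by the preloaded Lua script.
    for window, limit, ttl_ms in (
        ("minute", limits.user_rate_limit_minute, MINUTE_MS),
        ("hour", limits.user_rate_limit_hour, HOUR_MS),
        ("day", limits.user_rate_limit_day, DAY_MS),
    ):
        if limit is None:
            continue  # an unset limit means no cap for this window
        expire = await redis.evalsha(
            script_sha, 1, f"{window}:{user_id}", str(limit), str(ttl_ms)
        )
        if int(expire) > 0:  # over the limit: reject with 429
            raise HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail="Too Many Requests",
            )
```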
+ # In nuc mode, the apikey is associated with a subscription holder and the user_id is the user # For NUCs we want the rate limit to be per subscription holder, not per user - # In JWT mode, the apikey is the userid too + # In JWT mode, the apikey is the user_id too # So we use the apikey as the id return UserRateLimits( - subscription_holder=auth_info.user.apikey, + user_id=auth_info.user.user_id, token_rate_limit=auth_info.token_rate_limit, rate_limits=auth_info.user.rate_limits, ) @@ -105,21 +107,21 @@ async def __call__( await self.check_bucket( redis, redis_rate_limit_command, - f"minute:{user_limits.subscription_holder}", + f"minute:{user_limits.user_id}", user_limits.rate_limits.user_rate_limit_minute, MINUTE_MS, ) await self.check_bucket( redis, redis_rate_limit_command, - f"hour:{user_limits.subscription_holder}", + f"hour:{user_limits.user_id}", user_limits.rate_limits.user_rate_limit_hour, HOUR_MS, ) await self.check_bucket( redis, redis_rate_limit_command, - f"day:{user_limits.subscription_holder}", + f"day:{user_limits.user_id}", user_limits.rate_limits.user_rate_limit_day, DAY_MS, ) @@ -127,7 +129,7 @@ async def __call__( await self.check_bucket( redis, redis_rate_limit_command, - f"user:{user_limits.subscription_holder}", + f"user:{user_limits.user_id}", user_limits.rate_limits.user_rate_limit, 0, # No expiration for for-good rate limit ) @@ -187,7 +189,7 @@ async def __call__( await self.check_bucket( redis, redis_rate_limit_command, - f"web_search:{user_limits.subscription_holder}", + f"web_search:{user_limits.user_id}", user_limits.rate_limits.web_search_rate_limit, 0, ) @@ -206,13 +208,33 @@ async def check_bucket( times: int | None, milliseconds: int, ): + """ + Check if the rate limit is exceeded for a given key + + Args: + redis: The Redis client + redis_rate_limit_command: The Redis rate limit command + key: The key to check the rate limit for + times: The number of times allowed for the key + milliseconds: The expiration time in milliseconds of the rate limit + + Returns: + None if the rate limit is not exceeded + + Raises: + HTTPException: If the rate limit is exceeded + """ if times is None: return + # Evaluate the Lua script to check if the rate limit is exceeded expire = await redis.evalsha( redis_rate_limit_command, 1, key, str(times), str(milliseconds) ) # type: ignore if int(expire) > 0: + logger.error( + f"Rate limit exceeded for key: {key}, expires in: {expire} milliseconds, times allowed: {times}, expiration time: {milliseconds / 1000} seconds" + ) raise HTTPException( status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Too Many Requests", diff --git a/nilai-api/src/nilai_api/routers/endpoints/chat.py b/nilai-api/src/nilai_api/routers/endpoints/chat.py index 45425e31..783dcc8f 100644 --- a/nilai-api/src/nilai_api/routers/endpoints/chat.py +++ b/nilai-api/src/nilai_api/routers/endpoints/chat.py @@ -5,7 +5,7 @@ from base64 import b64encode from typing import AsyncGenerator, Optional, Union, List, Tuple -from fastapi import APIRouter, Body, Depends, HTTPException, status, Request +from fastapi import APIRouter, Body, Depends, HTTPException, status, Request, BackgroundTasks from fastapi.responses import StreamingResponse from openai import AsyncOpenAI @@ -13,7 +13,7 @@ from nilai_api.config import CONFIG from nilai_api.crypto import sign_message from nilai_api.credit import LLMMeter, LLMUsage -from nilai_api.db.logs import QueryLogManager +from
nilai_api.db.logs import QueryLogManager, QueryLogContext from nilai_api.db.users import UserManager from nilai_api.handlers.nildb.handler import get_prompt_from_nildb from nilai_api.handlers.nilrag import handle_nilrag @@ -72,6 +72,7 @@ async def chat_completion( ], ) ), + background_tasks: BackgroundTasks = BackgroundTasks(), _rate_limit=Depends( RateLimit( concurrent_extractor=chat_completion_concurrent_rate_limit, @@ -80,6 +81,7 @@ async def chat_completion( ), auth_info: AuthenticationInfo = Depends(get_auth_info), meter: MeteringContext = Depends(LLMMeter), + log_ctx: QueryLogContext = Depends(QueryLogContext), ) -> Union[SignedChatCompletion, StreamingResponse]: """ Generate a chat completion response from the AI model. @@ -133,243 +135,311 @@ async def chat_completion( ) response = await chat_completion(request, user) """ - - if len(req.messages) == 0: - raise HTTPException( - status_code=400, - detail="Request contained 0 messages", - ) + # Initialize log context early so we can log any errors + log_ctx.set_user(auth_info.user.user_id) + log_ctx.set_lockid(meter.lock_id) model_name = req.model request_id = str(uuid.uuid4()) t_start = time.monotonic() - logger.info(f"[chat] call start request_id={req.messages}") - endpoint = await state.get_model(model_name) - if endpoint is None: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Invalid model name {model_name}, check /v1/models for options", - ) - - has_multimodal = req.has_multimodal_content() - logger.info(f"[chat] has_multimodal: {has_multimodal}") - if has_multimodal and (not endpoint.metadata.multimodal_support or req.web_search): - raise HTTPException( - status_code=400, - detail="Model does not support multimodal content, remove image inputs from request", - ) - - model_url = endpoint.url + "/v1/" - logger.info( - f"[chat] start request_id={request_id} user={auth_info.user.userid} model={model_name} stream={req.stream} web_search={bool(req.web_search)} tools={bool(req.tools)} multimodal={has_multimodal} url={model_url}" - ) - - client = AsyncOpenAI(base_url=model_url, api_key="") - if auth_info.prompt_document: - try: - nildb_prompt: str = await get_prompt_from_nildb(auth_info.prompt_document) - req.messages.insert( - 0, MessageAdapter.new_message(role="system", content=nildb_prompt) + try: + if len(req.messages) == 0: + raise HTTPException( + status_code=400, + detail="Request contained 0 messages", ) - except Exception as e: + endpoint = await state.get_model(model_name) + if endpoint is None: raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=f"Unable to extract prompt from nilDB: {str(e)}", + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Invalid model name {model_name}, check /v1/models for options", ) - if req.nilrag: - logger.info(f"[chat] nilrag start request_id={request_id}") - t_nilrag = time.monotonic() - await handle_nilrag(req) - logger.info( - f"[chat] nilrag done request_id={request_id} duration_ms={(time.monotonic() - t_nilrag) * 1000:.0f}" - ) - - messages = req.messages - sources: Optional[List[Source]] = None + # Now we have a valid model, set it in log context + log_ctx.set_model(model_name) - if req.web_search: - logger.info(f"[chat] web_search start request_id={request_id}") - t_ws = time.monotonic() - web_search_result = await handle_web_search(req, model_name, client) - messages = web_search_result.messages - sources = web_search_result.sources - logger.info( - f"[chat] web_search done request_id={request_id} sources={len(sources) if sources else 0} 
duration_ms={(time.monotonic() - t_ws) * 1000:.0f}" - ) - logger.info(f"[chat] web_search messages: {messages}") - - request_kwargs = { - "model": req.model, - "messages": messages, - "top_p": req.top_p, - "temperature": req.temperature, - "max_tokens": req.max_tokens, - } - - if req.tools: - if not endpoint.metadata.tool_support: + if not endpoint.metadata.tool_support and req.tools: raise HTTPException( status_code=400, detail="Model does not support tool usage, remove tools from request", ) - if model_name == "openai/gpt-oss-20b": + + has_multimodal = req.has_multimodal_content() + logger.info(f"[chat] has_multimodal: {has_multimodal}") + if has_multimodal and ( + not endpoint.metadata.multimodal_support or req.web_search + ): raise HTTPException( status_code=400, - detail="This model only supports tool calls with responses endpoint", + detail="Model does not support multimodal content, remove image inputs from request", ) - request_kwargs["tools"] = req.tools - request_kwargs["tool_choice"] = req.tool_choice - if req.stream: + model_url = endpoint.url + "/v1/" - async def chat_completion_stream_generator() -> AsyncGenerator[str, None]: - t_call = time.monotonic() - prompt_token_usage = 0 - completion_token_usage = 0 + logger.info( + f"[chat] start request_id={request_id} user={auth_info.user.user_id} model={model_name} stream={req.stream} web_search={bool(req.web_search)} tools={bool(req.tools)} multimodal={has_multimodal} url={model_url}" + ) + log_ctx.set_request_params( + temperature=req.temperature, + max_tokens=req.max_tokens, + was_streamed=req.stream or False, + was_multimodal=has_multimodal, + was_nildb=bool(auth_info.prompt_document), + was_nilrag=bool(req.nilrag), + ) + client = AsyncOpenAI(base_url=model_url, api_key="") + if auth_info.prompt_document: try: - logger.info(f"[chat] stream start request_id={request_id}") + nildb_prompt: str = await get_prompt_from_nildb( + auth_info.prompt_document + ) + req.messages.insert( + 0, MessageAdapter.new_message(role="system", content=nildb_prompt) + ) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Unable to extract prompt from nilDB: {str(e)}", + ) + + if req.nilrag: + logger.info(f"[chat] nilrag start request_id={request_id}") + t_nilrag = time.monotonic() + await handle_nilrag(req) + logger.info( + f"[chat] nilrag done request_id={request_id} duration_ms={(time.monotonic() - t_nilrag) * 1000:.0f}" + ) - request_kwargs["stream"] = True - request_kwargs["extra_body"] = { - "stream_options": { - "include_usage": True, - "continuous_usage_stats": False, + messages = req.messages + sources: Optional[List[Source]] = None + + if req.web_search: + logger.info(f"[chat] web_search start request_id={request_id}") + t_ws = time.monotonic() + web_search_result = await handle_web_search(req, model_name, client) + messages = web_search_result.messages + sources = web_search_result.sources + logger.info( + f"[chat] web_search done request_id={request_id} sources={len(sources) if sources else 0} duration_ms={(time.monotonic() - t_ws) * 1000:.0f}" + ) + logger.info(f"[chat] web_search messages: {messages}") + + if req.stream: + + async def chat_completion_stream_generator() -> AsyncGenerator[str, None]: + t_call = time.monotonic() + prompt_token_usage = 0 + completion_token_usage = 0 + + try: + logger.info(f"[chat] stream start request_id={request_id}") + + log_ctx.start_model_timing() + + request_kwargs = { + "model": req.model, + "messages": messages, + "stream": True, + "top_p": req.top_p, + 
"temperature": req.temperature, + "max_tokens": req.max_tokens, + "extra_body": { + "stream_options": { + "include_usage": True, + "continuous_usage_stats": False, + } + }, } - } + if req.tools: + request_kwargs["tools"] = req.tools + + response = await client.chat.completions.create(**request_kwargs) + + async for chunk in response: + if chunk.usage is not None: + prompt_token_usage = chunk.usage.prompt_tokens + completion_token_usage = chunk.usage.completion_tokens + + payload = chunk.model_dump(exclude_unset=True) + + if chunk.usage is not None and sources: + payload["sources"] = [ + s.model_dump(mode="json") for s in sources + ] + + yield f"data: {json.dumps(payload)}\n\n" + + log_ctx.end_model_timing() + meter.set_response( + { + "usage": LLMUsage( + prompt_tokens=prompt_token_usage, + completion_tokens=completion_token_usage, + web_searches=len(sources) if sources else 0, + ) + } + ) + log_ctx.set_usage( + prompt_tokens=prompt_token_usage, + completion_tokens=completion_token_usage, + web_search_calls=len(sources) if sources else 0, + ) + background_tasks.add_task(log_ctx.commit) + logger.info( + "[chat] stream done request_id=%s prompt_tokens=%d completion_tokens=%d " + "duration_ms=%.0f total_ms=%.0f", + request_id, + prompt_token_usage, + completion_token_usage, + (time.monotonic() - t_call) * 1000, + (time.monotonic() - t_start) * 1000, + ) + + except Exception as e: + logger.error( + "[chat] stream error request_id=%s error=%s", request_id, e + ) + log_ctx.set_error(error_code=500, error_message=str(e)) + await log_ctx.commit() + yield f"data: {json.dumps({'error': 'stream_failed', 'message': str(e)})}\n\n" + + return StreamingResponse( + chat_completion_stream_generator(), + media_type="text/event-stream", + ) - response = await client.chat.completions.create(**request_kwargs) + current_messages = messages + request_kwargs = { + "model": req.model, + "messages": current_messages, # type: ignore + "top_p": req.top_p, + "temperature": req.temperature, + "max_tokens": req.max_tokens, + } + if req.tools: + request_kwargs["tools"] = req.tools # type: ignore + request_kwargs["tool_choice"] = req.tool_choice + + logger.info(f"[chat] call start request_id={request_id}") + logger.info(f"[chat] call message: {current_messages}") + t_call = time.monotonic() + log_ctx.start_model_timing() + response = await client.chat.completions.create(**request_kwargs) # type: ignore + log_ctx.end_model_timing() + logger.info( + f"[chat] call done request_id={request_id} duration_ms={(time.monotonic() - t_call) * 1000:.0f}" + ) + logger.info(f"[chat] call response: {response}") + + # Handle tool workflow fully inside tools.router + log_ctx.start_tool_timing() + ( + final_completion, + agg_prompt_tokens, + agg_completion_tokens, + ) = await handle_tool_workflow(client, req, current_messages, response) + log_ctx.end_tool_timing() + logger.info(f"[chat] call final_completion: {final_completion}") + model_response = SignedChatCompletion( + **final_completion.model_dump(), + signature="", + sources=sources, + ) - async for chunk in response: - if chunk.usage is not None: - prompt_token_usage = chunk.usage.prompt_tokens - completion_token_usage = chunk.usage.completion_tokens + logger.info( + f"[chat] model_response request_id={request_id} duration_ms={(time.monotonic() - t_call) * 1000:.0f}" + ) - payload = chunk.model_dump(exclude_unset=True) + if model_response.usage is None: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Model response does not contain usage 
statistics", + ) - if chunk.usage is not None and sources: - payload["sources"] = [ - s.model_dump(mode="json") for s in sources - ] + if agg_prompt_tokens or agg_completion_tokens: + total_prompt_tokens = response.usage.prompt_tokens + total_completion_tokens = response.usage.completion_tokens - yield f"data: {json.dumps(payload)}\n\n" + total_prompt_tokens += agg_prompt_tokens + total_completion_tokens += agg_completion_tokens - await UserManager.update_token_usage( - auth_info.user.userid, - prompt_tokens=prompt_token_usage, - completion_tokens=completion_token_usage, - ) - meter.set_response( - { - "usage": LLMUsage( - prompt_tokens=prompt_token_usage, - completion_tokens=completion_token_usage, - web_searches=len(sources) if sources else 0, - ) - } - ) - await QueryLogManager.log_query( - auth_info.user.userid, - model=req.model, - prompt_tokens=prompt_token_usage, - completion_tokens=completion_token_usage, - web_search_calls=len(sources) if sources else 0, - ) - logger.info( - "[chat] stream done request_id=%s prompt_tokens=%d completion_tokens=%d " - "duration_ms=%.0f total_ms=%.0f", - request_id, - prompt_token_usage, - completion_token_usage, - (time.monotonic() - t_call) * 1000, - (time.monotonic() - t_start) * 1000, - ) + model_response.usage.prompt_tokens = total_prompt_tokens + model_response.usage.completion_tokens = total_completion_tokens + model_response.usage.total_tokens = ( + total_prompt_tokens + total_completion_tokens + ) - except Exception as e: - logger.error( - "[chat] stream error request_id=%s error=%s", request_id, e + # Update token usage in DB + meter.set_response( + { + "usage": LLMUsage( + prompt_tokens=model_response.usage.prompt_tokens, + completion_tokens=model_response.usage.completion_tokens, + web_searches=len(sources) if sources else 0, ) - yield f"data: {json.dumps({'error': 'stream_failed', 'message': str(e)})}\n\n" - - return StreamingResponse( - chat_completion_stream_generator(), - media_type="text/event-stream", + } ) - logger.info(f"[chat] call start request_id={request_id}") - t_call = time.monotonic() - response = await client.chat.completions.create(**request_kwargs) - logger.info( - f"[chat] call done request_id={request_id} duration_ms={(time.monotonic() - t_call) * 1000:.0f}" - ) - logger.info(f"[chat] call response: {response}") - - # Handle tool workflow fully inside tools.router - ( - final_completion, - agg_prompt_tokens, - agg_completion_tokens, - ) = await handle_tool_workflow(client, req, request_kwargs["messages"], response) - logger.info(f"[chat] call final_completion: {final_completion}") - model_response = SignedChatCompletion( - **final_completion.model_dump(), - signature="", - sources=sources, - ) - - logger.info( - f"[chat] model_response request_id={request_id} duration_ms={(time.monotonic() - t_call) * 1000:.0f}" - ) + # Log query with context + tool_calls_count = 0 + if final_completion.choices and final_completion.choices[0].message.tool_calls: + tool_calls_count = len(final_completion.choices[0].message.tool_calls) - if model_response.usage is None: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Model response does not contain usage statistics", + log_ctx.set_usage( + prompt_tokens=model_response.usage.prompt_tokens, + completion_tokens=model_response.usage.completion_tokens, + tool_calls=tool_calls_count, + web_search_calls=len(sources) if sources else 0, ) + # Use background task for successful requests to avoid blocking response + background_tasks.add_task(log_ctx.commit) - 
if agg_prompt_tokens or agg_completion_tokens: - total_prompt_tokens = response.usage.prompt_tokens - total_completion_tokens = response.usage.completion_tokens + # Sign the response + response_json = model_response.model_dump_json() + signature = sign_message(state.private_key, response_json) + model_response.signature = b64encode(signature).decode() - total_prompt_tokens += agg_prompt_tokens - total_completion_tokens += agg_completion_tokens - - model_response.usage.prompt_tokens = total_prompt_tokens - model_response.usage.completion_tokens = total_completion_tokens - model_response.usage.total_tokens = ( - total_prompt_tokens + total_completion_tokens + logger.info( + f"[chat] done request_id={request_id} prompt_tokens={model_response.usage.prompt_tokens} completion_tokens={model_response.usage.completion_tokens} total_ms={(time.monotonic() - t_start) * 1000:.0f}" + ) + return model_response + except HTTPException as e: + # Extract error code from HTTPException, default to status code + error_code = e.status_code + error_message = str(e.detail) if e.detail else str(e) + logger.error( + f"[chat] HTTPException request_id={request_id} user={auth_info.user.user_id} " + f"model={model_name} error_code={error_code} error={error_message}", + exc_info=True, ) - # Update token usage in DB - await UserManager.update_token_usage( - auth_info.user.userid, - prompt_tokens=model_response.usage.prompt_tokens, - completion_tokens=model_response.usage.completion_tokens, - ) - meter.set_response( - { - "usage": LLMUsage( - prompt_tokens=model_response.usage.prompt_tokens, - completion_tokens=model_response.usage.completion_tokens, - web_searches=len(sources) if sources else 0, - ) - } - ) - await QueryLogManager.log_query( - auth_info.user.userid, - model=req.model, - prompt_tokens=model_response.usage.prompt_tokens, - completion_tokens=model_response.usage.completion_tokens, - web_search_calls=len(sources) if sources else 0, - ) - - # Sign the response - response_json = model_response.model_dump_json() - signature = sign_message(state.private_key, response_json) - model_response.signature = b64encode(signature).decode() - - logger.info( - f"[chat] done request_id={request_id} prompt_tokens={model_response.usage.prompt_tokens} completion_tokens={model_response.usage.completion_tokens} total_ms={(time.monotonic() - t_start) * 1000:.0f}" - ) - return model_response + # Only log server errors (5xx) to database to prevent DoS attacks via client errors + # Client errors (4xx) are logged to application logs only + if error_code >= 500: + # Set model if not already set (e.g., for validation errors before model validation) + if log_ctx.model is None: + log_ctx.set_model(model_name) + log_ctx.set_error(error_code=error_code, error_message=error_message) + await log_ctx.commit() + # For 4xx errors, we skip DB logging - they're logged above via logger.error() + # This prevents DoS attacks where attackers send many invalid requests + + raise + except Exception as e: + # Catch any other unexpected exceptions + error_message = str(e) + logger.error( + f"[chat] unexpected error request_id={request_id} user={auth_info.user.user_id} " + f"model={model_name} error={error_message}", + exc_info=True, + ) + # Set model if not already set + if log_ctx.model is None: + log_ctx.set_model(model_name) + log_ctx.set_error(error_code=500, error_message=error_message) + await log_ctx.commit() + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Internal server error: {error_message}", + ) \ No 
newline at end of file diff --git a/nilai-api/src/nilai_api/routers/endpoints/responses.py b/nilai-api/src/nilai_api/routers/endpoints/responses.py index 1b764674..fa519529 100644 --- a/nilai-api/src/nilai_api/routers/endpoints/responses.py +++ b/nilai-api/src/nilai_api/routers/endpoints/responses.py @@ -5,7 +5,7 @@ from base64 import b64encode from typing import AsyncGenerator, Optional, Union, List, Tuple -from fastapi import APIRouter, Body, Depends, HTTPException, status, Request +from fastapi import APIRouter, Body, Depends, HTTPException, status, Request, BackgroundTasks from fastapi.responses import StreamingResponse from openai import AsyncOpenAI @@ -13,7 +13,7 @@ from nilai_api.config import CONFIG from nilai_api.crypto import sign_message from nilai_api.credit import LLMMeter, LLMUsage -from nilai_api.db.logs import QueryLogManager +from nilai_api.db.logs import QueryLogManager, QueryLogContext from nilai_api.db.users import UserManager from nilai_api.handlers.nildb.handler import get_prompt_from_nildb @@ -73,6 +73,7 @@ async def create_response( "web_search": False, } ), + background_tasks: BackgroundTasks = BackgroundTasks(), _rate_limit=Depends( RateLimit( concurrent_extractor=responses_concurrent_rate_limit, @@ -81,6 +82,7 @@ async def create_response( ), auth_info: AuthenticationInfo = Depends(get_auth_info), meter: MeteringContext = Depends(LLMMeter), + log_ctx: QueryLogContext = Depends(QueryLogContext), ) -> Union[SignedResponse, StreamingResponse]: """ Generate a response from the AI model using the Responses API. @@ -124,225 +126,283 @@ async def create_response( - **500 Internal Server Error**: - Model fails to generate a response """ - if not req.input: - raise HTTPException( - status_code=400, - detail="Request 'input' field cannot be empty.", - ) + # Initialize log context early so we can log any errors + log_ctx.set_user(auth_info.user.user_id) + log_ctx.set_lockid(meter.lock_id) model_name = req.model request_id = str(uuid.uuid4()) t_start = time.monotonic() - endpoint = await state.get_model(model_name) - if endpoint is None: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Invalid model name {model_name}, check /v1/models for options", - ) + try: + if not req.input: + raise HTTPException( + status_code=400, + detail="Request 'input' field cannot be empty.", + ) - if not endpoint.metadata.tool_support and req.tools: - raise HTTPException( - status_code=400, - detail="Model does not support tool usage, remove tools from request", - ) + endpoint = await state.get_model(model_name) + if endpoint is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Invalid model name {model_name}, check /v1/models for options", + ) - has_multimodal = req.has_multimodal_content() - if has_multimodal and (not endpoint.metadata.multimodal_support or req.web_search): - raise HTTPException( - status_code=400, - detail="Model does not support multimodal content, remove image inputs from request", - ) + # Now we have a valid model, set it in log context + log_ctx.set_model(model_name) - model_url = endpoint.url + "/v1/" + if not endpoint.metadata.tool_support and req.tools: + raise HTTPException( + status_code=400, + detail="Model does not support tool usage, remove tools from request", + ) - client = AsyncOpenAI(base_url=model_url, api_key="") - if auth_info.prompt_document: - try: - nildb_prompt: str = await get_prompt_from_nildb(auth_info.prompt_document) - req.ensure_instructions(nildb_prompt) - except Exception as e: + 
has_multimodal = req.has_multimodal_content() + if has_multimodal and (not endpoint.metadata.multimodal_support or req.web_search): raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=f"Unable to extract prompt from nilDB: {str(e)}", + status_code=400, + detail="Model does not support multimodal content, remove image inputs from request", ) - input_items = req.input - instructions = req.instructions - sources: Optional[List[Source]] = None + model_url = endpoint.url + "/v1/" - if req.web_search: - logger.info(f"[responses] web_search start request_id={request_id}") - t_ws = time.monotonic() - web_search_result = await handle_web_search_for_responses( - req, model_name, client - ) - input_items = web_search_result.input - instructions = web_search_result.instructions - sources = web_search_result.sources logger.info( - f"[responses] web_search done request_id={request_id} sources={len(sources) if sources else 0} duration_ms={(time.monotonic() - t_ws) * 1000:.0f}" + f"[responses] start request_id={request_id} user={auth_info.user.user_id} model={model_name} stream={req.stream} web_search={bool(req.web_search)} tools={bool(req.tools)} multimodal={has_multimodal} url={model_url}" + ) + log_ctx.set_request_params( + temperature=req.temperature, + max_tokens=req.max_output_tokens, + was_streamed=req.stream or False, + was_multimodal=has_multimodal, + was_nildb=bool(auth_info.prompt_document), + was_nilrag=False, ) - if req.stream: - - async def response_stream_generator() -> AsyncGenerator[str, None]: - t_call = time.monotonic() - prompt_token_usage = 0 - completion_token_usage = 0 - + client = AsyncOpenAI(base_url=model_url, api_key="") + if auth_info.prompt_document: try: - logger.info(f"[responses] stream start request_id={request_id}") - request_kwargs = { - "model": req.model, - "input": input_items, - "instructions": instructions, - "stream": True, - "top_p": req.top_p, - "temperature": req.temperature, - "max_output_tokens": req.max_output_tokens, - "extra_body": { - "stream_options": { - "include_usage": True, - "continuous_usage_stats": False, - } - }, - } - if req.tools: - request_kwargs["tools"] = req.tools - - stream = await client.responses.create(**request_kwargs) - - async for event in stream: - payload = event.model_dump(exclude_unset=True) - - if isinstance(event, ResponseCompletedEvent): - if event.response and event.response.usage: - usage = event.response.usage - prompt_token_usage = usage.input_tokens - completion_token_usage = usage.output_tokens - payload["response"]["usage"] = usage.model_dump(mode="json") - - if sources: - if "data" not in payload: - payload["data"] = {} - payload["data"]["sources"] = [ - s.model_dump(mode="json") for s in sources - ] - - yield f"data: {json.dumps(payload)}\n\n" - - await UserManager.update_token_usage( - auth_info.user.userid, - prompt_tokens=prompt_token_usage, - completion_tokens=completion_token_usage, + nildb_prompt: str = await get_prompt_from_nildb(auth_info.prompt_document) + req.ensure_instructions(nildb_prompt) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Unable to extract prompt from nilDB: {str(e)}", ) - meter.set_response( - { - "usage": LLMUsage( - prompt_tokens=prompt_token_usage, - completion_tokens=completion_token_usage, - web_searches=len(sources) if sources else 0, - ) + + input_items = req.input + instructions = req.instructions + sources: Optional[List[Source]] = None + + if req.web_search: + logger.info(f"[responses] web_search start 
request_id={request_id}") + t_ws = time.monotonic() + web_search_result = await handle_web_search_for_responses( + req, model_name, client + ) + input_items = web_search_result.input + instructions = web_search_result.instructions + sources = web_search_result.sources + logger.info( + f"[responses] web_search done request_id={request_id} sources={len(sources) if sources else 0} duration_ms={(time.monotonic() - t_ws) * 1000:.0f}" + ) + + if req.stream: + + async def response_stream_generator() -> AsyncGenerator[str, None]: + t_call = time.monotonic() + prompt_token_usage = 0 + completion_token_usage = 0 + + try: + logger.info(f"[responses] stream start request_id={request_id}") + log_ctx.start_model_timing() + + request_kwargs = { + "model": req.model, + "input": input_items, + "instructions": instructions, + "stream": True, + "top_p": req.top_p, + "temperature": req.temperature, + "max_output_tokens": req.max_output_tokens, + "extra_body": { + "stream_options": { + "include_usage": True, + "continuous_usage_stats": False, + } + }, } - ) - await QueryLogManager.log_query( - auth_info.user.userid, - model=req.model, - prompt_tokens=prompt_token_usage, - completion_tokens=completion_token_usage, - web_search_calls=len(sources) if sources else 0, - ) - logger.info( - "[responses] stream done request_id=%s prompt_tokens=%d completion_tokens=%d duration_ms=%.0f total_ms=%.0f", - request_id, - prompt_token_usage, - completion_token_usage, - (time.monotonic() - t_call) * 1000, - (time.monotonic() - t_start) * 1000, - ) + if req.tools: + request_kwargs["tools"] = req.tools + + stream = await client.responses.create(**request_kwargs) + + async for event in stream: + payload = event.model_dump(exclude_unset=True) + + if isinstance(event, ResponseCompletedEvent): + if event.response and event.response.usage: + usage = event.response.usage + prompt_token_usage = usage.input_tokens + completion_token_usage = usage.output_tokens + payload["response"]["usage"] = usage.model_dump(mode="json") + + if sources: + if "data" not in payload: + payload["data"] = {} + payload["data"]["sources"] = [ + s.model_dump(mode="json") for s in sources + ] + + yield f"data: {json.dumps(payload)}\n\n" + + log_ctx.end_model_timing() + meter.set_response( + { + "usage": LLMUsage( + prompt_tokens=prompt_token_usage, + completion_tokens=completion_token_usage, + web_searches=len(sources) if sources else 0, + ) + } + ) + log_ctx.set_usage( + prompt_tokens=prompt_token_usage, + completion_tokens=completion_token_usage, + web_search_calls=len(sources) if sources else 0, + ) + background_tasks.add_task(log_ctx.commit) + logger.info( + "[responses] stream done request_id=%s prompt_tokens=%d completion_tokens=%d duration_ms=%.0f total_ms=%.0f", + request_id, + prompt_token_usage, + completion_token_usage, + (time.monotonic() - t_call) * 1000, + (time.monotonic() - t_start) * 1000, + ) + + except Exception as e: + logger.error( + "[responses] stream error request_id=%s error=%s", request_id, e + ) + log_ctx.set_error(error_code=500, error_message=str(e)) + await log_ctx.commit() + yield f"data: {json.dumps({'error': 'stream_failed', 'message': str(e)})}\n\n" + + return StreamingResponse( + response_stream_generator(), media_type="text/event-stream" + ) - except Exception as e: - logger.error( - "[responses] stream error request_id=%s error=%s", request_id, e - ) - yield f"data: {json.dumps({'error': 'stream_failed', 'message': str(e)})}\n\n" + request_kwargs = { + "model": req.model, + "input": input_items, + "instructions": 
instructions, + "top_p": req.top_p, + "temperature": req.temperature, + "max_output_tokens": req.max_output_tokens, + } + if req.tools: + request_kwargs["tools"] = req.tools + request_kwargs["tool_choice"] = req.tool_choice + + logger.info(f"[responses] call start request_id={request_id}") + t_call = time.monotonic() + log_ctx.start_model_timing() + response = await client.responses.create(**request_kwargs) + log_ctx.end_model_timing() + logger.info( + f"[responses] call done request_id={request_id} duration_ms={(time.monotonic() - t_call) * 1000:.0f}" + ) - return StreamingResponse( - response_stream_generator(), media_type="text/event-stream" + # Handle tool workflow + log_ctx.start_tool_timing() + ( + final_response, + agg_prompt_tokens, + agg_completion_tokens, + ) = await handle_responses_tool_workflow(client, req, input_items, response) + log_ctx.end_tool_timing() + + model_response = SignedResponse( + **final_response.model_dump(), + signature="", + sources=sources, ) - request_kwargs = { - "model": req.model, - "input": input_items, - "instructions": instructions, - "top_p": req.top_p, - "temperature": req.temperature, - "max_output_tokens": req.max_output_tokens, - } - if req.tools: - request_kwargs["tools"] = req.tools - request_kwargs["tool_choice"] = req.tool_choice - - logger.info(f"[responses] call start request_id={request_id}") - t_call = time.monotonic() - - response = await client.responses.create(**request_kwargs) - logger.info( - f"[responses] call done request_id={request_id} duration_ms={(time.monotonic() - t_call) * 1000:.0f}" - ) + if model_response.usage is None: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Model response does not contain usage statistics", + ) - ( - final_response, - agg_prompt_tokens, - agg_completion_tokens, - ) = await handle_responses_tool_workflow(client, req, input_items, response) + if agg_prompt_tokens or agg_completion_tokens: + model_response.usage.input_tokens += agg_prompt_tokens + model_response.usage.output_tokens += agg_completion_tokens - model_response = SignedResponse( - **final_response.model_dump(), - signature="", - sources=sources, - ) + prompt_tokens = model_response.usage.input_tokens + completion_tokens = model_response.usage.output_tokens - if model_response.usage is None: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Model response does not contain usage statistics", + meter.set_response( + { + "usage": LLMUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + web_searches=len(sources) if sources else 0, + ) + } ) - if agg_prompt_tokens or agg_completion_tokens: - model_response.usage.input_tokens += agg_prompt_tokens - model_response.usage.output_tokens += agg_completion_tokens + # Log query with context. + # Note: a Responses API result exposes a flat `output` list rather than + # chat-style `choices`, so tool calls are not counted here; only token + # and web-search usage is logged.
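If tool-call counting is ever wanted on this endpoint, it would look roughly like this (hypothetical helper; the `output` item shape with a `type` discriminator comes from the OpenAI Responses client, not from this patch):

```python
def count_tool_calls(final_response) -> int:
    # A Response carries a flat `output` list instead of chat-style choices;
    # function calls appear as output items with type == "function_call".
    return sum(
        1
        for item in (getattr(final_response, "output", None) or [])
        if getattr(item, "type", None) == "function_call"
    )
```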
+ + log_ctx.set_usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + web_search_calls=len(sources) if sources else 0, + ) + background_tasks.add_task(log_ctx.commit) - prompt_tokens = model_response.usage.input_tokens - completion_tokens = model_response.usage.output_tokens + response_json = model_response.model_dump_json() + signature = sign_message(state.private_key, response_json) + model_response.signature = b64encode(signature).decode() - await UserManager.update_token_usage( - auth_info.user.userid, - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - ) - meter.set_response( - { - "usage": LLMUsage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - web_searches=len(sources) if sources else 0, - ) - } - ) - await QueryLogManager.log_query( - auth_info.user.userid, - model=req.model, - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - web_search_calls=len(sources) if sources else 0, - ) - - response_json = model_response.model_dump_json() - signature = sign_message(state.private_key, response_json) - model_response.signature = b64encode(signature).decode() + logger.info( + f"[responses] done request_id={request_id} prompt_tokens={prompt_tokens} completion_tokens={completion_tokens} total_ms={(time.monotonic() - t_start) * 1000:.0f}" + ) + return model_response + + except HTTPException as e: + error_code = e.status_code + error_message = str(e.detail) if e.detail else str(e) + logger.error( + f"[responses] HTTPException request_id={request_id} user={auth_info.user.user_id} " + f"model={model_name} error_code={error_code} error={error_message}", + exc_info=True, + ) - logger.info( - f"[responses] done request_id={request_id} prompt_tokens={prompt_tokens} completion_tokens={completion_tokens} total_ms={(time.monotonic() - t_start) * 1000:.0f}" - ) - return model_response + if error_code >= 500: + if log_ctx.model is None: + log_ctx.set_model(model_name) + log_ctx.set_error(error_code=error_code, error_message=error_message) + await log_ctx.commit() + + raise + except Exception as e: + error_message = str(e) + logger.error( + f"[responses] unexpected error request_id={request_id} user={auth_info.user.user_id} " + f"model={model_name} error={error_message}", + exc_info=True, + ) + if log_ctx.model is None: + log_ctx.set_model(model_name) + log_ctx.set_error(error_code=500, error_message=error_message) + await log_ctx.commit() + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Internal server error: {error_message}", + ) diff --git a/nilai-api/src/nilai_api/routers/private.py b/nilai-api/src/nilai_api/routers/private.py index 386dce2d..1cb33935 100644 --- a/nilai-api/src/nilai_api/routers/private.py +++ b/nilai-api/src/nilai_api/routers/private.py @@ -5,6 +5,7 @@ from nilai_api.attestation import get_attestation_report from nilai_api.auth import get_auth_info, AuthenticationInfo +from nilai_api.db.logs import QueryLogManager from nilai_api.handlers.nildb.api_model import ( PromptDelegationRequest, PromptDelegationToken, @@ -17,7 +18,6 @@ from nilai_common import ( AttestationReport, ModelMetadata, - Nonce, Usage, ) @@ -32,14 +32,10 @@ @router.get("/v1/delegation") async def get_prompt_store_delegation( prompt_delegation_request: PromptDelegationRequest, - auth_info: AuthenticationInfo = Depends(get_auth_info), + _: AuthenticationInfo = Depends( + get_auth_info + ), # This is to satisfy that the user is authenticated ) -> PromptDelegationToken: - if not 
auth_info.user.is_subscription_owner: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=f"Prompt storage is reserved to subscription owners: {auth_info.user} is not a subscription owner, apikey: {auth_info.user}", - ) - try: return await get_nildb_delegation_token(prompt_delegation_request) except Exception as e: @@ -63,12 +59,15 @@ async def get_usage(auth_info: AuthenticationInfo = Depends(get_auth_info)) -> U usage = await get_usage(user) ``` """ - return Usage( - prompt_tokens=auth_info.user.prompt_tokens, - completion_tokens=auth_info.user.completion_tokens, - total_tokens=auth_info.user.prompt_tokens + auth_info.user.completion_tokens, - queries=auth_info.user.queries, # type: ignore # FIXME this field is not part of Usage + user_usage: Optional[Usage] = await QueryLogManager.get_user_token_usage( + auth_info.user.user_id ) + if user_usage is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found", + ) + return user_usage @router.get("/v1/attestation/report", tags=["Attestation"]) diff --git a/packages/nilai-common/src/nilai_common/__init__.py b/packages/nilai-common/src/nilai_common/__init__.py index 3d822cba..574b7437 100644 --- a/packages/nilai-common/src/nilai_common/__init__.py +++ b/packages/nilai-common/src/nilai_common/__init__.py @@ -32,6 +32,7 @@ ResponseInputItemParam, EasyInputMessageParam, ResponseFunctionToolCallParam, + Usage, ) from nilai_common.config import SETTINGS, MODEL_SETTINGS, MODEL_CAPABILITIES from nilai_common.discovery import ModelServiceDiscovery @@ -74,4 +75,5 @@ "ResponseInputItemParam", "EasyInputMessageParam", "ResponseFunctionToolCallParam", + "Usage", ] diff --git a/packages/nilai-common/src/nilai_common/api_models/__init__.py b/packages/nilai-common/src/nilai_common/api_models/__init__.py index 95243ff2..de76a353 100644 --- a/packages/nilai-common/src/nilai_common/api_models/__init__.py +++ b/packages/nilai-common/src/nilai_common/api_models/__init__.py @@ -30,6 +30,7 @@ MessageAdapter, ImageContent, TextContent, + Usage, ) from nilai_common.api_models.responses_model import ( @@ -74,6 +75,7 @@ "MessageAdapter", "ImageContent", "TextContent", + "Usage", "Response", "ResponseCompletedEvent", "ResponseRequest", diff --git a/packages/nilai-common/src/nilai_common/api_models/chat_completion_model.py b/packages/nilai-common/src/nilai_common/api_models/chat_completion_model.py index 41082ef5..743285ce 100644 --- a/packages/nilai-common/src/nilai_common/api_models/chat_completion_model.py +++ b/packages/nilai-common/src/nilai_common/api_models/chat_completion_model.py @@ -1,6 +1,5 @@ -from __future__ import annotations - from typing import ( + Annotated, Iterable, List, Optional, diff --git a/packages/nilai-common/src/nilai_common/discovery.py b/packages/nilai-common/src/nilai_common/discovery.py index 86c90224..345a9a6d 100644 --- a/packages/nilai-common/src/nilai_common/discovery.py +++ b/packages/nilai-common/src/nilai_common/discovery.py @@ -5,7 +5,7 @@ from typing import Dict, Optional import redis.asyncio as redis -from nilai_common.api_model import ModelEndpoint, ModelMetadata +from nilai_common.api_models import ModelEndpoint, ModelMetadata from tenacity import retry, stop_after_attempt, wait_exponential # Configure logging diff --git a/tests/integration/nilai_api/test_users_db_integration.py b/tests/integration/nilai_api/test_users_db_integration.py index 82d8d022..a9f5663a 100644 --- a/tests/integration/nilai_api/test_users_db_integration.py +++ 
b/tests/integration/nilai_api/test_users_db_integration.py @@ -4,6 +4,7 @@ These tests use a real PostgreSQL database via testcontainers. """ +import uuid import pytest import json @@ -17,37 +18,19 @@ class TestUserManagerIntegration: async def test_simple_user_creation(self, clean_database): """Test creating a simple user and retrieving it.""" # Insert user with minimal data - user = await UserManager.insert_user(name="Simple Test User") + user = await UserManager.insert_user(user_id="Simple Test User") # Verify user creation - assert user.name == "Simple Test User" - assert user.userid is not None - assert user.apikey is not None - assert user.userid != user.apikey # Should be different UUIDs + assert user.user_id == "Simple Test User" + assert user.rate_limits is None, ( + f"Rate limits are not set for user {user.user_id}" + ) # Retrieve user by ID - found_user = await UserManager.check_user(user.userid) + found_user = await UserManager.check_user(user.user_id) assert found_user is not None - assert found_user.userid == user.userid - assert found_user.name == "Simple Test User" - assert found_user.apikey == user.apikey - - @pytest.mark.asyncio - async def test_api_key_validation(self, clean_database): - """Test API key validation functionality.""" - # Create user - user = await UserManager.insert_user("API Test User") - - # Validate correct API key - api_user = await UserManager.check_api_key(user.apikey) - assert api_user is not None - assert api_user.apikey == user.apikey - assert api_user.userid == user.userid - assert api_user.name == "API Test User" - - # Test invalid API key - invalid_user = await UserManager.check_api_key("invalid-api-key") - assert invalid_user is None + assert found_user.user_id == user.user_id + assert found_user.rate_limits == user.rate_limits @pytest.mark.asyncio async def test_rate_limits_json_crud_basic(self, clean_database): @@ -66,14 +49,14 @@ async def test_rate_limits_json_crud_basic(self, clean_database): # CREATE: Insert user with rate limits user = await UserManager.insert_user( - name="Rate Limits Test User", rate_limits=rate_limits + user_id="Rate Limits Test User", rate_limits=rate_limits ) # Verify rate limits are stored as JSON assert user.rate_limits == rate_limits.model_dump() # READ: Retrieve user and verify rate limits JSON - retrieved_user = await UserManager.check_user(user.userid) + retrieved_user = await UserManager.check_user(user.user_id) assert retrieved_user is not None assert retrieved_user.rate_limits == rate_limits.model_dump() @@ -98,11 +81,11 @@ async def test_rate_limits_json_update(self, clean_database): ) user = await UserManager.insert_user( - name="Update Rate Limits User", rate_limits=initial_rate_limits + user_id="Update Rate Limits User", rate_limits=initial_rate_limits ) # Verify initial rate limits - retrieved_user = await UserManager.check_user(user.userid) + retrieved_user = await UserManager.check_user(user.user_id) assert retrieved_user is not None assert retrieved_user.rate_limits == initial_rate_limits.model_dump() @@ -125,19 +108,19 @@ async def test_rate_limits_json_update(self, clean_database): stmt = sa.text(""" UPDATE users SET rate_limits = :rate_limits_json - WHERE userid = :userid + WHERE user_id = :user_id """) await session.execute( stmt, { "rate_limits_json": updated_rate_limits.model_dump_json(), - "userid": user.userid, + "user_id": user.user_id, }, ) await session.commit() # READ: Verify the update worked - updated_user = await UserManager.check_user(user.userid) + updated_user = await 
UserManager.check_user(user.user_id) assert updated_user is not None assert updated_user.rate_limits == updated_rate_limits.model_dump() @@ -162,11 +145,11 @@ async def test_rate_limits_json_partial_update(self, clean_database): ) user = await UserManager.insert_user( - name="Partial Rate Limits User", rate_limits=partial_rate_limits + user_id="Partial Rate Limits User", rate_limits=partial_rate_limits ) # Verify partial data is stored correctly - retrieved_user = await UserManager.check_user(user.userid) + retrieved_user = await UserManager.check_user(user.user_id) assert retrieved_user is not None assert retrieved_user.rate_limits == partial_rate_limits.model_dump() @@ -183,13 +166,13 @@ async def test_rate_limits_json_partial_update(self, clean_database): '{user_rate_limit_hour}', '75' ) - WHERE userid = :userid + WHERE user_id = :user_id """) - await session.execute(stmt, {"userid": user.userid}) + await session.execute(stmt, {"user_id": user.user_id}) await session.commit() # Verify partial update worked - updated_user = await UserManager.check_user(user.userid) + updated_user = await UserManager.check_user(user.user_id) assert updated_user is not None expected_data = partial_rate_limits.model_dump() @@ -211,7 +194,7 @@ async def test_rate_limits_json_null_and_delete(self, clean_database): ) user = await UserManager.insert_user( - name="Delete Rate Limits User", rate_limits=rate_limits + user_id="Delete Rate Limits User", rate_limits=rate_limits ) # DELETE: Set rate_limits to NULL @@ -219,12 +202,14 @@ async def test_rate_limits_json_null_and_delete(self, clean_database): import sqlalchemy as sa async with get_db_session() as session: - stmt = sa.text("UPDATE users SET rate_limits = NULL WHERE userid = :userid") - await session.execute(stmt, {"userid": user.userid}) + stmt = sa.text( + "UPDATE users SET rate_limits = NULL WHERE user_id = :user_id" + ) + await session.execute(stmt, {"user_id": user.user_id}) await session.commit() # Verify NULL handling - null_user = await UserManager.check_user(user.userid) + null_user = await UserManager.check_user(user.user_id) assert null_user is not None assert null_user.rate_limits is None @@ -239,15 +224,15 @@ async def test_rate_limits_json_null_and_delete(self, clean_database): # First set some data new_data = {"user_rate_limit_day": 500, "web_search_rate_limit_day": 25} stmt = sa.text( - "UPDATE users SET rate_limits = :data WHERE userid = :userid" + "UPDATE users SET rate_limits = :data WHERE user_id = :user_id" ) await session.execute( - stmt, {"data": json.dumps(new_data), "userid": user.userid} + stmt, {"data": json.dumps(new_data), "user_id": user.user_id} ) await session.commit() # Verify data was set - updated_user = await UserManager.check_user(user.userid) + updated_user = await UserManager.check_user(user.user_id) assert updated_user is not None assert updated_user.rate_limits == new_data @@ -256,13 +241,13 @@ async def test_rate_limits_json_null_and_delete(self, clean_database): stmt = sa.text(""" UPDATE users SET rate_limits = rate_limits::jsonb - 'web_search_rate_limit_day' - WHERE userid = :userid + WHERE user_id = :user_id """) - await session.execute(stmt, {"userid": user.userid}) + await session.execute(stmt, {"user_id": user.user_id}) await session.commit() # Verify field was removed - final_user = await UserManager.check_user(user.userid) + final_user = await UserManager.check_user(user.user_id) expected_final_data = {"user_rate_limit_day": 500} assert final_user is not None assert final_user.rate_limits == 
expected_final_data @@ -293,15 +278,15 @@ async def test_rate_limits_json_validation_and_conversion(self, clean_database): for i, test_data in enumerate(test_cases): async with get_db_session() as session: stmt = sa.text( - "UPDATE users SET rate_limits = :data WHERE userid = :userid" + "UPDATE users SET rate_limits = :data WHERE user_id = :user_id" ) await session.execute( - stmt, {"data": json.dumps(test_data), "userid": user.userid} + stmt, {"data": json.dumps(test_data), "user_id": user.user_id} ) await session.commit() # Retrieve and verify - updated_user = await UserManager.check_user(user.userid) + updated_user = await UserManager.check_user(user.user_id) assert updated_user is not None assert updated_user.rate_limits == test_data @@ -327,11 +312,13 @@ async def test_rate_limits_json_validation_and_conversion(self, clean_database): # Test empty JSON object async with get_db_session() as session: - stmt = sa.text("UPDATE users SET rate_limits = '{}' WHERE userid = :userid") - await session.execute(stmt, {"userid": user.userid}) + stmt = sa.text( + "UPDATE users SET rate_limits = '{}' WHERE user_id = :user_id" + ) + await session.execute(stmt, {"user_id": user.user_id}) await session.commit() - empty_user = await UserManager.check_user(user.userid) + empty_user = await UserManager.check_user(user.user_id) assert empty_user is not None assert empty_user.rate_limits == {} empty_rate_limits_obj = empty_user.rate_limits_obj @@ -343,18 +330,18 @@ async def test_rate_limits_json_validation_and_conversion(self, clean_database): async with get_db_session() as session: # This should work as PostgreSQL JSONB validates JSON stmt = sa.text( - "UPDATE users SET rate_limits = :data WHERE userid = :userid" + "UPDATE users SET rate_limits = :data WHERE user_id = :user_id" ) await session.execute( stmt, { "data": '{"user_rate_limit_day": 5000}', # Valid JSON string - "userid": user.userid, + "user_id": user.user_id, }, ) await session.commit() - json_string_user = await UserManager.check_user(user.userid) + json_string_user = await UserManager.check_user(user.user_id) assert json_string_user is not None assert json_string_user.rate_limits == {"user_rate_limit_day": 5000} @@ -366,16 +353,15 @@ async def test_rate_limits_json_validation_and_conversion(self, clean_database): async def test_rate_limits_update_workflow(self, clean_database): """Test complete workflow: create user with no rate limits -> update rate limits -> verify update.""" # Step 1: Create user with NO rate limits - user = await UserManager.insert_user(name="Rate Limits Workflow User") + user_id = str(uuid.uuid4()) + user = await UserManager.insert_user(user_id=user_id) # Verify user was created with no rate limits - assert user.name == "Rate Limits Workflow User" - assert user.userid is not None - assert user.apikey is not None + assert user.user_id == user_id assert user.rate_limits is None # No rate limits initially # Step 2: Retrieve user and confirm no rate limits - retrieved_user = await UserManager.check_user(user.userid) + retrieved_user = await UserManager.check_user(user.user_id) assert retrieved_user is not None print(retrieved_user.to_pydantic()) assert retrieved_user is not None @@ -401,12 +387,12 @@ async def test_rate_limits_update_workflow(self, clean_database): # Step 4: Update the user's rate limits using the new function update_success = await UserManager.update_rate_limits( - user.userid, new_rate_limits + user.user_id, new_rate_limits ) assert update_success is True # Step 5: Retrieve user again and verify rate 
limits were updated - updated_user = await UserManager.check_user(user.userid) + updated_user = await UserManager.check_user(user.user_id) assert updated_user is not None assert updated_user.rate_limits is not None assert updated_user.rate_limits == new_rate_limits.model_dump() @@ -431,12 +417,12 @@ async def test_rate_limits_update_workflow(self, clean_database): ) partial_update_success = await UserManager.update_rate_limits( - user.userid, partial_rate_limits + user.user_id, partial_rate_limits ) assert partial_update_success is True # Step 8: Verify partial update worked - final_user = await UserManager.check_user(user.userid) + final_user = await UserManager.check_user(user.user_id) assert final_user is not None assert final_user.rate_limits == partial_rate_limits.model_dump() @@ -447,8 +433,8 @@ async def test_rate_limits_update_workflow(self, clean_database): # Other fields should have config defaults (not None due to get_effective_limits) # Step 9: Test error case - update non-existent user - fake_userid = "non-existent-user-id" + fake_user_id = "non-existent-user-id" error_update = await UserManager.update_rate_limits( - fake_userid, new_rate_limits + fake_user_id, new_rate_limits ) assert error_update is False diff --git a/tests/unit/nilai_api/__init__.py b/tests/unit/nilai_api/__init__.py index 0be52613..7cbc1237 100644 --- a/tests/unit/nilai_api/__init__.py +++ b/tests/unit/nilai_api/__init__.py @@ -21,11 +21,11 @@ def generate_api_key(self) -> str: async def insert_user(self, name: str, email: str) -> Dict[str, str]: """Insert a new user into the mock database.""" - userid = self.generate_user_id() + user_id = self.generate_user_id() apikey = self.generate_api_key() user_data = { - "userid": userid, + "user_id": user_id, "name": name, "email": email, "apikey": apikey, @@ -36,34 +36,34 @@ async def insert_user(self, name: str, email: str) -> Dict[str, str]: "last_activity": None, } - self.users[userid] = user_data - return {"userid": userid, "apikey": apikey} + self.users[user_id] = user_data + return {"user_id": user_id, "apikey": apikey} async def check_api_key(self, api_key: str) -> Optional[dict]: """Validate an API key in the mock database.""" for user in self.users.values(): if user["apikey"] == api_key: - return {"name": user["name"], "userid": user["userid"]} + return {"name": user["name"], "user_id": user["user_id"]} return None async def update_token_usage( - self, userid: str, prompt_tokens: int, completion_tokens: int + self, user_id: str, prompt_tokens: int, completion_tokens: int ): """Update token usage for a specific user.""" - if userid in self.users: - user = self.users[userid] + if user_id in self.users: + user = self.users[user_id] user["prompt_tokens"] += prompt_tokens user["completion_tokens"] += completion_tokens user["queries"] += 1 user["last_activity"] = datetime.now(timezone.utc) async def log_query( - self, userid: str, model: str, prompt_tokens: int, completion_tokens: int + self, user_id: str, model: str, prompt_tokens: int, completion_tokens: int ): """Log a user's query in the mock database.""" query_log = { "id": self._next_query_log_id, - "userid": userid, + "user_id": user_id, "query_timestamp": datetime.now(timezone.utc), "model": model, "prompt_tokens": prompt_tokens, @@ -74,9 +74,9 @@ async def log_query( self.query_logs[self._next_query_log_id] = query_log self._next_query_log_id += 1 - async def get_token_usage(self, userid: str) -> Optional[Dict[str, Any]]: + async def get_token_usage(self, user_id: str) -> Optional[Dict[str, Any]]: 
"""Get token usage for a specific user.""" - user = self.users.get(userid) + user = self.users.get(user_id) if user: return { "prompt_tokens": user["prompt_tokens"], @@ -90,9 +90,9 @@ async def get_all_users(self) -> Optional[List[Dict[str, Any]]]: """Retrieve all users from the mock database.""" return list(self.users.values()) if self.users else None - async def get_user_token_usage(self, userid: str) -> Optional[Dict[str, int]]: + async def get_user_token_usage(self, user_id: str) -> Optional[Dict[str, int]]: """Retrieve total token usage for a user.""" - user = self.users.get(userid) + user = self.users.get(user_id) if user: return { "prompt_tokens": user["prompt_tokens"], diff --git a/tests/unit/nilai_api/auth/test_auth.py b/tests/unit/nilai_api/auth/test_auth.py index 591c447a..7aee50a7 100644 --- a/tests/unit/nilai_api/auth/test_auth.py +++ b/tests/unit/nilai_api/auth/test_auth.py @@ -4,7 +4,6 @@ from nilai_api.db.users import RateLimits import pytest -from fastapi import HTTPException from fastapi.security import HTTPAuthorizationCredentials from nilai_api.config import CONFIG as config @@ -14,13 +13,9 @@ @pytest.fixture -def mock_user_manager(mocker): - from nilai_api.db.users import UserManager - - """Fixture to mock UserManager methods.""" - mocker.patch.object(UserManager, "check_api_key") - mocker.patch.object(UserManager, "update_last_activity") - return UserManager +def mock_validate_credential(mocker): + """Fixture to mock validate_credential function.""" + return mocker.patch("nilai_api.auth.strategies.validate_credential") @pytest.fixture @@ -29,7 +24,7 @@ def mock_user_model(): mock = MagicMock(spec=UserModel) mock.name = "Test User" - mock.userid = "test-user-id" + mock.user_id = "test-user-id" mock.apikey = "test-api-key" mock.prompt_tokens = 0 mock.completion_tokens = 0 @@ -49,53 +44,38 @@ def mock_user_data(mock_user_model): return UserData.from_sqlalchemy(mock_user_model) -@pytest.fixture -def mock_auth_info(): - from nilai_api.auth import AuthenticationInfo - - mock = MagicMock(spec=AuthenticationInfo) - mock.user = mock_user_data - return mock - - @pytest.mark.asyncio -async def test_get_auth_info_valid_token( - mock_user_manager, mock_auth_info, mock_user_model -): +async def test_get_auth_info_valid_token(mock_validate_credential, mock_user_model): from nilai_api.auth import get_auth_info """Test get_auth_info with a valid token.""" - mock_user_manager.check_api_key.return_value = mock_user_model + mock_validate_credential.return_value = mock_user_model credentials = HTTPAuthorizationCredentials( scheme="Bearer", credentials="valid-token" ) auth_info = await get_auth_info(credentials) print(auth_info) - assert auth_info.user.name == "Test User", ( - f"Expected Test User but got {auth_info.user.name}" - ) - assert auth_info.user.userid == "test-user-id", ( - f"Expected test-user-id but got {auth_info.user.userid}" + + assert auth_info.user.user_id == "test-user-id", ( + f"Expected test-user-id but got {auth_info.user.user_id}" ) @pytest.mark.asyncio -async def test_get_auth_info_invalid_token(mock_user_manager): +async def test_get_auth_info_invalid_token(mock_validate_credential): from nilai_api.auth import get_auth_info + from nilai_api.auth.common import AuthenticationError """Test get_auth_info with an invalid token.""" - mock_user_manager.check_api_key.return_value = None + mock_validate_credential.side_effect = AuthenticationError("Credential not found") credentials = HTTPAuthorizationCredentials( scheme="Bearer", credentials="invalid-token" ) - with 
pytest.raises(HTTPException) as exc_info: + with pytest.raises(AuthenticationError) as exc_info: auth_infor = await get_auth_info(credentials) print(auth_infor) print(exc_info) - assert exc_info.value.status_code == 401, ( - f"Expected status code 401 but got {exc_info.value.status_code}" - ) - assert exc_info.value.detail == "Missing or invalid API key", ( - f"Expected Missing or invalid API key but got {exc_info.value.detail}" + assert "Credential not found" in str(exc_info.value.detail), ( + f"Expected 'Credential not found' but got {exc_info.value.detail}" ) diff --git a/tests/unit/nilai_api/auth/test_strategies.py b/tests/unit/nilai_api/auth/test_strategies.py index 0c169f53..5b65c5b0 100644 --- a/tests/unit/nilai_api/auth/test_strategies.py +++ b/tests/unit/nilai_api/auth/test_strategies.py @@ -1,7 +1,6 @@ import pytest from unittest.mock import patch, MagicMock from datetime import datetime, timezone, timedelta -from fastapi import HTTPException from nilai_api.auth.strategies import api_key_strategy, nuc_strategy from nilai_api.auth.common import AuthenticationInfo, PromptDocument @@ -15,14 +14,7 @@ class TestAuthStrategies: def mock_user_model(self): """Mock UserModel fixture""" mock = MagicMock(spec=UserModel) - mock.name = "Test User" - mock.userid = "test-user-id" - mock.apikey = "test-api-key" - mock.prompt_tokens = 0 - mock.completion_tokens = 0 - mock.queries = 0 - mock.signup_date = datetime.now(timezone.utc) - mock.last_activity = datetime.now(timezone.utc) + mock.user_id = "test-user-id" mock.rate_limits = RateLimits().get_effective_limits().model_dump_json() mock.rate_limits_obj = RateLimits().get_effective_limits() return mock @@ -37,27 +29,26 @@ def mock_prompt_document(self): @pytest.mark.asyncio async def test_api_key_strategy_success(self, mock_user_model): """Test successful API key authentication""" - with patch("nilai_api.auth.strategies.UserManager.check_api_key") as mock_check: - mock_check.return_value = mock_user_model + with patch("nilai_api.auth.strategies.validate_credential") as mock_validate: + mock_validate.return_value = mock_user_model result = await api_key_strategy("test-api-key") assert isinstance(result, AuthenticationInfo) - assert result.user.name == "Test User" assert result.token_rate_limit is None assert result.prompt_document is None + mock_validate.assert_called_once_with("test-api-key", is_public=False) @pytest.mark.asyncio async def test_api_key_strategy_invalid_key(self): """Test API key authentication with invalid key""" - with patch("nilai_api.auth.strategies.UserManager.check_api_key") as mock_check: - mock_check.return_value = None + from nilai_api.auth.common import AuthenticationError - with pytest.raises(HTTPException) as exc_info: - await api_key_strategy("invalid-key") + with patch("nilai_api.auth.strategies.validate_credential") as mock_validate: + mock_validate.side_effect = AuthenticationError("Credential not found") - assert exc_info.value.status_code == 401 - assert "Missing or invalid API key" in str(exc_info.value.detail) + with pytest.raises(AuthenticationError, match="Credential not found"): + await api_key_strategy("invalid-key") @pytest.mark.asyncio async def test_nuc_strategy_existing_user_with_prompt_document( @@ -65,7 +56,7 @@ async def test_nuc_strategy_existing_user_with_prompt_document( ): """Test NUC authentication with existing user and prompt document""" with ( - patch("nilai_api.auth.strategies.validate_nuc") as mock_validate, + patch("nilai_api.auth.strategies.validate_nuc") as mock_validate_nuc, patch( 
"nilai_api.auth.strategies.get_token_rate_limit" ) as mock_get_rate_limit, @@ -73,23 +64,27 @@ async def test_nuc_strategy_existing_user_with_prompt_document( "nilai_api.auth.strategies.get_token_prompt_document" ) as mock_get_prompt_doc, patch( - "nilai_api.auth.strategies.UserManager.check_user" - ) as mock_check_user, + "nilai_api.auth.strategies.validate_credential" + ) as mock_validate_credential, ): - mock_validate.return_value = ("subscription_holder", "user_id") + mock_validate_nuc.return_value = ("subscription_holder", "user_id") mock_get_rate_limit.return_value = None mock_get_prompt_doc.return_value = mock_prompt_document - mock_check_user.return_value = mock_user_model + mock_validate_credential.return_value = mock_user_model result = await nuc_strategy("nuc-token") assert isinstance(result, AuthenticationInfo) - assert result.user.name == "Test User" assert result.token_rate_limit is None assert result.prompt_document == mock_prompt_document + mock_validate_credential.assert_called_once_with( + "subscription_holder", is_public=True + ) @pytest.mark.asyncio - async def test_nuc_strategy_new_user_with_token_limits(self, mock_prompt_document): + async def test_nuc_strategy_new_user_with_token_limits( + self, mock_prompt_document, mock_user_model + ): """Test NUC authentication creating new user with token limits""" from nilai_api.auth.nuc_helpers.usage import TokenRateLimits, TokenRateLimit @@ -104,7 +99,7 @@ async def test_nuc_strategy_new_user_with_token_limits(self, mock_prompt_documen ) with ( - patch("nilai_api.auth.strategies.validate_nuc") as mock_validate, + patch("nilai_api.auth.strategies.validate_nuc") as mock_validate_nuc, patch( "nilai_api.auth.strategies.get_token_rate_limit" ) as mock_get_rate_limit, @@ -112,30 +107,28 @@ async def test_nuc_strategy_new_user_with_token_limits(self, mock_prompt_documen "nilai_api.auth.strategies.get_token_prompt_document" ) as mock_get_prompt_doc, patch( - "nilai_api.auth.strategies.UserManager.check_user" - ) as mock_check_user, - patch( - "nilai_api.auth.strategies.UserManager.insert_user_model" - ) as mock_insert, + "nilai_api.auth.strategies.validate_credential" + ) as mock_validate_credential, ): - mock_validate.return_value = ("subscription_holder", "new_user_id") + mock_validate_nuc.return_value = ("subscription_holder", "new_user_id") mock_get_rate_limit.return_value = mock_token_limits mock_get_prompt_doc.return_value = mock_prompt_document - mock_check_user.return_value = None - mock_insert.return_value = None + mock_validate_credential.return_value = mock_user_model result = await nuc_strategy("nuc-token") assert isinstance(result, AuthenticationInfo) assert result.token_rate_limit == mock_token_limits assert result.prompt_document == mock_prompt_document - mock_insert.assert_called_once() + mock_validate_credential.assert_called_once_with( + "subscription_holder", is_public=True + ) @pytest.mark.asyncio async def test_nuc_strategy_no_prompt_document(self, mock_user_model): """Test NUC authentication when no prompt document is found""" with ( - patch("nilai_api.auth.strategies.validate_nuc") as mock_validate, + patch("nilai_api.auth.strategies.validate_nuc") as mock_validate_nuc, patch( "nilai_api.auth.strategies.get_token_rate_limit" ) as mock_get_rate_limit, @@ -143,18 +136,17 @@ async def test_nuc_strategy_no_prompt_document(self, mock_user_model): "nilai_api.auth.strategies.get_token_prompt_document" ) as mock_get_prompt_doc, patch( - "nilai_api.auth.strategies.UserManager.check_user" - ) as mock_check_user, + 
"nilai_api.auth.strategies.validate_credential" + ) as mock_validate_credential, ): - mock_validate.return_value = ("subscription_holder", "user_id") + mock_validate_nuc.return_value = ("subscription_holder", "user_id") mock_get_rate_limit.return_value = None mock_get_prompt_doc.return_value = None - mock_check_user.return_value = mock_user_model + mock_validate_credential.return_value = mock_user_model result = await nuc_strategy("nuc-token") assert isinstance(result, AuthenticationInfo) - assert result.user.name == "Test User" assert result.token_rate_limit is None assert result.prompt_document is None @@ -171,7 +163,7 @@ async def test_nuc_strategy_validation_error(self): async def test_nuc_strategy_get_prompt_document_error(self, mock_user_model): """Test NUC authentication when get_token_prompt_document fails""" with ( - patch("nilai_api.auth.strategies.validate_nuc") as mock_validate, + patch("nilai_api.auth.strategies.validate_nuc") as mock_validate_nuc, patch( "nilai_api.auth.strategies.get_token_rate_limit" ) as mock_get_rate_limit, @@ -179,15 +171,15 @@ async def test_nuc_strategy_get_prompt_document_error(self, mock_user_model): "nilai_api.auth.strategies.get_token_prompt_document" ) as mock_get_prompt_doc, patch( - "nilai_api.auth.strategies.UserManager.check_user" - ) as mock_check_user, + "nilai_api.auth.strategies.validate_credential" + ) as mock_validate_credential, ): - mock_validate.return_value = ("subscription_holder", "user_id") + mock_validate_nuc.return_value = ("subscription_holder", "user_id") mock_get_rate_limit.return_value = None mock_get_prompt_doc.side_effect = Exception( "Prompt document extraction failed" ) - mock_check_user.return_value = mock_user_model + mock_validate_credential.return_value = mock_user_model # The function should let the exception bubble up or handle it gracefully # Based on the diff, it looks like it doesn't catch exceptions from get_token_prompt_document @@ -200,29 +192,22 @@ async def test_all_strategies_return_authentication_info_with_prompt_document_fi ): """Test that all strategies return AuthenticationInfo with prompt_document field""" mock_user_model = MagicMock(spec=UserModel) - mock_user_model.name = "Test" - mock_user_model.userid = "test" - mock_user_model.apikey = "test" - mock_user_model.prompt_tokens = 0 - mock_user_model.completion_tokens = 0 - mock_user_model.queries = 0 - mock_user_model.signup_date = datetime.now(timezone.utc) - mock_user_model.last_activity = datetime.now(timezone.utc) + mock_user_model.user_id = "test" mock_user_model.rate_limits = ( RateLimits().get_effective_limits().model_dump_json() ) mock_user_model.rate_limits_obj = RateLimits().get_effective_limits() # Test API key strategy - with patch("nilai_api.auth.strategies.UserManager.check_api_key") as mock_check: - mock_check.return_value = mock_user_model + with patch("nilai_api.auth.strategies.validate_credential") as mock_validate: + mock_validate.return_value = mock_user_model result = await api_key_strategy("test-key") assert hasattr(result, "prompt_document") assert result.prompt_document is None # Test NUC strategy with ( - patch("nilai_api.auth.strategies.validate_nuc") as mock_validate, + patch("nilai_api.auth.strategies.validate_nuc") as mock_validate_nuc, patch( "nilai_api.auth.strategies.get_token_rate_limit" ) as mock_get_rate_limit, @@ -230,13 +215,13 @@ async def test_all_strategies_return_authentication_info_with_prompt_document_fi "nilai_api.auth.strategies.get_token_prompt_document" ) as mock_get_prompt_doc, patch( - 
"nilai_api.auth.strategies.UserManager.check_user" - ) as mock_check_user, + "nilai_api.auth.strategies.validate_credential" + ) as mock_validate_credential, ): - mock_validate.return_value = ("subscription_holder", "user_id") + mock_validate_nuc.return_value = ("subscription_holder", "user_id") mock_get_rate_limit.return_value = None mock_get_prompt_doc.return_value = None - mock_check_user.return_value = mock_user_model + mock_validate_credential.return_value = mock_user_model result = await nuc_strategy("nuc-token") assert hasattr(result, "prompt_document") diff --git a/tests/unit/nilai_api/routers/test_chat_completions_private.py b/tests/unit/nilai_api/routers/test_chat_completions_private.py index 2ba88cc8..35621a3c 100644 --- a/tests/unit/nilai_api/routers/test_chat_completions_private.py +++ b/tests/unit/nilai_api/routers/test_chat_completions_private.py @@ -20,7 +20,7 @@ async def test_runs_in_a_loop(): @pytest.fixture def mock_user(): mock = MagicMock(spec=UserModel) - mock.userid = "test-user-id" + mock.user_id = "test-user-id" mock.name = "Test User" mock.apikey = "test-api-key" mock.prompt_tokens = 100 @@ -38,62 +38,30 @@ def mock_user(): def mock_user_manager(mock_user, mocker): from nilai_api.db.logs import QueryLogManager from nilai_api.db.users import UserManager + from nilai_common import Usage + # Mock QueryLogManager for usage tracking mocker.patch.object( - UserManager, - "get_token_usage", - return_value={ - "prompt_tokens": 100, - "completion_tokens": 50, - "total_tokens": 150, - "queries": 10, - }, - ) - mocker.patch.object(UserManager, "update_token_usage") - mocker.patch.object( - UserManager, + QueryLogManager, "get_user_token_usage", - return_value={ - "prompt_tokens": 100, - "completion_tokens": 50, - "total_tokens": 150, - "completion_tokens_details": None, - "prompt_tokens_details": None, - "queries": 10, - }, - ) - mocker.patch.object( - UserManager, - "insert_user", - return_value={ - "userid": "test-user-id", - "apikey": "test-api-key", - "rate_limits": RateLimits().get_effective_limits().model_dump_json(), - }, + new_callable=AsyncMock, + return_value=Usage( + prompt_tokens=100, + completion_tokens=50, + total_tokens=150, + completion_tokens_details=None, + prompt_tokens_details=None, + ), ) - mocker.patch.object( - UserManager, - "check_api_key", + mocker.patch.object(QueryLogManager, "log_query", new_callable=AsyncMock) + + # Mock validate_credential for authentication + mocker.patch( + "nilai_api.auth.strategies.validate_credential", + new_callable=AsyncMock, return_value=mock_user, ) - mocker.patch.object( - UserManager, - "get_all_users", - return_value=[ - { - "userid": "test-user-id", - "apikey": "test-api-key", - "rate_limits": RateLimits().get_effective_limits().model_dump_json(), - }, - { - "userid": "test-user-id-2", - "apikey": "test-api-key", - "rate_limits": RateLimits().get_effective_limits().model_dump_json(), - }, - ], - ) - mocker.patch.object(QueryLogManager, "log_query") - mocker.patch.object(UserManager, "update_last_activity") + return UserManager @@ -173,7 +141,6 @@ def test_get_usage(mock_user, mock_user_manager, mock_state, client): "total_tokens": 150, "completion_tokens_details": None, "prompt_tokens_details": None, - "queries": 10, } @@ -220,6 +187,9 @@ def test_chat_completion(mock_user, mock_state, mock_user_manager, mocker, clien "nilai_api.routers.endpoints.chat.handle_tool_workflow", return_value=(response_data, 0, 0), ) + mocker.patch( + "nilai_api.routers.private.QueryLogContext.commit", new_callable=AsyncMock + ) 
response = client.post( "/v1/chat/completions", json={ @@ -260,6 +230,9 @@ def test_chat_completion_stream_includes_sources( "nilai_api.routers.endpoints.chat.handle_web_search", new=AsyncMock(return_value=mock_web_search_result), ) + mocker.patch( + "nilai_api.routers.private.QueryLogContext.commit", new_callable=AsyncMock + ) class MockChunk: def __init__(self, data, usage=None): diff --git a/tests/unit/nilai_api/routers/test_nildb_endpoints.py b/tests/unit/nilai_api/routers/test_nildb_endpoints.py index b54b664c..98a68da0 100644 --- a/tests/unit/nilai_api/routers/test_nildb_endpoints.py +++ b/tests/unit/nilai_api/routers/test_nildb_endpoints.py @@ -18,14 +18,7 @@ class TestNilDBEndpoints: def mock_subscription_owner_user(self): """Mock user data for subscription owner""" mock_user_model = MagicMock(spec=UserModel) - mock_user_model.name = "Subscription Owner" - mock_user_model.userid = "owner-id" - mock_user_model.apikey = "owner-id" # Same as userid for subscription owner - mock_user_model.prompt_tokens = 0 - mock_user_model.completion_tokens = 0 - mock_user_model.queries = 0 - mock_user_model.signup_date = datetime.now(timezone.utc) - mock_user_model.last_activity = datetime.now(timezone.utc) + mock_user_model.user_id = "owner-id" mock_user_model.rate_limits = ( RateLimits().get_effective_limits().model_dump_json() ) @@ -37,14 +30,7 @@ def mock_subscription_owner_user(self): def mock_regular_user(self): """Mock user data for regular user (not subscription owner)""" mock_user_model = MagicMock(spec=UserModel) - mock_user_model.name = "Regular User" - mock_user_model.userid = "user-id" - mock_user_model.apikey = "different-api-key" # Different from userid - mock_user_model.prompt_tokens = 0 - mock_user_model.completion_tokens = 0 - mock_user_model.queries = 0 - mock_user_model.signup_date = datetime.now(timezone.utc) - mock_user_model.last_activity = datetime.now(timezone.utc) + mock_user_model.user_id = "user-id" mock_user_model.rate_limits = ( RateLimits().get_effective_limits().model_dump_json() ) @@ -99,21 +85,25 @@ async def test_get_prompt_store_delegation_success( mock_get_delegation.assert_called_once_with("user-123") @pytest.mark.asyncio - async def test_get_prompt_store_delegation_forbidden_regular_user( - self, mock_auth_info_regular_user + async def test_get_prompt_store_delegation_success_regular_user( + self, mock_auth_info_regular_user, mock_prompt_delegation_token ): - """Test delegation token request by regular user (not subscription owner)""" + """Test delegation token request by regular user (endpoint no longer checks subscription ownership)""" from nilai_api.routers.private import get_prompt_store_delegation - request = "user-123" + with patch( + "nilai_api.routers.private.get_nildb_delegation_token" + ) as mock_get_delegation: + mock_get_delegation.return_value = mock_prompt_delegation_token - with pytest.raises(HTTPException) as exc_info: - await get_prompt_store_delegation(request, mock_auth_info_regular_user) + request = "user-123" - assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN - assert "Prompt storage is reserved to subscription owners" in str( - exc_info.value.detail - ) + result = await get_prompt_store_delegation( + request, mock_auth_info_regular_user + ) + + assert isinstance(result, PromptDelegationToken) + assert result.token == "delegation_token_123" @pytest.mark.asyncio async def test_get_prompt_store_delegation_handler_error( @@ -150,7 +140,7 @@ async def test_chat_completion_with_prompt_document_injection(self): ) mock_user = 
MagicMock() - mock_user.userid = "test-user-id" + mock_user.user_id = "test-user-id" mock_user.name = "Test User" mock_user.apikey = "test-api-key" mock_user.rate_limits = RateLimits().get_effective_limits() @@ -163,6 +153,14 @@ async def test_chat_completion_with_prompt_document_injection(self): mock_meter = MagicMock() mock_meter.set_response = MagicMock() + # Mock log context + mock_log_ctx = MagicMock() + mock_log_ctx.set_user = MagicMock() + mock_log_ctx.set_model = MagicMock() + mock_log_ctx.set_request_params = MagicMock() + mock_log_ctx.start_model_timing = MagicMock() + mock_log_ctx.end_model_timing = MagicMock() + request = ChatRequest( model="test-model", messages=[{"role": "user", "content": "Hello"}] ) @@ -205,10 +203,6 @@ async def test_chat_completion_with_prompt_document_injection(self): mock_web_search_result.sources = [] mock_handle_web_search.return_value = mock_web_search_result - # Mock async database operations - mock_update_usage.return_value = None - mock_log_query.return_value = None - # Mock OpenAI client mock_client_instance = MagicMock() mock_response = MagicMock() @@ -266,9 +260,7 @@ async def test_chat_completion_prompt_document_extraction_error(self): ) mock_user = MagicMock() - mock_user.userid = "test-user-id" - mock_user.name = "Test User" - mock_user.apikey = "test-api-key" + mock_user.user_id = "test-user-id" mock_user.rate_limits = RateLimits().get_effective_limits() mock_auth_info = AuthenticationInfo( @@ -279,6 +271,12 @@ async def test_chat_completion_prompt_document_extraction_error(self): mock_meter = MagicMock() mock_meter.set_response = MagicMock() + # Mock log context + mock_log_ctx = MagicMock() + mock_log_ctx.set_user = MagicMock() + mock_log_ctx.set_model = MagicMock() + mock_log_ctx.set_request_params = MagicMock() + request = ChatRequest( model="test-model", messages=[{"role": "user", "content": "Hello"}] ) @@ -300,7 +298,10 @@ async def test_chat_completion_prompt_document_extraction_error(self): with pytest.raises(HTTPException) as exc_info: await chat_completion( - req=request, auth_info=mock_auth_info, meter=mock_meter + req=request, + auth_info=mock_auth_info, + meter=mock_meter, + log_ctx=mock_log_ctx, ) assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN @@ -316,9 +317,7 @@ async def test_chat_completion_without_prompt_document(self): from nilai_common import ChatRequest mock_user = MagicMock() - mock_user.userid = "test-user-id" - mock_user.name = "Test User" - mock_user.apikey = "test-api-key" + mock_user.user_id = "test-user-id" mock_user.rate_limits = RateLimits().get_effective_limits() mock_auth_info = AuthenticationInfo( @@ -331,6 +330,14 @@ async def test_chat_completion_without_prompt_document(self): mock_meter = MagicMock() mock_meter.set_response = MagicMock() + # Mock log context + mock_log_ctx = MagicMock() + mock_log_ctx.set_user = MagicMock() + mock_log_ctx.set_model = MagicMock() + mock_log_ctx.set_request_params = MagicMock() + mock_log_ctx.start_model_timing = MagicMock() + mock_log_ctx.end_model_timing = MagicMock() + request = ChatRequest( model="test-model", messages=[{"role": "user", "content": "Hello"}] ) @@ -371,10 +378,6 @@ async def test_chat_completion_without_prompt_document(self): mock_web_search_result.sources = [] mock_handle_web_search.return_value = mock_web_search_result - # Mock async database operations - mock_update_usage.return_value = None - mock_log_query.return_value = None - # Mock OpenAI client mock_client_instance = MagicMock() mock_response = MagicMock() @@ -658,12 +661,10 @@ def 
test_prompt_delegation_token_model_validation(self): assert token.token == "delegation_token_123" assert token.did == "did:nil:builder123" - def test_user_is_subscription_owner_property( - self, mock_subscription_owner_user, mock_regular_user - ): - """Test the is_subscription_owner property""" - # Subscription owner (userid == apikey) - assert mock_subscription_owner_user.is_subscription_owner is True - - # Regular user (userid != apikey) - assert mock_regular_user.is_subscription_owner is False + def test_user_data_structure(self, mock_subscription_owner_user, mock_regular_user): + """Test the UserData structure has required fields""" + # Check that UserData has the expected fields + assert hasattr(mock_subscription_owner_user, "user_id") + assert hasattr(mock_subscription_owner_user, "rate_limits") + assert hasattr(mock_regular_user, "user_id") + assert hasattr(mock_regular_user, "rate_limits") diff --git a/tests/unit/nilai_api/routers/test_responses_private.py b/tests/unit/nilai_api/routers/test_responses_private.py index cf122c79..4e2f7932 100644 --- a/tests/unit/nilai_api/routers/test_responses_private.py +++ b/tests/unit/nilai_api/routers/test_responses_private.py @@ -43,61 +43,25 @@ def mock_user_manager(mock_user, mocker): from nilai_api.db.users import UserManager from nilai_api.db.logs import QueryLogManager + # Patch QueryLogManager for usage mocker.patch.object( - UserManager, - "get_token_usage", - return_value={ - "prompt_tokens": 100, - "completion_tokens": 50, - "total_tokens": 150, - "queries": 10, - }, - ) - mocker.patch.object(UserManager, "update_token_usage") - mocker.patch.object( - UserManager, + QueryLogManager, "get_user_token_usage", return_value={ "prompt_tokens": 100, "completion_tokens": 50, - "total_tokens": 150, - "completion_tokens_details": None, - "prompt_tokens_details": None, "queries": 10, }, ) + mocker.patch.object(QueryLogManager, "log_query") + + # Patch UserManager.check_user instead of check_api_key mocker.patch.object( UserManager, - "insert_user", - return_value={ - "userid": "test-user-id", - "apikey": "test-api-key", - "rate_limits": RateLimits().get_effective_limits().model_dump_json(), - }, - ) - mocker.patch.object( - UserManager, - "check_api_key", + "check_user", return_value=mock_user, ) - mocker.patch.object( - UserManager, - "get_all_users", - return_value=[ - { - "userid": "test-user-id", - "apikey": "test-api-key", - "rate_limits": RateLimits().get_effective_limits().model_dump_json(), - }, - { - "userid": "test-user-id-2", - "apikey": "test-api-key", - "rate_limits": RateLimits().get_effective_limits().model_dump_json(), - }, - ], - ) - mocker.patch.object(QueryLogManager, "log_query") - mocker.patch.object(UserManager, "update_last_activity") + return UserManager @@ -107,6 +71,7 @@ def mock_state(mocker): mock_discovery_service = mocker.Mock() mock_discovery_service.discover_models = AsyncMock(return_value=expected_models) + mock_discovery_service.initialize = AsyncMock() mocker.patch.object(state, "discovery_service", mock_discovery_service) diff --git a/tests/unit/nilai_api/test_db.py b/tests/unit/nilai_api/test_db.py index dff0fd8b..3979321d 100644 --- a/tests/unit/nilai_api/test_db.py +++ b/tests/unit/nilai_api/test_db.py @@ -15,7 +15,7 @@ async def test_insert_user(mock_db): """Test user insertion functionality.""" user = await mock_db.insert_user("Test User", "test@example.com") - assert "userid" in user + assert "user_id" in user assert "apikey" in user assert len(mock_db.users) == 1 @@ -38,9 +38,9 @@ async def 
test_token_usage(mock_db): """Test token usage tracking.""" user = await mock_db.insert_user("Test User", "test@example.com") - await mock_db.update_token_usage(user["userid"], 50, 20) + await mock_db.update_token_usage(user["user_id"], 50, 20) - token_usage = await mock_db.get_token_usage(user["userid"]) + token_usage = await mock_db.get_token_usage(user["user_id"]) assert token_usage["prompt_tokens"] == 50 assert token_usage["completion_tokens"] == 20 assert token_usage["queries"] == 1 @@ -51,9 +51,9 @@ async def test_query_logging(mock_db): """Test query logging functionality.""" user = await mock_db.insert_user("Test User", "test@example.com") - await mock_db.log_query(user["userid"], "test-model", 10, 15) + await mock_db.log_query(user["user_id"], "test-model", 10, 15) assert len(mock_db.query_logs) == 1 log_entry = list(mock_db.query_logs.values())[0] - assert log_entry["userid"] == user["userid"] + assert log_entry["user_id"] == user["user_id"] assert log_entry["model"] == "test-model" diff --git a/tests/unit/nilai_api/test_rate_limiting.py b/tests/unit/nilai_api/test_rate_limiting.py index 27a5c1bc..ff85470e 100644 --- a/tests/unit/nilai_api/test_rate_limiting.py +++ b/tests/unit/nilai_api/test_rate_limiting.py @@ -45,7 +45,7 @@ async def test_concurrent_rate_limit(req): rate_limit = RateLimit(concurrent_extractor=lambda _: (5, "test")) user_limits = UserRateLimits( - subscription_holder=random_id(), + user_id=random_id(), token_rate_limit=None, rate_limits=RateLimits( user_rate_limit_day=None, @@ -117,7 +117,7 @@ async def web_search_extractor(_): "user_limits", [ UserRateLimits( - subscription_holder=random_id(), + user_id=random_id(), token_rate_limit=None, rate_limits=RateLimits( user_rate_limit_day=10, @@ -131,7 +131,7 @@ async def web_search_extractor(_): ), ), UserRateLimits( - subscription_holder=random_id(), + user_id=random_id(), token_rate_limit=None, rate_limits=RateLimits( user_rate_limit_day=None, @@ -145,7 +145,7 @@ async def web_search_extractor(_): ), ), UserRateLimits( - subscription_holder=random_id(), + user_id=random_id(), token_rate_limit=None, rate_limits=RateLimits( user_rate_limit_day=None, @@ -159,7 +159,7 @@ async def web_search_extractor(_): ), ), UserRateLimits( - subscription_holder=random_id(), + user_id=random_id(), token_rate_limit=TokenRateLimits( limits=[ TokenRateLimit( @@ -220,7 +220,7 @@ async def web_search_extractor(request): rate_limit = RateLimit(web_search_extractor=web_search_extractor) user_limits = UserRateLimits( - subscription_holder=apikey, + user_id=apikey, token_rate_limit=None, rate_limits=RateLimits( user_rate_limit_day=None, diff --git a/uv.lock b/uv.lock index 65292720..03918e23 100644 --- a/uv.lock +++ b/uv.lock @@ -2238,7 +2238,7 @@ requires-dist = [ { name = "gunicorn", specifier = ">=23.0.0" }, { name = "httpx", specifier = ">=0.27.2" }, { name = "nilai-common", editable = "packages/nilai-common" }, - { name = "nilauth-credit-middleware", specifier = "==0.1.1" }, + { name = "nilauth-credit-middleware", specifier = ">=0.1.2" }, { name = "nilrag", specifier = ">=0.1.11" }, { name = "nuc", specifier = ">=0.1.0" }, { name = "openai", specifier = ">=1.59.9" }, @@ -2328,7 +2328,7 @@ dev = [ [[package]] name = "nilauth-credit-middleware" -version = "0.1.1" +version = "0.1.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "fastapi", extra = ["standard"] }, @@ -2336,9 +2336,9 @@ dependencies = [ { name = "nuc" }, { name = "pydantic" }, ] -sdist = { url = 
"https://files.pythonhosted.org/packages/9f/cf/7716fa5f4aca83ef39d6f9f8bebc1d80d194c52c9ce6e75ee6bd1f401217/nilauth_credit_middleware-0.1.1.tar.gz", hash = "sha256:ae32c4c1e6bc083c8a7581d72a6da271ce9c0f0f9271a1694acb81ccd0a4a8bd", size = 10259, upload-time = "2025-10-16T11:15:03.918Z" } +sdist = { url = "https://files.pythonhosted.org/packages/46/bc/ae9b2c26919151fc7193b406a98831eeef197f6ec46b0c075138e66ec016/nilauth_credit_middleware-0.1.2.tar.gz", hash = "sha256:66423a4d18aba1eb5f5d47a04c8f7ae6a19ab4e34433475aa9dc1ba398483fdd", size = 11979, upload-time = "2025-10-30T16:21:20.538Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a7/b5/6e4090ae2ae8848d12e43f82d8d995cd1dff9de8e947cf5fb2b8a72a828e/nilauth_credit_middleware-0.1.1-py3-none-any.whl", hash = "sha256:10a0fda4ac11f51b9a5dd7b3a8fbabc0b28ff92a170a7729ac11eb15c7b37887", size = 14919, upload-time = "2025-10-16T11:15:02.201Z" }, + { url = "https://files.pythonhosted.org/packages/05/c3/73d55667aad701a64f3d1330d66c90a8c292fd19f054093ca74960aca1fb/nilauth_credit_middleware-0.1.2-py3-none-any.whl", hash = "sha256:31f3233e6706c6167b6246a4edb9a405d587eccb1399231223f95c0cdf1ce57c", size = 18121, upload-time = "2025-10-30T16:21:19.547Z" }, ] [[package]] From cd60426783169442e7329f9f7faff0ff0fe0e26a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Cabrero-Holgueras?= Date: Wed, 5 Nov 2025 13:22:13 +0100 Subject: [PATCH 2/7] feat: updated nilauth-credit image hash --- docker-compose.dev.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index 33d9174c..fe4ed99d 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -94,7 +94,7 @@ services: retries: 5 nilauth-credit-server: - image: ghcr.io/nillionnetwork/nilauth-credit:sha-cb9e36a + image: ghcr.io/nillionnetwork/nilauth-credit:sha-6754a1d platform: linux/amd64 # for macOS to force running on Rosetta 2 environment: DATABASE_URL: postgresql://nilauth:nilauth_dev_password@nilauth-postgres:5432/nilauth_credit From 72344acc073625dab320c740896e6f3fa80bb1b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Cabrero-Holgueras?= Date: Thu, 13 Nov 2025 11:28:52 +0100 Subject: [PATCH 3/7] feat: unit test corrections --- changes.patch | 3200 ----------------- tests/e2e/test_chat_completions.py | 7 +- tests/e2e/test_chat_completions_http.py | 1 - tests/unit/nilai_api/auth/test_auth.py | 8 - .../routers/test_chat_completions_private.py | 8 - 5 files changed, 1 insertion(+), 3223 deletions(-) delete mode 100644 changes.patch diff --git a/changes.patch b/changes.patch deleted file mode 100644 index e2061aab..00000000 --- a/changes.patch +++ /dev/null @@ -1,3200 +0,0 @@ -diff --git a/.gitignore b/.gitignore -index f3d8ab4..7cbf1bf 100644 ---- a/.gitignore -+++ b/.gitignore -@@ -179,3 +179,5 @@ private_key.key.lock - - development-compose.yml - production-compose.yml -+ -+.vscode/ -diff --git a/QUERY_LOG_DEPENDENCY_MIGRATION.md b/QUERY_LOG_DEPENDENCY_MIGRATION.md -new file mode 100644 -index 0000000..0eaca24 ---- /dev/null -+++ b/QUERY_LOG_DEPENDENCY_MIGRATION.md -@@ -0,0 +1,253 @@ -+# QueryLog Dependency Migration Guide -+ -+## Overview -+ -+The `QueryLogManager` has been converted to a FastAPI dependency pattern using `QueryLogContext`. This provides better integration with the request lifecycle and more accurate timing metrics. 
-+ -+## What Changed -+ -+### Before (Static Manager) -+```python -+from nilai_api.db.logs import QueryLogManager -+ -+# Manual logging with all parameters -+await QueryLogManager.log_query( -+ userid=auth_info.user.userid, -+ model=req.model, -+ prompt_tokens=prompt_tokens, -+ completion_tokens=completion_tokens, -+ response_time_ms=response_time_ms, -+ web_search_calls=len(sources) if sources else 0, -+ was_streamed=req.stream, -+ was_multimodal=has_multimodal, -+ was_nilrag=bool(req.nilrag), -+ was_nildb=bool(auth_info.prompt_document), -+) -+``` -+ -+### After (Dependency Pattern) -+```python -+from fastapi import Depends -+from nilai_api.db.logs import QueryLogContext, get_query_log_context -+ -+@router.post("/endpoint") -+async def endpoint( -+ log_ctx: QueryLogContext = Depends(get_query_log_context), # Inject dependency -+): -+ # Set context as you go -+ log_ctx.set_user(auth_info.user.userid) -+ log_ctx.set_model(req.model) -+ -+ # ... do work ... -+ -+ # Commit at the end (calculates timing automatically) -+ await log_ctx.commit() -+``` -+ -+## Key Features -+ -+### 1. Automatic Timing Tracking -+```python -+# Context automatically tracks: -+# - Total request time (from dependency creation) -+# - Model inference time (with start_model_timing/end_model_timing) -+# - Tool execution time (with start_tool_timing/end_tool_timing) -+ -+log_ctx.start_model_timing() -+response = await model.generate() -+log_ctx.end_model_timing() -+``` -+ -+### 2. Incremental Context Building -+```python -+# Set request parameters -+log_ctx.set_request_params( -+ temperature=req.temperature, -+ max_tokens=req.max_tokens, -+ was_streamed=req.stream, -+ was_multimodal=has_multimodal, -+ was_nildb=bool(auth_info.prompt_document), -+ was_nilrag=bool(req.nilrag), -+) -+ -+# Set usage metrics (can be called multiple times, last wins) -+log_ctx.set_usage( -+ prompt_tokens=100, -+ completion_tokens=50, -+ tool_calls=2, -+ web_search_calls=1, -+) -+``` -+ -+### 3. Error Tracking -+```python -+try: -+ # ... process request ... -+except HTTPException as e: -+ log_ctx.set_error(error_code=e.status_code, error_message=str(e.detail)) -+ await log_ctx.commit() -+ raise -+``` -+ -+### 4. Safe Commit (No Breaking) -+```python -+# Commit never raises exceptions - logging failures are logged but don't break requests -+await log_ctx.commit() -+``` -+ -+## Migration Steps for `/v1/chat/completions` -+ -+### Step 1: Add Dependency to Function Signature -+ -+```python -+@router.post("/v1/chat/completions", tags=["Chat"], response_model=None) -+async def chat_completion( -+ req: ChatRequest = Body(...), -+ _rate_limit=Depends(RateLimit(...)), -+ auth_info: AuthenticationInfo = Depends(get_auth_info), -+ meter: MeteringContext = Depends(LLMMeter), -+ log_ctx: QueryLogContext = Depends(get_query_log_context), # ADD THIS -+): -+``` -+ -+### Step 2: Initialize Context Early -+ -+```python -+ # Right after validation -+ log_ctx.set_user(auth_info.user.userid) -+ log_ctx.set_model(req.model) -+ log_ctx.set_request_params( -+ temperature=req.temperature, -+ max_tokens=req.max_tokens, -+ was_streamed=req.stream, -+ was_multimodal=has_multimodal, -+ was_nildb=bool(auth_info.prompt_document), -+ was_nilrag=bool(req.nilrag), -+ ) -+``` -+ -+### Step 3: Track Model Timing -+ -+```python -+ # Before model call -+ log_ctx.start_model_timing() -+ -+ response = await client.chat.completions.create(...) 
-+ -+ # After model call -+ log_ctx.end_model_timing() -+``` -+ -+### Step 4: Track Tool Timing (if applicable) -+ -+```python -+ if req.tools: -+ log_ctx.start_tool_timing() -+ -+ (final_completion, agg_prompt, agg_completion) = await handle_tool_workflow(...) -+ -+ log_ctx.end_tool_timing() -+ log_ctx.set_usage(tool_calls=len(response.choices[0].message.tool_calls or [])) -+``` -+ -+### Step 5: Replace QueryLogManager.log_query() -+ -+```python -+ # OLD - Remove this: -+ await QueryLogManager.log_query( -+ auth_info.user.userid, -+ model=req.model, -+ prompt_tokens=..., -+ completion_tokens=..., -+ response_time_ms=..., -+ web_search_calls=..., -+ ) -+ -+ # NEW - Replace with: -+ log_ctx.set_usage( -+ prompt_tokens=model_response.usage.prompt_tokens, -+ completion_tokens=model_response.usage.completion_tokens, -+ web_search_calls=len(sources) if sources else 0, -+ ) -+ await log_ctx.commit() -+``` -+ -+### Step 6: Handle Streaming Case -+ -+For streaming responses, commit inside the generator: -+ -+```python -+async def chat_completion_stream_generator(): -+ try: -+ # ... streaming logic ... -+ -+ async for chunk in response: -+ if chunk.usage is not None: -+ prompt_token_usage = chunk.usage.prompt_tokens -+ completion_token_usage = chunk.usage.completion_tokens -+ # ... yield chunks ... -+ -+ # At the end of stream -+ log_ctx.set_usage( -+ prompt_tokens=prompt_token_usage, -+ completion_tokens=completion_token_usage, -+ web_search_calls=len(sources) if sources else 0, -+ ) -+ await log_ctx.commit() -+ except Exception as e: -+ log_ctx.set_error(error_code=500, error_message=str(e)) -+ await log_ctx.commit() -+ raise -+``` -+ -+## Complete Example -+ -+Here's a minimal complete example: -+ -+```python -+@router.post("/v1/chat/completions") -+async def chat_completion( -+ req: ChatRequest, -+ auth_info: AuthenticationInfo = Depends(get_auth_info), -+ log_ctx: QueryLogContext = Depends(get_query_log_context), -+): -+ # Setup -+ log_ctx.set_user(auth_info.user.userid) -+ log_ctx.set_model(req.model) -+ -+ try: -+ # Process request -+ log_ctx.start_model_timing() -+ response = await process_request(req) -+ log_ctx.end_model_timing() -+ -+ # Set usage -+ log_ctx.set_usage( -+ prompt_tokens=response.usage.prompt_tokens, -+ completion_tokens=response.usage.completion_tokens, -+ ) -+ -+ # Commit -+ await log_ctx.commit() -+ -+ return response -+ except HTTPException as e: -+ log_ctx.set_error(e.status_code, str(e.detail)) -+ await log_ctx.commit() -+ raise -+``` -+ -+## Benefits -+ -+1. ✅ **Automatic timing** - No manual time.monotonic() tracking needed -+2. ✅ **Granular metrics** - Separate model vs tool timing -+3. ✅ **Error tracking** - Built-in error code and message support -+4. ✅ **Type safety** - Full type hints throughout -+5. ✅ **Non-breaking** - Legacy `QueryLogManager.log_query()` still works -+6. ✅ **Clean separation** - Logging logic separate from business logic -+7. ✅ **Request isolation** - Each request gets its own context instance -+8. ✅ **Flexible updates** - Update metrics as you discover them during request processing -+ -+## Backward Compatibility -+ -+The old `QueryLogManager.log_query()` static method still works and is marked as "legacy support". You can migrate endpoints gradually without breaking existing functionality. 
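Where patching `QueryLogContext.commit` directly is awkward, FastAPI's `dependency_overrides` is another way to keep unit tests away from the real `query_logs` table. A sketch, assuming the `get_query_log_context` name from this guide and a hypothetical `app` import path:

```python
from unittest.mock import AsyncMock, MagicMock

from fastapi.testclient import TestClient

from nilai_api.app import app  # hypothetical import path for the FastAPI app
from nilai_api.db.logs import get_query_log_context


def make_stub_context():
    stub = MagicMock()
    stub.commit = AsyncMock()  # the endpoint awaits commit(), so it must be awaitable
    return stub


app.dependency_overrides[get_query_log_context] = make_stub_context
try:
    with TestClient(app) as client:
        client.post(
            "/v1/chat/completions",
            json={"model": "test-model", "messages": [{"role": "user", "content": "hi"}]},
        )  # no rows are written to query_logs
finally:
    app.dependency_overrides.clear()  # avoid leaking the override into other tests
```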
-diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml -index d40200e..5784fca 100644 ---- a/docker-compose.dev.yml -+++ b/docker-compose.dev.yml -@@ -33,8 +33,6 @@ services: - condition: service_healthy - nilauth-credit-server: - condition: service_healthy -- environment: -- - POSTGRES_DB=${POSTGRES_DB_NUC} - volumes: - - ./nilai-api/:/app/nilai-api/ - - ./packages/:/app/packages/ -@@ -97,7 +95,7 @@ services: - - nilauth-credit-server: - image: ghcr.io/nillionnetwork/nilauth-credit:sha-cb9e36a -- platform: linux/amd64 # for macOS to force running on Rosetta 2 -+ # platform: linux/amd64 # for macOS to force running on Rosetta 2 - environment: - DATABASE_URL: postgresql://nilauth:nilauth_dev_password@nilauth-postgres:5432/nilauth_credit - HOST: 0.0.0.0 -diff --git a/grafana/runtime-data/dashboards/nuc-query-data.json b/grafana/runtime-data/dashboards/nuc-query-data.json -index d66fd42..c7bbb6b 100644 ---- a/grafana/runtime-data/dashboards/nuc-query-data.json -+++ b/grafana/runtime-data/dashboards/nuc-query-data.json -@@ -126,7 +126,7 @@ - "editorMode": "code", - "format": "time_series", - "rawQuery": true, -- "rawSql": "SELECT \n date_trunc('${time_granularity}', q.query_timestamp) AS \"time\", \n COUNT(q.id) AS \"Queries\"\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:value}' = 'All' OR u.name = '${user_filter:value}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY date_trunc('${time_granularity}', q.query_timestamp)\nORDER BY \"time\";", -+ "rawSql": "SELECT \n date_trunc('${time_granularity}', q.query_timestamp) AS \"time\", \n COUNT(q.id) AS \"Queries\"\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:value}' = 'All' OR u.name = '${user_filter:value}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY date_trunc('${time_granularity}', q.query_timestamp)\nORDER BY \"time\";", - "refId": "A", - "sql": { - "columns": [ -@@ -218,7 +218,7 @@ - "editorMode": "code", - "format": "table", - "rawQuery": true, -- "rawSql": "SELECT \n q.model, \n COUNT(q.id) AS total_queries\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\nGROUP BY q.model\nORDER BY total_queries DESC;", -+ "rawSql": "SELECT \n q.model, \n COUNT(q.id) AS total_queries\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\nGROUP BY q.model\nORDER BY total_queries DESC;", - "refId": "A", - "sql": { - "columns": [ -@@ -352,7 +352,7 @@ - "editorMode": "code", - "format": "table", - "rawQuery": true, -- "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 12 THEN LEFT(u.name, 3) || '...' 
|| RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name\nORDER BY \"Queries\" DESC;", -+ "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 12 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name\nORDER BY \"Queries\" DESC;", - "refId": "A", - "sql": { - "columns": [ -@@ -360,7 +360,7 @@ - "alias": "\"User\"", - "parameters": [ - { -- "name": "userid", -+ "name": "user_id", - "type": "functionParameter" - } - ], -@@ -381,7 +381,7 @@ - "groupBy": [ - { - "property": { -- "name": "userid", -+ "name": "user_id", - "type": "string" - }, - "type": "groupBy" -@@ -481,7 +481,7 @@ - "editorMode": "code", - "format": "table", - "rawQuery": true, -- "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 8 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\",\n q.model AS \"Model\",\n COUNT(q.id) AS \"Queries\",\n MIN(q.query_timestamp) AS \"First Query\",\n MAX(q.query_timestamp) AS \"Last Query\"\nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name, q.model\nORDER BY \"Queries\" DESC\nLIMIT 20;", -+ "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 8 THEN LEFT(u.name, 3) || '...' 
|| RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\",\n q.model AS \"Model\",\n COUNT(q.id) AS \"Queries\",\n MIN(q.query_timestamp) AS \"First Query\",\n MAX(q.query_timestamp) AS \"Last Query\"\nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name, q.model\nORDER BY \"Queries\" DESC\nLIMIT 20;", - "refId": "A", - "sql": { - "columns": [ -diff --git a/grafana/runtime-data/dashboards/query-data.json b/grafana/runtime-data/dashboards/query-data.json -index 8e0b774..f33f87a 100644 ---- a/grafana/runtime-data/dashboards/query-data.json -+++ b/grafana/runtime-data/dashboards/query-data.json -@@ -126,7 +126,7 @@ - "editorMode": "code", - "format": "time_series", - "rawQuery": true, -- "rawSql": "SELECT \n date_trunc('${time_granularity}', q.query_timestamp) AS \"time\", \n COUNT(q.id) AS \"Queries\"\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:value}' = 'All' OR u.name = '${user_filter:value}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY date_trunc('${time_granularity}', q.query_timestamp)\nORDER BY \"time\";", -+ "rawSql": "SELECT \n date_trunc('${time_granularity}', q.query_timestamp) AS \"time\", \n COUNT(q.id) AS \"Queries\"\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:value}' = 'All' OR u.name = '${user_filter:value}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY date_trunc('${time_granularity}', q.query_timestamp)\nORDER BY \"time\";", - "refId": "A", - "sql": { - "columns": [ -@@ -218,7 +218,7 @@ - "editorMode": "code", - "format": "table", - "rawQuery": true, -- "rawSql": "SELECT \n q.model, \n COUNT(q.id) AS total_queries\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\nGROUP BY q.model\nORDER BY total_queries DESC;", -+ "rawSql": "SELECT \n q.model, \n COUNT(q.id) AS total_queries\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\nGROUP BY q.model\nORDER BY total_queries DESC;", - "refId": "A", - "sql": { - "columns": [ -@@ -352,7 +352,7 @@ - "editorMode": "code", - "format": "table", - "rawQuery": true, -- "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 12 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name\nORDER BY \"Queries\" DESC;", -+ "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 12 THEN LEFT(u.name, 3) || '...' 
|| RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name\nORDER BY \"Queries\" DESC;", - "refId": "A", - "sql": { - "columns": [ -@@ -360,7 +360,7 @@ - "alias": "\"User\"", - "parameters": [ - { -- "name": "userid", -+ "name": "user_id", - "type": "functionParameter" - } - ], -@@ -381,7 +381,7 @@ - "groupBy": [ - { - "property": { -- "name": "userid", -+ "name": "user_id", - "type": "string" - }, - "type": "groupBy" -@@ -481,7 +481,7 @@ - "editorMode": "code", - "format": "table", - "rawQuery": true, -- "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 8 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\",\n q.model AS \"Model\",\n COUNT(q.id) AS \"Queries\",\n MIN(q.query_timestamp) AS \"First Query\",\n MAX(q.query_timestamp) AS \"Last Query\"\nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name, q.model\nORDER BY \"Queries\" DESC\nLIMIT 20;", -+ "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 8 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\",\n q.model AS \"Model\",\n COUNT(q.id) AS \"Queries\",\n MIN(q.query_timestamp) AS \"First Query\",\n MAX(q.query_timestamp) AS \"Last Query\"\nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name, q.model\nORDER BY \"Queries\" DESC\nLIMIT 20;", - "refId": "A", - "sql": { - "columns": [ -diff --git a/grafana/runtime-data/dashboards/testnet-nuc-query-data.json b/grafana/runtime-data/dashboards/testnet-nuc-query-data.json -index f98d70e..358ba4e 100644 ---- a/grafana/runtime-data/dashboards/testnet-nuc-query-data.json -+++ b/grafana/runtime-data/dashboards/testnet-nuc-query-data.json -@@ -126,7 +126,7 @@ - "editorMode": "code", - "format": "time_series", - "rawQuery": true, -- "rawSql": "SELECT \n date_trunc('${time_granularity}', q.query_timestamp) AS \"time\", \n COUNT(q.id) AS \"Queries\"\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:value}' = 'All' OR u.name = '${user_filter:value}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY date_trunc('${time_granularity}', q.query_timestamp)\nORDER BY \"time\";", -+ "rawSql": "SELECT \n date_trunc('${time_granularity}', q.query_timestamp) AS \"time\", \n COUNT(q.id) AS \"Queries\"\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:value}' = 'All' OR u.name = '${user_filter:value}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY date_trunc('${time_granularity}', q.query_timestamp)\nORDER BY \"time\";", - "refId": "A", - "sql": { - "columns": [ -@@ 
-218,7 +218,7 @@ - "editorMode": "code", - "format": "table", - "rawQuery": true, -- "rawSql": "SELECT \n q.model, \n COUNT(q.id) AS total_queries\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\nGROUP BY q.model\nORDER BY total_queries DESC;", -+ "rawSql": "SELECT \n q.model, \n COUNT(q.id) AS total_queries\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\nGROUP BY q.model\nORDER BY total_queries DESC;", - "refId": "A", - "sql": { - "columns": [ -@@ -352,7 +352,7 @@ - "editorMode": "code", - "format": "table", - "rawQuery": true, -- "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 12 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name\nORDER BY \"Queries\" DESC;", -+ "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 12 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name\nORDER BY \"Queries\" DESC;", - "refId": "A", - "sql": { - "columns": [ -@@ -360,7 +360,7 @@ - "alias": "\"User\"", - "parameters": [ - { -- "name": "userid", -+ "name": "user_id", - "type": "functionParameter" - } - ], -@@ -381,7 +381,7 @@ - "groupBy": [ - { - "property": { -- "name": "userid", -+ "name": "user_id", - "type": "string" - }, - "type": "groupBy" -@@ -481,7 +481,7 @@ - "editorMode": "code", - "format": "table", - "rawQuery": true, -- "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 8 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\",\n q.model AS \"Model\",\n COUNT(q.id) AS \"Queries\",\n MIN(q.query_timestamp) AS \"First Query\",\n MAX(q.query_timestamp) AS \"Last Query\"\nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name, q.model\nORDER BY \"Queries\" DESC\nLIMIT 20;", -+ "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 8 THEN LEFT(u.name, 3) || '...' 
|| RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\",\n q.model AS \"Model\",\n COUNT(q.id) AS \"Queries\",\n MIN(q.query_timestamp) AS \"First Query\",\n MAX(q.query_timestamp) AS \"Last Query\"\nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name, q.model\nORDER BY \"Queries\" DESC\nLIMIT 20;", - "refId": "A", - "sql": { - "columns": [ -diff --git a/grafana/runtime-data/dashboards/totals-data.json b/grafana/runtime-data/dashboards/totals-data.json -index 2db20c7..ff66ce0 100644 ---- a/grafana/runtime-data/dashboards/totals-data.json -+++ b/grafana/runtime-data/dashboards/totals-data.json -@@ -83,7 +83,7 @@ - "editorMode": "code", - "format": "table", - "rawQuery": true, -- "rawSql": "SELECT \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')", -+ "rawSql": "SELECT \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')", - "refId": "A", - "sql": { - "columns": [ -@@ -165,7 +165,7 @@ - "editorMode": "code", - "format": "table", - "rawQuery": true, -- "rawSql": "SELECT SUM(total_tokens) AS total_tokens\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')", -+ "rawSql": "SELECT SUM(total_tokens) AS total_tokens\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')", - "refId": "A", - "sql": { - "columns": [ -@@ -248,7 +248,7 @@ - "editorMode": "code", - "format": "table", - "rawQuery": true, -- "rawSql": "SELECT SUM(q.prompt_tokens) AS total_tokens\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')", -+ "rawSql": "SELECT SUM(q.prompt_tokens) AS total_tokens\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')", - "refId": "A", - "sql": { - "columns": [ -@@ -331,7 +331,7 @@ - "editorMode": "code", - "format": "table", - "rawQuery": true, -- "rawSql": "SELECT SUM(q.completion_tokens) AS total_tokens\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')", -+ "rawSql": "SELECT SUM(q.completion_tokens) AS total_tokens\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = 
'${user_filter:single}')", - "refId": "A", - "sql": { - "columns": [ -@@ -397,4 +397,4 @@ - "uid": "aex54yzf0nmyoc", - "version": 1, - "weekStart": "" --} -\ No newline at end of file -+} -diff --git a/grafana/runtime-data/dashboards/usage-data.json b/grafana/runtime-data/dashboards/usage-data.json -index 88857f9..a22bf91 100644 ---- a/grafana/runtime-data/dashboards/usage-data.json -+++ b/grafana/runtime-data/dashboards/usage-data.json -@@ -299,7 +299,7 @@ - "editorMode": "code", - "format": "table", - "rawQuery": true, -- "rawSql": "SELECT \n u.email AS \"User ID\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE q.query_timestamp >= NOW() - INTERVAL '1 hours'\nGROUP BY u.email;", -+ "rawSql": "SELECT \n u.email AS \"User ID\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE q.query_timestamp >= NOW() - INTERVAL '1 hours'\nGROUP BY u.email;", - "refId": "A", - "sql": { - "columns": [ -@@ -307,7 +307,7 @@ - "alias": "\"User ID\"", - "parameters": [ - { -- "name": "userid", -+ "name": "user_id", - "type": "functionParameter" - } - ], -@@ -328,7 +328,7 @@ - "groupBy": [ - { - "property": { -- "name": "userid", -+ "name": "user_id", - "type": "string" - }, - "type": "groupBy" -@@ -430,7 +430,7 @@ - "editorMode": "code", - "format": "table", - "rawQuery": true, -- "rawSql": "SELECT \n u.email AS \"User ID\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE q.query_timestamp >= NOW() - INTERVAL '7 days'\nGROUP BY u.email;", -+ "rawSql": "SELECT \n u.email AS \"User ID\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE q.query_timestamp >= NOW() - INTERVAL '7 days'\nGROUP BY u.email;", - "refId": "A", - "sql": { - "columns": [ -diff --git a/nilai-api/alembic/versions/0ba073468afc_chore_improved_database_schema.py b/nilai-api/alembic/versions/0ba073468afc_chore_improved_database_schema.py -new file mode 100644 -index 0000000..ebaca5a ---- /dev/null -+++ b/nilai-api/alembic/versions/0ba073468afc_chore_improved_database_schema.py -@@ -0,0 +1,206 @@ -+"""chore: merged database schema updates -+ -+Revision ID: 0ba073468afc -+Revises: ea942d6c7a00 -+Create Date: 2025-10-31 09:43:12.022675 -+ -+""" -+ -+from typing import Sequence, Union -+ -+from alembic import op -+import sqlalchemy as sa -+from sqlalchemy.dialects import postgresql -+ -+# revision identifiers, used by Alembic. 
-+revision: str = "0ba073468afc" -+down_revision: Union[str, None] = "9ddf28cf6b6f" -+branch_labels: Union[str, Sequence[str], None] = None -+depends_on: Union[str, Sequence[str], None] = None -+ -+ -+def upgrade() -> None: -+ # ### merged commands from ea942d6c7a00 and 0ba073468afc ### -+ # query_logs: new telemetry columns (with defaults to backfill existing rows) -+ op.add_column( -+ "query_logs", -+ sa.Column( -+ "tool_calls", sa.Integer(), server_default=sa.text("0"), nullable=False -+ ), -+ ) -+ op.add_column( -+ "query_logs", -+ sa.Column( -+ "temperature", sa.Float(), server_default=sa.text("0.9"), nullable=True -+ ), -+ ) -+ op.add_column( -+ "query_logs", -+ sa.Column( -+ "max_tokens", sa.Integer(), server_default=sa.text("4096"), nullable=True -+ ), -+ ) -+ op.add_column( -+ "query_logs", -+ sa.Column( -+ "response_time_ms", -+ sa.Integer(), -+ server_default=sa.text("-1"), -+ nullable=False, -+ ), -+ ) -+ op.add_column( -+ "query_logs", -+ sa.Column( -+ "model_response_time_ms", -+ sa.Integer(), -+ server_default=sa.text("-1"), -+ nullable=False, -+ ), -+ ) -+ op.add_column( -+ "query_logs", -+ sa.Column( -+ "tool_response_time_ms", -+ sa.Integer(), -+ server_default=sa.text("-1"), -+ nullable=False, -+ ), -+ ) -+ op.add_column( -+ "query_logs", -+ sa.Column( -+ "was_streamed", -+ sa.Boolean(), -+ server_default=sa.text("False"), -+ nullable=False, -+ ), -+ ) -+ op.add_column( -+ "query_logs", -+ sa.Column( -+ "was_multimodal", -+ sa.Boolean(), -+ server_default=sa.text("False"), -+ nullable=False, -+ ), -+ ) -+ op.add_column( -+ "query_logs", -+ sa.Column( -+ "was_nildb", sa.Boolean(), server_default=sa.text("False"), nullable=False -+ ), -+ ) -+ op.add_column( -+ "query_logs", -+ sa.Column( -+ "was_nilrag", sa.Boolean(), server_default=sa.text("False"), nullable=False -+ ), -+ ) -+ op.add_column( -+ "query_logs", -+ sa.Column( -+ "error_code", sa.Integer(), server_default=sa.text("200"), nullable=False -+ ), -+ ) -+ op.add_column( -+ "query_logs", -+ sa.Column( -+ "error_message", sa.Text(), server_default=sa.text("'OK'"), nullable=False -+ ), -+ ) -+ -+ # query_logs: remove FK to users.userid before dropping the column later -+ op.drop_constraint("query_logs_userid_fkey", "query_logs", type_="foreignkey") -+ -+ # query_logs: add lockid and index, drop legacy userid and its index -+ op.add_column( -+ "query_logs", sa.Column("lockid", sa.String(length=75), nullable=False) -+ ) -+ op.drop_index("ix_query_logs_userid", table_name="query_logs") -+ op.create_index( -+ op.f("ix_query_logs_lockid"), "query_logs", ["lockid"], unique=False -+ ) -+ op.drop_column("query_logs", "userid") -+ -+ # users: drop legacy token counters -+ op.drop_column("users", "prompt_tokens") -+ op.drop_column("users", "completion_tokens") -+ -+ # users: reshape identity columns and indexes -+ op.add_column("users", sa.Column("user_id", sa.String(length=75), nullable=False)) -+ op.drop_index("ix_users_apikey", table_name="users") -+ op.drop_index("ix_users_userid", table_name="users") -+ op.create_index(op.f("ix_users_user_id"), "users", ["user_id"], unique=False) -+ op.drop_column("users", "last_activity") -+ op.drop_column("users", "userid") -+ op.drop_column("users", "apikey") -+ op.drop_column("users", "signup_date") -+ op.drop_column("users", "queries") -+ op.drop_column("users", "name") -+ # ### end merged commands ### -+ -+ -+def downgrade() -> None: -+ # ### revert merged commands back to 9ddf28cf6b6f ### -+ # users: restore legacy columns and indexes -+ op.add_column("users", sa.Column("name", 
sa.VARCHAR(length=100), nullable=False)) -+ op.add_column("users", sa.Column("queries", sa.INTEGER(), nullable=False)) -+ op.add_column( -+ "users", -+ sa.Column( -+ "signup_date", -+ postgresql.TIMESTAMP(timezone=True), -+ server_default=sa.text("now()"), -+ nullable=False, -+ ), -+ ) -+ op.add_column("users", sa.Column("apikey", sa.VARCHAR(length=75), nullable=False)) -+ op.add_column("users", sa.Column("userid", sa.VARCHAR(length=75), nullable=False)) -+ op.add_column( -+ "users", -+ sa.Column("last_activity", postgresql.TIMESTAMP(timezone=True), nullable=True), -+ ) -+ op.drop_index(op.f("ix_users_user_id"), table_name="users") -+ op.create_index("ix_users_userid", "users", ["userid"], unique=False) -+ op.create_index("ix_users_apikey", "users", ["apikey"], unique=False) -+ op.drop_column("users", "user_id") -+ op.add_column( -+ "users", -+ sa.Column( -+ "completion_tokens", -+ sa.INTEGER(), -+ server_default=sa.text("0"), -+ nullable=False, -+ ), -+ ) -+ op.add_column( -+ "users", -+ sa.Column( -+ "prompt_tokens", sa.INTEGER(), server_default=sa.text("0"), nullable=False -+ ), -+ ) -+ -+ # query_logs: restore userid, index and FK; drop new columns -+ op.add_column( -+ "query_logs", sa.Column("userid", sa.VARCHAR(length=75), nullable=False) -+ ) -+ op.drop_index(op.f("ix_query_logs_lockid"), table_name="query_logs") -+ op.create_index("ix_query_logs_userid", "query_logs", ["userid"], unique=False) -+ op.create_foreign_key( -+ "query_logs_userid_fkey", "query_logs", "users", ["userid"], ["userid"] -+ ) -+ op.drop_column("query_logs", "lockid") -+ op.drop_column("query_logs", "error_message") -+ op.drop_column("query_logs", "error_code") -+ op.drop_column("query_logs", "was_nilrag") -+ op.drop_column("query_logs", "was_nildb") -+ op.drop_column("query_logs", "was_multimodal") -+ op.drop_column("query_logs", "was_streamed") -+ op.drop_column("query_logs", "tool_response_time_ms") -+ op.drop_column("query_logs", "model_response_time_ms") -+ op.drop_column("query_logs", "response_time_ms") -+ op.drop_column("query_logs", "max_tokens") -+ op.drop_column("query_logs", "temperature") -+ op.drop_column("query_logs", "tool_calls") -+ # ### end revert ### -diff --git a/nilai-api/alembic/versions/43b23c73035b_fix_userid_change_to_user_id.py b/nilai-api/alembic/versions/43b23c73035b_fix_userid_change_to_user_id.py -new file mode 100644 -index 0000000..4c20bb6 ---- /dev/null -+++ b/nilai-api/alembic/versions/43b23c73035b_fix_userid_change_to_user_id.py -@@ -0,0 +1,37 @@ -+"""fix: userid change to user_id -+ -+Revision ID: 43b23c73035b -+Revises: 0ba073468afc -+Create Date: 2025-11-03 11:33:03.006101 -+ -+""" -+ -+from typing import Sequence, Union -+ -+from alembic import op -+import sqlalchemy as sa -+ -+ -+# revision identifiers, used by Alembic. -+revision: str = "43b23c73035b" -+down_revision: Union[str, None] = "0ba073468afc" -+branch_labels: Union[str, Sequence[str], None] = None -+depends_on: Union[str, Sequence[str], None] = None -+ -+ -+def upgrade() -> None: -+ # ### commands auto generated by Alembic - please adjust! ### -+ op.add_column( -+ "query_logs", sa.Column("user_id", sa.String(length=75), nullable=False) -+ ) -+ op.create_index( -+ op.f("ix_query_logs_user_id"), "query_logs", ["user_id"], unique=False -+ ) -+ # ### end Alembic commands ### -+ -+ -+def downgrade() -> None: -+ # ### commands auto generated by Alembic - please adjust! 
### -+ op.drop_index(op.f("ix_query_logs_user_id"), table_name="query_logs") -+ op.drop_column("query_logs", "user_id") -+ # ### end Alembic commands ### -diff --git a/nilai-api/examples/users.py b/nilai-api/examples/users.py -deleted file mode 100644 -index b6b206d..0000000 ---- a/nilai-api/examples/users.py -+++ /dev/null -@@ -1,43 +0,0 @@ --#!/usr/bin/python -- --from nilai_api.db.logs import QueryLogManager --from nilai_api.db.users import UserManager -- -- --# Example Usage --async def main(): -- # Add some users -- bob = await UserManager.insert_user("Bob", "bob@example.com") -- alice = await UserManager.insert_user("Alice", "alice@example.com") -- -- print(f"Bob's details: {bob}") -- print(f"Alice's details: {alice}") -- -- # Check API key -- user_name = await UserManager.check_api_key(bob.apikey) -- print(f"API key validation: {user_name}") -- -- # Update and retrieve token usage -- await UserManager.update_token_usage( -- bob.userid, prompt_tokens=50, completion_tokens=20 -- ) -- usage = await UserManager.get_user_token_usage(bob.userid) -- print(f"Bob's token usage: {usage}") -- -- # Log a query -- await QueryLogManager.log_query( -- userid=bob.userid, -- model="gpt-3.5-turbo", -- prompt_tokens=8, -- completion_tokens=7, -- web_search_calls=1, -- ) -- -- --if __name__ == "__main__": -- import asyncio -- from dotenv import load_dotenv -- -- load_dotenv() -- -- asyncio.run(main()) -diff --git a/nilai-api/pyproject.toml b/nilai-api/pyproject.toml -index 0bbfba3..9caae2a 100644 ---- a/nilai-api/pyproject.toml -+++ b/nilai-api/pyproject.toml -@@ -35,7 +35,7 @@ dependencies = [ - "trafilatura>=1.7.0", - "secretvaults", - "e2b-code-interpreter>=1.0.3", -- "nilauth-credit-middleware>=0.1.1", -+ "nilauth-credit-middleware>=0.1.2", - ] - - -diff --git a/nilai-api/src/nilai_api/auth/__init__.py b/nilai-api/src/nilai_api/auth/__init__.py -index 2e7cd6f..2123685 100644 ---- a/nilai-api/src/nilai_api/auth/__init__.py -+++ b/nilai-api/src/nilai_api/auth/__init__.py -@@ -4,7 +4,6 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer - from logging import getLogger - - from nilai_api.config import CONFIG --from nilai_api.db.users import UserManager - from nilai_api.auth.strategies import AuthenticationStrategy - - from nuc.validate import ValidationException -@@ -36,7 +35,6 @@ async def get_auth_info( - ) - - auth_info = await strategy(credentials.credentials) -- await UserManager.update_last_activity(userid=auth_info.user.userid) - return auth_info - except AuthenticationError as e: - raise e -diff --git a/nilai-api/src/nilai_api/auth/nuc.py b/nilai-api/src/nilai_api/auth/nuc.py -index 4645935..614d9ef 100644 ---- a/nilai-api/src/nilai_api/auth/nuc.py -+++ b/nilai-api/src/nilai_api/auth/nuc.py -@@ -86,11 +86,11 @@ def validate_nuc(nuc_token: str) -> Tuple[str, str]: - - # Validate the - # Return the subject of the token, the subscription holder -- subscription_holder = token.subject.public_key.hex() -- user = token.issuer.public_key.hex() -+ subscription_holder = token.subject -+ user = token.issuer - logger.info(f"Subscription holder: {subscription_holder}") - logger.info(f"User: {user}") -- return subscription_holder, user -+ return str(subscription_holder), str(user) - - - def get_token_rate_limit(nuc_token: str) -> Optional[TokenRateLimits]: -diff --git a/nilai-api/src/nilai_api/auth/strategies.py b/nilai-api/src/nilai_api/auth/strategies.py -index 9917ee3..089e7e9 100644 ---- a/nilai-api/src/nilai_api/auth/strategies.py -+++ b/nilai-api/src/nilai_api/auth/strategies.py 
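For reference, the nuc.py hunk above reduces to the following identity extraction: the token's subject and issuer are now stringified whole instead of being reduced to hex public keys (a sketch; token stands in for a decoded NUC token, and treating the str() form as a DID-style identifier is an assumption based on the casts in that hunk):

    def extract_identities(token) -> tuple[str, str]:
        # str() replaces the previous `.public_key.hex()` access on both fields.
        subscription_holder = str(token.subject)  # subscription holder identity
        user = str(token.issuer)                  # end-user identity
        return subscription_holder, user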
-@@ -1,6 +1,6 @@ - from typing import Callable, Awaitable, Optional --from datetime import datetime, timezone - -+from fastapi import HTTPException - from nilai_api.db.users import UserManager, UserModel, UserData - from nilai_api.auth.nuc import ( - validate_nuc, -@@ -11,11 +11,18 @@ from nilai_api.config import CONFIG - from nilai_api.auth.common import ( - PromptDocument, - TokenRateLimits, -- AuthenticationInfo, - AuthenticationError, -+ AuthenticationInfo, -+) -+ -+from nilauth_credit_middleware import ( -+ CreditClientSingleton, - ) -+from nilauth_credit_middleware.api_model import ValidateCredentialResponse -+ - - from enum import Enum -+ - # All strategies must return a UserModel - # The strategies can raise any exception, which will be caught and converted to an AuthenticationError - # The exception detail will be passed to the client -@@ -44,18 +51,10 @@ def allow_token( - return await function(token) - - if token == allowed_token: -- user_model: UserModel | None = await UserManager.check_user( -- allowed_token -+ user_model = UserModel( -+ user_id=allowed_token, -+ rate_limits=None, - ) -- if user_model is None: -- user_model = UserModel( -- userid=allowed_token, -- name=allowed_token, -- apikey=allowed_token, -- signup_date=datetime.now(timezone.utc), -- ) -- await UserManager.insert_user_model(user_model) -- - return AuthenticationInfo( - user=UserData.from_sqlalchemy(user_model), - token_rate_limit=None, -@@ -68,16 +67,41 @@ def allow_token( - return decorator - - -+async def validate_credential(credential: str, is_public: bool) -> UserModel: -+ """ -+ Validate a credential with nilauth credit middleware and return the user model -+ """ -+ credit_client = CreditClientSingleton.get_client() -+ try: -+ validate_response: ValidateCredentialResponse = ( -+ await credit_client.validate_credential(credential, is_public=is_public) -+ ) -+ except HTTPException as e: -+ if e.status_code == 404: -+ raise AuthenticationError(f"Credential not found: {e.detail}") -+ elif e.status_code == 401: -+ raise AuthenticationError(f"Credential is inactive: {e.detail}") -+ else: -+ raise AuthenticationError(f"Failed to validate credential: {e.detail}") -+ -+ user_model = await UserManager.check_user(validate_response.user_id) -+ if user_model is None: -+ user_model = UserModel( -+ user_id=validate_response.user_id, -+ rate_limits=None, -+ ) -+ return user_model -+ -+ - @allow_token(CONFIG.docs.token) - async def api_key_strategy(api_key: str) -> AuthenticationInfo: -- user_model: Optional[UserModel] = await UserManager.check_api_key(api_key) -- if user_model: -- return AuthenticationInfo( -- user=UserData.from_sqlalchemy(user_model), -- token_rate_limit=None, -- prompt_document=None, -- ) -- raise AuthenticationError("Missing or invalid API key") -+ user_model = await validate_credential(api_key, is_public=False) -+ -+ return AuthenticationInfo( -+ user=UserData.from_sqlalchemy(user_model), -+ token_rate_limit=None, -+ prompt_document=None, -+ ) - - - @allow_token(CONFIG.docs.token) -@@ -89,20 +113,7 @@ async def nuc_strategy(nuc_token) -> AuthenticationInfo: - token_rate_limits: Optional[TokenRateLimits] = get_token_rate_limit(nuc_token) - prompt_document: Optional[PromptDocument] = get_token_prompt_document(nuc_token) - -- user_model: Optional[UserModel] = await UserManager.check_user(user) -- if user_model: -- return AuthenticationInfo( -- user=UserData.from_sqlalchemy(user_model), -- token_rate_limit=token_rate_limits, -- prompt_document=prompt_document, -- ) -- -- user_model = UserModel( -- 
userid=user, -- name=user, -- apikey=subscription_holder, -- ) -- await UserManager.insert_user_model(user_model) -+ user_model = await validate_credential(subscription_holder, is_public=True) - return AuthenticationInfo( - user=UserData.from_sqlalchemy(user_model), - token_rate_limit=token_rate_limits, -diff --git a/nilai-api/src/nilai_api/commands/add_user.py b/nilai-api/src/nilai_api/commands/add_user.py -index e9f49e5..5bd488b 100644 ---- a/nilai-api/src/nilai_api/commands/add_user.py -+++ b/nilai-api/src/nilai_api/commands/add_user.py -@@ -6,9 +6,7 @@ import click - - - @click.command() --@click.option("--name", type=str, required=True, help="User Name") --@click.option("--apikey", type=str, help="API Key") --@click.option("--userid", type=str, help="User Id") -+@click.option("--user_id", type=str, help="User Id") - @click.option("--ratelimit-day", type=int, help="number of request per day") - @click.option("--ratelimit-hour", type=int, help="number of request per hour") - @click.option("--ratelimit-minute", type=int, help="number of request per minute") -@@ -26,9 +24,7 @@ import click - help="number of web search request per minute", - ) - def main( -- name, -- apikey: str | None, -- userid: str | None, -+ user_id: str | None, - ratelimit_day: int | None, - ratelimit_hour: int | None, - ratelimit_minute: int | None, -@@ -38,9 +34,7 @@ def main( - ): - async def add_user(): - user: UserModel = await UserManager.insert_user( -- name, -- apikey, -- userid, -+ user_id, - RateLimits( - user_rate_limit_day=ratelimit_day, - user_rate_limit_hour=ratelimit_hour, -@@ -52,7 +46,7 @@ def main( - ) - json_user = json.dumps( - { -- "userid": user.userid, -+ "user_id": user.user_id, - "name": user.name, - "apikey": user.apikey, - "ratelimit_day": user.rate_limits_obj.user_rate_limit_day, -diff --git a/nilai-api/src/nilai_api/config/__init__.py b/nilai-api/src/nilai_api/config/__init__.py -index 1939c74..b61a9fe 100644 ---- a/nilai-api/src/nilai_api/config/__init__.py -+++ b/nilai-api/src/nilai_api/config/__init__.py -@@ -64,3 +64,4 @@ __all__ = [ - ] - - logging.info(CONFIG.prettify()) -+print(CONFIG.prettify()) -diff --git a/nilai-api/src/nilai_api/credit.py b/nilai-api/src/nilai_api/credit.py -index 3a06135..b9d7ea6 100644 ---- a/nilai-api/src/nilai_api/credit.py -+++ b/nilai-api/src/nilai_api/credit.py -@@ -20,6 +20,9 @@ logger = logging.getLogger(__name__) - class NoOpMeteringContext: - """A no-op metering context for requests that should skip metering (e.g., Docs Token).""" - -+ def __init__(self): -+ self.lock_id: str = "noop-lock-id" -+ - def set_response(self, response_data: dict) -> None: - """No-op method that does nothing.""" - pass -diff --git a/nilai-api/src/nilai_api/db/logs.py b/nilai-api/src/nilai_api/db/logs.py -index 030c869..4a78c8a 100644 ---- a/nilai-api/src/nilai_api/db/logs.py -+++ b/nilai-api/src/nilai_api/db/logs.py -@@ -1,12 +1,14 @@ - import logging -+import time - from datetime import datetime, timezone -+from typing import Optional - -+from nilai_common import Usage - import sqlalchemy - --from sqlalchemy import ForeignKey, Integer, String, DateTime, Text -+from sqlalchemy import Integer, String, DateTime, Text, Boolean, Float - from sqlalchemy.exc import SQLAlchemyError - from nilai_api.db import Base, Column, get_db_session --from nilai_api.db.users import UserModel - - logger = logging.getLogger(__name__) - -@@ -16,9 +18,8 @@ class QueryLog(Base): - __tablename__ = "query_logs" - - id: int = Column(Integer, primary_key=True, autoincrement=True) # type: ignore -- 
userid: str = Column( -- String(75), ForeignKey(UserModel.userid), nullable=False, index=True -- ) # type: ignore -+ user_id: str = Column(String(75), nullable=False, index=True) # type: ignore -+ lockid: str = Column(String(75), nullable=False, index=True) # type: ignore - query_timestamp: datetime = Column( - DateTime(timezone=True), server_default=sqlalchemy.func.now(), nullable=False - ) # type: ignore -@@ -26,51 +27,285 @@ class QueryLog(Base): - prompt_tokens: int = Column(Integer, nullable=False) # type: ignore - completion_tokens: int = Column(Integer, nullable=False) # type: ignore - total_tokens: int = Column(Integer, nullable=False) # type: ignore -+ tool_calls: int = Column(Integer, nullable=False) # type: ignore - web_search_calls: int = Column(Integer, nullable=False) # type: ignore -+ temperature: Optional[float] = Column(Float, nullable=True) # type: ignore -+ max_tokens: Optional[int] = Column(Integer, nullable=True) # type: ignore -+ -+ response_time_ms: int = Column(Integer, nullable=False) # type: ignore -+ model_response_time_ms: int = Column(Integer, nullable=False) # type: ignore -+ tool_response_time_ms: int = Column(Integer, nullable=False) # type: ignore -+ -+ was_streamed: bool = Column(Boolean, nullable=False) # type: ignore -+ was_multimodal: bool = Column(Boolean, nullable=False) # type: ignore -+ was_nildb: bool = Column(Boolean, nullable=False) # type: ignore -+ was_nilrag: bool = Column(Boolean, nullable=False) # type: ignore -+ -+ error_code: int = Column(Integer, nullable=False) # type: ignore -+ error_message: str = Column(Text, nullable=False) # type: ignore - - def __repr__(self): -- return f"" -+ return f"" -+ -+ -+class QueryLogContext: -+ """ -+ Context manager for logging query metrics during a request. -+ Used as a FastAPI dependency to track request metrics. 
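The intended per-request call pattern for this class looks roughly like the following (a usage sketch with hypothetical ids; the real wiring, via Depends(QueryLogContext), appears in the router hunks further down):

    ctx = QueryLogContext()
    ctx.set_user("user-123")            # hypothetical user id
    ctx.set_lockid("lock-abc")          # hypothetical metering lock id
    ctx.set_model("demo-model")         # hypothetical model name
    ctx.start_model_timing()
    # ... await the inference call ...
    ctx.end_model_timing()
    ctx.set_usage(prompt_tokens=8, completion_tokens=7)
    await ctx.commit()                  # must run inside an async handler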
-+ """ -+ -+ def __init__(self): -+ self.user_id: Optional[str] = None -+ self.lockid: Optional[str] = None -+ self.model: Optional[str] = None -+ self.prompt_tokens: int = 0 -+ self.completion_tokens: int = 0 -+ self.tool_calls: int = 0 -+ self.web_search_calls: int = 0 -+ self.temperature: Optional[float] = None -+ self.max_tokens: Optional[int] = None -+ self.was_streamed: bool = False -+ self.was_multimodal: bool = False -+ self.was_nildb: bool = False -+ self.was_nilrag: bool = False -+ self.error_code: int = 0 -+ self.error_message: str = "" -+ -+ # Timing tracking -+ self.start_time: float = time.monotonic() -+ self.model_start_time: Optional[float] = None -+ self.model_end_time: Optional[float] = None -+ self.tool_start_time: Optional[float] = None -+ self.tool_end_time: Optional[float] = None -+ -+ def set_user(self, user_id: str) -> None: -+ """Set the user ID for this query.""" -+ self.user_id = user_id -+ -+ def set_lockid(self, lockid: str) -> None: -+ """Set the lock ID for this query.""" -+ self.lockid = lockid -+ -+ def set_model(self, model: str) -> None: -+ """Set the model name for this query.""" -+ self.model = model -+ -+ def set_request_params( -+ self, -+ temperature: Optional[float] = None, -+ max_tokens: Optional[int] = None, -+ was_streamed: bool = False, -+ was_multimodal: bool = False, -+ was_nildb: bool = False, -+ was_nilrag: bool = False, -+ ) -> None: -+ """Set request parameters.""" -+ self.temperature = temperature -+ self.max_tokens = max_tokens -+ self.was_streamed = was_streamed -+ self.was_multimodal = was_multimodal -+ self.was_nildb = was_nildb -+ self.was_nilrag = was_nilrag -+ -+ def set_usage( -+ self, -+ prompt_tokens: int = 0, -+ completion_tokens: int = 0, -+ tool_calls: int = 0, -+ web_search_calls: int = 0, -+ ) -> None: -+ """Set token usage and feature usage.""" -+ self.prompt_tokens = prompt_tokens -+ self.completion_tokens = completion_tokens -+ self.tool_calls = tool_calls -+ self.web_search_calls = web_search_calls -+ -+ def set_error(self, error_code: int, error_message: str) -> None: -+ """Set error information.""" -+ self.error_code = error_code -+ self.error_message = error_message -+ -+ def start_model_timing(self) -> None: -+ """Mark the start of model inference.""" -+ self.model_start_time = time.monotonic() -+ -+ def end_model_timing(self) -> None: -+ """Mark the end of model inference.""" -+ self.model_end_time = time.monotonic() -+ -+ def start_tool_timing(self) -> None: -+ """Mark the start of tool execution.""" -+ self.tool_start_time = time.monotonic() -+ -+ def end_tool_timing(self) -> None: -+ """Mark the end of tool execution.""" -+ self.tool_end_time = time.monotonic() -+ -+ def _calculate_timings(self) -> tuple[int, int, int]: -+ """Calculate response times in milliseconds.""" -+ total_ms = int((time.monotonic() - self.start_time) * 1000) -+ -+ model_ms = 0 -+ if self.model_start_time and self.model_end_time: -+ model_ms = int((self.model_end_time - self.model_start_time) * 1000) -+ -+ tool_ms = 0 -+ if self.tool_start_time and self.tool_end_time: -+ tool_ms = int((self.tool_end_time - self.tool_start_time) * 1000) -+ -+ return total_ms, model_ms, tool_ms -+ -+ async def commit(self) -> None: -+ """ -+ Commit the query log to the database. -+ Should be called at the end of the request lifecycle. 
-+ """ -+ if not self.user_id or not self.model: -+ logger.warning( -+ "Skipping query log: user_id or model not set " -+ f"(user_id={self.user_id}, model={self.model})" -+ ) -+ return -+ -+ total_ms, model_ms, tool_ms = self._calculate_timings() -+ total_tokens = self.prompt_tokens + self.completion_tokens -+ -+ try: -+ async with get_db_session() as session: -+ query_log = QueryLog( -+ user_id=self.user_id, -+ lockid=self.lockid, -+ model=self.model, -+ prompt_tokens=self.prompt_tokens, -+ completion_tokens=self.completion_tokens, -+ total_tokens=total_tokens, -+ tool_calls=self.tool_calls, -+ web_search_calls=self.web_search_calls, -+ temperature=self.temperature, -+ max_tokens=self.max_tokens, -+ query_timestamp=datetime.now(timezone.utc), -+ response_time_ms=total_ms, -+ model_response_time_ms=model_ms, -+ tool_response_time_ms=tool_ms, -+ was_streamed=self.was_streamed, -+ was_multimodal=self.was_multimodal, -+ was_nilrag=self.was_nilrag, -+ was_nildb=self.was_nildb, -+ error_code=self.error_code, -+ error_message=self.error_message, -+ ) -+ session.add(query_log) -+ await session.commit() -+ logger.info( -+ f"Query logged for user {self.user_id}: model={self.model}, " -+ f"tokens={total_tokens}, total_ms={total_ms}" -+ ) -+ except SQLAlchemyError as e: -+ logger.error(f"Error logging query: {e}") -+ # Don't raise - logging failure shouldn't break the request - - - class QueryLogManager: -+ """Static methods for direct query logging (legacy support).""" -+ - @staticmethod - async def log_query( -- userid: str, -+ user_id: str, -+ lockid: str, - model: str, - prompt_tokens: int, - completion_tokens: int, -+ response_time_ms: int, - web_search_calls: int, -+ was_streamed: bool, -+ was_multimodal: bool, -+ was_nilrag: bool, -+ was_nildb: bool, -+ tool_calls: int = 0, -+ temperature: float = 1.0, -+ max_tokens: int = 0, -+ model_response_time_ms: int = 0, -+ tool_response_time_ms: int = 0, -+ error_code: int = 0, -+ error_message: str = "", - ): - """ -- Log a user's query. -- -- Args: -- userid (str): User's unique ID -- model (str): The model that generated the response -- prompt_tokens (int): Number of input tokens used -- completion_tokens (int): Number of tokens in the generated response -+ Log a user's query (legacy method). -+ Consider using QueryLogContext as a dependency instead. - """ - total_tokens = prompt_tokens + completion_tokens - - try: - async with get_db_session() as session: - query_log = QueryLog( -- userid=userid, -+ user_id=user_id, -+ lockid=lockid, - model=model, - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=total_tokens, -- query_timestamp=datetime.now(timezone.utc), -+ tool_calls=tool_calls, - web_search_calls=web_search_calls, -+ temperature=temperature, -+ max_tokens=max_tokens, -+ query_timestamp=datetime.now(timezone.utc), -+ response_time_ms=response_time_ms, -+ model_response_time_ms=model_response_time_ms, -+ tool_response_time_ms=tool_response_time_ms, -+ was_streamed=was_streamed, -+ was_multimodal=was_multimodal, -+ was_nilrag=was_nilrag, -+ was_nildb=was_nildb, -+ error_code=error_code, -+ error_message=error_message, - ) - session.add(query_log) - await session.commit() - logger.info( -- f"Query logged for user {userid} with total tokens {total_tokens}." -+ f"Query logged for user {user_id} with total tokens {total_tokens}." 
- ) - except SQLAlchemyError as e: - logger.error(f"Error logging query: {e}") - raise - -+ @staticmethod -+ async def get_user_token_usage(user_id: str) -> Optional[Usage]: -+ """ -+ Get aggregated token usage for a specific user using server-side SQL aggregation. -+ This is more efficient than fetching all records and calculating in Python. -+ """ -+ try: -+ async with get_db_session() as session: -+ # Use SQL aggregation functions to calculate on the database server -+ query = ( -+ sqlalchemy.select( -+ sqlalchemy.func.coalesce( -+ sqlalchemy.func.sum(QueryLog.prompt_tokens), 0 -+ ).label("prompt_tokens"), -+ sqlalchemy.func.coalesce( -+ sqlalchemy.func.sum(QueryLog.completion_tokens), 0 -+ ).label("completion_tokens"), -+ sqlalchemy.func.coalesce( -+ sqlalchemy.func.sum(QueryLog.total_tokens), 0 -+ ).label("total_tokens"), -+ sqlalchemy.func.count().label("queries"), -+ ).where(QueryLog.user_id == user_id) # type: ignore[arg-type] -+ ) -+ -+ result = await session.execute(query) -+ row = result.one_or_none() -+ -+ if row is None: -+ return None -+ -+ return Usage( -+ prompt_tokens=int(row.prompt_tokens), -+ completion_tokens=int(row.completion_tokens), -+ total_tokens=int(row.total_tokens), -+ ) -+ except SQLAlchemyError as e: -+ logger.error(f"Error getting token usage: {e}") -+ return None -+ - --__all__ = ["QueryLogManager", "QueryLog"] -+__all__ = ["QueryLogManager", "QueryLog", "QueryLogContext"] -diff --git a/nilai-api/src/nilai_api/db/users.py b/nilai-api/src/nilai_api/db/users.py -index 515ba38..e475c42 100644 ---- a/nilai-api/src/nilai_api/db/users.py -+++ b/nilai-api/src/nilai_api/db/users.py -@@ -2,11 +2,10 @@ import logging - import uuid - from pydantic import BaseModel, ConfigDict, Field - --from datetime import datetime, timezone --from typing import Any, Dict, List, Optional -+from typing import Optional - - import sqlalchemy --from sqlalchemy import Integer, String, DateTime, JSON -+from sqlalchemy import String, JSON - from sqlalchemy.exc import SQLAlchemyError - - from nilai_api.db import Base, Column, get_db_session -@@ -57,21 +56,11 @@ class RateLimits(BaseModel): - # Enhanced User Model with additional constraints and validation - class UserModel(Base): - __tablename__ = "users" -- -- userid: str = Column(String(75), primary_key=True, index=True) # type: ignore -- name: str = Column(String(100), nullable=False) # type: ignore -- apikey: str = Column(String(75), unique=False, nullable=False, index=True) # type: ignore -- prompt_tokens: int = Column(Integer, default=0, nullable=False) # type: ignore -- completion_tokens: int = Column(Integer, default=0, nullable=False) # type: ignore -- queries: int = Column(Integer, default=0, nullable=False) # type: ignore -- signup_date: datetime = Column( -- DateTime(timezone=True), server_default=sqlalchemy.func.now(), nullable=False -- ) # type: ignore -- last_activity: datetime = Column(DateTime(timezone=True), nullable=True) # type: ignore -+ user_id: str = Column(String(75), primary_key=True, index=True) # type: ignore - rate_limits: dict = Column(JSON, nullable=True) # type: ignore - - def __repr__(self): -- return f"" -+ return f"" - - @property - def rate_limits_obj(self) -> RateLimits: -@@ -85,14 +74,7 @@ class UserModel(Base): - - - class UserData(BaseModel): -- userid: str -- name: str -- apikey: str -- prompt_tokens: int = 0 -- completion_tokens: int = 0 -- queries: int = 0 -- signup_date: datetime -- last_activity: Optional[datetime] = None -+ user_id: str # apikey or subscription holder public key - rate_limits: 
RateLimits = Field(default_factory=RateLimits().get_effective_limits) - - model_config = ConfigDict(from_attributes=True) -@@ -100,21 +82,10 @@ class UserData(BaseModel): - @classmethod - def from_sqlalchemy(cls, user: UserModel) -> "UserData": - return cls( -- userid=user.userid, -- name=user.name, -- apikey=user.apikey, -- prompt_tokens=user.prompt_tokens or 0, -- completion_tokens=user.completion_tokens or 0, -- queries=user.queries or 0, -- signup_date=user.signup_date or datetime.now(timezone.utc), -- last_activity=user.last_activity, -+ user_id=user.user_id, - rate_limits=user.rate_limits_obj, - ) - -- @property -- def is_subscription_owner(self): -- return self.userid == self.apikey -- - - class UserManager: - @staticmethod -@@ -127,31 +98,9 @@ class UserManager: - """Generate a unique API key.""" - return str(uuid.uuid4()) - -- @staticmethod -- async def update_last_activity(userid: str): -- """ -- Update the last activity timestamp for a user. -- -- Args: -- userid (str): User's unique ID -- """ -- try: -- async with get_db_session() as session: -- user = await session.get(UserModel, userid) -- if user: -- user.last_activity = datetime.now(timezone.utc) -- await session.commit() -- logger.info(f"Updated last activity for user {userid}") -- else: -- logger.warning(f"User {userid} not found") -- except SQLAlchemyError as e: -- logger.error(f"Error updating last activity: {e}") -- - @staticmethod - async def insert_user( -- name: str, -- apikey: str | None = None, -- userid: str | None = None, -+ user_id: str | None = None, - rate_limits: RateLimits | None = None, - ) -> UserModel: - """ -@@ -160,19 +109,16 @@ class UserManager: - Args: - name (str): Name of the user - apikey (str): API key for the user -- userid (str): Unique ID for the user -+ user_id (str): Unique ID for the user - rate_limits (RateLimits): Rate limit configuration - - Returns: - UserModel: The created user model - """ -- userid = userid if userid else UserManager.generate_user_id() -- apikey = apikey if apikey else UserManager.generate_api_key() -+ user_id = user_id if user_id else UserManager.generate_user_id() - - user = UserModel( -- userid=userid, -- name=name, -- apikey=apikey, -+ user_id=user_id, - rate_limits=rate_limits.model_dump() if rate_limits else None, - ) - return await UserManager.insert_user_model(user) -@@ -189,35 +135,14 @@ class UserManager: - async with get_db_session() as session: - session.add(user) - await session.commit() -- logger.info(f"User {user.name} added successfully.") -+ logger.info(f"User {user.user_id} added successfully.") - return user - except SQLAlchemyError as e: - logger.error(f"Error inserting user: {e}") - raise - - @staticmethod -- async def check_user(userid: str) -> Optional[UserModel]: -- """ -- Validate a user. -- -- Args: -- userid (str): User ID to validate -- -- Returns: -- User's name if user is valid, None otherwise -- """ -- try: -- async with get_db_session() as session: -- query = sqlalchemy.select(UserModel).filter(UserModel.userid == userid) # type: ignore -- user = await session.execute(query) -- user = user.scalar_one_or_none() -- return user -- except SQLAlchemyError as e: -- logger.error(f"Error checking API key: {e}") -- return None -- -- @staticmethod -- async def check_api_key(api_key: str) -> Optional[UserModel]: -+ async def check_user(user_id: str) -> Optional[UserModel]: - """ - Validate an API key. 
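The slimmed-down users table keeps rate limits as a JSON blob, with the pydantic RateLimits model as the serialization boundary. A round-trip sketch (field names taken from this patch; values hypothetical):

    from nilai_api.db.users import RateLimits

    limits = RateLimits(user_rate_limit_day=1000, user_rate_limit_hour=100)
    as_json = limits.model_dump()       # what UserModel.rate_limits stores
    restored = RateLimits(**as_json)
    assert restored.user_rate_limit_day == 1000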
- -@@ -225,118 +150,27 @@ class UserManager: - api_key (str): API key to validate - - Returns: -- User's name if API key is valid, None otherwise -+ The user model if the user id is valid, None otherwise - """ - try: - async with get_db_session() as session: -- query = sqlalchemy.select(UserModel).filter(UserModel.apikey == api_key) # type: ignore -+ query = sqlalchemy.select(UserModel).filter( -+ UserModel.user_id == user_id # type: ignore -+ ) - user = await session.execute(query) - user = user.scalar_one_or_none() - return user - except SQLAlchemyError as e: -- logger.error(f"Error checking API key: {e}") -- return None -- -- @staticmethod -- async def update_token_usage( -- userid: str, prompt_tokens: int, completion_tokens: int -- ): -- """ -- Update token usage for a specific user. -- -- Args: -- userid (str): User's unique ID -- prompt_tokens (int): Number of input tokens -- completion_tokens (int): Number of generated tokens -- """ -- try: -- async with get_db_session() as session: -- user = await session.get(UserModel, userid) -- if user: -- user.prompt_tokens += prompt_tokens -- user.completion_tokens += completion_tokens -- user.queries += 1 -- await session.commit() -- logger.info(f"Updated token usage for user {userid}") -- else: -- logger.warning(f"User {userid} not found") -- except SQLAlchemyError as e: -- logger.error(f"Error updating token usage: {e}") -- -- @staticmethod -- async def get_token_usage(userid: str) -> Optional[Dict[str, Any]]: -- """ -- Get token usage for a specific user. -- -- Args: -- userid (str): User's unique ID -- """ -- try: -- async with get_db_session() as session: -- user = await session.get(UserModel, userid) -- if user: -- return { -- "prompt_tokens": user.prompt_tokens, -- "completion_tokens": user.completion_tokens, -- "total_tokens": user.prompt_tokens + user.completion_tokens, -- "queries": user.queries, -- } -- else: -- logger.warning(f"User {userid} not found") -- return None -- except SQLAlchemyError as e: -- logger.error(f"Error updating token usage: {e}") -- return None -- -- @staticmethod -- async def get_all_users() -> Optional[List[UserData]]: -- """ -- Retrieve all users from the database. -- -- Returns: -- List of UserData or None if no users found -- """ -- try: -- async with get_db_session() as session: -- users = await session.execute(sqlalchemy.select(UserModel)) -- users = users.scalars().all() -- return [UserData.from_sqlalchemy(user) for user in users] -- except SQLAlchemyError as e: -- logger.error(f"Error retrieving all users: {e}") -- return None -- -- @staticmethod -- async def get_user_token_usage(userid: str) -> Optional[Dict[str, int]]: -- """ -- Retrieve total token usage for a user. -- -- Args: -- userid (str): User's unique ID -- -- Returns: -- Dict of token usage or None if user not found -- """ -- try: -- async with get_db_session() as session: -- user = await session.get(UserModel, userid) -- if user: -- return { -- "prompt_tokens": user.prompt_tokens, -- "completion_tokens": user.completion_tokens, -- "queries": user.queries, -- } -- return None -- except SQLAlchemyError as e: -- logger.error(f"Error retrieving token usage: {e}") -+ logger.error(f"Error checking user id: {e}") - return None - - @staticmethod -- async def update_rate_limits(userid: str, rate_limits: RateLimits) -> bool: -+ async def update_rate_limits(user_id: str, rate_limits: RateLimits) -> bool: - """ - Update rate limits for a specific user. 
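Calling the updated method end to end would look like this (a sketch with a hypothetical user id; it needs the service database to be reachable, so it is illustrative rather than a test):

    import asyncio

    from nilai_api.db.users import RateLimits, UserManager

    async def main() -> None:
        ok = await UserManager.update_rate_limits(
            "did:example:alice",                      # hypothetical user id
            RateLimits(user_rate_limit_minute=10),
        )
        print("updated" if ok else "user not found")

    asyncio.run(main())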
- - Args: -- userid (str): User's unique ID -+ user_id (str): User's unique ID - rate_limits (RateLimits): New rate limit configuration - - Returns: -@@ -344,14 +178,14 @@ - """ - try: - async with get_db_session() as session: -- user = await session.get(UserModel, userid) -+ user = await session.get(UserModel, user_id) - if user: - user.rate_limits = rate_limits.model_dump() - await session.commit() -- logger.info(f"Updated rate limits for user {userid}") -+ logger.info(f"Updated rate limits for user {user_id}") - return True - else: -- logger.warning(f"User {userid} not found") -+ logger.warning(f"User {user_id} not found") - return False - except SQLAlchemyError as e: - logger.error(f"Error updating rate limits: {e}") -diff --git a/nilai-api/src/nilai_api/rate_limiting.py b/nilai-api/src/nilai_api/rate_limiting.py -index 8205b55..347f162 100644 ---- a/nilai-api/src/nilai_api/rate_limiting.py -+++ b/nilai-api/src/nilai_api/rate_limiting.py -@@ -53,7 +53,7 @@ async def _extract_coroutine_result(maybe_future, request: Request): - - - class UserRateLimits(BaseModel): -- subscription_holder: str -+ user_id: str - token_rate_limit: TokenRateLimits | None - rate_limits: RateLimits - -@@ -61,14 +61,13 @@ - def get_user_limits( - auth_info: Annotated[AuthenticationInfo, Depends(get_auth_info)], - ) -> UserRateLimits: -- # TODO: When the only allowed strategy is NUC, we can change the apikey name to subscription_holder -- # In apikey mode, the apikey is unique as the userid. -- # In nuc mode, the apikey is associated with a subscription holder and the userid is the user -+ # In apikey mode, the apikey is used directly as the user_id. -+ # In nuc mode, the apikey is associated with a subscription holder and the user_id is the user -+ # In JWT mode, the apikey is the user_id too - # For NUCs we want the rate limit to be per subscription holder, not per user - # So we use the apikey as the id - return UserRateLimits( -- subscription_holder=auth_info.user.apikey, -+ user_id=auth_info.user.user_id, - token_rate_limit=auth_info.token_rate_limit, - rate_limits=auth_info.user.rate_limits, - ) -@@ -106,21 +105,21 @@ class RateLimit: - await self.check_bucket( - redis, - redis_rate_limit_command, -- f"minute:{user_limits.subscription_holder}", -+ f"minute:{user_limits.user_id}", - user_limits.rate_limits.user_rate_limit_minute, - MINUTE_MS, - ) - await self.check_bucket( - redis, - redis_rate_limit_command, -- f"hour:{user_limits.subscription_holder}", -+ f"hour:{user_limits.user_id}", - user_limits.rate_limits.user_rate_limit_hour, - HOUR_MS, - ) - await self.check_bucket( - redis, - redis_rate_limit_command, -- f"day:{user_limits.subscription_holder}", -+ f"day:{user_limits.user_id}", - user_limits.rate_limits.user_rate_limit_day, - DAY_MS, - ) -@@ -128,7 +127,7 @@ - await self.check_bucket( - redis, - redis_rate_limit_command, -- f"user:{user_limits.subscription_holder}", -+ f"user:{user_limits.user_id}", - user_limits.rate_limits.user_rate_limit, - 0, # No expiration for for-good rate limit - ) -@@ -176,7 +175,7 @@ - await self.check_bucket( - redis, - redis_rate_limit_command, -- f"web_search:{user_limits.subscription_holder}", -+ f"web_search:{user_limits.user_id}", - user_limits.rate_limits.web_search_rate_limit, - 0, # No expiration for for-good rate limit - ) -@@ -199,7 +198,7 @@ - await self.check_bucket( - redis, - redis_rate_limit_command, -- 
f"web_search_{time_unit}:{user_limits.subscription_holder}", -+ f"web_search_{time_unit}:{user_limits.user_id}", - limit, - milliseconds, - ) -diff --git a/nilai-api/src/nilai_api/routers/private.py b/nilai-api/src/nilai_api/routers/private.py -index db75067..038f8db 100644 ---- a/nilai-api/src/nilai_api/routers/private.py -+++ b/nilai-api/src/nilai_api/routers/private.py -@@ -11,13 +11,20 @@ from nilai_api.handlers.nilrag import handle_nilrag - from nilai_api.handlers.web_search import handle_web_search - from nilai_api.handlers.tools.tool_router import handle_tool_workflow - --from fastapi import APIRouter, Body, Depends, HTTPException, status, Request -+from fastapi import ( -+ APIRouter, -+ BackgroundTasks, -+ Body, -+ Depends, -+ HTTPException, -+ status, -+ Request, -+) - from fastapi.responses import StreamingResponse - from nilai_api.auth import get_auth_info, AuthenticationInfo - from nilai_api.config import CONFIG - from nilai_api.crypto import sign_message --from nilai_api.db.logs import QueryLogManager --from nilai_api.db.users import UserManager -+from nilai_api.db.logs import QueryLogContext, QueryLogManager - from nilai_api.rate_limiting import RateLimit - from nilai_api.state import state - -@@ -53,14 +60,10 @@ router = APIRouter() - @router.get("/v1/delegation") - async def get_prompt_store_delegation( - prompt_delegation_request: PromptDelegationRequest, -- auth_info: AuthenticationInfo = Depends(get_auth_info), -+ _: AuthenticationInfo = Depends( -+ get_auth_info -+ ), # This is to satisfy that the user is authenticated - ) -> PromptDelegationToken: -- if not auth_info.user.is_subscription_owner: -- raise HTTPException( -- status_code=status.HTTP_403_FORBIDDEN, -- detail=f"Prompt storage is reserved to subscription owners: {auth_info.user} is not a subscription owner, apikey: {auth_info.user}", -- ) -- - try: - return await get_nildb_delegation_token(prompt_delegation_request) - except Exception as e: -@@ -84,12 +87,15 @@ async def get_usage(auth_info: AuthenticationInfo = Depends(get_auth_info)) -> U - usage = await get_usage(user) - ``` - """ -- return Usage( -- prompt_tokens=auth_info.user.prompt_tokens, -- completion_tokens=auth_info.user.completion_tokens, -- total_tokens=auth_info.user.prompt_tokens + auth_info.user.completion_tokens, -- queries=auth_info.user.queries, # type: ignore # FIXME this field is not part of Usage -+ user_usage: Optional[Usage] = await QueryLogManager.get_user_token_usage( -+ auth_info.user.user_id - ) -+ if user_usage is None: -+ raise HTTPException( -+ status_code=status.HTTP_404_NOT_FOUND, -+ detail="User not found", -+ ) -+ return user_usage - - - @router.get("/v1/attestation/report", tags=["Attestation"]) -@@ -173,6 +179,7 @@ async def chat_completion( - ], - ) - ), -+ background_tasks: BackgroundTasks = BackgroundTasks(), - _rate_limit=Depends( - RateLimit( - concurrent_extractor=chat_completion_concurrent_rate_limit, -@@ -181,6 +188,7 @@ async def chat_completion( - ), - auth_info: AuthenticationInfo = Depends(get_auth_info), - meter: MeteringContext = Depends(LLMMeter), -+ log_ctx: QueryLogContext = Depends(QueryLogContext), - ) -> Union[SignedChatCompletion, StreamingResponse]: - """ - Generate a chat completion response from the AI model. 
-@@ -234,249 +242,312 @@ - ) - response = await chat_completion(request, user) - """ -- -- if len(req.messages) == 0: -- raise HTTPException( -- status_code=400, -- detail="Request contained 0 messages", -- ) -+ # Initialize log context early so we can log any errors -+ log_ctx.set_user(auth_info.user.user_id) -+ log_ctx.set_lockid(meter.lock_id) - model_name = req.model - request_id = str(uuid.uuid4()) - t_start = time.monotonic() -- logger.info(f"[chat] call start request_id={req.messages}") -- endpoint = await state.get_model(model_name) -- if endpoint is None: -- raise HTTPException( -- status_code=status.HTTP_400_BAD_REQUEST, -- detail=f"Invalid model name {model_name}, check /v1/models for options", -- ) -- -- if not endpoint.metadata.tool_support and req.tools: -- raise HTTPException( -- status_code=400, -- detail="Model does not support tool usage, remove tools from request", -- ) - -- has_multimodal = req.has_multimodal_content() -- logger.info(f"[chat] has_multimodal: {has_multimodal}") -- if has_multimodal and (not endpoint.metadata.multimodal_support or req.web_search): -- raise HTTPException( -- status_code=400, -- detail="Model does not support multimodal content, remove image inputs from request", -- ) -- -- model_url = endpoint.url + "/v1/" -+ try: -+ if len(req.messages) == 0: -+ raise HTTPException( -+ status_code=400, -+ detail="Request contained 0 messages", -+ ) -+ logger.info(f"[chat] call start request_id={request_id}") -+ endpoint = await state.get_model(model_name) -+ if endpoint is None: -+ raise HTTPException( -+ status_code=status.HTTP_400_BAD_REQUEST, -+ detail=f"Invalid model name {model_name}, check /v1/models for options", -+ ) - -- logger.info( -- f"[chat] start request_id={request_id} user={auth_info.user.userid} model={model_name} stream={req.stream} web_search={bool(req.web_search)} tools={bool(req.tools)} multimodal={has_multimodal} url={model_url}" -- ) -+ # Now we have a valid model, set it in log context -+ log_ctx.set_model(model_name) - -- client = AsyncOpenAI(base_url=model_url, api_key="") -- if auth_info.prompt_document: -- try: -- nildb_prompt: str = await get_prompt_from_nildb(auth_info.prompt_document) -- req.messages.insert( -- 0, MessageAdapter.new_message(role="system", content=nildb_prompt) -- ) -- except Exception as e: -+ if not endpoint.metadata.tool_support and req.tools: - raise HTTPException( -- status_code=status.HTTP_403_FORBIDDEN, -- detail=f"Unable to extract prompt from nilDB: {str(e)}", -+ status_code=400, -+ detail="Model does not support tool usage, remove tools from request", - ) - -- if req.nilrag: -- logger.info(f"[chat] nilrag start request_id={request_id}") -- t_nilrag = time.monotonic() -- await handle_nilrag(req) -- logger.info( -- f"[chat] nilrag done request_id={request_id} duration_ms={(time.monotonic() - t_nilrag) * 1000:.0f}" -- ) -+ has_multimodal = req.has_multimodal_content() -+ logger.info(f"[chat] has_multimodal: {has_multimodal}") -+ if has_multimodal and ( -+ not endpoint.metadata.multimodal_support or req.web_search -+ ): -+ raise HTTPException( -+ status_code=400, -+ detail="Model does not support multimodal content, remove image inputs from request", -+ ) - -- messages = req.messages -- sources: Optional[List[Source]] = None -+ model_url = endpoint.url + "/v1/" - -- if req.web_search: -- logger.info(f"[chat] web_search start request_id={request_id}") -- t_ws = time.monotonic() -- web_search_result = await handle_web_search(req, model_name, client) -- messages = 
web_search_result.messages -- sources = web_search_result.sources - logger.info( -- f"[chat] web_search done request_id={request_id} sources={len(sources) if sources else 0} duration_ms={(time.monotonic() - t_ws) * 1000:.0f}" -+ f"[chat] start request_id={request_id} user={auth_info.user.user_id} model={model_name} stream={req.stream} web_search={bool(req.web_search)} tools={bool(req.tools)} multimodal={has_multimodal} url={model_url}" -+ ) -+ log_ctx.set_request_params( -+ temperature=req.temperature, -+ max_tokens=req.max_tokens, -+ was_streamed=req.stream or False, -+ was_multimodal=has_multimodal, -+ was_nildb=bool(auth_info.prompt_document), -+ was_nilrag=bool(req.nilrag), - ) -- logger.info(f"[chat] web_search messages: {messages}") - -- if req.stream: -+ client = AsyncOpenAI(base_url=model_url, api_key="") -+ if auth_info.prompt_document: -+ try: -+ nildb_prompt: str = await get_prompt_from_nildb( -+ auth_info.prompt_document -+ ) -+ req.messages.insert( -+ 0, MessageAdapter.new_message(role="system", content=nildb_prompt) -+ ) -+ except Exception as e: -+ raise HTTPException( -+ status_code=status.HTTP_403_FORBIDDEN, -+ detail=f"Unable to extract prompt from nilDB: {str(e)}", -+ ) - -- async def chat_completion_stream_generator() -> AsyncGenerator[str, None]: -- t_call = time.monotonic() -- prompt_token_usage = 0 -- completion_token_usage = 0 -+ if req.nilrag: -+ logger.info(f"[chat] nilrag start request_id={request_id}") -+ t_nilrag = time.monotonic() -+ await handle_nilrag(req) -+ logger.info( -+ f"[chat] nilrag done request_id={request_id} duration_ms={(time.monotonic() - t_nilrag) * 1000:.0f}" -+ ) - -- try: -- logger.info(f"[chat] stream start request_id={request_id}") -- -- request_kwargs = { -- "model": req.model, -- "messages": messages, -- "stream": True, -- "top_p": req.top_p, -- "temperature": req.temperature, -- "max_tokens": req.max_tokens, -- "extra_body": { -- "stream_options": { -- "include_usage": True, -- "continuous_usage_stats": False, -- } -- }, -- } -- if req.tools: -- request_kwargs["tools"] = req.tools -+ messages = req.messages -+ sources: Optional[List[Source]] = None -+ -+ if req.web_search: -+ logger.info(f"[chat] web_search start request_id={request_id}") -+ t_ws = time.monotonic() -+ web_search_result = await handle_web_search(req, model_name, client) -+ messages = web_search_result.messages -+ sources = web_search_result.sources -+ logger.info( -+ f"[chat] web_search done request_id={request_id} sources={len(sources) if sources else 0} duration_ms={(time.monotonic() - t_ws) * 1000:.0f}" -+ ) -+ logger.info(f"[chat] web_search messages: {messages}") -+ -+ if req.stream: -+ -+ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]: -+ t_call = time.monotonic() -+ prompt_token_usage = 0 -+ completion_token_usage = 0 -+ -+ try: -+ logger.info(f"[chat] stream start request_id={request_id}") -+ -+ log_ctx.start_model_timing() -+ -+ request_kwargs = { -+ "model": req.model, -+ "messages": messages, -+ "stream": True, -+ "top_p": req.top_p, -+ "temperature": req.temperature, -+ "max_tokens": req.max_tokens, -+ "extra_body": { -+ "stream_options": { -+ "include_usage": True, -+ "continuous_usage_stats": False, -+ } -+ }, -+ } -+ if req.tools: -+ request_kwargs["tools"] = req.tools - -- response = await client.chat.completions.create(**request_kwargs) -+ response = await client.chat.completions.create(**request_kwargs) - -- async for chunk in response: -- if chunk.usage is not None: -- prompt_token_usage = chunk.usage.prompt_tokens -- 
completion_token_usage = chunk.usage.completion_tokens -+ async for chunk in response: -+ if chunk.usage is not None: -+ prompt_token_usage = chunk.usage.prompt_tokens -+ completion_token_usage = chunk.usage.completion_tokens - -- payload = chunk.model_dump(exclude_unset=True) -+ payload = chunk.model_dump(exclude_unset=True) - -- if chunk.usage is not None and sources: -- payload["sources"] = [ -- s.model_dump(mode="json") for s in sources -- ] -+ if chunk.usage is not None and sources: -+ payload["sources"] = [ -+ s.model_dump(mode="json") for s in sources -+ ] - -- yield f"data: {json.dumps(payload)}\n\n" -+ yield f"data: {json.dumps(payload)}\n\n" - -- await UserManager.update_token_usage( -- auth_info.user.userid, -- prompt_tokens=prompt_token_usage, -- completion_tokens=completion_token_usage, -- ) -- meter.set_response( -- { -- "usage": LLMUsage( -- prompt_tokens=prompt_token_usage, -- completion_tokens=completion_token_usage, -- web_searches=len(sources) if sources else 0, -- ) -- } -- ) -- await QueryLogManager.log_query( -- auth_info.user.userid, -- model=req.model, -- prompt_tokens=prompt_token_usage, -- completion_tokens=completion_token_usage, -- web_search_calls=len(sources) if sources else 0, -- ) -- logger.info( -- "[chat] stream done request_id=%s prompt_tokens=%d completion_tokens=%d " -- "duration_ms=%.0f total_ms=%.0f", -- request_id, -- prompt_token_usage, -- completion_token_usage, -- (time.monotonic() - t_call) * 1000, -- (time.monotonic() - t_start) * 1000, -- ) -+ log_ctx.end_model_timing() -+ meter.set_response( -+ { -+ "usage": LLMUsage( -+ prompt_tokens=prompt_token_usage, -+ completion_tokens=completion_token_usage, -+ web_searches=len(sources) if sources else 0, -+ ) -+ } -+ ) -+ log_ctx.set_usage( -+ prompt_tokens=prompt_token_usage, -+ completion_tokens=completion_token_usage, -+ web_search_calls=len(sources) if sources else 0, -+ ) -+ await log_ctx.commit() -+ logger.info( -+ "[chat] stream done request_id=%s prompt_tokens=%d completion_tokens=%d " -+ "duration_ms=%.0f total_ms=%.0f", -+ request_id, -+ prompt_token_usage, -+ completion_token_usage, -+ (time.monotonic() - t_call) * 1000, -+ (time.monotonic() - t_start) * 1000, -+ ) -+ -+ except Exception as e: -+ logger.error( -+ "[chat] stream error request_id=%s error=%s", request_id, e -+ ) -+ log_ctx.set_error(error_code=500, error_message=str(e)) -+ await log_ctx.commit() -+ yield f"data: {json.dumps({'error': 'stream_failed', 'message': str(e)})}\n\n" -+ -+ return StreamingResponse( -+ chat_completion_stream_generator(), -+ media_type="text/event-stream", -+ ) - -- except Exception as e: -- logger.error( -- "[chat] stream error request_id=%s error=%s", request_id, e -- ) -- yield f"data: {json.dumps({'error': 'stream_failed', 'message': str(e)})}\n\n" -+ current_messages = messages -+ request_kwargs = { -+ "model": req.model, -+ "messages": current_messages, # type: ignore -+ "top_p": req.top_p, -+ "temperature": req.temperature, -+ "max_tokens": req.max_tokens, -+ } -+ if req.tools: -+ request_kwargs["tools"] = req.tools # type: ignore -+ request_kwargs["tool_choice"] = req.tool_choice -+ -+ logger.info(f"[chat] call start request_id={request_id}") -+ logger.info(f"[chat] call message: {current_messages}") -+ t_call = time.monotonic() -+ log_ctx.start_model_timing() -+ response = await client.chat.completions.create(**request_kwargs) # type: ignore -+ log_ctx.end_model_timing() -+ logger.info( -+ f"[chat] call done request_id={request_id} duration_ms={(time.monotonic() - t_call) * 1000:.0f}" -+ ) -+ 
logger.info(f"[chat] call response: {response}") -+ -+ # Handle tool workflow fully inside tools.router -+ log_ctx.start_tool_timing() -+ ( -+ final_completion, -+ agg_prompt_tokens, -+ agg_completion_tokens, -+ ) = await handle_tool_workflow(client, req, current_messages, response) -+ log_ctx.end_tool_timing() -+ logger.info(f"[chat] call final_completion: {final_completion}") -+ model_response = SignedChatCompletion( -+ **final_completion.model_dump(), -+ signature="", -+ sources=sources, -+ ) - -- return StreamingResponse( -- chat_completion_stream_generator(), -- media_type="text/event-stream", -+ logger.info( -+ f"[chat] model_response request_id={request_id} duration_ms={(time.monotonic() - t_call) * 1000:.0f}" - ) - -- current_messages = messages -- request_kwargs = { -- "model": req.model, -- "messages": current_messages, # type: ignore -- "top_p": req.top_p, -- "temperature": req.temperature, -- "max_tokens": req.max_tokens, -- } -- if req.tools: -- request_kwargs["tools"] = req.tools # type: ignore -- request_kwargs["tool_choice"] = req.tool_choice -- -- logger.info(f"[chat] call start request_id={request_id}") -- logger.info(f"[chat] call message: {current_messages}") -- t_call = time.monotonic() -- response = await client.chat.completions.create(**request_kwargs) # type: ignore -- logger.info( -- f"[chat] call done request_id={request_id} duration_ms={(time.monotonic() - t_call) * 1000:.0f}" -- ) -- logger.info(f"[chat] call response: {response}") -- -- # Handle tool workflow fully inside tools.router -- ( -- final_completion, -- agg_prompt_tokens, -- agg_completion_tokens, -- ) = await handle_tool_workflow(client, req, current_messages, response) -- logger.info(f"[chat] call final_completion: {final_completion}") -- model_response = SignedChatCompletion( -- **final_completion.model_dump(), -- signature="", -- sources=sources, -- ) -+ if model_response.usage is None: -+ raise HTTPException( -+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, -+ detail="Model response does not contain usage statistics", -+ ) - -- logger.info( -- f"[chat] model_response request_id={request_id} duration_ms={(time.monotonic() - t_call) * 1000:.0f}" -- ) -+ if agg_prompt_tokens or agg_completion_tokens: -+ total_prompt_tokens = response.usage.prompt_tokens -+ total_completion_tokens = response.usage.completion_tokens - -- if model_response.usage is None: -- raise HTTPException( -- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, -- detail="Model response does not contain usage statistics", -- ) -+ total_prompt_tokens += agg_prompt_tokens -+ total_completion_tokens += agg_completion_tokens - -- if agg_prompt_tokens or agg_completion_tokens: -- total_prompt_tokens = response.usage.prompt_tokens -- total_completion_tokens = response.usage.completion_tokens -+ model_response.usage.prompt_tokens = total_prompt_tokens -+ model_response.usage.completion_tokens = total_completion_tokens -+ model_response.usage.total_tokens = ( -+ total_prompt_tokens + total_completion_tokens -+ ) -+ -+ # Update token usage in DB -+ meter.set_response( -+ { -+ "usage": LLMUsage( -+ prompt_tokens=model_response.usage.prompt_tokens, -+ completion_tokens=model_response.usage.completion_tokens, -+ web_searches=len(sources) if sources else 0, -+ ) -+ } -+ ) - -- total_prompt_tokens += agg_prompt_tokens -- total_completion_tokens += agg_completion_tokens -+ # Log query with context -+ tool_calls_count = 0 -+ if final_completion.choices and final_completion.choices[0].message.tool_calls: -+ tool_calls_count = 
len(final_completion.choices[0].message.tool_calls) - -- model_response.usage.prompt_tokens = total_prompt_tokens -- model_response.usage.completion_tokens = total_completion_tokens -- model_response.usage.total_tokens = ( -- total_prompt_tokens + total_completion_tokens -+ log_ctx.set_usage( -+ prompt_tokens=model_response.usage.prompt_tokens, -+ completion_tokens=model_response.usage.completion_tokens, -+ tool_calls=tool_calls_count, -+ web_search_calls=len(sources) if sources else 0, - ) -+ # Use background task for successful requests to avoid blocking response -+ background_tasks.add_task(log_ctx.commit) - -- # Update token usage in DB -- await UserManager.update_token_usage( -- auth_info.user.userid, -- prompt_tokens=model_response.usage.prompt_tokens, -- completion_tokens=model_response.usage.completion_tokens, -- ) -- meter.set_response( -- { -- "usage": LLMUsage( -- prompt_tokens=model_response.usage.prompt_tokens, -- completion_tokens=model_response.usage.completion_tokens, -- web_searches=len(sources) if sources else 0, -- ) -- } -- ) -- await QueryLogManager.log_query( -- auth_info.user.userid, -- model=req.model, -- prompt_tokens=model_response.usage.prompt_tokens, -- completion_tokens=model_response.usage.completion_tokens, -- web_search_calls=len(sources) if sources else 0, -- ) -+ # Sign the response -+ response_json = model_response.model_dump_json() -+ signature = sign_message(state.private_key, response_json) -+ model_response.signature = b64encode(signature).decode() - -- # Sign the response -- response_json = model_response.model_dump_json() -- signature = sign_message(state.private_key, response_json) -- model_response.signature = b64encode(signature).decode() -+ logger.info( -+ f"[chat] done request_id={request_id} prompt_tokens={model_response.usage.prompt_tokens} completion_tokens={model_response.usage.completion_tokens} total_ms={(time.monotonic() - t_start) * 1000:.0f}" -+ ) -+ return model_response -+ except HTTPException as e: -+ # Extract error code from HTTPException, default to status code -+ error_code = e.status_code -+ error_message = str(e.detail) if e.detail else str(e) -+ logger.error( -+ f"[chat] HTTPException request_id={request_id} user={auth_info.user.user_id} " -+ f"model={model_name} error_code={error_code} error={error_message}", -+ exc_info=True, -+ ) - -- logger.info( -- f"[chat] done request_id={request_id} prompt_tokens={model_response.usage.prompt_tokens} completion_tokens={model_response.usage.completion_tokens} total_ms={(time.monotonic() - t_start) * 1000:.0f}" -- ) -- return model_response -+ # Only log server errors (5xx) to database to prevent DoS attacks via client errors -+ # Client errors (4xx) are logged to application logs only -+ if error_code >= 500: -+ # Set model if not already set (e.g., for validation errors before model validation) -+ if log_ctx.model is None: -+ log_ctx.set_model(model_name) -+ log_ctx.set_error(error_code=error_code, error_message=error_message) -+ await log_ctx.commit() -+ # For 4xx errors, we skip DB logging - they're logged above via logger.error() -+ # This prevents DoS attacks where attackers send many invalid requests -+ -+ raise -+ except Exception as e: -+ # Catch any other unexpected exceptions -+ error_message = str(e) -+ logger.error( -+ f"[chat] unexpected error request_id={request_id} user={auth_info.user.user_id} " -+ f"model={model_name} error={error_message}", -+ exc_info=True, -+ ) -+ # Set model if not already set -+ if log_ctx.model is None: -+ log_ctx.set_model(model_name) -+ 
log_ctx.set_error(error_code=500, error_message=error_message) -+ await log_ctx.commit() -+ raise HTTPException( -+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, -+ detail=f"Internal server error: {error_message}", -+ ) -diff --git a/tests/e2e/config.py b/tests/e2e/config.py -index e06f9d4..3111902 100644 ---- a/tests/e2e/config.py -+++ b/tests/e2e/config.py -@@ -38,7 +38,7 @@ models = { - "meta-llama/Llama-3.1-8B-Instruct", - ], - "ci": [ -- "meta-llama/Llama-3.2-1B-Instruct", -+ "llama-3.2-1b-instruct", - ], - } - -diff --git a/tests/integration/nilai_api/test_users_db_integration.py b/tests/integration/nilai_api/test_users_db_integration.py -index 82d8d02..892a3af 100644 ---- a/tests/integration/nilai_api/test_users_db_integration.py -+++ b/tests/integration/nilai_api/test_users_db_integration.py -@@ -17,37 +17,17 @@ class TestUserManagerIntegration: - async def test_simple_user_creation(self, clean_database): - """Test creating a simple user and retrieving it.""" - # Insert user with minimal data -- user = await UserManager.insert_user(name="Simple Test User") -+ user = await UserManager.insert_user(user_id="Simple Test User") - - # Verify user creation -- assert user.name == "Simple Test User" -- assert user.userid is not None -- assert user.apikey is not None -- assert user.userid != user.apikey # Should be different UUIDs -+ assert user.user_id == "Simple Test User" -+ assert user.rate_limits is not None - - # Retrieve user by ID -- found_user = await UserManager.check_user(user.userid) -+ found_user = await UserManager.check_user(user.user_id) - assert found_user is not None -- assert found_user.userid == user.userid -- assert found_user.name == "Simple Test User" -- assert found_user.apikey == user.apikey -- -- @pytest.mark.asyncio -- async def test_api_key_validation(self, clean_database): -- """Test API key validation functionality.""" -- # Create user -- user = await UserManager.insert_user("API Test User") -- -- # Validate correct API key -- api_user = await UserManager.check_api_key(user.apikey) -- assert api_user is not None -- assert api_user.apikey == user.apikey -- assert api_user.userid == user.userid -- assert api_user.name == "API Test User" -- -- # Test invalid API key -- invalid_user = await UserManager.check_api_key("invalid-api-key") -- assert invalid_user is None -+ assert found_user.user_id == user.user_id -+ assert found_user.rate_limits == user.rate_limits - - @pytest.mark.asyncio - async def test_rate_limits_json_crud_basic(self, clean_database): -@@ -66,14 +46,14 @@ class TestUserManagerIntegration: - - # CREATE: Insert user with rate limits - user = await UserManager.insert_user( -- name="Rate Limits Test User", rate_limits=rate_limits -+ user_id="Rate Limits Test User", rate_limits=rate_limits - ) - - # Verify rate limits are stored as JSON - assert user.rate_limits == rate_limits.model_dump() - - # READ: Retrieve user and verify rate limits JSON -- retrieved_user = await UserManager.check_user(user.userid) -+ retrieved_user = await UserManager.check_user(user.user_id) - assert retrieved_user is not None - assert retrieved_user.rate_limits == rate_limits.model_dump() - -@@ -98,11 +78,11 @@ class TestUserManagerIntegration: - ) - - user = await UserManager.insert_user( -- name="Update Rate Limits User", rate_limits=initial_rate_limits -+ user_id="Update Rate Limits User", rate_limits=initial_rate_limits - ) - - # Verify initial rate limits -- retrieved_user = await UserManager.check_user(user.userid) -+ retrieved_user = await 
UserManager.check_user(user.user_id) - assert retrieved_user is not None - assert retrieved_user.rate_limits == initial_rate_limits.model_dump() - -@@ -125,19 +105,19 @@ class TestUserManagerIntegration: - stmt = sa.text(""" - UPDATE users - SET rate_limits = :rate_limits_json -- WHERE userid = :userid -+ WHERE user_id = :user_id - """) - await session.execute( - stmt, - { - "rate_limits_json": updated_rate_limits.model_dump_json(), -- "userid": user.userid, -+ "user_id": user.user_id, - }, - ) - await session.commit() - - # READ: Verify the update worked -- updated_user = await UserManager.check_user(user.userid) -+ updated_user = await UserManager.check_user(user.user_id) - assert updated_user is not None - assert updated_user.rate_limits == updated_rate_limits.model_dump() - -@@ -162,11 +142,11 @@ class TestUserManagerIntegration: - ) - - user = await UserManager.insert_user( -- name="Partial Rate Limits User", rate_limits=partial_rate_limits -+ user_id="Partial Rate Limits User", rate_limits=partial_rate_limits - ) - - # Verify partial data is stored correctly -- retrieved_user = await UserManager.check_user(user.userid) -+ retrieved_user = await UserManager.check_user(user.user_id) - assert retrieved_user is not None - assert retrieved_user.rate_limits == partial_rate_limits.model_dump() - -@@ -183,13 +163,13 @@ class TestUserManagerIntegration: - '{user_rate_limit_hour}', - '75' - ) -- WHERE userid = :userid -+ WHERE user_id = :user_id - """) -- await session.execute(stmt, {"userid": user.userid}) -+ await session.execute(stmt, {"user_id": user.user_id}) - await session.commit() - - # Verify partial update worked -- updated_user = await UserManager.check_user(user.userid) -+ updated_user = await UserManager.check_user(user.user_id) - assert updated_user is not None - - expected_data = partial_rate_limits.model_dump() -@@ -211,7 +191,7 @@ class TestUserManagerIntegration: - ) - - user = await UserManager.insert_user( -- name="Delete Rate Limits User", rate_limits=rate_limits -+ user_id="Delete Rate Limits User", rate_limits=rate_limits - ) - - # DELETE: Set rate_limits to NULL -@@ -219,12 +199,14 @@ class TestUserManagerIntegration: - import sqlalchemy as sa - - async with get_db_session() as session: -- stmt = sa.text("UPDATE users SET rate_limits = NULL WHERE userid = :userid") -- await session.execute(stmt, {"userid": user.userid}) -+ stmt = sa.text( -+ "UPDATE users SET rate_limits = NULL WHERE user_id = :user_id" -+ ) -+ await session.execute(stmt, {"user_id": user.user_id}) - await session.commit() - - # Verify NULL handling -- null_user = await UserManager.check_user(user.userid) -+ null_user = await UserManager.check_user(user.user_id) - assert null_user is not None - assert null_user.rate_limits is None - -@@ -239,15 +221,15 @@ class TestUserManagerIntegration: - # First set some data - new_data = {"user_rate_limit_day": 500, "web_search_rate_limit_day": 25} - stmt = sa.text( -- "UPDATE users SET rate_limits = :data WHERE userid = :userid" -+ "UPDATE users SET rate_limits = :data WHERE user_id = :user_id" - ) - await session.execute( -- stmt, {"data": json.dumps(new_data), "userid": user.userid} -+ stmt, {"data": json.dumps(new_data), "user_id": user.user_id} - ) - await session.commit() - - # Verify data was set -- updated_user = await UserManager.check_user(user.userid) -+ updated_user = await UserManager.check_user(user.user_id) - assert updated_user is not None - assert updated_user.rate_limits == new_data - -@@ -256,13 +238,13 @@ class TestUserManagerIntegration: - stmt 
= sa.text(""" - UPDATE users - SET rate_limits = rate_limits::jsonb - 'web_search_rate_limit_day' -- WHERE userid = :userid -+ WHERE user_id = :user_id - """) -- await session.execute(stmt, {"userid": user.userid}) -+ await session.execute(stmt, {"user_id": user.user_id}) - await session.commit() - - # Verify field was removed -- final_user = await UserManager.check_user(user.userid) -+ final_user = await UserManager.check_user(user.user_id) - expected_final_data = {"user_rate_limit_day": 500} - assert final_user is not None - assert final_user.rate_limits == expected_final_data -@@ -293,15 +275,15 @@ class TestUserManagerIntegration: - for i, test_data in enumerate(test_cases): - async with get_db_session() as session: - stmt = sa.text( -- "UPDATE users SET rate_limits = :data WHERE userid = :userid" -+ "UPDATE users SET rate_limits = :data WHERE user_id = :user_id" - ) - await session.execute( -- stmt, {"data": json.dumps(test_data), "userid": user.userid} -+ stmt, {"data": json.dumps(test_data), "user_id": user.user_id} - ) - await session.commit() - - # Retrieve and verify -- updated_user = await UserManager.check_user(user.userid) -+ updated_user = await UserManager.check_user(user.user_id) - assert updated_user is not None - assert updated_user.rate_limits == test_data - -@@ -327,11 +309,13 @@ class TestUserManagerIntegration: - - # Test empty JSON object - async with get_db_session() as session: -- stmt = sa.text("UPDATE users SET rate_limits = '{}' WHERE userid = :userid") -- await session.execute(stmt, {"userid": user.userid}) -+ stmt = sa.text( -+ "UPDATE users SET rate_limits = '{}' WHERE user_id = :user_id" -+ ) -+ await session.execute(stmt, {"user_id": user.user_id}) - await session.commit() - -- empty_user = await UserManager.check_user(user.userid) -+ empty_user = await UserManager.check_user(user.user_id) - assert empty_user is not None - assert empty_user.rate_limits == {} - empty_rate_limits_obj = empty_user.rate_limits_obj -@@ -343,18 +327,18 @@ class TestUserManagerIntegration: - async with get_db_session() as session: - # This should work as PostgreSQL JSONB validates JSON - stmt = sa.text( -- "UPDATE users SET rate_limits = :data WHERE userid = :userid" -+ "UPDATE users SET rate_limits = :data WHERE user_id = :user_id" - ) - await session.execute( - stmt, - { - "data": '{"user_rate_limit_day": 5000}', # Valid JSON string -- "userid": user.userid, -+ "user_id": user.user_id, - }, - ) - await session.commit() - -- json_string_user = await UserManager.check_user(user.userid) -+ json_string_user = await UserManager.check_user(user.user_id) - assert json_string_user is not None - assert json_string_user.rate_limits == {"user_rate_limit_day": 5000} - -@@ -366,16 +350,16 @@ class TestUserManagerIntegration: - async def test_rate_limits_update_workflow(self, clean_database): - """Test complete workflow: create user with no rate limits -> update rate limits -> verify update.""" - # Step 1: Create user with NO rate limits -- user = await UserManager.insert_user(name="Rate Limits Workflow User") -+ user = await UserManager.insert_user(user_id="Rate Limits Workflow User") - - # Verify user was created with no rate limits - assert user.name == "Rate Limits Workflow User" -- assert user.userid is not None -+ assert user.user_id is not None - assert user.apikey is not None - assert user.rate_limits is None # No rate limits initially - - # Step 2: Retrieve user and confirm no rate limits -- retrieved_user = await UserManager.check_user(user.userid) -+ retrieved_user = await 
UserManager.check_user(user.user_id) - assert retrieved_user is not None - print(retrieved_user.to_pydantic()) - assert retrieved_user is not None -@@ -401,12 +385,12 @@ class TestUserManagerIntegration: - - # Step 4: Update the user's rate limits using the new function - update_success = await UserManager.update_rate_limits( -- user.userid, new_rate_limits -+ user.user_id, new_rate_limits - ) - assert update_success is True - - # Step 5: Retrieve user again and verify rate limits were updated -- updated_user = await UserManager.check_user(user.userid) -+ updated_user = await UserManager.check_user(user.user_id) - assert updated_user is not None - assert updated_user.rate_limits is not None - assert updated_user.rate_limits == new_rate_limits.model_dump() -@@ -431,12 +415,12 @@ class TestUserManagerIntegration: - ) - - partial_update_success = await UserManager.update_rate_limits( -- user.userid, partial_rate_limits -+ user.user_id, partial_rate_limits - ) - assert partial_update_success is True - - # Step 8: Verify partial update worked -- final_user = await UserManager.check_user(user.userid) -+ final_user = await UserManager.check_user(user.user_id) - assert final_user is not None - assert final_user.rate_limits == partial_rate_limits.model_dump() - -@@ -447,8 +431,8 @@ class TestUserManagerIntegration: - # Other fields should have config defaults (not None due to get_effective_limits) - - # Step 9: Test error case - update non-existent user -- fake_userid = "non-existent-user-id" -+ fake_user_id = "non-existent-user-id" - error_update = await UserManager.update_rate_limits( -- fake_userid, new_rate_limits -+ fake_user_id, new_rate_limits - ) - assert error_update is False -diff --git a/tests/unit/nilai_api/__init__.py b/tests/unit/nilai_api/__init__.py -index 0be5261..7cbc123 100644 ---- a/tests/unit/nilai_api/__init__.py -+++ b/tests/unit/nilai_api/__init__.py -@@ -21,11 +21,11 @@ class MockUserDatabase: - - async def insert_user(self, name: str, email: str) -> Dict[str, str]: - """Insert a new user into the mock database.""" -- userid = self.generate_user_id() -+ user_id = self.generate_user_id() - apikey = self.generate_api_key() - - user_data = { -- "userid": userid, -+ "user_id": user_id, - "name": name, - "email": email, - "apikey": apikey, -@@ -36,34 +36,34 @@ class MockUserDatabase: - "last_activity": None, - } - -- self.users[userid] = user_data -- return {"userid": userid, "apikey": apikey} -+ self.users[user_id] = user_data -+ return {"user_id": user_id, "apikey": apikey} - - async def check_api_key(self, api_key: str) -> Optional[dict]: - """Validate an API key in the mock database.""" - for user in self.users.values(): - if user["apikey"] == api_key: -- return {"name": user["name"], "userid": user["userid"]} -+ return {"name": user["name"], "user_id": user["user_id"]} - return None - - async def update_token_usage( -- self, userid: str, prompt_tokens: int, completion_tokens: int -+ self, user_id: str, prompt_tokens: int, completion_tokens: int - ): - """Update token usage for a specific user.""" -- if userid in self.users: -- user = self.users[userid] -+ if user_id in self.users: -+ user = self.users[user_id] - user["prompt_tokens"] += prompt_tokens - user["completion_tokens"] += completion_tokens - user["queries"] += 1 - user["last_activity"] = datetime.now(timezone.utc) - - async def log_query( -- self, userid: str, model: str, prompt_tokens: int, completion_tokens: int -+ self, user_id: str, model: str, prompt_tokens: int, completion_tokens: int - ): - """Log a user's 
query in the mock database.""" - query_log = { - "id": self._next_query_log_id, -- "userid": userid, -+ "user_id": user_id, - "query_timestamp": datetime.now(timezone.utc), - "model": model, - "prompt_tokens": prompt_tokens, -@@ -74,9 +74,9 @@ class MockUserDatabase: - self.query_logs[self._next_query_log_id] = query_log - self._next_query_log_id += 1 - -- async def get_token_usage(self, userid: str) -> Optional[Dict[str, Any]]: -+ async def get_token_usage(self, user_id: str) -> Optional[Dict[str, Any]]: - """Get token usage for a specific user.""" -- user = self.users.get(userid) -+ user = self.users.get(user_id) - if user: - return { - "prompt_tokens": user["prompt_tokens"], -@@ -90,9 +90,9 @@ class MockUserDatabase: - """Retrieve all users from the mock database.""" - return list(self.users.values()) if self.users else None - -- async def get_user_token_usage(self, userid: str) -> Optional[Dict[str, int]]: -+ async def get_user_token_usage(self, user_id: str) -> Optional[Dict[str, int]]: - """Retrieve total token usage for a user.""" -- user = self.users.get(userid) -+ user = self.users.get(user_id) - if user: - return { - "prompt_tokens": user["prompt_tokens"], -diff --git a/tests/unit/nilai_api/auth/test_auth.py b/tests/unit/nilai_api/auth/test_auth.py -index 591c447..ec1aabc 100644 ---- a/tests/unit/nilai_api/auth/test_auth.py -+++ b/tests/unit/nilai_api/auth/test_auth.py -@@ -29,7 +29,7 @@ def mock_user_model(): - - mock = MagicMock(spec=UserModel) - mock.name = "Test User" -- mock.userid = "test-user-id" -+ mock.user_id = "test-user-id" - mock.apikey = "test-api-key" - mock.prompt_tokens = 0 - mock.completion_tokens = 0 -@@ -72,11 +72,9 @@ async def test_get_auth_info_valid_token( - - auth_info = await get_auth_info(credentials) - print(auth_info) -- assert auth_info.user.name == "Test User", ( -- f"Expected Test User but got {auth_info.user.name}" -- ) -- assert auth_info.user.userid == "test-user-id", ( -- f"Expected test-user-id but got {auth_info.user.userid}" -+ -+ assert auth_info.user.user_id == "test-user-id", ( -+ f"Expected test-user-id but got {auth_info.user.user_id}" - ) - - -diff --git a/tests/unit/nilai_api/auth/test_strategies.py b/tests/unit/nilai_api/auth/test_strategies.py -index 0c169f5..d362786 100644 ---- a/tests/unit/nilai_api/auth/test_strategies.py -+++ b/tests/unit/nilai_api/auth/test_strategies.py -@@ -16,7 +16,7 @@ class TestAuthStrategies: - """Mock UserModel fixture""" - mock = MagicMock(spec=UserModel) - mock.name = "Test User" -- mock.userid = "test-user-id" -+ mock.user_id = "test-user-id" - mock.apikey = "test-api-key" - mock.prompt_tokens = 0 - mock.completion_tokens = 0 -@@ -43,7 +43,6 @@ class TestAuthStrategies: - result = await api_key_strategy("test-api-key") - - assert isinstance(result, AuthenticationInfo) -- assert result.user.name == "Test User" - assert result.token_rate_limit is None - assert result.prompt_document is None - -@@ -84,7 +83,6 @@ class TestAuthStrategies: - result = await nuc_strategy("nuc-token") - - assert isinstance(result, AuthenticationInfo) -- assert result.user.name == "Test User" - assert result.token_rate_limit is None - assert result.prompt_document == mock_prompt_document - -@@ -154,7 +152,6 @@ class TestAuthStrategies: - result = await nuc_strategy("nuc-token") - - assert isinstance(result, AuthenticationInfo) -- assert result.user.name == "Test User" - assert result.token_rate_limit is None - assert result.prompt_document is None - -@@ -201,7 +198,7 @@ class TestAuthStrategies: - """Test that all strategies 
return AuthenticationInfo with prompt_document field""" - mock_user_model = MagicMock(spec=UserModel) - mock_user_model.name = "Test" -- mock_user_model.userid = "test" -+ mock_user_model.user_id = "test" - mock_user_model.apikey = "test" - mock_user_model.prompt_tokens = 0 - mock_user_model.completion_tokens = 0 -diff --git a/tests/unit/nilai_api/routers/test_nildb_endpoints.py b/tests/unit/nilai_api/routers/test_nildb_endpoints.py -index c0103ea..0648980 100644 ---- a/tests/unit/nilai_api/routers/test_nildb_endpoints.py -+++ b/tests/unit/nilai_api/routers/test_nildb_endpoints.py -@@ -18,8 +18,8 @@ class TestNilDBEndpoints: - """Mock user data for subscription owner""" - mock_user_model = MagicMock(spec=UserModel) - mock_user_model.name = "Subscription Owner" -- mock_user_model.userid = "owner-id" -- mock_user_model.apikey = "owner-id" # Same as userid for subscription owner -+ mock_user_model.user_id = "owner-id" -+ mock_user_model.apikey = "owner-id" # Same as user_id for subscription owner - mock_user_model.prompt_tokens = 0 - mock_user_model.completion_tokens = 0 - mock_user_model.queries = 0 -@@ -37,8 +37,8 @@ class TestNilDBEndpoints: - """Mock user data for regular user (not subscription owner)""" - mock_user_model = MagicMock(spec=UserModel) - mock_user_model.name = "Regular User" -- mock_user_model.userid = "user-id" -- mock_user_model.apikey = "different-api-key" # Different from userid -+ mock_user_model.user_id = "user-id" -+ mock_user_model.apikey = "different-api-key" # Different from user_id - mock_user_model.prompt_tokens = 0 - mock_user_model.completion_tokens = 0 - mock_user_model.queries = 0 -@@ -149,7 +149,7 @@ class TestNilDBEndpoints: - ) - - mock_user = MagicMock() -- mock_user.userid = "test-user-id" -+ mock_user.user_id = "test-user-id" - mock_user.name = "Test User" - mock_user.apikey = "test-api-key" - mock_user.rate_limits = RateLimits().get_effective_limits() -@@ -256,7 +256,7 @@ class TestNilDBEndpoints: - ) - - mock_user = MagicMock() -- mock_user.userid = "test-user-id" -+ mock_user.user_id = "test-user-id" - mock_user.name = "Test User" - mock_user.apikey = "test-api-key" - mock_user.rate_limits = RateLimits().get_effective_limits() -@@ -304,7 +304,7 @@ class TestNilDBEndpoints: - from nilai_common import ChatRequest - - mock_user = MagicMock() -- mock_user.userid = "test-user-id" -+ mock_user.user_id = "test-user-id" - mock_user.name = "Test User" - mock_user.apikey = "test-api-key" - mock_user.rate_limits = RateLimits().get_effective_limits() -@@ -419,8 +419,8 @@ class TestNilDBEndpoints: - self, mock_subscription_owner_user, mock_regular_user - ): - """Test the is_subscription_owner property""" -- # Subscription owner (userid == apikey) -+ # Subscription owner (user_id == apikey) - assert mock_subscription_owner_user.is_subscription_owner is True - -- # Regular user (userid != apikey) -+ # Regular user (user_id != apikey) - assert mock_regular_user.is_subscription_owner is False -diff --git a/tests/unit/nilai_api/routers/test_private.py b/tests/unit/nilai_api/routers/test_private.py -index 1978e83..daafc86 100644 ---- a/tests/unit/nilai_api/routers/test_private.py -+++ b/tests/unit/nilai_api/routers/test_private.py -@@ -20,7 +20,7 @@ async def test_runs_in_a_loop(): - @pytest.fixture - def mock_user(): - mock = MagicMock(spec=UserModel) -- mock.userid = "test-user-id" -+ mock.user_id = "test-user-id" - mock.name = "Test User" - mock.apikey = "test-api-key" - mock.prompt_tokens = 100 -@@ -66,7 +66,7 @@ def mock_user_manager(mock_user, mocker): - 
UserManager, - "insert_user", - return_value={ -- "userid": "test-user-id", -+ "user_id": "test-user-id", - "apikey": "test-api-key", - "rate_limits": RateLimits().get_effective_limits().model_dump_json(), - }, -@@ -81,12 +81,12 @@ def mock_user_manager(mock_user, mocker): - "get_all_users", - return_value=[ - { -- "userid": "test-user-id", -+ "user_id": "test-user-id", - "apikey": "test-api-key", - "rate_limits": RateLimits().get_effective_limits().model_dump_json(), - }, - { -- "userid": "test-user-id-2", -+ "user_id": "test-user-id-2", - "apikey": "test-api-key", - "rate_limits": RateLimits().get_effective_limits().model_dump_json(), - }, -diff --git a/tests/unit/nilai_api/test_db.py b/tests/unit/nilai_api/test_db.py -index dff0fd8..3979321 100644 ---- a/tests/unit/nilai_api/test_db.py -+++ b/tests/unit/nilai_api/test_db.py -@@ -15,7 +15,7 @@ async def test_insert_user(mock_db): - """Test user insertion functionality.""" - user = await mock_db.insert_user("Test User", "test@example.com") - -- assert "userid" in user -+ assert "user_id" in user - assert "apikey" in user - assert len(mock_db.users) == 1 - -@@ -38,9 +38,9 @@ async def test_token_usage(mock_db): - """Test token usage tracking.""" - user = await mock_db.insert_user("Test User", "test@example.com") - -- await mock_db.update_token_usage(user["userid"], 50, 20) -+ await mock_db.update_token_usage(user["user_id"], 50, 20) - -- token_usage = await mock_db.get_token_usage(user["userid"]) -+ token_usage = await mock_db.get_token_usage(user["user_id"]) - assert token_usage["prompt_tokens"] == 50 - assert token_usage["completion_tokens"] == 20 - assert token_usage["queries"] == 1 -@@ -51,9 +51,9 @@ async def test_query_logging(mock_db): - """Test query logging functionality.""" - user = await mock_db.insert_user("Test User", "test@example.com") - -- await mock_db.log_query(user["userid"], "test-model", 10, 15) -+ await mock_db.log_query(user["user_id"], "test-model", 10, 15) - - assert len(mock_db.query_logs) == 1 - log_entry = list(mock_db.query_logs.values())[0] -- assert log_entry["userid"] == user["userid"] -+ assert log_entry["user_id"] == user["user_id"] - assert log_entry["model"] == "test-model" -diff --git a/tests/unit/nilai_api/test_rate_limiting.py b/tests/unit/nilai_api/test_rate_limiting.py -index 4cf53b0..82d2119 100644 ---- a/tests/unit/nilai_api/test_rate_limiting.py -+++ b/tests/unit/nilai_api/test_rate_limiting.py -@@ -44,7 +44,7 @@ async def test_concurrent_rate_limit(req): - rate_limit = RateLimit(concurrent_extractor=lambda _: (5, "test")) - - user_limits = UserRateLimits( -- subscription_holder=random_id(), -+ user_id=random_id(), - token_rate_limit=None, - rate_limits=RateLimits( - user_rate_limit_day=None, -@@ -77,7 +77,7 @@ async def test_concurrent_rate_limit(req): - "user_limits", - [ - UserRateLimits( -- subscription_holder=random_id(), -+ user_id=random_id(), - token_rate_limit=None, - rate_limits=RateLimits( - user_rate_limit_day=10, -@@ -91,7 +91,7 @@ async def test_concurrent_rate_limit(req): - ), - ), - UserRateLimits( -- subscription_holder=random_id(), -+ user_id=random_id(), - token_rate_limit=None, - rate_limits=RateLimits( - user_rate_limit_day=None, -@@ -105,7 +105,7 @@ async def test_concurrent_rate_limit(req): - ), - ), - UserRateLimits( -- subscription_holder=random_id(), -+ user_id=random_id(), - token_rate_limit=None, - rate_limits=RateLimits( - user_rate_limit_day=None, -@@ -119,7 +119,7 @@ async def test_concurrent_rate_limit(req): - ), - ), - UserRateLimits( -- 
subscription_holder=random_id(), -+ user_id=random_id(), - token_rate_limit=TokenRateLimits( - limits=[ - TokenRateLimit( -@@ -180,7 +180,7 @@ async def test_web_search_rate_limits(redis_client): - - rate_limit = RateLimit(web_search_extractor=web_search_extractor) - user_limits = UserRateLimits( -- subscription_holder=apikey, -+ user_id=apikey, - token_rate_limit=None, - rate_limits=RateLimits( - user_rate_limit_day=None, -@@ -212,7 +212,7 @@ async def test_global_web_search_rps_limit(req, redis_client, monkeypatch): - - rate_limit = RateLimit(web_search_extractor=lambda _: True) - user_limits = UserRateLimits( -- subscription_holder=random_id(), -+ user_id=random_id(), - token_rate_limit=None, - rate_limits=RateLimits( - user_rate_limit_day=None, -@@ -253,7 +253,7 @@ async def test_queueing_across_seconds(req, redis_client, monkeypatch): - - rate_limit = RateLimit(web_search_extractor=lambda _: True) - user_limits = UserRateLimits( -- subscription_holder=random_id(), -+ user_id=random_id(), - token_rate_limit=None, - rate_limits=RateLimits( - user_rate_limit_day=None, -diff --git a/uv.lock b/uv.lock -index d54a2ef..1482449 100644 ---- a/uv.lock -+++ b/uv.lock -@@ -1649,7 +1649,7 @@ requires-dist = [ - { name = "gunicorn", specifier = ">=23.0.0" }, - { name = "httpx", specifier = ">=0.27.2" }, - { name = "nilai-common", editable = "packages/nilai-common" }, -- { name = "nilauth-credit-middleware", specifier = ">=0.1.1" }, -+ { name = "nilauth-credit-middleware", specifier = ">=0.1.2" }, - { name = "nilrag", specifier = ">=0.1.11" }, - { name = "nuc", specifier = ">=0.1.0" }, - { name = "openai", specifier = ">=1.59.9" }, -@@ -1658,7 +1658,7 @@ requires-dist = [ - { name = "python-dotenv", specifier = ">=1.0.1" }, - { name = "pyyaml", specifier = ">=6.0.1" }, - { name = "redis", specifier = ">=5.2.1" }, -- { name = "secretvaults", git = "https://github.com/NillionNetwork/secretvaults-py?rev=feat%2Fbackport-did-key-and-ethr-parsing" }, -+ { name = "secretvaults", git = "https://github.com/jcabrero/secretvaults-py?rev=main" }, - { name = "sqlalchemy", specifier = ">=2.0.36" }, - { name = "trafilatura", specifier = ">=1.7.0" }, - { name = "uvicorn", specifier = ">=0.32.1" }, -@@ -1739,7 +1739,7 @@ dev = [ - - [[package]] - name = "nilauth-credit-middleware" --version = "0.1.1" -+version = "0.1.2" - source = { registry = "https://pypi.org/simple" } - dependencies = [ - { name = "fastapi", extra = ["standard"] }, -@@ -1747,9 +1747,9 @@ dependencies = [ - { name = "nuc" }, - { name = "pydantic" }, - ] --sdist = { url = "https://files.pythonhosted.org/packages/9f/cf/7716fa5f4aca83ef39d6f9f8bebc1d80d194c52c9ce6e75ee6bd1f401217/nilauth_credit_middleware-0.1.1.tar.gz", hash = "sha256:ae32c4c1e6bc083c8a7581d72a6da271ce9c0f0f9271a1694acb81ccd0a4a8bd", size = 10259, upload-time = "2025-10-16T11:15:03.918Z" } -+sdist = { url = "https://files.pythonhosted.org/packages/46/bc/ae9b2c26919151fc7193b406a98831eeef197f6ec46b0c075138e66ec016/nilauth_credit_middleware-0.1.2.tar.gz", hash = "sha256:66423a4d18aba1eb5f5d47a04c8f7ae6a19ab4e34433475aa9dc1ba398483fdd", size = 11979, upload-time = "2025-10-30T16:21:20.538Z" } - wheels = [ -- { url = "https://files.pythonhosted.org/packages/a7/b5/6e4090ae2ae8848d12e43f82d8d995cd1dff9de8e947cf5fb2b8a72a828e/nilauth_credit_middleware-0.1.1-py3-none-any.whl", hash = "sha256:10a0fda4ac11f51b9a5dd7b3a8fbabc0b28ff92a170a7729ac11eb15c7b37887", size = 14919, upload-time = "2025-10-16T11:15:02.201Z" }, -+ { url = 
"https://files.pythonhosted.org/packages/05/c3/73d55667aad701a64f3d1330d66c90a8c292fd19f054093ca74960aca1fb/nilauth_credit_middleware-0.1.2-py3-none-any.whl", hash = "sha256:31f3233e6706c6167b6246a4edb9a405d587eccb1399231223f95c0cdf1ce57c", size = 18121, upload-time = "2025-10-30T16:21:19.547Z" }, - ] - - [[package]] -@@ -2854,8 +2854,8 @@ sdist = { url = "https://files.pythonhosted.org/packages/9b/41/bb668a6e419230354 - - [[package]] - name = "secretvaults" --version = "0.3.0" --source = { git = "https://github.com/NillionNetwork/secretvaults-py?rev=feat%2Fbackport-did-key-and-ethr-parsing#b40aebf572c6d4c94dc381e022b82724d727df23" } -+version = "0.2.1" -+source = { git = "https://github.com/jcabrero/secretvaults-py?rev=main#498ee5304fdcc730d1810fcf6172e56fa6dd7d16" } - dependencies = [ - { name = "aiohttp" }, - { name = "blindfold" }, diff --git a/tests/e2e/test_chat_completions.py b/tests/e2e/test_chat_completions.py index b24137a1..7121d6d7 100644 --- a/tests/e2e/test_chat_completions.py +++ b/tests/e2e/test_chat_completions.py @@ -480,12 +480,7 @@ def test_usage_endpoint(client): assert isinstance(usage_data, dict), "Usage data should be a dictionary" # Check for expected keys - expected_keys = [ - "total_tokens", - "completion_tokens", - "prompt_tokens", - "queries", - ] + expected_keys = ["total_tokens", "completion_tokens", "prompt_tokens"] for key in expected_keys: assert key in usage_data, f"Expected key {key} not found in usage data" diff --git a/tests/e2e/test_chat_completions_http.py b/tests/e2e/test_chat_completions_http.py index 5b8b9da7..080c5370 100644 --- a/tests/e2e/test_chat_completions_http.py +++ b/tests/e2e/test_chat_completions_http.py @@ -139,7 +139,6 @@ def test_usage_endpoint(client): "total_tokens", "completion_tokens", "prompt_tokens", - "queries", ] for key in expected_keys: diff --git a/tests/unit/nilai_api/auth/test_auth.py b/tests/unit/nilai_api/auth/test_auth.py index 7aee50a7..47559272 100644 --- a/tests/unit/nilai_api/auth/test_auth.py +++ b/tests/unit/nilai_api/auth/test_auth.py @@ -1,4 +1,3 @@ -from datetime import datetime, timezone import logging from unittest.mock import MagicMock @@ -23,14 +22,7 @@ def mock_user_model(): from nilai_api.db.users import UserModel mock = MagicMock(spec=UserModel) - mock.name = "Test User" mock.user_id = "test-user-id" - mock.apikey = "test-api-key" - mock.prompt_tokens = 0 - mock.completion_tokens = 0 - mock.queries = 0 - mock.signup_date = datetime.now(timezone.utc) - mock.last_activity = datetime.now(timezone.utc) mock.rate_limits = RateLimits().get_effective_limits().model_dump_json() mock.rate_limits_obj = RateLimits().get_effective_limits() return mock diff --git a/tests/unit/nilai_api/routers/test_chat_completions_private.py b/tests/unit/nilai_api/routers/test_chat_completions_private.py index 35621a3c..5a96366f 100644 --- a/tests/unit/nilai_api/routers/test_chat_completions_private.py +++ b/tests/unit/nilai_api/routers/test_chat_completions_private.py @@ -21,14 +21,6 @@ async def test_runs_in_a_loop(): def mock_user(): mock = MagicMock(spec=UserModel) mock.user_id = "test-user-id" - mock.name = "Test User" - mock.apikey = "test-api-key" - mock.prompt_tokens = 100 - mock.completion_tokens = 50 - mock.total_tokens = 150 - mock.completion_tokens_details = None - mock.prompt_tokens_details = None - mock.queries = 10 mock.rate_limits = RateLimits().get_effective_limits().model_dump_json() mock.rate_limits_obj = RateLimits().get_effective_limits() return mock From 1d5851d253cf56cc4edd144744a7a998f30a882c Mon Sep 17 
00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Cabrero-Holgueras?= Date: Fri, 28 Nov 2025 15:44:51 +0100 Subject: [PATCH 4/7] fix: pytests unit and e2e errors --- .../src/nilai_api/attestation/__init__.py | 3 +- nilai-api/src/nilai_api/db/__init__.py | 6 +- nilai-api/src/nilai_api/rate_limiting.py | 6 +- .../src/nilai_api/routers/endpoints/chat.py | 15 +- .../nilai_api/routers/endpoints/responses.py | 27 ++- nilai-api/src/nilai_api/routers/private.py | 5 +- .../api_models/chat_completion_model.py | 7 +- tests/e2e/conftest.py | 215 ++++++++++++++++++ tests/e2e/test_chat_completions.py | 91 ++------ tests/e2e/test_chat_completions_http.py | 87 +------ tests/e2e/test_responses.py | 74 ++---- tests/e2e/test_responses_http.py | 95 -------- tests/unit/nilai-common/test_discovery.py | 2 +- .../routers/test_chat_completions_private.py | 9 +- .../nilai_api/routers/test_nildb_endpoints.py | 71 +++--- .../routers/test_responses_private.py | 27 ++- tests/unit/nilai_api/test_rate_limiting.py | 2 +- 17 files changed, 353 insertions(+), 389 deletions(-) create mode 100644 tests/e2e/conftest.py diff --git a/nilai-api/src/nilai_api/attestation/__init__.py b/nilai-api/src/nilai_api/attestation/__init__.py index 61697229..ed13f68e 100644 --- a/nilai-api/src/nilai_api/attestation/__init__.py +++ b/nilai-api/src/nilai_api/attestation/__init__.py @@ -5,7 +5,7 @@ ATTESTATION_URL = "http://nilcc-attester/v2/report" -async def get_attestation_report() -> AttestationReport: +async def get_attestation_report(nonce: str) -> AttestationReport: """Get the attestation report""" try: @@ -13,6 +13,7 @@ async def get_attestation_report() -> AttestationReport: response: httpx.Response = await client.get(ATTESTATION_URL) response_json = response.json() return AttestationReport( + nonce=nonce, gpu_attestation=response_json["report"], cpu_attestation=response_json["gpu_token"], verifying_key="", # Added later by the API diff --git a/nilai-api/src/nilai_api/db/__init__.py b/nilai-api/src/nilai_api/db/__init__.py index ee70ffe0..0cc8a623 100644 --- a/nilai-api/src/nilai_api/db/__init__.py +++ b/nilai-api/src/nilai_api/db/__init__.py @@ -14,11 +14,11 @@ from nilai_api.config import CONFIG -_engine: Optional[sqlalchemy.ext.asyncio.AsyncEngine] = None +_engine: Optional[sqlalchemy.ext.asyncio.AsyncEngine] = None # type: ignore[reportAttributeAccessIssue] _SessionLocal: Optional[sessionmaker] = None # Create base and engine with improved configuration -Base = sqlalchemy.orm.declarative_base() +Base = sqlalchemy.orm.declarative_base() # type: ignore[reportAttributeAccessIssue] logger = logging.getLogger(__name__) @@ -49,7 +49,7 @@ def from_env() -> "DatabaseConfig": return DatabaseConfig(database_url) -def get_engine() -> sqlalchemy.ext.asyncio.AsyncEngine: +def get_engine() -> sqlalchemy.ext.asyncio.AsyncEngine: # type: ignore[reportAttributeAccessIssue] global _engine if _engine is None: config = DatabaseConfig.from_env() diff --git a/nilai-api/src/nilai_api/rate_limiting.py b/nilai-api/src/nilai_api/rate_limiting.py index ae1e63ae..0a70f15f 100644 --- a/nilai-api/src/nilai_api/rate_limiting.py +++ b/nilai-api/src/nilai_api/rate_limiting.py @@ -161,21 +161,21 @@ async def __call__( await self.check_bucket( redis, redis_rate_limit_command, - f"web_search_minute:{user_limits.subscription_holder}", + f"web_search_minute:{user_limits.user_id}", user_limits.rate_limits.web_search_rate_limit_minute, MINUTE_MS, ) await self.check_bucket( redis, redis_rate_limit_command, - f"web_search_hour:{user_limits.subscription_holder}", + 
f"web_search_hour:{user_limits.user_id}", user_limits.rate_limits.web_search_rate_limit_hour, HOUR_MS, ) await self.check_bucket( redis, redis_rate_limit_command, - f"web_search_day:{user_limits.subscription_holder}", + f"web_search_day:{user_limits.user_id}", user_limits.rate_limits.web_search_rate_limit_day, DAY_MS, ) diff --git a/nilai-api/src/nilai_api/routers/endpoints/chat.py b/nilai-api/src/nilai_api/routers/endpoints/chat.py index 783dcc8f..a72c9594 100644 --- a/nilai-api/src/nilai_api/routers/endpoints/chat.py +++ b/nilai-api/src/nilai_api/routers/endpoints/chat.py @@ -5,7 +5,15 @@ from base64 import b64encode from typing import AsyncGenerator, Optional, Union, List, Tuple -from fastapi import APIRouter, Body, Depends, HTTPException, status, Request, BackgroundTasks +from fastapi import ( + APIRouter, + Body, + Depends, + HTTPException, + status, + Request, + BackgroundTasks, +) from fastapi.responses import StreamingResponse from openai import AsyncOpenAI @@ -13,8 +21,7 @@ from nilai_api.config import CONFIG from nilai_api.crypto import sign_message from nilai_api.credit import LLMMeter, LLMUsage -from nilai_api.db.logs import QueryLogManager, QueryLogContext -from nilai_api.db.users import UserManager +from nilai_api.db.logs import QueryLogContext from nilai_api.handlers.nildb.handler import get_prompt_from_nildb from nilai_api.handlers.nilrag import handle_nilrag from nilai_api.handlers.tools.tool_router import handle_tool_workflow @@ -442,4 +449,4 @@ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Internal server error: {error_message}", - ) \ No newline at end of file + ) diff --git a/nilai-api/src/nilai_api/routers/endpoints/responses.py b/nilai-api/src/nilai_api/routers/endpoints/responses.py index fa519529..54f9ca92 100644 --- a/nilai-api/src/nilai_api/routers/endpoints/responses.py +++ b/nilai-api/src/nilai_api/routers/endpoints/responses.py @@ -5,7 +5,15 @@ from base64 import b64encode from typing import AsyncGenerator, Optional, Union, List, Tuple -from fastapi import APIRouter, Body, Depends, HTTPException, status, Request, BackgroundTasks +from fastapi import ( + APIRouter, + Body, + Depends, + HTTPException, + status, + Request, + BackgroundTasks, +) from fastapi.responses import StreamingResponse from openai import AsyncOpenAI @@ -13,8 +21,7 @@ from nilai_api.config import CONFIG from nilai_api.crypto import sign_message from nilai_api.credit import LLMMeter, LLMUsage -from nilai_api.db.logs import QueryLogManager, QueryLogContext -from nilai_api.db.users import UserManager +from nilai_api.db.logs import QueryLogContext from nilai_api.handlers.nildb.handler import get_prompt_from_nildb # from nilai_api.handlers.nilrag import handle_nilrag_for_responses @@ -157,7 +164,9 @@ async def create_response( ) has_multimodal = req.has_multimodal_content() - if has_multimodal and (not endpoint.metadata.multimodal_support or req.web_search): + if has_multimodal and ( + not endpoint.metadata.multimodal_support or req.web_search + ): raise HTTPException( status_code=400, detail="Model does not support multimodal content, remove image inputs from request", @@ -180,7 +189,9 @@ async def create_response( client = AsyncOpenAI(base_url=model_url, api_key="") if auth_info.prompt_document: try: - nildb_prompt: str = await get_prompt_from_nildb(auth_info.prompt_document) + nildb_prompt: str = await get_prompt_from_nildb( + auth_info.prompt_document + ) 
req.ensure_instructions(nildb_prompt) except Exception as e: raise HTTPException( @@ -244,7 +255,9 @@ async def response_stream_generator() -> AsyncGenerator[str, None]: usage = event.response.usage prompt_token_usage = usage.input_tokens completion_token_usage = usage.output_tokens - payload["response"]["usage"] = usage.model_dump(mode="json") + payload["response"]["usage"] = usage.model_dump( + mode="json" + ) if sources: if "data" not in payload: @@ -389,7 +402,7 @@ async def response_stream_generator() -> AsyncGenerator[str, None]: log_ctx.set_model(model_name) log_ctx.set_error(error_code=error_code, error_message=error_message) await log_ctx.commit() - + raise except Exception as e: error_message = str(e) diff --git a/nilai-api/src/nilai_api/routers/private.py b/nilai-api/src/nilai_api/routers/private.py index 1cb33935..5cbcbf64 100644 --- a/nilai-api/src/nilai_api/routers/private.py +++ b/nilai-api/src/nilai_api/routers/private.py @@ -72,12 +72,13 @@ async def get_usage(auth_info: AuthenticationInfo = Depends(get_auth_info)) -> U @router.get("/v1/attestation/report", tags=["Attestation"]) async def get_attestation( + nonce: str, auth_info: AuthenticationInfo = Depends(get_auth_info), ) -> AttestationReport: """ Generate a cryptographic attestation report. - - **attestation_request**: Attestation request containing a nonce + - **nonce**: Nonce for the attestation request (64 character hex string) - **user**: Authenticated user information (through HTTP Bearer header) - **Returns**: Attestation details for service verification @@ -90,7 +91,7 @@ async def get_attestation( Provides cryptographic proof of the service's integrity and environment. """ - attestation_report = await get_attestation_report() + attestation_report = await get_attestation_report(nonce) attestation_report.verifying_key = state.b64_public_key return attestation_report diff --git a/packages/nilai-common/src/nilai_common/api_models/chat_completion_model.py b/packages/nilai-common/src/nilai_common/api_models/chat_completion_model.py index 743285ce..b34ce06e 100644 --- a/packages/nilai-common/src/nilai_common/api_models/chat_completion_model.py +++ b/packages/nilai-common/src/nilai_common/api_models/chat_completion_model.py @@ -1,3 +1,5 @@ +import uuid + from typing import ( Annotated, Iterable, @@ -54,10 +56,6 @@ "ResultContent", "Choice", "Source", - "SearchResult", - "Topic", - "TopicResponse", - "TopicQuery", "MessageAdapter", "WebSearchEnhancedMessages", "WebSearchContext", @@ -77,7 +75,6 @@ class ResultContent(BaseModel): text: str truncated: bool = False -Message: TypeAlias = ChatCompletionMessageParam class Choice(OpenaAIChoice): diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py new file mode 100644 index 00000000..7c132d82 --- /dev/null +++ b/tests/e2e/conftest.py @@ -0,0 +1,215 @@ +from .config import BASE_URL, api_key_getter +from .nuc import ( + get_rate_limited_nuc_token, + get_invalid_rate_limited_nuc_token, + get_document_id_nuc_token, +) +import httpx +import pytest +import pytest_asyncio +from openai import OpenAI, AsyncOpenAI + + +# ============================================================================ +# HTTP Client Fixtures (for test_chat_completions_http.py, test_responses_http.py) +# ============================================================================ + + +@pytest.fixture +def http_client(): + """Create an HTTPX client with default headers for HTTP-based tests""" + invocation_token: str = api_key_getter() + print("invocation_token", invocation_token) + return httpx.Client( + 
base_url=BASE_URL, + headers={ + "accept": "application/json", + "Content-Type": "application/json", + "Authorization": f"Bearer {invocation_token}", + }, + verify=False, + timeout=None, + ) + + +# Alias for backward compatibility +client = http_client + + +@pytest.fixture +def rate_limited_http_client(): + """Create an HTTPX client with rate limiting for HTTP-based tests""" + invocation_token = get_rate_limited_nuc_token(rate_limit=1) + return httpx.Client( + base_url=BASE_URL, + headers={ + "accept": "application/json", + "Content-Type": "application/json", + "Authorization": f"Bearer {invocation_token}", + }, + timeout=None, + verify=False, + ) + + +# Alias for backward compatibility +rate_limited_client = rate_limited_http_client + + +@pytest.fixture +def invalid_rate_limited_http_client(): + """Create an HTTPX client with invalid rate limiting for HTTP-based tests""" + invocation_token = get_invalid_rate_limited_nuc_token() + return httpx.Client( + base_url=BASE_URL, + headers={ + "accept": "application/json", + "Content-Type": "application/json", + "Authorization": f"Bearer {invocation_token}", + }, + timeout=None, + verify=False, + ) + + +# Alias for backward compatibility +invalid_rate_limited_client = invalid_rate_limited_http_client + + +@pytest.fixture +def nillion_2025_client(): + """Create an HTTPX client with default headers""" + return httpx.Client( + base_url=BASE_URL, + headers={ + "accept": "application/json", + "Content-Type": "application/json", + "Authorization": "Bearer Nillion2025", + }, + verify=False, + timeout=None, + ) + + +@pytest.fixture +def document_id_client(): + """Create an HTTPX client with default headers""" + invocation_token = get_document_id_nuc_token() + return httpx.Client( + base_url=BASE_URL, + headers={ + "accept": "application/json", + "Content-Type": "application/json", + "Authorization": f"Bearer {invocation_token}", + }, + verify=False, + timeout=None, + ) + + +# ============================================================================ +# OpenAI SDK Client Fixtures (for test_chat_completions.py, test_responses.py) +# ============================================================================ + + +def _create_openai_client(api_key: str) -> OpenAI: + """Helper function to create an OpenAI client with SSL verification disabled""" + transport = httpx.HTTPTransport(verify=False) + return OpenAI( + base_url=BASE_URL, + api_key=api_key, + http_client=httpx.Client(transport=transport), + ) + + +def _create_async_openai_client(api_key: str) -> AsyncOpenAI: + """Helper function to create an async OpenAI client with SSL verification disabled""" + transport = httpx.AsyncHTTPTransport(verify=False) + return AsyncOpenAI( + base_url=BASE_URL, + api_key=api_key, + http_client=httpx.AsyncClient(transport=transport), + ) + + +@pytest.fixture +def openai_client(): + """Create an OpenAI SDK client configured to use the Nilai API""" + invocation_token: str = api_key_getter() + return _create_openai_client(invocation_token) + + +@pytest_asyncio.fixture +async def async_openai_client(): + """Create an async OpenAI SDK client configured to use the Nilai API""" + invocation_token: str = api_key_getter() + transport = httpx.AsyncHTTPTransport(verify=False) + httpx_client = httpx.AsyncClient(transport=transport) + client = AsyncOpenAI( + base_url=BASE_URL, api_key=invocation_token, http_client=httpx_client + ) + yield client + await httpx_client.aclose() + + +@pytest.fixture +def rate_limited_openai_client(): + """Create an OpenAI SDK client with rate limiting""" + 
invocation_token = get_rate_limited_nuc_token(rate_limit=1) + return _create_openai_client(invocation_token) + + +@pytest.fixture +def invalid_rate_limited_openai_client(): + """Create an OpenAI SDK client with invalid rate limiting""" + invocation_token = get_invalid_rate_limited_nuc_token() + return _create_openai_client(invocation_token) + + +@pytest.fixture +def document_id_openai_client(): + """Create an OpenAI SDK client with document ID token""" + invocation_token = get_document_id_nuc_token() + return _create_openai_client(invocation_token) + + +@pytest.fixture +def high_web_search_rate_limit(monkeypatch): + """Set high rate limits for web search for RPS tests""" + monkeypatch.setenv("WEB_SEARCH_RATE_LIMIT_MINUTE", "9999") + monkeypatch.setenv("WEB_SEARCH_RATE_LIMIT_HOUR", "9999") + monkeypatch.setenv("WEB_SEARCH_RATE_LIMIT_DAY", "9999") + monkeypatch.setenv("WEB_SEARCH_RATE_LIMIT", "9999") + monkeypatch.setenv("USER_RATE_LIMIT_MINUTE", "9999") + monkeypatch.setenv("USER_RATE_LIMIT_HOUR", "9999") + monkeypatch.setenv("USER_RATE_LIMIT_DAY", "9999") + monkeypatch.setenv("USER_RATE_LIMIT", "9999") + monkeypatch.setenv( + "MODEL_CONCURRENT_RATE_LIMIT", + ( + '{"meta-llama/Llama-3.2-1B-Instruct": 500, ' + '"meta-llama/Llama-3.2-3B-Instruct": 500, ' + '"meta-llama/Llama-3.1-8B-Instruct": 300, ' + '"cognitivecomputations/Dolphin3.0-Llama3.1-8B": 300, ' + '"deepseek-ai/DeepSeek-R1-Distill-Qwen-14B": 50, ' + '"hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4": 50, ' + '"openai/gpt-oss-20b": 500, ' + '"google/gemma-3-27b-it": 500, ' + '"default": 500}' + ), + ) + + +# ============================================================================ +# Convenience Aliases for OpenAI SDK Tests +# These allow test files to use 'client' instead of 'openai_client' +# Note: These will be shadowed by local fixtures in test_chat_completions.py +# and test_responses.py if those files redefine them +# ============================================================================ + +# Uncomment these if you want to use the conftest fixtures without shadowing: +# client = openai_client +# async_client = async_openai_client +# rate_limited_client = rate_limited_openai_client +# invalid_rate_limited_client = invalid_rate_limited_openai_client +# nildb_client = document_id_openai_client diff --git a/tests/e2e/test_chat_completions.py b/tests/e2e/test_chat_completions.py index 7121d6d7..dd362885 100644 --- a/tests/e2e/test_chat_completions.py +++ b/tests/e2e/test_chat_completions.py @@ -11,98 +11,41 @@ import json import os import re -import httpx import pytest import pytest_asyncio -from openai import OpenAI -from openai import AsyncOpenAI from openai.types.chat import ChatCompletion from .config import BASE_URL, ENVIRONMENT, test_models, AUTH_STRATEGY, api_key_getter -from .nuc import ( - get_rate_limited_nuc_token, - get_invalid_rate_limited_nuc_token, -) - - -def _create_openai_client(api_key: str) -> OpenAI: - """Helper function to create an OpenAI client with SSL verification disabled""" - transport = httpx.HTTPTransport(verify=False) - return OpenAI( - base_url=BASE_URL, - api_key=api_key, - http_client=httpx.Client(transport=transport), - ) -def _create_async_openai_client(api_key: str) -> AsyncOpenAI: - transport = httpx.AsyncHTTPTransport(verify=False) - return AsyncOpenAI( - base_url=BASE_URL, - api_key=api_key, - http_client=httpx.AsyncClient(transport=transport), - ) +# ============================================================================ +# Fixture Aliases for OpenAI SDK Tests +# These 
create local aliases that reference the centralized fixtures in conftest.py +# This allows tests to use 'client' instead of 'openai_client', maintaining backward compatibility +# ============================================================================ @pytest.fixture -def client(): - """Create an OpenAI client configured to use the Nilai API""" - invocation_token: str = api_key_getter() - - return _create_openai_client(invocation_token) +def client(openai_client): + """Alias for openai_client fixture from conftest.py""" + return openai_client @pytest_asyncio.fixture -async def async_client(): - invocation_token: str = api_key_getter() - transport = httpx.AsyncHTTPTransport(verify=False) - httpx_client = httpx.AsyncClient(transport=transport) - client = AsyncOpenAI( - base_url=BASE_URL, api_key=invocation_token, http_client=httpx_client - ) - yield client - await httpx_client.aclose() +async def async_client(async_openai_client): + """Alias for async_openai_client fixture from conftest.py""" + return async_openai_client @pytest.fixture -def rate_limited_client(): - """Create an OpenAI client configured to use the Nilai API with rate limiting""" - invocation_token = get_rate_limited_nuc_token(rate_limit=1) - return _create_openai_client(invocation_token) +def rate_limited_client(rate_limited_openai_client): + """Alias for rate_limited_openai_client fixture from conftest.py""" + return rate_limited_openai_client @pytest.fixture -def invalid_rate_limited_client(): - """Create an OpenAI client configured to use the Nilai API with rate limiting""" - invocation_token = get_invalid_rate_limited_nuc_token() - print(f"invocation_token: {invocation_token}") - return _create_openai_client(invocation_token) - - -@pytest.fixture -def high_web_search_rate_limit(monkeypatch): - """Set high rate limits for web search for RPS tests""" - monkeypatch.setenv("WEB_SEARCH_RATE_LIMIT_MINUTE", "9999") - monkeypatch.setenv("WEB_SEARCH_RATE_LIMIT_HOUR", "9999") - monkeypatch.setenv("WEB_SEARCH_RATE_LIMIT_DAY", "9999") - monkeypatch.setenv("WEB_SEARCH_RATE_LIMIT", "9999") - monkeypatch.setenv("USER_RATE_LIMIT_MINUTE", "9999") - monkeypatch.setenv("USER_RATE_LIMIT_HOUR", "9999") - monkeypatch.setenv("USER_RATE_LIMIT_DAY", "9999") - monkeypatch.setenv("USER_RATE_LIMIT", "9999") - monkeypatch.setenv( - "MODEL_CONCURRENT_RATE_LIMIT", - ( - '{"meta-llama/Llama-3.2-1B-Instruct": 500, ' - '"meta-llama/Llama-3.2-3B-Instruct": 500, ' - '"meta-llama/Llama-3.1-8B-Instruct": 300, ' - '"cognitivecomputations/Dolphin3.0-Llama3.1-8B": 300, ' - '"deepseek-ai/DeepSeek-R1-Distill-Qwen-14B": 50, ' - '"hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4": 50, ' - '"openai/gpt-oss-20b": 500, ' - '"google/gemma-3-27b-it": 500, ' - '"default": 500}' - ), - ) +def invalid_rate_limited_client(invalid_rate_limited_openai_client): + """Alias for invalid_rate_limited_openai_client fixture from conftest.py""" + return invalid_rate_limited_openai_client @pytest.mark.parametrize( diff --git a/tests/e2e/test_chat_completions_http.py b/tests/e2e/test_chat_completions_http.py index 080c5370..6791f830 100644 --- a/tests/e2e/test_chat_completions_http.py +++ b/tests/e2e/test_chat_completions_http.py @@ -12,96 +12,11 @@ import os import re -from .config import BASE_URL, ENVIRONMENT, test_models, AUTH_STRATEGY, api_key_getter -from .nuc import ( - get_rate_limited_nuc_token, - get_invalid_rate_limited_nuc_token, - get_document_id_nuc_token, -) +from .config import BASE_URL, ENVIRONMENT, test_models, AUTH_STRATEGY import httpx import pytest 
-@pytest.fixture -def client(): - """Create an HTTPX client with default headers""" - invocation_token: str = api_key_getter() - print("invocation_token", invocation_token) - return httpx.Client( - base_url=BASE_URL, - headers={ - "accept": "application/json", - "Content-Type": "application/json", - "Authorization": f"Bearer {invocation_token}", - }, - verify=False, - timeout=None, - ) - - -@pytest.fixture -def rate_limited_client(): - """Create an HTTPX client with default headers""" - invocation_token = get_rate_limited_nuc_token(rate_limit=1) - return httpx.Client( - base_url=BASE_URL, - headers={ - "accept": "application/json", - "Content-Type": "application/json", - "Authorization": f"Bearer {invocation_token}", - }, - timeout=None, - verify=False, - ) - - -@pytest.fixture -def invalid_rate_limited_client(): - """Create an HTTPX client with default headers""" - invocation_token = get_invalid_rate_limited_nuc_token() - return httpx.Client( - base_url=BASE_URL, - headers={ - "accept": "application/json", - "Content-Type": "application/json", - "Authorization": f"Bearer {invocation_token}", - }, - timeout=None, - verify=False, - ) - - -@pytest.fixture -def nillion_2025_client(): - """Create an HTTPX client with default headers""" - return httpx.Client( - base_url=BASE_URL, - headers={ - "accept": "application/json", - "Content-Type": "application/json", - "Authorization": "Bearer Nillion2025", - }, - verify=False, - timeout=None, - ) - - -@pytest.fixture -def document_id_client(): - """Create an HTTPX client with default headers""" - invocation_token = get_document_id_nuc_token() - return httpx.Client( - base_url=BASE_URL, - headers={ - "accept": "application/json", - "Content-Type": "application/json", - "Authorization": f"Bearer {invocation_token}", - }, - verify=False, - timeout=None, - ) - - def test_health_endpoint(client): """Test the health endpoint""" response = client.get("health") diff --git a/tests/e2e/test_responses.py b/tests/e2e/test_responses.py index f5f931c5..9e7ba64e 100644 --- a/tests/e2e/test_responses.py +++ b/tests/e2e/test_responses.py @@ -1,80 +1,46 @@ import json import os -import httpx import pytest import pytest_asyncio -from openai import OpenAI -from openai import AsyncOpenAI from .config import BASE_URL, test_models, AUTH_STRATEGY, api_key_getter -from .nuc import ( - get_rate_limited_nuc_token, - get_invalid_rate_limited_nuc_token, - get_nildb_nuc_token, -) - -def _create_openai_client(api_key: str) -> OpenAI: - """Helper function to create an OpenAI client with SSL verification disabled""" - transport = httpx.HTTPTransport(verify=False) - return OpenAI( - base_url=BASE_URL, - api_key=api_key, - http_client=httpx.Client(transport=transport), - ) - -def _create_async_openai_client(api_key: str) -> AsyncOpenAI: - transport = httpx.AsyncHTTPTransport(verify=False) - return AsyncOpenAI( - base_url=BASE_URL, - api_key=api_key, - http_client=httpx.AsyncClient(transport=transport), - ) +# ============================================================================ +# Fixture Aliases for OpenAI SDK Tests +# These create local aliases that reference the centralized fixtures in conftest.py +# This allows tests to use 'client' instead of 'openai_client', maintaining backward compatibility +# ============================================================================ @pytest.fixture -def client(): - invocation_token: str = api_key_getter() - return _create_openai_client(invocation_token) +def client(openai_client): + """Alias for openai_client fixture from 
conftest.py""" + return openai_client @pytest_asyncio.fixture -async def async_client(): - invocation_token: str = api_key_getter() - transport = httpx.AsyncHTTPTransport(verify=False) - httpx_client = httpx.AsyncClient(transport=transport) - client = AsyncOpenAI( - base_url=BASE_URL, api_key=invocation_token, http_client=httpx_client - ) - yield client - await httpx_client.aclose() - - -@pytest.fixture -def rate_limited_client(): - invocation_token = get_rate_limited_nuc_token(rate_limit=1) - return _create_openai_client(invocation_token.token) +async def async_client(async_openai_client): + """Alias for async_openai_client fixture from conftest.py""" + return async_openai_client @pytest.fixture -def invalid_rate_limited_client(): - invocation_token = get_invalid_rate_limited_nuc_token() - return _create_openai_client(invocation_token.token) +def rate_limited_client(rate_limited_openai_client): + """Alias for rate_limited_openai_client fixture from conftest.py""" + return rate_limited_openai_client @pytest.fixture -def nildb_client(): - invocation_token = get_nildb_nuc_token() - return _create_openai_client(invocation_token.token) +def invalid_rate_limited_client(invalid_rate_limited_openai_client): + """Alias for invalid_rate_limited_openai_client fixture from conftest.py""" + return invalid_rate_limited_openai_client @pytest.fixture -def high_web_search_rate_limit(monkeypatch): - monkeypatch.setenv("WEB_SEARCH_RATE_LIMIT_MINUTE", "9999") - monkeypatch.setenv("WEB_SEARCH_RATE_LIMIT_HOUR", "9999") - monkeypatch.setenv("WEB_SEARCH_RATE_LIMIT_DAY", "9999") - monkeypatch.setenv("WEB_SEARCH_RATE_LIMIT", "9999") +def nildb_client(document_id_openai_client): + """Alias for document_id_openai_client fixture from conftest.py""" + return document_id_openai_client @pytest.mark.parametrize("model", test_models) diff --git a/tests/e2e/test_responses_http.py b/tests/e2e/test_responses_http.py index a92c8ddf..ae9e4f74 100644 --- a/tests/e2e/test_responses_http.py +++ b/tests/e2e/test_responses_http.py @@ -5,101 +5,6 @@ import pytest from .config import BASE_URL, test_models, AUTH_STRATEGY, api_key_getter -from .nuc import ( - get_rate_limited_nuc_token, - get_invalid_rate_limited_nuc_token, - get_nildb_nuc_token, - get_document_id_nuc_token, -) - - -@pytest.fixture -def client(): - invocation_token: str = api_key_getter() - return httpx.Client( - base_url=BASE_URL, - headers={ - "accept": "application/json", - "Content-Type": "application/json", - "Authorization": f"Bearer {invocation_token}", - }, - verify=False, - timeout=None, - ) - - -@pytest.fixture -def rate_limited_client(): - invocation_token = get_rate_limited_nuc_token(rate_limit=1) - return httpx.Client( - base_url=BASE_URL, - headers={ - "accept": "application/json", - "Content-Type": "application/json", - "Authorization": f"Bearer {invocation_token.token}", - }, - timeout=None, - verify=False, - ) - - -@pytest.fixture -def invalid_rate_limited_client(): - invocation_token = get_invalid_rate_limited_nuc_token() - return httpx.Client( - base_url=BASE_URL, - headers={ - "accept": "application/json", - "Content-Type": "application/json", - "Authorization": f"Bearer {invocation_token.token}", - }, - timeout=None, - verify=False, - ) - - -@pytest.fixture -def nildb_client(): - invocation_token = get_nildb_nuc_token() - return httpx.Client( - base_url=BASE_URL, - headers={ - "accept": "application/json", - "Content-Type": "application/json", - "Authorization": f"Bearer {invocation_token.token}", - }, - timeout=None, - verify=False, - ) - - 
-@pytest.fixture -def nillion_2025_client(): - return httpx.Client( - base_url=BASE_URL, - headers={ - "accept": "application/json", - "Content-Type": "application/json", - "Authorization": "Bearer Nillion2025", - }, - verify=False, - timeout=None, - ) - - -@pytest.fixture -def document_id_client(): - invocation_token = get_document_id_nuc_token() - return httpx.Client( - base_url=BASE_URL, - headers={ - "accept": "application/json", - "Content-Type": "application/json", - "Authorization": f"Bearer {invocation_token.token}", - }, - verify=False, - timeout=None, - ) @pytest.mark.parametrize("model", test_models) diff --git a/tests/unit/nilai-common/test_discovery.py b/tests/unit/nilai-common/test_discovery.py index 8899eaff..363e99d0 100644 --- a/tests/unit/nilai-common/test_discovery.py +++ b/tests/unit/nilai-common/test_discovery.py @@ -2,7 +2,7 @@ import pytest import pytest_asyncio -from nilai_common.api_model import ModelEndpoint, ModelMetadata +from nilai_common.api_models import ModelEndpoint, ModelMetadata from nilai_common.discovery import ModelServiceDiscovery diff --git a/tests/unit/nilai_api/routers/test_chat_completions_private.py b/tests/unit/nilai_api/routers/test_chat_completions_private.py index 5a96366f..4c1f30b2 100644 --- a/tests/unit/nilai_api/routers/test_chat_completions_private.py +++ b/tests/unit/nilai_api/routers/test_chat_completions_private.py @@ -77,6 +77,7 @@ def mock_state(mocker): # Patch get_attestation method attestation_response = AttestationReport( + nonce="0" * 64, verifying_key="test-verifying-key", cpu_attestation="test-cpu-attestation", gpu_attestation="test-gpu-attestation", @@ -179,9 +180,7 @@ def test_chat_completion(mock_user, mock_state, mock_user_manager, mocker, clien "nilai_api.routers.endpoints.chat.handle_tool_workflow", return_value=(response_data, 0, 0), ) - mocker.patch( - "nilai_api.routers.private.QueryLogContext.commit", new_callable=AsyncMock - ) + mocker.patch("nilai_api.db.logs.QueryLogContext.commit", new_callable=AsyncMock) response = client.post( "/v1/chat/completions", json={ @@ -222,9 +221,7 @@ def test_chat_completion_stream_includes_sources( "nilai_api.routers.endpoints.chat.handle_web_search", new=AsyncMock(return_value=mock_web_search_result), ) - mocker.patch( - "nilai_api.routers.private.QueryLogContext.commit", new_callable=AsyncMock - ) + mocker.patch("nilai_api.db.logs.QueryLogContext.commit", new_callable=AsyncMock) class MockChunk: def __init__(self, data, usage=None): diff --git a/tests/unit/nilai_api/routers/test_nildb_endpoints.py b/tests/unit/nilai_api/routers/test_nildb_endpoints.py index 98a68da0..ff1ecdd6 100644 --- a/tests/unit/nilai_api/routers/test_nildb_endpoints.py +++ b/tests/unit/nilai_api/routers/test_nildb_endpoints.py @@ -7,7 +7,6 @@ from nilai_api.handlers.nildb.api_model import ( PromptDelegationToken, ) -from datetime import datetime, timezone from nilai_common import ResponseRequest @@ -177,12 +176,6 @@ async def test_chat_completion_with_prompt_document_injection(self): patch( "nilai_api.routers.endpoints.chat.handle_web_search" ) as mock_handle_web_search, - patch( - "nilai_api.routers.endpoints.chat.UserManager.update_token_usage" - ) as mock_update_usage, - patch( - "nilai_api.routers.endpoints.chat.QueryLogManager.log_query" - ) as mock_log_query, patch( "nilai_api.routers.endpoints.chat.handle_tool_workflow" ) as mock_handle_tool_workflow, @@ -244,7 +237,10 @@ async def test_chat_completion_with_prompt_document_injection(self): # Call the function (this will test the prompt injection logic) 
await chat_completion( - req=request, auth_info=mock_auth_info, meter=mock_meter + req=request, + auth_info=mock_auth_info, + meter=mock_meter, + log_ctx=mock_log_ctx, ) mock_get_prompt.assert_called_once_with(mock_prompt_document) @@ -354,12 +350,6 @@ async def test_chat_completion_without_prompt_document(self): patch( "nilai_api.routers.endpoints.chat.handle_web_search" ) as mock_handle_web_search, - patch( - "nilai_api.routers.endpoints.chat.UserManager.update_token_usage" - ) as mock_update_usage, - patch( - "nilai_api.routers.endpoints.chat.QueryLogManager.log_query" - ) as mock_log_query, patch( "nilai_api.routers.endpoints.chat.handle_tool_workflow" ) as mock_handle_tool_workflow, @@ -415,7 +405,10 @@ async def test_chat_completion_without_prompt_document(self): # Call the function await chat_completion( - req=request, auth_info=mock_auth_info, meter=mock_meter + req=request, + auth_info=mock_auth_info, + meter=mock_meter, + log_ctx=mock_log_ctx, ) # Should not call get_prompt_from_nildb when no prompt document @@ -431,7 +424,7 @@ async def test_responses_with_prompt_document_injection(self): ) mock_user = MagicMock() - mock_user.userid = "test-user-id" + mock_user.user_id = "test-user-id" mock_user.name = "Test User" mock_user.apikey = "test-api-key" mock_user.rate_limits = RateLimits().get_effective_limits() @@ -471,12 +464,6 @@ async def test_responses_with_prompt_document_injection(self): patch( "nilai_api.routers.endpoints.responses.state.get_model" ) as mock_get_model, - patch( - "nilai_api.routers.endpoints.responses.UserManager.update_token_usage" - ) as mock_update_usage, - patch( - "nilai_api.routers.endpoints.responses.QueryLogManager.log_query" - ) as mock_log_query, patch( "nilai_api.routers.endpoints.responses.handle_responses_tool_workflow" ) as mock_handle_tool_workflow, @@ -489,9 +476,6 @@ async def test_responses_with_prompt_document_injection(self): mock_model_endpoint.metadata.multimodal_support = True mock_get_model.return_value = mock_model_endpoint - mock_update_usage.return_value = None - mock_log_query.return_value = None - mock_client_instance = MagicMock() mock_response = MagicMock() mock_response.model_dump.return_value = response_payload @@ -505,8 +489,13 @@ async def test_responses_with_prompt_document_injection(self): mock_meter = MagicMock() mock_meter.set_response = MagicMock() + mock_log_ctx = MagicMock() + await create_response( - req=request, auth_info=mock_auth_info, meter=mock_meter + req=request, + auth_info=mock_auth_info, + meter=mock_meter, + log_ctx=mock_log_ctx, ) mock_get_prompt.assert_called_once_with(mock_prompt_document) @@ -521,7 +510,7 @@ async def test_responses_prompt_document_extraction_error(self): ) mock_user = MagicMock() - mock_user.userid = "test-user-id" + mock_user.user_id = "test-user-id" mock_user.name = "Test User" mock_user.apikey = "test-api-key" mock_user.rate_limits = RateLimits().get_effective_limits() @@ -548,8 +537,16 @@ async def test_responses_prompt_document_extraction_error(self): mock_get_prompt.side_effect = Exception("Unable to extract prompt") + mock_meter = MagicMock() + mock_log_ctx = MagicMock() + with pytest.raises(HTTPException) as exc_info: - await create_response(req=request, auth_info=mock_auth_info) + await create_response( + req=request, + auth_info=mock_auth_info, + meter=mock_meter, + log_ctx=mock_log_ctx, + ) assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN assert ( @@ -563,7 +560,7 @@ async def test_responses_without_prompt_document(self): from 
nilai_api.routers.endpoints.responses import create_response mock_user = MagicMock() - mock_user.userid = "test-user-id" + mock_user.user_id = "test-user-id" mock_user.name = "Test User" mock_user.apikey = "test-api-key" mock_user.rate_limits = RateLimits().get_effective_limits() @@ -605,12 +602,6 @@ async def test_responses_without_prompt_document(self): patch( "nilai_api.routers.endpoints.responses.state.get_model" ) as mock_get_model, - patch( - "nilai_api.routers.endpoints.responses.UserManager.update_token_usage" - ) as mock_update_usage, - patch( - "nilai_api.routers.endpoints.responses.QueryLogManager.log_query" - ) as mock_log_query, patch( "nilai_api.routers.endpoints.responses.handle_responses_tool_workflow" ) as mock_handle_tool_workflow, @@ -621,9 +612,6 @@ async def test_responses_without_prompt_document(self): mock_model_endpoint.metadata.multimodal_support = True mock_get_model.return_value = mock_model_endpoint - mock_update_usage.return_value = None - mock_log_query.return_value = None - mock_client_instance = MagicMock() mock_response = MagicMock() mock_response.model_dump.return_value = response_payload @@ -637,8 +625,13 @@ async def test_responses_without_prompt_document(self): mock_meter = MagicMock() mock_meter.set_response = MagicMock() + mock_log_ctx = MagicMock() + await create_response( - req=request, auth_info=mock_auth_info, meter=mock_meter + req=request, + auth_info=mock_auth_info, + meter=mock_meter, + log_ctx=mock_log_ctx, ) mock_get_prompt.assert_not_called() diff --git a/tests/unit/nilai_api/routers/test_responses_private.py b/tests/unit/nilai_api/routers/test_responses_private.py index 4e2f7932..d5962dfc 100644 --- a/tests/unit/nilai_api/routers/test_responses_private.py +++ b/tests/unit/nilai_api/routers/test_responses_private.py @@ -24,7 +24,7 @@ async def test_runs_in_a_loop(): @pytest.fixture def mock_user(): mock = MagicMock(spec=UserModel) - mock.userid = "test-user-id" + mock.user_id = "test-user-id" mock.name = "Test User" mock.apikey = "test-api-key" mock.prompt_tokens = 100 @@ -50,18 +50,19 @@ def mock_user_manager(mock_user, mocker): return_value={ "prompt_tokens": 100, "completion_tokens": 50, + "total_tokens": 150, "queries": 10, }, ) mocker.patch.object(QueryLogManager, "log_query") - - # Patch UserManager.check_user instead of check_api_key - mocker.patch.object( - UserManager, - "check_user", + + # Mock validate_credential for authentication + mocker.patch( + "nilai_api.auth.strategies.validate_credential", + new_callable=AsyncMock, return_value=mock_user, ) - + return UserManager @@ -103,7 +104,7 @@ def mock_metering_context(mocker): @pytest.fixture -def client(mock_user_manager, mock_metering_context): +def client(mock_user_manager, mock_state, mock_metering_context): from nilai_api.app import app from nilai_api.credit import LLMMeter @@ -175,6 +176,11 @@ def test_create_response(mock_user, mock_state, mock_user_manager, mocker, clien "nilai_api.routers.endpoints.responses.handle_responses_tool_workflow", return_value=(response_data, 0, 0), ) + mocker.patch( + "nilai_api.routers.endpoints.responses.state.get_model", + return_value=model_endpoint, + ) + mocker.patch("nilai_api.db.logs.QueryLogContext.commit", new_callable=AsyncMock) payload = { "model": "meta-llama/Llama-3.2-1B-Instruct", @@ -273,6 +279,11 @@ async def chunk_generator(): "nilai_api.routers.endpoints.responses.AsyncOpenAI", return_value=mock_async_openai_instance, ) + mocker.patch( + "nilai_api.routers.endpoints.responses.state.get_model", + return_value=model_endpoint, + 
) + mocker.patch("nilai_api.db.logs.QueryLogContext.commit", new_callable=AsyncMock) payload = { "model": "meta-llama/Llama-3.2-1B-Instruct", diff --git a/tests/unit/nilai_api/test_rate_limiting.py b/tests/unit/nilai_api/test_rate_limiting.py index ff85470e..bd8a41e4 100644 --- a/tests/unit/nilai_api/test_rate_limiting.py +++ b/tests/unit/nilai_api/test_rate_limiting.py @@ -86,7 +86,7 @@ async def web_search_extractor(_): rate_limit = RateLimit(web_search_extractor=web_search_extractor) user_limits = UserRateLimits( - subscription_holder=random_id(), + user_id=random_id(), token_rate_limit=None, rate_limits=RateLimits( user_rate_limit_day=None, From d0b3a4bfbd2508ce9d6d99a7e7cac465076a361d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Cabrero-Holgueras?= Date: Fri, 28 Nov 2025 16:04:10 +0100 Subject: [PATCH 5/7] fix: improved and fixed gpt oss CI docker compose file --- .../compose/docker-compose.gpt-20b-gpu.ci.yml | 21 +++++++------------ 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/docker/compose/docker-compose.gpt-20b-gpu.ci.yml b/docker/compose/docker-compose.gpt-20b-gpu.ci.yml index dcfef4cb..5fb24352 100644 --- a/docker/compose/docker-compose.gpt-20b-gpu.ci.yml +++ b/docker/compose/docker-compose.gpt-20b-gpu.ci.yml @@ -7,7 +7,7 @@ services: devices: - driver: nvidia count: 1 - capabilities: [gpu] + capabilities: [ gpu ] ulimits: memlock: -1 @@ -16,27 +16,20 @@ services: - .env restart: unless-stopped depends_on: - etcd: + redis: condition: service_healthy command: > - --model openai/gpt-oss-20b - --gpu-memory-utilization 0.95 - --max-model-len 10000 - --max-num-batched-tokens 10000 - --max-num-seqs 2 - --tensor-parallel-size 1 - --uvicorn-log-level warning - --async-scheduling + --model openai/gpt-oss-20b --gpu-memory-utilization 0.95 --max-model-len 10000 --max-num-batched-tokens 10000 --max-num-seqs 2 --tensor-parallel-size 1 --uvicorn-log-level warning --async-scheduling environment: - SVC_HOST=gpt_20b_gpu - SVC_PORT=8000 - - ETCD_HOST=etcd - - ETCD_PORT=2379 + - DISCOVERY_HOST=redis + - DISCOVERY_PORT=6379 - TOOL_SUPPORT=true volumes: - - hugging_face_models:/root/.cache/huggingface # cache models + - hugging_face_models:/root/.cache/huggingface # cache models healthcheck: - test: ["CMD", "curl", "-f", "http://localhost:8000/health"] + test: [ "CMD", "curl", "-f", "http://localhost:8000/health" ] interval: 30s retries: 10 start_period: 900s From 7d78d8aea362f2110cd85479bea682b98a4f8019 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Cabrero-Holgueras?= Date: Mon, 1 Dec 2025 10:13:03 +0100 Subject: [PATCH 6/7] fix: e2e fixes for responses endpoints --- tests/e2e/conftest.py | 26 +++++++++++++++++++++++++- tests/e2e/nuc.py | 23 +++++++++++++++++++++++ tests/e2e/test_responses.py | 17 +++++++++++++---- tests/e2e/test_responses_http.py | 11 +++++++---- 4 files changed, 68 insertions(+), 9 deletions(-) diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index 7c132d82..36f783b3 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -3,6 +3,7 @@ get_rate_limited_nuc_token, get_invalid_rate_limited_nuc_token, get_document_id_nuc_token, + get_invalid_nildb_nuc_token, ) import httpx import pytest @@ -107,6 +108,22 @@ def document_id_client(): ) +@pytest.fixture +def invalid_nildb(): + """Create an HTTPX client with default headers""" + invocation_token = get_invalid_nildb_nuc_token() + return httpx.Client( + base_url=BASE_URL, + headers={ + "accept": "application/json", + "Content-Type": "application/json", + "Authorization": f"Bearer 
{invocation_token}", + }, + timeout=None, + verify=False, + ) + + # ============================================================================ # OpenAI SDK Client Fixtures (for test_chat_completions.py, test_responses.py) # ============================================================================ @@ -173,6 +190,13 @@ def document_id_openai_client(): return _create_openai_client(invocation_token) +@pytest.fixture +def nildb_openai_client(): + """Create an OpenAI SDK client with document ID token""" + invocation_token = get_invalid_nildb_nuc_token() + return _create_openai_client(invocation_token) + + @pytest.fixture def high_web_search_rate_limit(monkeypatch): """Set high rate limits for web search for RPS tests""" @@ -212,4 +236,4 @@ def high_web_search_rate_limit(monkeypatch): # async_client = async_openai_client # rate_limited_client = rate_limited_openai_client # invalid_rate_limited_client = invalid_rate_limited_openai_client -# nildb_client = document_id_openai_client +nildb_client = document_id_openai_client diff --git a/tests/e2e/nuc.py b/tests/e2e/nuc.py index 0dac8edc..034b3a95 100644 --- a/tests/e2e/nuc.py +++ b/tests/e2e/nuc.py @@ -2,6 +2,9 @@ NilAuthPrivateKey, ) +from nuc.builder import NucTokenBuilder +from nuc.token import Did, InvocationBody, Command + from nilai_py import ( Client, DelegationTokenServer, @@ -124,3 +127,23 @@ def get_document_id_nuc_client() -> Client: def get_document_id_nuc_token() -> str: """Convenience function for getting NILDB NUC tokens.""" return get_document_id_nuc_client()._get_invocation_token() + + +def get_invalid_nildb_nuc_token() -> str: + """Convenience function for getting NILDB NUC tokens.""" + http_client = DefaultHttpxClient(verify=False) + client = Client( + base_url="https://localhost/nuc/v1", + auth_type=AuthType.API_KEY, + http_client=http_client, + api_key=PRIVATE_KEY, + ) + + invocation_token: str = ( + NucTokenBuilder.extending(client.root_token) + .body(InvocationBody(args={})) + .audience(Did(client.nilai_public_key.serialize())) + .command(Command(["nil", "db", "generate"])) + .build(NilAuthPrivateKey(PRIVATE_KEY)) + ) + return invocation_token diff --git a/tests/e2e/test_responses.py b/tests/e2e/test_responses.py index 9e7ba64e..95da3d0c 100644 --- a/tests/e2e/test_responses.py +++ b/tests/e2e/test_responses.py @@ -3,7 +3,7 @@ import pytest import pytest_asyncio -from .config import BASE_URL, test_models, AUTH_STRATEGY, api_key_getter +from .config import BASE_URL, test_models, AUTH_STRATEGY, api_key_getter, ENVIRONMENT # ============================================================================ @@ -43,6 +43,12 @@ def nildb_client(document_id_openai_client): return document_id_openai_client +@pytest.fixture +def invalid_nildb(invalid_document_id_openai_client): + """Alias for invalid_document_id_openai_client fixture from conftest.py""" + return invalid_document_id_openai_client + + @pytest.mark.parametrize("model", test_models) def test_response_generation(client, model): """Test basic response generation with different models""" @@ -161,14 +167,14 @@ def test_invalid_rate_limiting_nucs(invalid_rate_limited_client, model): @pytest.mark.skipif( AUTH_STRATEGY != "nuc", reason="NUC rate limiting not used with API key" ) -def test_invalid_nildb_command_nucs(nildb_client, model): +def test_invalid_nildb_command_nucs(invalid_nildb, model): """Test invalid NILDB command handling""" import openai forbidden = False for _ in range(4): try: - nildb_client.responses.create( + invalid_nildb.responses.create( model=model, 
input="What is the capital of France?", instructions="You are a helpful assistant that provides accurate and concise information.", @@ -447,7 +453,6 @@ def test_usage_endpoint(client): "total_tokens", "completion_tokens", "prompt_tokens", - "queries", ] for key in expected_keys: assert key in usage_data, f"Expected key {key} not found in usage data" @@ -458,6 +463,10 @@ def test_usage_endpoint(client): pytest.fail(f"Error testing usage endpoint: {str(e)}") +@pytest.mark.skipif( + ENVIRONMENT != "mainnet", + reason="Attestation endpoint not available in non-mainnet environment", +) def test_attestation_endpoint(client): """Test retrieving attestation report""" try: diff --git a/tests/e2e/test_responses_http.py b/tests/e2e/test_responses_http.py index ae9e4f74..17d632a3 100644 --- a/tests/e2e/test_responses_http.py +++ b/tests/e2e/test_responses_http.py @@ -4,7 +4,7 @@ import httpx import pytest -from .config import BASE_URL, test_models, AUTH_STRATEGY, api_key_getter +from .config import BASE_URL, test_models, AUTH_STRATEGY, api_key_getter, ENVIRONMENT @pytest.mark.parametrize("model", test_models) @@ -435,12 +435,12 @@ def test_invalid_rate_limiting_nucs(invalid_rate_limited_client): @pytest.mark.skipif( AUTH_STRATEGY != "nuc", reason="NUC rate limiting not used with API key" ) -def test_invalid_nildb_command_nucs(nildb_client): +def test_invalid_nildb_command_nucs(invalid_nildb): payload = { "model": test_models[0], "input": "What is your name?", } - response = nildb_client.post("/responses", json=payload) + response = invalid_nildb.post("/responses", json=payload) assert response.status_code == 401, "Invalid NILDB command should return 401" @@ -627,7 +627,6 @@ def test_usage_endpoint(client): "total_tokens", "completion_tokens", "prompt_tokens", - "queries", ] for key in expected_keys: assert key in usage_data, f"Expected key {key} not found in usage data" @@ -638,6 +637,10 @@ def test_usage_endpoint(client): pytest.fail(f"Error testing usage endpoint: {str(e)}") +@pytest.mark.skipif( + ENVIRONMENT != "mainnet", + reason="Attestation endpoint not available in non-mainnet environment", +) def test_attestation_endpoint(client): try: import requests From 9717772449ac05f7448131e767b5bc96e32be55f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Cabrero-Holgueras?= Date: Mon, 1 Dec 2025 13:21:29 +0100 Subject: [PATCH 7/7] fix: nildb commands --- tests/e2e/conftest.py | 2 +- tests/e2e/nuc.py | 3 ++- tests/e2e/test_responses.py | 6 +++--- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index 36f783b3..fc57b1c6 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -191,7 +191,7 @@ def document_id_openai_client(): @pytest.fixture -def nildb_openai_client(): +def invalid_nildb_openai_client(): """Create an OpenAI SDK client with document ID token""" invocation_token = get_invalid_nildb_nuc_token() return _create_openai_client(invocation_token) diff --git a/tests/e2e/nuc.py b/tests/e2e/nuc.py index 034b3a95..f63248c8 100644 --- a/tests/e2e/nuc.py +++ b/tests/e2e/nuc.py @@ -131,6 +131,7 @@ def get_document_id_nuc_token() -> str: def get_invalid_nildb_nuc_token() -> str: """Convenience function for getting NILDB NUC tokens.""" + private_key = NilAuthPrivateKey(bytes.fromhex(PRIVATE_KEY)) http_client = DefaultHttpxClient(verify=False) client = Client( base_url="https://localhost/nuc/v1", @@ -144,6 +145,6 @@ def get_invalid_nildb_nuc_token() -> str: .body(InvocationBody(args={})) 
.audience(Did(client.nilai_public_key.serialize())) .command(Command(["nil", "db", "generate"])) - .build(NilAuthPrivateKey(PRIVATE_KEY)) + .build(private_key) ) return invocation_token diff --git a/tests/e2e/test_responses.py b/tests/e2e/test_responses.py index 95da3d0c..79679d7c 100644 --- a/tests/e2e/test_responses.py +++ b/tests/e2e/test_responses.py @@ -44,9 +44,9 @@ def nildb_client(document_id_openai_client): @pytest.fixture -def invalid_nildb(invalid_document_id_openai_client): - """Alias for invalid_document_id_openai_client fixture from conftest.py""" - return invalid_document_id_openai_client +def invalid_nildb(invalid_nildb_openai_client): + """Alias for invalid_nildb_openai_client fixture from conftest.py""" + return invalid_nildb_openai_client @pytest.mark.parametrize("model", test_models)