diff --git a/.gitignore b/.gitignore index f3d8ab42..7cbf1bfd 100644 --- a/.gitignore +++ b/.gitignore @@ -179,3 +179,5 @@ private_key.key.lock development-compose.yml production-compose.yml + +.vscode/ diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 24c9e519..00000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "workbench.colorCustomizations": { - "activityBar.activeBackground": "#65c89b", - "activityBar.background": "#65c89b", - "activityBar.foreground": "#15202b", - "activityBar.inactiveForeground": "#15202b99", - "activityBarBadge.background": "#945bc4", - "activityBarBadge.foreground": "#e7e7e7", - "commandCenter.border": "#15202b99", - "sash.hoverBorder": "#65c89b", - "statusBar.background": "#42b883", - "statusBar.foreground": "#15202b", - "statusBarItem.hoverBackground": "#359268", - "statusBarItem.remoteBackground": "#42b883", - "statusBarItem.remoteForeground": "#15202b", - "titleBar.activeBackground": "#42b883", - "titleBar.activeForeground": "#15202b", - "titleBar.inactiveBackground": "#42b88399", - "titleBar.inactiveForeground": "#15202b99" - }, - "peacock.color": "#42b883" -} diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index a3c77e01..fe4ed99d 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -33,8 +33,6 @@ services: condition: service_healthy nilauth-credit-server: condition: service_healthy - environment: - - POSTGRES_DB=${POSTGRES_DB_NUC} volumes: - ./nilai-api/:/app/nilai-api/ - ./packages/:/app/packages/ @@ -96,7 +94,7 @@ services: retries: 5 nilauth-credit-server: - image: ghcr.io/nillionnetwork/nilauth-credit:sha-cb9e36a + image: ghcr.io/nillionnetwork/nilauth-credit:sha-6754a1d platform: linux/amd64 # for macOS to force running on Rosetta 2 environment: DATABASE_URL: postgresql://nilauth:nilauth_dev_password@nilauth-postgres:5432/nilauth_credit diff --git a/docker/compose/docker-compose.gpt-20b-gpu.ci.yml b/docker/compose/docker-compose.gpt-20b-gpu.ci.yml index dcfef4cb..5fb24352 100644 --- a/docker/compose/docker-compose.gpt-20b-gpu.ci.yml +++ b/docker/compose/docker-compose.gpt-20b-gpu.ci.yml @@ -7,7 +7,7 @@ services: devices: - driver: nvidia count: 1 - capabilities: [gpu] + capabilities: [ gpu ] ulimits: memlock: -1 @@ -16,27 +16,20 @@ services: - .env restart: unless-stopped depends_on: - etcd: + redis: condition: service_healthy command: > - --model openai/gpt-oss-20b - --gpu-memory-utilization 0.95 - --max-model-len 10000 - --max-num-batched-tokens 10000 - --max-num-seqs 2 - --tensor-parallel-size 1 - --uvicorn-log-level warning - --async-scheduling + --model openai/gpt-oss-20b --gpu-memory-utilization 0.95 --max-model-len 10000 --max-num-batched-tokens 10000 --max-num-seqs 2 --tensor-parallel-size 1 --uvicorn-log-level warning --async-scheduling environment: - SVC_HOST=gpt_20b_gpu - SVC_PORT=8000 - - ETCD_HOST=etcd - - ETCD_PORT=2379 + - DISCOVERY_HOST=redis + - DISCOVERY_PORT=6379 - TOOL_SUPPORT=true volumes: - - hugging_face_models:/root/.cache/huggingface # cache models + - hugging_face_models:/root/.cache/huggingface # cache models healthcheck: - test: ["CMD", "curl", "-f", "http://localhost:8000/health"] + test: [ "CMD", "curl", "-f", "http://localhost:8000/health" ] interval: 30s retries: 10 start_period: 900s diff --git a/grafana/runtime-data/dashboards/nuc-query-data.json b/grafana/runtime-data/dashboards/nuc-query-data.json index d66fd428..c7bbb6b2 100644 --- a/grafana/runtime-data/dashboards/nuc-query-data.json +++ 
b/grafana/runtime-data/dashboards/nuc-query-data.json @@ -126,7 +126,7 @@ "editorMode": "code", "format": "time_series", "rawQuery": true, - "rawSql": "SELECT \n date_trunc('${time_granularity}', q.query_timestamp) AS \"time\", \n COUNT(q.id) AS \"Queries\"\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:value}' = 'All' OR u.name = '${user_filter:value}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY date_trunc('${time_granularity}', q.query_timestamp)\nORDER BY \"time\";", + "rawSql": "SELECT \n date_trunc('${time_granularity}', q.query_timestamp) AS \"time\", \n COUNT(q.id) AS \"Queries\"\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:value}' = 'All' OR u.name = '${user_filter:value}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY date_trunc('${time_granularity}', q.query_timestamp)\nORDER BY \"time\";", "refId": "A", "sql": { "columns": [ @@ -218,7 +218,7 @@ "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT \n q.model, \n COUNT(q.id) AS total_queries\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\nGROUP BY q.model\nORDER BY total_queries DESC;", + "rawSql": "SELECT \n q.model, \n COUNT(q.id) AS total_queries\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\nGROUP BY q.model\nORDER BY total_queries DESC;", "refId": "A", "sql": { "columns": [ @@ -352,7 +352,7 @@ "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 12 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name\nORDER BY \"Queries\" DESC;", + "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 12 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name\nORDER BY \"Queries\" DESC;", "refId": "A", "sql": { "columns": [ @@ -360,7 +360,7 @@ "alias": "\"User\"", "parameters": [ { - "name": "userid", + "name": "user_id", "type": "functionParameter" } ], @@ -381,7 +381,7 @@ "groupBy": [ { "property": { - "name": "userid", + "name": "user_id", "type": "string" }, "type": "groupBy" @@ -481,7 +481,7 @@ "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 8 THEN LEFT(u.name, 3) || '...' 
|| RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\",\n q.model AS \"Model\",\n COUNT(q.id) AS \"Queries\",\n MIN(q.query_timestamp) AS \"First Query\",\n MAX(q.query_timestamp) AS \"Last Query\"\nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name, q.model\nORDER BY \"Queries\" DESC\nLIMIT 20;", + "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 8 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\",\n q.model AS \"Model\",\n COUNT(q.id) AS \"Queries\",\n MIN(q.query_timestamp) AS \"First Query\",\n MAX(q.query_timestamp) AS \"Last Query\"\nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name, q.model\nORDER BY \"Queries\" DESC\nLIMIT 20;", "refId": "A", "sql": { "columns": [ diff --git a/grafana/runtime-data/dashboards/query-data.json b/grafana/runtime-data/dashboards/query-data.json index 8e0b774f..f33f87a8 100644 --- a/grafana/runtime-data/dashboards/query-data.json +++ b/grafana/runtime-data/dashboards/query-data.json @@ -126,7 +126,7 @@ "editorMode": "code", "format": "time_series", "rawQuery": true, - "rawSql": "SELECT \n date_trunc('${time_granularity}', q.query_timestamp) AS \"time\", \n COUNT(q.id) AS \"Queries\"\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:value}' = 'All' OR u.name = '${user_filter:value}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY date_trunc('${time_granularity}', q.query_timestamp)\nORDER BY \"time\";", + "rawSql": "SELECT \n date_trunc('${time_granularity}', q.query_timestamp) AS \"time\", \n COUNT(q.id) AS \"Queries\"\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:value}' = 'All' OR u.name = '${user_filter:value}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY date_trunc('${time_granularity}', q.query_timestamp)\nORDER BY \"time\";", "refId": "A", "sql": { "columns": [ @@ -218,7 +218,7 @@ "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT \n q.model, \n COUNT(q.id) AS total_queries\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\nGROUP BY q.model\nORDER BY total_queries DESC;", + "rawSql": "SELECT \n q.model, \n COUNT(q.id) AS total_queries\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\nGROUP BY q.model\nORDER BY total_queries DESC;", "refId": "A", "sql": { "columns": [ @@ -352,7 +352,7 @@ "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 12 THEN LEFT(u.name, 
3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name\nORDER BY \"Queries\" DESC;", + "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 12 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name\nORDER BY \"Queries\" DESC;", "refId": "A", "sql": { "columns": [ @@ -360,7 +360,7 @@ "alias": "\"User\"", "parameters": [ { - "name": "userid", + "name": "user_id", "type": "functionParameter" } ], @@ -381,7 +381,7 @@ "groupBy": [ { "property": { - "name": "userid", + "name": "user_id", "type": "string" }, "type": "groupBy" @@ -481,7 +481,7 @@ "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 8 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\",\n q.model AS \"Model\",\n COUNT(q.id) AS \"Queries\",\n MIN(q.query_timestamp) AS \"First Query\",\n MAX(q.query_timestamp) AS \"Last Query\"\nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name, q.model\nORDER BY \"Queries\" DESC\nLIMIT 20;", + "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 8 THEN LEFT(u.name, 3) || '...' 
|| RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\",\n q.model AS \"Model\",\n COUNT(q.id) AS \"Queries\",\n MIN(q.query_timestamp) AS \"First Query\",\n MAX(q.query_timestamp) AS \"Last Query\"\nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name, q.model\nORDER BY \"Queries\" DESC\nLIMIT 20;", "refId": "A", "sql": { "columns": [ diff --git a/grafana/runtime-data/dashboards/testnet-nuc-query-data.json b/grafana/runtime-data/dashboards/testnet-nuc-query-data.json index f98d70e9..358ba4eb 100644 --- a/grafana/runtime-data/dashboards/testnet-nuc-query-data.json +++ b/grafana/runtime-data/dashboards/testnet-nuc-query-data.json @@ -126,7 +126,7 @@ "editorMode": "code", "format": "time_series", "rawQuery": true, - "rawSql": "SELECT \n date_trunc('${time_granularity}', q.query_timestamp) AS \"time\", \n COUNT(q.id) AS \"Queries\"\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:value}' = 'All' OR u.name = '${user_filter:value}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY date_trunc('${time_granularity}', q.query_timestamp)\nORDER BY \"time\";", + "rawSql": "SELECT \n date_trunc('${time_granularity}', q.query_timestamp) AS \"time\", \n COUNT(q.id) AS \"Queries\"\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:value}' = 'All' OR u.name = '${user_filter:value}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY date_trunc('${time_granularity}', q.query_timestamp)\nORDER BY \"time\";", "refId": "A", "sql": { "columns": [ @@ -218,7 +218,7 @@ "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT \n q.model, \n COUNT(q.id) AS total_queries\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\nGROUP BY q.model\nORDER BY total_queries DESC;", + "rawSql": "SELECT \n q.model, \n COUNT(q.id) AS total_queries\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\nGROUP BY q.model\nORDER BY total_queries DESC;", "refId": "A", "sql": { "columns": [ @@ -352,7 +352,7 @@ "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 12 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name\nORDER BY \"Queries\" DESC;", + "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 12 THEN LEFT(u.name, 3) || '...' 
|| RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name\nORDER BY \"Queries\" DESC;", "refId": "A", "sql": { "columns": [ @@ -360,7 +360,7 @@ "alias": "\"User\"", "parameters": [ { - "name": "userid", + "name": "user_id", "type": "functionParameter" } ], @@ -381,7 +381,7 @@ "groupBy": [ { "property": { - "name": "userid", + "name": "user_id", "type": "string" }, "type": "groupBy" @@ -481,7 +481,7 @@ "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 8 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\",\n q.model AS \"Model\",\n COUNT(q.id) AS \"Queries\",\n MIN(q.query_timestamp) AS \"First Query\",\n MAX(q.query_timestamp) AS \"Last Query\"\nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name, q.model\nORDER BY \"Queries\" DESC\nLIMIT 20;", + "rawSql": "SELECT \n CASE \n WHEN LENGTH(u.name) > 8 THEN LEFT(u.name, 3) || '...' || RIGHT(u.name, 3)\n ELSE u.name\n END AS \"User\",\n q.model AS \"Model\",\n COUNT(q.id) AS \"Queries\",\n MIN(q.query_timestamp) AS \"First Query\",\n MAX(q.query_timestamp) AS \"Last Query\"\nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')\n AND ('${model_filter:single}' = 'All' OR q.model = '${model_filter:single}')\nGROUP BY u.name, q.model\nORDER BY \"Queries\" DESC\nLIMIT 20;", "refId": "A", "sql": { "columns": [ diff --git a/grafana/runtime-data/dashboards/totals-data.json b/grafana/runtime-data/dashboards/totals-data.json index 2db20c7d..ff66ce0e 100644 --- a/grafana/runtime-data/dashboards/totals-data.json +++ b/grafana/runtime-data/dashboards/totals-data.json @@ -83,7 +83,7 @@ "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')", + "rawSql": "SELECT \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')", "refId": "A", "sql": { "columns": [ @@ -165,7 +165,7 @@ "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT SUM(total_tokens) AS total_tokens\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')", + "rawSql": "SELECT SUM(total_tokens) AS total_tokens\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND 
('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')", "refId": "A", "sql": { "columns": [ @@ -248,7 +248,7 @@ "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT SUM(q.prompt_tokens) AS total_tokens\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')", + "rawSql": "SELECT SUM(q.prompt_tokens) AS total_tokens\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')", "refId": "A", "sql": { "columns": [ @@ -331,7 +331,7 @@ "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT SUM(q.completion_tokens) AS total_tokens\nFROM query_logs q\nLEFT JOIN users u ON q.userid = u.userid\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')", + "rawSql": "SELECT SUM(q.completion_tokens) AS total_tokens\nFROM query_logs q\nLEFT JOIN users u ON q.user_id = u.user_id\nWHERE \n q.query_timestamp >= $__timeFrom()\n AND q.query_timestamp <= $__timeTo()\n AND ('${user_filter:single}' = 'All' OR u.name = '${user_filter:single}')", "refId": "A", "sql": { "columns": [ @@ -397,4 +397,4 @@ "uid": "aex54yzf0nmyoc", "version": 1, "weekStart": "" -} \ No newline at end of file +} diff --git a/grafana/runtime-data/dashboards/usage-data.json b/grafana/runtime-data/dashboards/usage-data.json index 88857f91..a22bf914 100644 --- a/grafana/runtime-data/dashboards/usage-data.json +++ b/grafana/runtime-data/dashboards/usage-data.json @@ -299,7 +299,7 @@ "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT \n u.email AS \"User ID\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE q.query_timestamp >= NOW() - INTERVAL '1 hours'\nGROUP BY u.email;", + "rawSql": "SELECT \n u.email AS \"User ID\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE q.query_timestamp >= NOW() - INTERVAL '1 hours'\nGROUP BY u.email;", "refId": "A", "sql": { "columns": [ @@ -307,7 +307,7 @@ "alias": "\"User ID\"", "parameters": [ { - "name": "userid", + "name": "user_id", "type": "functionParameter" } ], @@ -328,7 +328,7 @@ "groupBy": [ { "property": { - "name": "userid", + "name": "user_id", "type": "string" }, "type": "groupBy" @@ -430,7 +430,7 @@ "editorMode": "code", "format": "table", "rawQuery": true, - "rawSql": "SELECT \n u.email AS \"User ID\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.userid = u.userid\nWHERE q.query_timestamp >= NOW() - INTERVAL '7 days'\nGROUP BY u.email;", + "rawSql": "SELECT \n u.email AS \"User ID\", \n COUNT(q.id) AS \"Queries\" \nFROM query_logs q \nJOIN users u ON q.user_id = u.user_id\nWHERE q.query_timestamp >= NOW() - INTERVAL '7 days'\nGROUP BY u.email;", "refId": "A", "sql": { "columns": [ diff --git a/nilai-api/alembic/versions/0ba073468afc_chore_improved_database_schema.py b/nilai-api/alembic/versions/0ba073468afc_chore_improved_database_schema.py new file mode 100644 index 00000000..ebaca5a6 --- /dev/null +++ b/nilai-api/alembic/versions/0ba073468afc_chore_improved_database_schema.py @@ -0,0 +1,206 @@ +"""chore: merged database schema updates + +Revision ID: 
0ba073468afc +Revises: ea942d6c7a00 +Create Date: 2025-10-31 09:43:12.022675 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "0ba073468afc" +down_revision: Union[str, None] = "9ddf28cf6b6f" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### merged commands from ea942d6c7a00 and 0ba073468afc ### + # query_logs: new telemetry columns (with defaults to backfill existing rows) + op.add_column( + "query_logs", + sa.Column( + "tool_calls", sa.Integer(), server_default=sa.text("0"), nullable=False + ), + ) + op.add_column( + "query_logs", + sa.Column( + "temperature", sa.Float(), server_default=sa.text("0.9"), nullable=True + ), + ) + op.add_column( + "query_logs", + sa.Column( + "max_tokens", sa.Integer(), server_default=sa.text("4096"), nullable=True + ), + ) + op.add_column( + "query_logs", + sa.Column( + "response_time_ms", + sa.Integer(), + server_default=sa.text("-1"), + nullable=False, + ), + ) + op.add_column( + "query_logs", + sa.Column( + "model_response_time_ms", + sa.Integer(), + server_default=sa.text("-1"), + nullable=False, + ), + ) + op.add_column( + "query_logs", + sa.Column( + "tool_response_time_ms", + sa.Integer(), + server_default=sa.text("-1"), + nullable=False, + ), + ) + op.add_column( + "query_logs", + sa.Column( + "was_streamed", + sa.Boolean(), + server_default=sa.text("False"), + nullable=False, + ), + ) + op.add_column( + "query_logs", + sa.Column( + "was_multimodal", + sa.Boolean(), + server_default=sa.text("False"), + nullable=False, + ), + ) + op.add_column( + "query_logs", + sa.Column( + "was_nildb", sa.Boolean(), server_default=sa.text("False"), nullable=False + ), + ) + op.add_column( + "query_logs", + sa.Column( + "was_nilrag", sa.Boolean(), server_default=sa.text("False"), nullable=False + ), + ) + op.add_column( + "query_logs", + sa.Column( + "error_code", sa.Integer(), server_default=sa.text("200"), nullable=False + ), + ) + op.add_column( + "query_logs", + sa.Column( + "error_message", sa.Text(), server_default=sa.text("'OK'"), nullable=False + ), + ) + + # query_logs: remove FK to users.userid before dropping the column later + op.drop_constraint("query_logs_userid_fkey", "query_logs", type_="foreignkey") + + # query_logs: add lockid and index, drop legacy userid and its index + op.add_column( + "query_logs", sa.Column("lockid", sa.String(length=75), nullable=False) + ) + op.drop_index("ix_query_logs_userid", table_name="query_logs") + op.create_index( + op.f("ix_query_logs_lockid"), "query_logs", ["lockid"], unique=False + ) + op.drop_column("query_logs", "userid") + + # users: drop legacy token counters + op.drop_column("users", "prompt_tokens") + op.drop_column("users", "completion_tokens") + + # users: reshape identity columns and indexes + op.add_column("users", sa.Column("user_id", sa.String(length=75), nullable=False)) + op.drop_index("ix_users_apikey", table_name="users") + op.drop_index("ix_users_userid", table_name="users") + op.create_index(op.f("ix_users_user_id"), "users", ["user_id"], unique=False) + op.drop_column("users", "last_activity") + op.drop_column("users", "userid") + op.drop_column("users", "apikey") + op.drop_column("users", "signup_date") + op.drop_column("users", "queries") + op.drop_column("users", "name") + # ### end merged commands ### + + +def downgrade() -> None: + # ### revert 
merged commands back to 9ddf28cf6b6f ### + # users: restore legacy columns and indexes + op.add_column("users", sa.Column("name", sa.VARCHAR(length=100), nullable=False)) + op.add_column("users", sa.Column("queries", sa.INTEGER(), nullable=False)) + op.add_column( + "users", + sa.Column( + "signup_date", + postgresql.TIMESTAMP(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + ) + op.add_column("users", sa.Column("apikey", sa.VARCHAR(length=75), nullable=False)) + op.add_column("users", sa.Column("userid", sa.VARCHAR(length=75), nullable=False)) + op.add_column( + "users", + sa.Column("last_activity", postgresql.TIMESTAMP(timezone=True), nullable=True), + ) + op.drop_index(op.f("ix_users_user_id"), table_name="users") + op.create_index("ix_users_userid", "users", ["userid"], unique=False) + op.create_index("ix_users_apikey", "users", ["apikey"], unique=False) + op.drop_column("users", "user_id") + op.add_column( + "users", + sa.Column( + "completion_tokens", + sa.INTEGER(), + server_default=sa.text("0"), + nullable=False, + ), + ) + op.add_column( + "users", + sa.Column( + "prompt_tokens", sa.INTEGER(), server_default=sa.text("0"), nullable=False + ), + ) + + # query_logs: restore userid, index and FK; drop new columns + op.add_column( + "query_logs", sa.Column("userid", sa.VARCHAR(length=75), nullable=False) + ) + op.drop_index(op.f("ix_query_logs_lockid"), table_name="query_logs") + op.create_index("ix_query_logs_userid", "query_logs", ["userid"], unique=False) + op.create_foreign_key( + "query_logs_userid_fkey", "query_logs", "users", ["userid"], ["userid"] + ) + op.drop_column("query_logs", "lockid") + op.drop_column("query_logs", "error_message") + op.drop_column("query_logs", "error_code") + op.drop_column("query_logs", "was_nilrag") + op.drop_column("query_logs", "was_nildb") + op.drop_column("query_logs", "was_multimodal") + op.drop_column("query_logs", "was_streamed") + op.drop_column("query_logs", "tool_response_time_ms") + op.drop_column("query_logs", "model_response_time_ms") + op.drop_column("query_logs", "response_time_ms") + op.drop_column("query_logs", "max_tokens") + op.drop_column("query_logs", "temperature") + op.drop_column("query_logs", "tool_calls") + # ### end revert ### diff --git a/nilai-api/alembic/versions/43b23c73035b_fix_userid_change_to_user_id.py b/nilai-api/alembic/versions/43b23c73035b_fix_userid_change_to_user_id.py new file mode 100644 index 00000000..4c20bb6d --- /dev/null +++ b/nilai-api/alembic/versions/43b23c73035b_fix_userid_change_to_user_id.py @@ -0,0 +1,37 @@ +"""fix: userid change to user_id + +Revision ID: 43b23c73035b +Revises: 0ba073468afc +Create Date: 2025-11-03 11:33:03.006101 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "43b23c73035b" +down_revision: Union[str, None] = "0ba073468afc" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "query_logs", sa.Column("user_id", sa.String(length=75), nullable=False) + ) + op.create_index( + op.f("ix_query_logs_user_id"), "query_logs", ["user_id"], unique=False + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_index(op.f("ix_query_logs_user_id"), table_name="query_logs") + op.drop_column("query_logs", "user_id") + # ### end Alembic commands ### diff --git a/nilai-api/examples/users.py b/nilai-api/examples/users.py deleted file mode 100644 index b6b206d5..00000000 --- a/nilai-api/examples/users.py +++ /dev/null @@ -1,43 +0,0 @@ -#!/usr/bin/python - -from nilai_api.db.logs import QueryLogManager -from nilai_api.db.users import UserManager - - -# Example Usage -async def main(): - # Add some users - bob = await UserManager.insert_user("Bob", "bob@example.com") - alice = await UserManager.insert_user("Alice", "alice@example.com") - - print(f"Bob's details: {bob}") - print(f"Alice's details: {alice}") - - # Check API key - user_name = await UserManager.check_api_key(bob.apikey) - print(f"API key validation: {user_name}") - - # Update and retrieve token usage - await UserManager.update_token_usage( - bob.userid, prompt_tokens=50, completion_tokens=20 - ) - usage = await UserManager.get_user_token_usage(bob.userid) - print(f"Bob's token usage: {usage}") - - # Log a query - await QueryLogManager.log_query( - userid=bob.userid, - model="gpt-3.5-turbo", - prompt_tokens=8, - completion_tokens=7, - web_search_calls=1, - ) - - -if __name__ == "__main__": - import asyncio - from dotenv import load_dotenv - - load_dotenv() - - asyncio.run(main()) diff --git a/nilai-api/pyproject.toml b/nilai-api/pyproject.toml index 42a1cf4f..fd6f1eef 100644 --- a/nilai-api/pyproject.toml +++ b/nilai-api/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ "trafilatura>=1.7.0", "secretvaults", "e2b-code-interpreter>=1.0.3", - "nilauth-credit-middleware==0.1.1", + "nilauth-credit-middleware>=0.1.2", ] diff --git a/nilai-api/src/nilai_api/attestation/__init__.py b/nilai-api/src/nilai_api/attestation/__init__.py index 61697229..ed13f68e 100644 --- a/nilai-api/src/nilai_api/attestation/__init__.py +++ b/nilai-api/src/nilai_api/attestation/__init__.py @@ -5,7 +5,7 @@ ATTESTATION_URL = "http://nilcc-attester/v2/report" -async def get_attestation_report() -> AttestationReport: +async def get_attestation_report(nonce: str) -> AttestationReport: """Get the attestation report""" try: @@ -13,6 +13,7 @@ async def get_attestation_report() -> AttestationReport: response: httpx.Response = await client.get(ATTESTATION_URL) response_json = response.json() return AttestationReport( + nonce=nonce, gpu_attestation=response_json["report"], cpu_attestation=response_json["gpu_token"], verifying_key="", # Added later by the API diff --git a/nilai-api/src/nilai_api/auth/__init__.py b/nilai-api/src/nilai_api/auth/__init__.py index 2e7cd6f7..2123685a 100644 --- a/nilai-api/src/nilai_api/auth/__init__.py +++ b/nilai-api/src/nilai_api/auth/__init__.py @@ -4,7 +4,6 @@ from logging import getLogger from nilai_api.config import CONFIG -from nilai_api.db.users import UserManager from nilai_api.auth.strategies import AuthenticationStrategy from nuc.validate import ValidationException @@ -36,7 +35,6 @@ async def get_auth_info( ) auth_info = await strategy(credentials.credentials) - await UserManager.update_last_activity(userid=auth_info.user.userid) return auth_info except AuthenticationError as e: raise e diff --git a/nilai-api/src/nilai_api/auth/nuc.py b/nilai-api/src/nilai_api/auth/nuc.py index 46459356..614d9ef1 100644 --- a/nilai-api/src/nilai_api/auth/nuc.py +++ b/nilai-api/src/nilai_api/auth/nuc.py @@ -86,11 +86,11 @@ def validate_nuc(nuc_token: str) -> Tuple[str, str]: # Validate the # Return the subject of the token, the subscription holder 
- subscription_holder = token.subject.public_key.hex() - user = token.issuer.public_key.hex() + subscription_holder = token.subject + user = token.issuer logger.info(f"Subscription holder: {subscription_holder}") logger.info(f"User: {user}") - return subscription_holder, user + return str(subscription_holder), str(user) def get_token_rate_limit(nuc_token: str) -> Optional[TokenRateLimits]: diff --git a/nilai-api/src/nilai_api/auth/strategies.py b/nilai-api/src/nilai_api/auth/strategies.py index 9917ee39..089e7e94 100644 --- a/nilai-api/src/nilai_api/auth/strategies.py +++ b/nilai-api/src/nilai_api/auth/strategies.py @@ -1,6 +1,6 @@ from typing import Callable, Awaitable, Optional -from datetime import datetime, timezone +from fastapi import HTTPException from nilai_api.db.users import UserManager, UserModel, UserData from nilai_api.auth.nuc import ( validate_nuc, @@ -11,11 +11,18 @@ from nilai_api.auth.common import ( PromptDocument, TokenRateLimits, - AuthenticationInfo, AuthenticationError, + AuthenticationInfo, +) + +from nilauth_credit_middleware import ( + CreditClientSingleton, ) +from nilauth_credit_middleware.api_model import ValidateCredentialResponse + from enum import Enum + # All strategies must return a UserModel # The strategies can raise any exception, which will be caught and converted to an AuthenticationError # The exception detail will be passed to the client @@ -44,18 +51,10 @@ async def wrapper(token) -> AuthenticationInfo: return await function(token) if token == allowed_token: - user_model: UserModel | None = await UserManager.check_user( - allowed_token + user_model = UserModel( + user_id=allowed_token, + rate_limits=None, ) - if user_model is None: - user_model = UserModel( - userid=allowed_token, - name=allowed_token, - apikey=allowed_token, - signup_date=datetime.now(timezone.utc), - ) - await UserManager.insert_user_model(user_model) - return AuthenticationInfo( user=UserData.from_sqlalchemy(user_model), token_rate_limit=None, @@ -68,16 +67,41 @@ async def wrapper(token) -> AuthenticationInfo: return decorator +async def validate_credential(credential: str, is_public: bool) -> UserModel: + """ + Validate a credential with nilauth credit middleware and return the user model + """ + credit_client = CreditClientSingleton.get_client() + try: + validate_response: ValidateCredentialResponse = ( + await credit_client.validate_credential(credential, is_public=is_public) + ) + except HTTPException as e: + if e.status_code == 404: + raise AuthenticationError(f"Credential not found: {e.detail}") + elif e.status_code == 401: + raise AuthenticationError(f"Credential is inactive: {e.detail}") + else: + raise AuthenticationError(f"Failed to validate credential: {e.detail}") + + user_model = await UserManager.check_user(validate_response.user_id) + if user_model is None: + user_model = UserModel( + user_id=validate_response.user_id, + rate_limits=None, + ) + return user_model + + @allow_token(CONFIG.docs.token) async def api_key_strategy(api_key: str) -> AuthenticationInfo: - user_model: Optional[UserModel] = await UserManager.check_api_key(api_key) - if user_model: - return AuthenticationInfo( - user=UserData.from_sqlalchemy(user_model), - token_rate_limit=None, - prompt_document=None, - ) - raise AuthenticationError("Missing or invalid API key") + user_model = await validate_credential(api_key, is_public=False) + + return AuthenticationInfo( + user=UserData.from_sqlalchemy(user_model), + token_rate_limit=None, + prompt_document=None, + ) @allow_token(CONFIG.docs.token) @@ 
-89,20 +113,7 @@ async def nuc_strategy(nuc_token) -> AuthenticationInfo: token_rate_limits: Optional[TokenRateLimits] = get_token_rate_limit(nuc_token) prompt_document: Optional[PromptDocument] = get_token_prompt_document(nuc_token) - user_model: Optional[UserModel] = await UserManager.check_user(user) - if user_model: - return AuthenticationInfo( - user=UserData.from_sqlalchemy(user_model), - token_rate_limit=token_rate_limits, - prompt_document=prompt_document, - ) - - user_model = UserModel( - userid=user, - name=user, - apikey=subscription_holder, - ) - await UserManager.insert_user_model(user_model) + user_model = await validate_credential(subscription_holder, is_public=True) return AuthenticationInfo( user=UserData.from_sqlalchemy(user_model), token_rate_limit=token_rate_limits, diff --git a/nilai-api/src/nilai_api/commands/add_user.py b/nilai-api/src/nilai_api/commands/add_user.py index e9f49e55..202b70d4 100644 --- a/nilai-api/src/nilai_api/commands/add_user.py +++ b/nilai-api/src/nilai_api/commands/add_user.py @@ -6,9 +6,7 @@ @click.command() -@click.option("--name", type=str, required=True, help="User Name") -@click.option("--apikey", type=str, help="API Key") -@click.option("--userid", type=str, help="User Id") +@click.option("--user_id", type=str, help="User Id") @click.option("--ratelimit-day", type=int, help="number of request per day") @click.option("--ratelimit-hour", type=int, help="number of request per hour") @click.option("--ratelimit-minute", type=int, help="number of request per minute") @@ -26,9 +24,7 @@ help="number of web search request per minute", ) def main( - name, - apikey: str | None, - userid: str | None, + user_id: str | None, ratelimit_day: int | None, ratelimit_hour: int | None, ratelimit_minute: int | None, @@ -38,9 +34,7 @@ def main( ): async def add_user(): user: UserModel = await UserManager.insert_user( - name, - apikey, - userid, + user_id, RateLimits( user_rate_limit_day=ratelimit_day, user_rate_limit_hour=ratelimit_hour, @@ -52,9 +46,7 @@ async def add_user(): ) json_user = json.dumps( { - "userid": user.userid, - "name": user.name, - "apikey": user.apikey, + "user_id": user.user_id, "ratelimit_day": user.rate_limits_obj.user_rate_limit_day, "ratelimit_hour": user.rate_limits_obj.user_rate_limit_hour, "ratelimit_minute": user.rate_limits_obj.user_rate_limit_minute, diff --git a/nilai-api/src/nilai_api/config/__init__.py b/nilai-api/src/nilai_api/config/__init__.py index 7af72318..3f19f85e 100644 --- a/nilai-api/src/nilai_api/config/__init__.py +++ b/nilai-api/src/nilai_api/config/__init__.py @@ -70,3 +70,4 @@ def prettify(self): ] logging.info(CONFIG.prettify()) +print(CONFIG.prettify()) diff --git a/nilai-api/src/nilai_api/credit.py b/nilai-api/src/nilai_api/credit.py index 3a06135e..b9d7ea6f 100644 --- a/nilai-api/src/nilai_api/credit.py +++ b/nilai-api/src/nilai_api/credit.py @@ -20,6 +20,9 @@ class NoOpMeteringContext: """A no-op metering context for requests that should skip metering (e.g., Docs Token).""" + def __init__(self): + self.lock_id: str = "noop-lock-id" + def set_response(self, response_data: dict) -> None: """No-op method that does nothing.""" pass diff --git a/nilai-api/src/nilai_api/db/__init__.py b/nilai-api/src/nilai_api/db/__init__.py index ee70ffe0..0cc8a623 100644 --- a/nilai-api/src/nilai_api/db/__init__.py +++ b/nilai-api/src/nilai_api/db/__init__.py @@ -14,11 +14,11 @@ from nilai_api.config import CONFIG -_engine: Optional[sqlalchemy.ext.asyncio.AsyncEngine] = None +_engine: Optional[sqlalchemy.ext.asyncio.AsyncEngine] 
= None # type: ignore[reportAttributeAccessIssue] _SessionLocal: Optional[sessionmaker] = None # Create base and engine with improved configuration -Base = sqlalchemy.orm.declarative_base() +Base = sqlalchemy.orm.declarative_base() # type: ignore[reportAttributeAccessIssue] logger = logging.getLogger(__name__) @@ -49,7 +49,7 @@ def from_env() -> "DatabaseConfig": return DatabaseConfig(database_url) -def get_engine() -> sqlalchemy.ext.asyncio.AsyncEngine: +def get_engine() -> sqlalchemy.ext.asyncio.AsyncEngine: # type: ignore[reportAttributeAccessIssue] global _engine if _engine is None: config = DatabaseConfig.from_env() diff --git a/nilai-api/src/nilai_api/db/logs.py b/nilai-api/src/nilai_api/db/logs.py index 030c8696..4a78c8a7 100644 --- a/nilai-api/src/nilai_api/db/logs.py +++ b/nilai-api/src/nilai_api/db/logs.py @@ -1,12 +1,14 @@ import logging +import time from datetime import datetime, timezone +from typing import Optional +from nilai_common import Usage import sqlalchemy -from sqlalchemy import ForeignKey, Integer, String, DateTime, Text +from sqlalchemy import Integer, String, DateTime, Text, Boolean, Float from sqlalchemy.exc import SQLAlchemyError from nilai_api.db import Base, Column, get_db_session -from nilai_api.db.users import UserModel logger = logging.getLogger(__name__) @@ -16,9 +18,8 @@ class QueryLog(Base): __tablename__ = "query_logs" id: int = Column(Integer, primary_key=True, autoincrement=True) # type: ignore - userid: str = Column( - String(75), ForeignKey(UserModel.userid), nullable=False, index=True - ) # type: ignore + user_id: str = Column(String(75), nullable=False, index=True) # type: ignore + lockid: str = Column(String(75), nullable=False, index=True) # type: ignore query_timestamp: datetime = Column( DateTime(timezone=True), server_default=sqlalchemy.func.now(), nullable=False ) # type: ignore @@ -26,51 +27,285 @@ class QueryLog(Base): prompt_tokens: int = Column(Integer, nullable=False) # type: ignore completion_tokens: int = Column(Integer, nullable=False) # type: ignore total_tokens: int = Column(Integer, nullable=False) # type: ignore + tool_calls: int = Column(Integer, nullable=False) # type: ignore web_search_calls: int = Column(Integer, nullable=False) # type: ignore + temperature: Optional[float] = Column(Float, nullable=True) # type: ignore + max_tokens: Optional[int] = Column(Integer, nullable=True) # type: ignore + + response_time_ms: int = Column(Integer, nullable=False) # type: ignore + model_response_time_ms: int = Column(Integer, nullable=False) # type: ignore + tool_response_time_ms: int = Column(Integer, nullable=False) # type: ignore + + was_streamed: bool = Column(Boolean, nullable=False) # type: ignore + was_multimodal: bool = Column(Boolean, nullable=False) # type: ignore + was_nildb: bool = Column(Boolean, nullable=False) # type: ignore + was_nilrag: bool = Column(Boolean, nullable=False) # type: ignore + + error_code: int = Column(Integer, nullable=False) # type: ignore + error_message: str = Column(Text, nullable=False) # type: ignore def __repr__(self): - return f"" + return f"" + + +class QueryLogContext: + """ + Context manager for logging query metrics during a request. + Used as a FastAPI dependency to track request metrics. 
+ """ + + def __init__(self): + self.user_id: Optional[str] = None + self.lockid: Optional[str] = None + self.model: Optional[str] = None + self.prompt_tokens: int = 0 + self.completion_tokens: int = 0 + self.tool_calls: int = 0 + self.web_search_calls: int = 0 + self.temperature: Optional[float] = None + self.max_tokens: Optional[int] = None + self.was_streamed: bool = False + self.was_multimodal: bool = False + self.was_nildb: bool = False + self.was_nilrag: bool = False + self.error_code: int = 0 + self.error_message: str = "" + + # Timing tracking + self.start_time: float = time.monotonic() + self.model_start_time: Optional[float] = None + self.model_end_time: Optional[float] = None + self.tool_start_time: Optional[float] = None + self.tool_end_time: Optional[float] = None + + def set_user(self, user_id: str) -> None: + """Set the user ID for this query.""" + self.user_id = user_id + + def set_lockid(self, lockid: str) -> None: + """Set the lock ID for this query.""" + self.lockid = lockid + + def set_model(self, model: str) -> None: + """Set the model name for this query.""" + self.model = model + + def set_request_params( + self, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + was_streamed: bool = False, + was_multimodal: bool = False, + was_nildb: bool = False, + was_nilrag: bool = False, + ) -> None: + """Set request parameters.""" + self.temperature = temperature + self.max_tokens = max_tokens + self.was_streamed = was_streamed + self.was_multimodal = was_multimodal + self.was_nildb = was_nildb + self.was_nilrag = was_nilrag + + def set_usage( + self, + prompt_tokens: int = 0, + completion_tokens: int = 0, + tool_calls: int = 0, + web_search_calls: int = 0, + ) -> None: + """Set token usage and feature usage.""" + self.prompt_tokens = prompt_tokens + self.completion_tokens = completion_tokens + self.tool_calls = tool_calls + self.web_search_calls = web_search_calls + + def set_error(self, error_code: int, error_message: str) -> None: + """Set error information.""" + self.error_code = error_code + self.error_message = error_message + + def start_model_timing(self) -> None: + """Mark the start of model inference.""" + self.model_start_time = time.monotonic() + + def end_model_timing(self) -> None: + """Mark the end of model inference.""" + self.model_end_time = time.monotonic() + + def start_tool_timing(self) -> None: + """Mark the start of tool execution.""" + self.tool_start_time = time.monotonic() + + def end_tool_timing(self) -> None: + """Mark the end of tool execution.""" + self.tool_end_time = time.monotonic() + + def _calculate_timings(self) -> tuple[int, int, int]: + """Calculate response times in milliseconds.""" + total_ms = int((time.monotonic() - self.start_time) * 1000) + + model_ms = 0 + if self.model_start_time and self.model_end_time: + model_ms = int((self.model_end_time - self.model_start_time) * 1000) + + tool_ms = 0 + if self.tool_start_time and self.tool_end_time: + tool_ms = int((self.tool_end_time - self.tool_start_time) * 1000) + + return total_ms, model_ms, tool_ms + + async def commit(self) -> None: + """ + Commit the query log to the database. + Should be called at the end of the request lifecycle. 
+ """ + if not self.user_id or not self.model: + logger.warning( + "Skipping query log: user_id or model not set " + f"(user_id={self.user_id}, model={self.model})" + ) + return + + total_ms, model_ms, tool_ms = self._calculate_timings() + total_tokens = self.prompt_tokens + self.completion_tokens + + try: + async with get_db_session() as session: + query_log = QueryLog( + user_id=self.user_id, + lockid=self.lockid, + model=self.model, + prompt_tokens=self.prompt_tokens, + completion_tokens=self.completion_tokens, + total_tokens=total_tokens, + tool_calls=self.tool_calls, + web_search_calls=self.web_search_calls, + temperature=self.temperature, + max_tokens=self.max_tokens, + query_timestamp=datetime.now(timezone.utc), + response_time_ms=total_ms, + model_response_time_ms=model_ms, + tool_response_time_ms=tool_ms, + was_streamed=self.was_streamed, + was_multimodal=self.was_multimodal, + was_nilrag=self.was_nilrag, + was_nildb=self.was_nildb, + error_code=self.error_code, + error_message=self.error_message, + ) + session.add(query_log) + await session.commit() + logger.info( + f"Query logged for user {self.user_id}: model={self.model}, " + f"tokens={total_tokens}, total_ms={total_ms}" + ) + except SQLAlchemyError as e: + logger.error(f"Error logging query: {e}") + # Don't raise - logging failure shouldn't break the request class QueryLogManager: + """Static methods for direct query logging (legacy support).""" + @staticmethod async def log_query( - userid: str, + user_id: str, + lockid: str, model: str, prompt_tokens: int, completion_tokens: int, + response_time_ms: int, web_search_calls: int, + was_streamed: bool, + was_multimodal: bool, + was_nilrag: bool, + was_nildb: bool, + tool_calls: int = 0, + temperature: float = 1.0, + max_tokens: int = 0, + model_response_time_ms: int = 0, + tool_response_time_ms: int = 0, + error_code: int = 0, + error_message: str = "", ): """ - Log a user's query. - - Args: - userid (str): User's unique ID - model (str): The model that generated the response - prompt_tokens (int): Number of input tokens used - completion_tokens (int): Number of tokens in the generated response + Log a user's query (legacy method). + Consider using QueryLogContext as a dependency instead. """ total_tokens = prompt_tokens + completion_tokens try: async with get_db_session() as session: query_log = QueryLog( - userid=userid, + user_id=user_id, + lockid=lockid, model=model, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, - query_timestamp=datetime.now(timezone.utc), + tool_calls=tool_calls, web_search_calls=web_search_calls, + temperature=temperature, + max_tokens=max_tokens, + query_timestamp=datetime.now(timezone.utc), + response_time_ms=response_time_ms, + model_response_time_ms=model_response_time_ms, + tool_response_time_ms=tool_response_time_ms, + was_streamed=was_streamed, + was_multimodal=was_multimodal, + was_nilrag=was_nilrag, + was_nildb=was_nildb, + error_code=error_code, + error_message=error_message, ) session.add(query_log) await session.commit() logger.info( - f"Query logged for user {userid} with total tokens {total_tokens}." + f"Query logged for user {user_id} with total tokens {total_tokens}." ) except SQLAlchemyError as e: logger.error(f"Error logging query: {e}") raise + @staticmethod + async def get_user_token_usage(user_id: str) -> Optional[Usage]: + """ + Get aggregated token usage for a specific user using server-side SQL aggregation. + This is more efficient than fetching all records and calculating in Python. 
+ """ + try: + async with get_db_session() as session: + # Use SQL aggregation functions to calculate on the database server + query = ( + sqlalchemy.select( + sqlalchemy.func.coalesce( + sqlalchemy.func.sum(QueryLog.prompt_tokens), 0 + ).label("prompt_tokens"), + sqlalchemy.func.coalesce( + sqlalchemy.func.sum(QueryLog.completion_tokens), 0 + ).label("completion_tokens"), + sqlalchemy.func.coalesce( + sqlalchemy.func.sum(QueryLog.total_tokens), 0 + ).label("total_tokens"), + sqlalchemy.func.count().label("queries"), + ).where(QueryLog.user_id == user_id) # type: ignore[arg-type] + ) + + result = await session.execute(query) + row = result.one_or_none() + + if row is None: + return None + + return Usage( + prompt_tokens=int(row.prompt_tokens), + completion_tokens=int(row.completion_tokens), + total_tokens=int(row.total_tokens), + ) + except SQLAlchemyError as e: + logger.error(f"Error getting token usage: {e}") + return None + -__all__ = ["QueryLogManager", "QueryLog"] +__all__ = ["QueryLogManager", "QueryLog", "QueryLogContext"] diff --git a/nilai-api/src/nilai_api/db/users.py b/nilai-api/src/nilai_api/db/users.py index 515ba389..e475c424 100644 --- a/nilai-api/src/nilai_api/db/users.py +++ b/nilai-api/src/nilai_api/db/users.py @@ -2,11 +2,10 @@ import uuid from pydantic import BaseModel, ConfigDict, Field -from datetime import datetime, timezone -from typing import Any, Dict, List, Optional +from typing import Optional import sqlalchemy -from sqlalchemy import Integer, String, DateTime, JSON +from sqlalchemy import String, JSON from sqlalchemy.exc import SQLAlchemyError from nilai_api.db import Base, Column, get_db_session @@ -57,21 +56,11 @@ def get_effective_limits(self) -> "RateLimits": # Enhanced User Model with additional constraints and validation class UserModel(Base): __tablename__ = "users" - - userid: str = Column(String(75), primary_key=True, index=True) # type: ignore - name: str = Column(String(100), nullable=False) # type: ignore - apikey: str = Column(String(75), unique=False, nullable=False, index=True) # type: ignore - prompt_tokens: int = Column(Integer, default=0, nullable=False) # type: ignore - completion_tokens: int = Column(Integer, default=0, nullable=False) # type: ignore - queries: int = Column(Integer, default=0, nullable=False) # type: ignore - signup_date: datetime = Column( - DateTime(timezone=True), server_default=sqlalchemy.func.now(), nullable=False - ) # type: ignore - last_activity: datetime = Column(DateTime(timezone=True), nullable=True) # type: ignore + user_id: str = Column(String(75), primary_key=True, index=True) # type: ignore rate_limits: dict = Column(JSON, nullable=True) # type: ignore def __repr__(self): - return f"" + return f"" @property def rate_limits_obj(self) -> RateLimits: @@ -85,14 +74,7 @@ def to_pydantic(self) -> "UserData": class UserData(BaseModel): - userid: str - name: str - apikey: str - prompt_tokens: int = 0 - completion_tokens: int = 0 - queries: int = 0 - signup_date: datetime - last_activity: Optional[datetime] = None + user_id: str # apikey or subscription holder public key rate_limits: RateLimits = Field(default_factory=RateLimits().get_effective_limits) model_config = ConfigDict(from_attributes=True) @@ -100,21 +82,10 @@ class UserData(BaseModel): @classmethod def from_sqlalchemy(cls, user: UserModel) -> "UserData": return cls( - userid=user.userid, - name=user.name, - apikey=user.apikey, - prompt_tokens=user.prompt_tokens or 0, - completion_tokens=user.completion_tokens or 0, - queries=user.queries or 0, - 
signup_date=user.signup_date or datetime.now(timezone.utc), - last_activity=user.last_activity, + user_id=user.user_id, rate_limits=user.rate_limits_obj, ) - @property - def is_subscription_owner(self): - return self.userid == self.apikey - class UserManager: @staticmethod @@ -127,31 +98,9 @@ def generate_api_key() -> str: """Generate a unique API key.""" return str(uuid.uuid4()) - @staticmethod - async def update_last_activity(userid: str): - """ - Update the last activity timestamp for a user. - - Args: - userid (str): User's unique ID - """ - try: - async with get_db_session() as session: - user = await session.get(UserModel, userid) - if user: - user.last_activity = datetime.now(timezone.utc) - await session.commit() - logger.info(f"Updated last activity for user {userid}") - else: - logger.warning(f"User {userid} not found") - except SQLAlchemyError as e: - logger.error(f"Error updating last activity: {e}") - @staticmethod async def insert_user( - name: str, - apikey: str | None = None, - userid: str | None = None, + user_id: str | None = None, rate_limits: RateLimits | None = None, ) -> UserModel: """ @@ -160,19 +109,16 @@ async def insert_user( Args: name (str): Name of the user apikey (str): API key for the user - userid (str): Unique ID for the user + user_id (str): Unique ID for the user rate_limits (RateLimits): Rate limit configuration Returns: UserModel: The created user model """ - userid = userid if userid else UserManager.generate_user_id() - apikey = apikey if apikey else UserManager.generate_api_key() + user_id = user_id if user_id else UserManager.generate_user_id() user = UserModel( - userid=userid, - name=name, - apikey=apikey, + user_id=user_id, rate_limits=rate_limits.model_dump() if rate_limits else None, ) return await UserManager.insert_user_model(user) @@ -189,35 +135,14 @@ async def insert_user_model(user: UserModel) -> UserModel: async with get_db_session() as session: session.add(user) await session.commit() - logger.info(f"User {user.name} added successfully.") + logger.info(f"User {user.user_id} added successfully.") return user except SQLAlchemyError as e: logger.error(f"Error inserting user: {e}") raise @staticmethod - async def check_user(userid: str) -> Optional[UserModel]: - """ - Validate a user. - - Args: - userid (str): User ID to validate - - Returns: - User's name if user is valid, None otherwise - """ - try: - async with get_db_session() as session: - query = sqlalchemy.select(UserModel).filter(UserModel.userid == userid) # type: ignore - user = await session.execute(query) - user = user.scalar_one_or_none() - return user - except SQLAlchemyError as e: - logger.error(f"Error checking API key: {e}") - return None - - @staticmethod - async def check_api_key(api_key: str) -> Optional[UserModel]: + async def check_user(user_id: str) -> Optional[UserModel]: """ Validate an API key. 
@@ -225,118 +150,27 @@ async def check_api_key(api_key: str) -> Optional[UserModel]: api_key (str): API key to validate Returns: - User's name if API key is valid, None otherwise + User's rate limits if user id is valid, None otherwise """ try: async with get_db_session() as session: - query = sqlalchemy.select(UserModel).filter(UserModel.apikey == api_key) # type: ignore + query = sqlalchemy.select(UserModel).filter( + UserModel.user_id == user_id # type: ignore + ) user = await session.execute(query) user = user.scalar_one_or_none() return user except SQLAlchemyError as e: - logger.error(f"Error checking API key: {e}") - return None - - @staticmethod - async def update_token_usage( - userid: str, prompt_tokens: int, completion_tokens: int - ): - """ - Update token usage for a specific user. - - Args: - userid (str): User's unique ID - prompt_tokens (int): Number of input tokens - completion_tokens (int): Number of generated tokens - """ - try: - async with get_db_session() as session: - user = await session.get(UserModel, userid) - if user: - user.prompt_tokens += prompt_tokens - user.completion_tokens += completion_tokens - user.queries += 1 - await session.commit() - logger.info(f"Updated token usage for user {userid}") - else: - logger.warning(f"User {userid} not found") - except SQLAlchemyError as e: - logger.error(f"Error updating token usage: {e}") - - @staticmethod - async def get_token_usage(userid: str) -> Optional[Dict[str, Any]]: - """ - Get token usage for a specific user. - - Args: - userid (str): User's unique ID - """ - try: - async with get_db_session() as session: - user = await session.get(UserModel, userid) - if user: - return { - "prompt_tokens": user.prompt_tokens, - "completion_tokens": user.completion_tokens, - "total_tokens": user.prompt_tokens + user.completion_tokens, - "queries": user.queries, - } - else: - logger.warning(f"User {userid} not found") - return None - except SQLAlchemyError as e: - logger.error(f"Error updating token usage: {e}") - return None - - @staticmethod - async def get_all_users() -> Optional[List[UserData]]: - """ - Retrieve all users from the database. - - Returns: - List of UserData or None if no users found - """ - try: - async with get_db_session() as session: - users = await session.execute(sqlalchemy.select(UserModel)) - users = users.scalars().all() - return [UserData.from_sqlalchemy(user) for user in users] - except SQLAlchemyError as e: - logger.error(f"Error retrieving all users: {e}") - return None - - @staticmethod - async def get_user_token_usage(userid: str) -> Optional[Dict[str, int]]: - """ - Retrieve total token usage for a user. - - Args: - userid (str): User's unique ID - - Returns: - Dict of token usage or None if user not found - """ - try: - async with get_db_session() as session: - user = await session.get(UserModel, userid) - if user: - return { - "prompt_tokens": user.prompt_tokens, - "completion_tokens": user.completion_tokens, - "queries": user.queries, - } - return None - except SQLAlchemyError as e: - logger.error(f"Error retrieving token usage: {e}") + logger.error(f"Rate limit checking user id: {e}") return None @staticmethod - async def update_rate_limits(userid: str, rate_limits: RateLimits) -> bool: + async def update_rate_limits(user_id: str, rate_limits: RateLimits) -> bool: """ Update rate limits for a specific user. 
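With the per-user counters gone, the users table is just an id plus optional JSON rate limits, so the typical access pattern becomes look-up-or-create. A small sketch using the UserManager methods above (the flow itself is illustrative, not a function in this diff):

from nilai_api.db.users import UserManager, UserModel


async def ensure_user(user_id: str) -> UserModel:
    # Return the existing row, or create one with default (NULL) rate limits
    user = await UserManager.check_user(user_id)
    if user is None:
        user = await UserManager.insert_user(user_id=user_id)
    return user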
Args: - userid (str): User's unique ID + user_id (str): User's unique ID rate_limits (RateLimits): New rate limit configuration Returns: @@ -344,14 +178,14 @@ async def update_rate_limits(userid: str, rate_limits: RateLimits) -> bool: """ try: async with get_db_session() as session: - user = await session.get(UserModel, userid) + user = await session.get(UserModel, user_id) if user: user.rate_limits = rate_limits.model_dump() await session.commit() - logger.info(f"Updated rate limits for user {userid}") + logger.info(f"Updated rate limits for user {user_id}") return True else: - logger.warning(f"User {userid} not found") + logger.warning(f"User {user_id} not found") return False except SQLAlchemyError as e: logger.error(f"Error updating rate limits: {e}") diff --git a/nilai-api/src/nilai_api/rate_limiting.py b/nilai-api/src/nilai_api/rate_limiting.py index c2d03273..0a70f15f 100644 --- a/nilai-api/src/nilai_api/rate_limiting.py +++ b/nilai-api/src/nilai_api/rate_limiting.py @@ -1,3 +1,4 @@ +import logging from asyncio import iscoroutine from typing import Callable, Tuple, Awaitable, Annotated @@ -11,6 +12,8 @@ from nilai_api.auth import get_auth_info, AuthenticationInfo, TokenRateLimits from nilai_api.config import CONFIG +logger = logging.getLogger(__name__) + LUA_RATE_LIMIT_SCRIPT = """ local key = KEYS[1] local limit = tonumber(ARGV[1]) @@ -52,7 +55,7 @@ async def _extract_coroutine_result(maybe_future, request: Request): class UserRateLimits(BaseModel): - subscription_holder: str + user_id: str token_rate_limit: TokenRateLimits | None rate_limits: RateLimits @@ -60,14 +63,13 @@ class UserRateLimits(BaseModel): def get_user_limits( auth_info: Annotated[AuthenticationInfo, Depends(get_auth_info)], ) -> UserRateLimits: - # TODO: When the only allowed strategy is NUC, we can change the apikey name to subscription_holder - # In apikey mode, the apikey is unique as the userid. - # In nuc mode, the apikey is associated with a subscription holder and the userid is the user + # In apikey mode, the apikey is unique as the user_id. 
+ # In nuc mode, the apikey is associated with a subscription holder and the user_id is the user # For NUCs we want the rate limit to be per subscription holder, not per user - # In JWT mode, the apikey is the userid too + # In JWT mode, the apikey is the user_id too # So we use the apikey as the id return UserRateLimits( - subscription_holder=auth_info.user.apikey, + user_id=auth_info.user.user_id, token_rate_limit=auth_info.token_rate_limit, rate_limits=auth_info.user.rate_limits, ) @@ -105,21 +107,21 @@ async def __call__( await self.check_bucket( redis, redis_rate_limit_command, - f"minute:{user_limits.subscription_holder}", + f"minute:{user_limits.user_id}", user_limits.rate_limits.user_rate_limit_minute, MINUTE_MS, ) await self.check_bucket( redis, redis_rate_limit_command, - f"hour:{user_limits.subscription_holder}", + f"hour:{user_limits.user_id}", user_limits.rate_limits.user_rate_limit_hour, HOUR_MS, ) await self.check_bucket( redis, redis_rate_limit_command, - f"day:{user_limits.subscription_holder}", + f"day:{user_limits.user_id}", user_limits.rate_limits.user_rate_limit_day, DAY_MS, ) @@ -127,7 +129,7 @@ async def __call__( await self.check_bucket( redis, redis_rate_limit_command, - f"user:{user_limits.subscription_holder}", + f"user:{user_limits.user_id}", user_limits.rate_limits.user_rate_limit, 0, # No expiration for for-good rate limit ) @@ -159,21 +161,21 @@ async def __call__( await self.check_bucket( redis, redis_rate_limit_command, - f"web_search_minute:{user_limits.subscription_holder}", + f"web_search_minute:{user_limits.user_id}", user_limits.rate_limits.web_search_rate_limit_minute, MINUTE_MS, ) await self.check_bucket( redis, redis_rate_limit_command, - f"web_search_hour:{user_limits.subscription_holder}", + f"web_search_hour:{user_limits.user_id}", user_limits.rate_limits.web_search_rate_limit_hour, HOUR_MS, ) await self.check_bucket( redis, redis_rate_limit_command, - f"web_search_day:{user_limits.subscription_holder}", + f"web_search_day:{user_limits.user_id}", user_limits.rate_limits.web_search_rate_limit_day, DAY_MS, ) @@ -187,7 +189,7 @@ async def __call__( await self.check_bucket( redis, redis_rate_limit_command, - f"web_search:{user_limits.subscription_holder}", + f"web_search:{user_limits.user_id}", user_limits.rate_limits.web_search_rate_limit, 0, ) @@ -206,13 +208,33 @@ async def check_bucket( times: int | None, milliseconds: int, ): + """ + Check if the rate limit is exceeded for a given key + + Args: + redis: The Redis client + redis_rate_limit_command: The Redis rate limit command + key: The key to check the rate limit for + times: The number of times allowed for the key + milliseconds: The expiration time in milliseconds of the rate limit + + Returns: + None if the rate limit is not exceeded + The number of milliseconds to wait before the rate limit is reset if the rate limit is exceeded + + Raises: + HTTPException: If the rate limit is exceeded + """ if times is None: return + # Evaluate the Lua script to check if the rate limit is exceeded expire = await redis.evalsha( redis_rate_limit_command, 1, key, str(times), str(milliseconds) ) # type: ignore - if int(expire) > 0: + logger.error( + f"Rate limit exceeded for key: {key}, expires in: {expire} milliseconds, times allowed: {times}, expiration time: {milliseconds / 1000} seconds" + ) raise HTTPException( status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Too Many Requests", diff --git a/nilai-api/src/nilai_api/routers/endpoints/chat.py b/nilai-api/src/nilai_api/routers/endpoints/chat.py index 
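Each bucket check above runs the same Lua script through EVALSHA so the increment-and-expire is atomic on the Redis side. A self-contained sketch of a fixed-window variant (the script body here is illustrative, not the project's exact LUA_RATE_LIMIT_SCRIPT):

import redis.asyncio as redis

FIXED_WINDOW_LUA = """
local key = KEYS[1]
local limit = tonumber(ARGV[1])
local ttl_ms = tonumber(ARGV[2])
local current = redis.call('INCR', key)
if current == 1 and ttl_ms > 0 then
    redis.call('PEXPIRE', key, ttl_ms)
end
if current > limit then
    return redis.call('PTTL', key)
end
return 0
"""


async def allow(client: redis.Redis, script_sha: str, key: str, limit: int, window_ms: int) -> bool:
    # Returns False (leaving the window TTL intact) once the counter exceeds the limit
    expire = await client.evalsha(script_sha, 1, key, str(limit), str(window_ms))
    return int(expire) == 0


# One-time registration at startup:
#   script_sha = await client.script_load(FIXED_WINDOW_LUA)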
45425e31..a72c9594 100644 --- a/nilai-api/src/nilai_api/routers/endpoints/chat.py +++ b/nilai-api/src/nilai_api/routers/endpoints/chat.py @@ -5,7 +5,15 @@ from base64 import b64encode from typing import AsyncGenerator, Optional, Union, List, Tuple -from fastapi import APIRouter, Body, Depends, HTTPException, status, Request +from fastapi import ( + APIRouter, + Body, + Depends, + HTTPException, + status, + Request, + BackgroundTasks, +) from fastapi.responses import StreamingResponse from openai import AsyncOpenAI @@ -13,8 +21,7 @@ from nilai_api.config import CONFIG from nilai_api.crypto import sign_message from nilai_api.credit import LLMMeter, LLMUsage -from nilai_api.db.logs import QueryLogManager -from nilai_api.db.users import UserManager +from nilai_api.db.logs import QueryLogContext from nilai_api.handlers.nildb.handler import get_prompt_from_nildb from nilai_api.handlers.nilrag import handle_nilrag from nilai_api.handlers.tools.tool_router import handle_tool_workflow @@ -72,6 +79,7 @@ async def chat_completion( ], ) ), + background_tasks: BackgroundTasks = BackgroundTasks(), _rate_limit=Depends( RateLimit( concurrent_extractor=chat_completion_concurrent_rate_limit, @@ -80,6 +88,7 @@ async def chat_completion( ), auth_info: AuthenticationInfo = Depends(get_auth_info), meter: MeteringContext = Depends(LLMMeter), + log_ctx: QueryLogContext = Depends(QueryLogContext), ) -> Union[SignedChatCompletion, StreamingResponse]: """ Generate a chat completion response from the AI model. @@ -133,243 +142,311 @@ async def chat_completion( ) response = await chat_completion(request, user) """ - - if len(req.messages) == 0: - raise HTTPException( - status_code=400, - detail="Request contained 0 messages", - ) + # Initialize log context early so we can log any errors + log_ctx.set_user(auth_info.user.user_id) + log_ctx.set_lockid(meter.lock_id) model_name = req.model request_id = str(uuid.uuid4()) t_start = time.monotonic() - logger.info(f"[chat] call start request_id={req.messages}") - endpoint = await state.get_model(model_name) - if endpoint is None: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Invalid model name {model_name}, check /v1/models for options", - ) - - has_multimodal = req.has_multimodal_content() - logger.info(f"[chat] has_multimodal: {has_multimodal}") - if has_multimodal and (not endpoint.metadata.multimodal_support or req.web_search): - raise HTTPException( - status_code=400, - detail="Model does not support multimodal content, remove image inputs from request", - ) - model_url = endpoint.url + "/v1/" - - logger.info( - f"[chat] start request_id={request_id} user={auth_info.user.userid} model={model_name} stream={req.stream} web_search={bool(req.web_search)} tools={bool(req.tools)} multimodal={has_multimodal} url={model_url}" - ) - - client = AsyncOpenAI(base_url=model_url, api_key="") - if auth_info.prompt_document: - try: - nildb_prompt: str = await get_prompt_from_nildb(auth_info.prompt_document) - req.messages.insert( - 0, MessageAdapter.new_message(role="system", content=nildb_prompt) + try: + if len(req.messages) == 0: + raise HTTPException( + status_code=400, + detail="Request contained 0 messages", ) - except Exception as e: + endpoint = await state.get_model(model_name) + if endpoint is None: raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=f"Unable to extract prompt from nilDB: {str(e)}", + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Invalid model name {model_name}, check /v1/models for options", ) - if 
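chat_completion now receives a per-request QueryLogContext and commits it through BackgroundTasks, so the log write happens after the response has been sent. A stripped-down sketch of that shape, with a hypothetical stand-in for the context:

from fastapi import APIRouter, BackgroundTasks, Depends

router = APIRouter()


class RequestLog:
    # hypothetical stand-in for QueryLogContext
    def __init__(self) -> None:
        self.fields: dict = {}

    def set(self, **kwargs) -> None:
        self.fields.update(kwargs)

    async def commit(self) -> None:
        ...  # persist self.fields to the query_logs table


@router.post("/v1/example")
async def example(
    background_tasks: BackgroundTasks,
    log_ctx: RequestLog = Depends(RequestLog),
):
    log_ctx.set(user_id="...", model="...")
    # commit is scheduled to run after the response body has gone out
    background_tasks.add_task(log_ctx.commit)
    return {"ok": True}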
req.nilrag: - logger.info(f"[chat] nilrag start request_id={request_id}") - t_nilrag = time.monotonic() - await handle_nilrag(req) - logger.info( - f"[chat] nilrag done request_id={request_id} duration_ms={(time.monotonic() - t_nilrag) * 1000:.0f}" - ) - - messages = req.messages - sources: Optional[List[Source]] = None + # Now we have a valid model, set it in log context + log_ctx.set_model(model_name) - if req.web_search: - logger.info(f"[chat] web_search start request_id={request_id}") - t_ws = time.monotonic() - web_search_result = await handle_web_search(req, model_name, client) - messages = web_search_result.messages - sources = web_search_result.sources - logger.info( - f"[chat] web_search done request_id={request_id} sources={len(sources) if sources else 0} duration_ms={(time.monotonic() - t_ws) * 1000:.0f}" - ) - logger.info(f"[chat] web_search messages: {messages}") - - request_kwargs = { - "model": req.model, - "messages": messages, - "top_p": req.top_p, - "temperature": req.temperature, - "max_tokens": req.max_tokens, - } - - if req.tools: - if not endpoint.metadata.tool_support: + if not endpoint.metadata.tool_support and req.tools: raise HTTPException( status_code=400, detail="Model does not support tool usage, remove tools from request", ) - if model_name == "openai/gpt-oss-20b": + + has_multimodal = req.has_multimodal_content() + logger.info(f"[chat] has_multimodal: {has_multimodal}") + if has_multimodal and ( + not endpoint.metadata.multimodal_support or req.web_search + ): raise HTTPException( status_code=400, - detail="This model only supports tool calls with responses endpoint", + detail="Model does not support multimodal content, remove image inputs from request", ) - request_kwargs["tools"] = req.tools - request_kwargs["tool_choice"] = req.tool_choice - if req.stream: + model_url = endpoint.url + "/v1/" - async def chat_completion_stream_generator() -> AsyncGenerator[str, None]: - t_call = time.monotonic() - prompt_token_usage = 0 - completion_token_usage = 0 + logger.info( + f"[chat] start request_id={request_id} user={auth_info.user.user_id} model={model_name} stream={req.stream} web_search={bool(req.web_search)} tools={bool(req.tools)} multimodal={has_multimodal} url={model_url}" + ) + log_ctx.set_request_params( + temperature=req.temperature, + max_tokens=req.max_tokens, + was_streamed=req.stream or False, + was_multimodal=has_multimodal, + was_nildb=bool(auth_info.prompt_document), + was_nilrag=bool(req.nilrag), + ) + client = AsyncOpenAI(base_url=model_url, api_key="") + if auth_info.prompt_document: try: - logger.info(f"[chat] stream start request_id={request_id}") + nildb_prompt: str = await get_prompt_from_nildb( + auth_info.prompt_document + ) + req.messages.insert( + 0, MessageAdapter.new_message(role="system", content=nildb_prompt) + ) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Unable to extract prompt from nilDB: {str(e)}", + ) + + if req.nilrag: + logger.info(f"[chat] nilrag start request_id={request_id}") + t_nilrag = time.monotonic() + await handle_nilrag(req) + logger.info( + f"[chat] nilrag done request_id={request_id} duration_ms={(time.monotonic() - t_nilrag) * 1000:.0f}" + ) - request_kwargs["stream"] = True - request_kwargs["extra_body"] = { - "stream_options": { - "include_usage": True, - "continuous_usage_stats": False, + messages = req.messages + sources: Optional[List[Source]] = None + + if req.web_search: + logger.info(f"[chat] web_search start request_id={request_id}") + t_ws = 
time.monotonic() + web_search_result = await handle_web_search(req, model_name, client) + messages = web_search_result.messages + sources = web_search_result.sources + logger.info( + f"[chat] web_search done request_id={request_id} sources={len(sources) if sources else 0} duration_ms={(time.monotonic() - t_ws) * 1000:.0f}" + ) + logger.info(f"[chat] web_search messages: {messages}") + + if req.stream: + + async def chat_completion_stream_generator() -> AsyncGenerator[str, None]: + t_call = time.monotonic() + prompt_token_usage = 0 + completion_token_usage = 0 + + try: + logger.info(f"[chat] stream start request_id={request_id}") + + log_ctx.start_model_timing() + + request_kwargs = { + "model": req.model, + "messages": messages, + "stream": True, + "top_p": req.top_p, + "temperature": req.temperature, + "max_tokens": req.max_tokens, + "extra_body": { + "stream_options": { + "include_usage": True, + "continuous_usage_stats": False, + } + }, } - } + if req.tools: + request_kwargs["tools"] = req.tools + + response = await client.chat.completions.create(**request_kwargs) + + async for chunk in response: + if chunk.usage is not None: + prompt_token_usage = chunk.usage.prompt_tokens + completion_token_usage = chunk.usage.completion_tokens + + payload = chunk.model_dump(exclude_unset=True) + + if chunk.usage is not None and sources: + payload["sources"] = [ + s.model_dump(mode="json") for s in sources + ] + + yield f"data: {json.dumps(payload)}\n\n" + + log_ctx.end_model_timing() + meter.set_response( + { + "usage": LLMUsage( + prompt_tokens=prompt_token_usage, + completion_tokens=completion_token_usage, + web_searches=len(sources) if sources else 0, + ) + } + ) + log_ctx.set_usage( + prompt_tokens=prompt_token_usage, + completion_tokens=completion_token_usage, + web_search_calls=len(sources) if sources else 0, + ) + background_tasks.add_task(log_ctx.commit) + logger.info( + "[chat] stream done request_id=%s prompt_tokens=%d completion_tokens=%d " + "duration_ms=%.0f total_ms=%.0f", + request_id, + prompt_token_usage, + completion_token_usage, + (time.monotonic() - t_call) * 1000, + (time.monotonic() - t_start) * 1000, + ) + + except Exception as e: + logger.error( + "[chat] stream error request_id=%s error=%s", request_id, e + ) + log_ctx.set_error(error_code=500, error_message=str(e)) + await log_ctx.commit() + yield f"data: {json.dumps({'error': 'stream_failed', 'message': str(e)})}\n\n" + + return StreamingResponse( + chat_completion_stream_generator(), + media_type="text/event-stream", + ) - response = await client.chat.completions.create(**request_kwargs) + current_messages = messages + request_kwargs = { + "model": req.model, + "messages": current_messages, # type: ignore + "top_p": req.top_p, + "temperature": req.temperature, + "max_tokens": req.max_tokens, + } + if req.tools: + request_kwargs["tools"] = req.tools # type: ignore + request_kwargs["tool_choice"] = req.tool_choice + + logger.info(f"[chat] call start request_id={request_id}") + logger.info(f"[chat] call message: {current_messages}") + t_call = time.monotonic() + log_ctx.start_model_timing() + response = await client.chat.completions.create(**request_kwargs) # type: ignore + log_ctx.end_model_timing() + logger.info( + f"[chat] call done request_id={request_id} duration_ms={(time.monotonic() - t_call) * 1000:.0f}" + ) + logger.info(f"[chat] call response: {response}") + + # Handle tool workflow fully inside tools.router + log_ctx.start_tool_timing() + ( + final_completion, + agg_prompt_tokens, + agg_completion_tokens, + ) = 
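The streaming branch asks the upstream server to attach aggregate usage to the final chunk via stream_options, then yields each chunk as a server-sent event. A reduced sketch of that loop, assuming an AsyncOpenAI client pointed at the model endpoint:

import json

from openai import AsyncOpenAI


async def stream_chat(client: AsyncOpenAI, model: str, messages: list):
    prompt_tokens = completion_tokens = 0
    stream = await client.chat.completions.create(
        model=model,
        messages=messages,
        stream=True,
        # vLLM-style option: emit one final chunk carrying aggregate usage
        extra_body={"stream_options": {"include_usage": True}},
    )
    async for chunk in stream:
        if chunk.usage is not None:
            prompt_tokens = chunk.usage.prompt_tokens
            completion_tokens = chunk.usage.completion_tokens
        yield f"data: {json.dumps(chunk.model_dump(exclude_unset=True))}\n\n"
    # prompt_tokens / completion_tokens are now ready for metering and the query log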
await handle_tool_workflow(client, req, current_messages, response) + log_ctx.end_tool_timing() + logger.info(f"[chat] call final_completion: {final_completion}") + model_response = SignedChatCompletion( + **final_completion.model_dump(), + signature="", + sources=sources, + ) - async for chunk in response: - if chunk.usage is not None: - prompt_token_usage = chunk.usage.prompt_tokens - completion_token_usage = chunk.usage.completion_tokens + logger.info( + f"[chat] model_response request_id={request_id} duration_ms={(time.monotonic() - t_call) * 1000:.0f}" + ) - payload = chunk.model_dump(exclude_unset=True) + if model_response.usage is None: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Model response does not contain usage statistics", + ) - if chunk.usage is not None and sources: - payload["sources"] = [ - s.model_dump(mode="json") for s in sources - ] + if agg_prompt_tokens or agg_completion_tokens: + total_prompt_tokens = response.usage.prompt_tokens + total_completion_tokens = response.usage.completion_tokens - yield f"data: {json.dumps(payload)}\n\n" + total_prompt_tokens += agg_prompt_tokens + total_completion_tokens += agg_completion_tokens - await UserManager.update_token_usage( - auth_info.user.userid, - prompt_tokens=prompt_token_usage, - completion_tokens=completion_token_usage, - ) - meter.set_response( - { - "usage": LLMUsage( - prompt_tokens=prompt_token_usage, - completion_tokens=completion_token_usage, - web_searches=len(sources) if sources else 0, - ) - } - ) - await QueryLogManager.log_query( - auth_info.user.userid, - model=req.model, - prompt_tokens=prompt_token_usage, - completion_tokens=completion_token_usage, - web_search_calls=len(sources) if sources else 0, - ) - logger.info( - "[chat] stream done request_id=%s prompt_tokens=%d completion_tokens=%d " - "duration_ms=%.0f total_ms=%.0f", - request_id, - prompt_token_usage, - completion_token_usage, - (time.monotonic() - t_call) * 1000, - (time.monotonic() - t_start) * 1000, - ) + model_response.usage.prompt_tokens = total_prompt_tokens + model_response.usage.completion_tokens = total_completion_tokens + model_response.usage.total_tokens = ( + total_prompt_tokens + total_completion_tokens + ) - except Exception as e: - logger.error( - "[chat] stream error request_id=%s error=%s", request_id, e + # Update token usage in DB + meter.set_response( + { + "usage": LLMUsage( + prompt_tokens=model_response.usage.prompt_tokens, + completion_tokens=model_response.usage.completion_tokens, + web_searches=len(sources) if sources else 0, ) - yield f"data: {json.dumps({'error': 'stream_failed', 'message': str(e)})}\n\n" - - return StreamingResponse( - chat_completion_stream_generator(), - media_type="text/event-stream", + } ) - logger.info(f"[chat] call start request_id={request_id}") - t_call = time.monotonic() - response = await client.chat.completions.create(**request_kwargs) - logger.info( - f"[chat] call done request_id={request_id} duration_ms={(time.monotonic() - t_call) * 1000:.0f}" - ) - logger.info(f"[chat] call response: {response}") - - # Handle tool workflow fully inside tools.router - ( - final_completion, - agg_prompt_tokens, - agg_completion_tokens, - ) = await handle_tool_workflow(client, req, request_kwargs["messages"], response) - logger.info(f"[chat] call final_completion: {final_completion}") - model_response = SignedChatCompletion( - **final_completion.model_dump(), - signature="", - sources=sources, - ) - - logger.info( - f"[chat] model_response 
request_id={request_id} duration_ms={(time.monotonic() - t_call) * 1000:.0f}" - ) + # Log query with context + tool_calls_count = 0 + if final_completion.choices and final_completion.choices[0].message.tool_calls: + tool_calls_count = len(final_completion.choices[0].message.tool_calls) - if model_response.usage is None: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Model response does not contain usage statistics", + log_ctx.set_usage( + prompt_tokens=model_response.usage.prompt_tokens, + completion_tokens=model_response.usage.completion_tokens, + tool_calls=tool_calls_count, + web_search_calls=len(sources) if sources else 0, ) + # Use background task for successful requests to avoid blocking response + background_tasks.add_task(log_ctx.commit) - if agg_prompt_tokens or agg_completion_tokens: - total_prompt_tokens = response.usage.prompt_tokens - total_completion_tokens = response.usage.completion_tokens - - total_prompt_tokens += agg_prompt_tokens - total_completion_tokens += agg_completion_tokens + # Sign the response + response_json = model_response.model_dump_json() + signature = sign_message(state.private_key, response_json) + model_response.signature = b64encode(signature).decode() - model_response.usage.prompt_tokens = total_prompt_tokens - model_response.usage.completion_tokens = total_completion_tokens - model_response.usage.total_tokens = ( - total_prompt_tokens + total_completion_tokens + logger.info( + f"[chat] done request_id={request_id} prompt_tokens={model_response.usage.prompt_tokens} completion_tokens={model_response.usage.completion_tokens} total_ms={(time.monotonic() - t_start) * 1000:.0f}" + ) + return model_response + except HTTPException as e: + # Extract error code from HTTPException, default to status code + error_code = e.status_code + error_message = str(e.detail) if e.detail else str(e) + logger.error( + f"[chat] HTTPException request_id={request_id} user={auth_info.user.user_id} " + f"model={model_name} error_code={error_code} error={error_message}", + exc_info=True, ) - # Update token usage in DB - await UserManager.update_token_usage( - auth_info.user.userid, - prompt_tokens=model_response.usage.prompt_tokens, - completion_tokens=model_response.usage.completion_tokens, - ) - meter.set_response( - { - "usage": LLMUsage( - prompt_tokens=model_response.usage.prompt_tokens, - completion_tokens=model_response.usage.completion_tokens, - web_searches=len(sources) if sources else 0, - ) - } - ) - await QueryLogManager.log_query( - auth_info.user.userid, - model=req.model, - prompt_tokens=model_response.usage.prompt_tokens, - completion_tokens=model_response.usage.completion_tokens, - web_search_calls=len(sources) if sources else 0, - ) - - # Sign the response - response_json = model_response.model_dump_json() - signature = sign_message(state.private_key, response_json) - model_response.signature = b64encode(signature).decode() - - logger.info( - f"[chat] done request_id={request_id} prompt_tokens={model_response.usage.prompt_tokens} completion_tokens={model_response.usage.completion_tokens} total_ms={(time.monotonic() - t_start) * 1000:.0f}" - ) - return model_response + # Only log server errors (5xx) to database to prevent DoS attacks via client errors + # Client errors (4xx) are logged to application logs only + if error_code >= 500: + # Set model if not already set (e.g., for validation errors before model validation) + if log_ctx.model is None: + log_ctx.set_model(model_name) + log_ctx.set_error(error_code=error_code, 
error_message=error_message) + await log_ctx.commit() + # For 4xx errors, we skip DB logging - they're logged above via logger.error() + # This prevents DoS attacks where attackers send many invalid requests + + raise + except Exception as e: + # Catch any other unexpected exceptions + error_message = str(e) + logger.error( + f"[chat] unexpected error request_id={request_id} user={auth_info.user.user_id} " + f"model={model_name} error={error_message}", + exc_info=True, + ) + # Set model if not already set + if log_ctx.model is None: + log_ctx.set_model(model_name) + log_ctx.set_error(error_code=500, error_message=error_message) + await log_ctx.commit() + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Internal server error: {error_message}", + ) diff --git a/nilai-api/src/nilai_api/routers/endpoints/responses.py b/nilai-api/src/nilai_api/routers/endpoints/responses.py index 1b764674..54f9ca92 100644 --- a/nilai-api/src/nilai_api/routers/endpoints/responses.py +++ b/nilai-api/src/nilai_api/routers/endpoints/responses.py @@ -5,7 +5,15 @@ from base64 import b64encode from typing import AsyncGenerator, Optional, Union, List, Tuple -from fastapi import APIRouter, Body, Depends, HTTPException, status, Request +from fastapi import ( + APIRouter, + Body, + Depends, + HTTPException, + status, + Request, + BackgroundTasks, +) from fastapi.responses import StreamingResponse from openai import AsyncOpenAI @@ -13,8 +21,7 @@ from nilai_api.config import CONFIG from nilai_api.crypto import sign_message from nilai_api.credit import LLMMeter, LLMUsage -from nilai_api.db.logs import QueryLogManager -from nilai_api.db.users import UserManager +from nilai_api.db.logs import QueryLogContext from nilai_api.handlers.nildb.handler import get_prompt_from_nildb # from nilai_api.handlers.nilrag import handle_nilrag_for_responses @@ -73,6 +80,7 @@ async def create_response( "web_search": False, } ), + background_tasks: BackgroundTasks = BackgroundTasks(), _rate_limit=Depends( RateLimit( concurrent_extractor=responses_concurrent_rate_limit, @@ -81,6 +89,7 @@ async def create_response( ), auth_info: AuthenticationInfo = Depends(get_auth_info), meter: MeteringContext = Depends(LLMMeter), + log_ctx: QueryLogContext = Depends(QueryLogContext), ) -> Union[SignedResponse, StreamingResponse]: """ Generate a response from the AI model using the Responses API. 
@@ -124,225 +133,289 @@ async def create_response( - **500 Internal Server Error**: - Model fails to generate a response """ - if not req.input: - raise HTTPException( - status_code=400, - detail="Request 'input' field cannot be empty.", - ) + # Initialize log context early so we can log any errors + log_ctx.set_user(auth_info.user.user_id) + log_ctx.set_lockid(meter.lock_id) model_name = req.model request_id = str(uuid.uuid4()) t_start = time.monotonic() - endpoint = await state.get_model(model_name) - if endpoint is None: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Invalid model name {model_name}, check /v1/models for options", - ) + try: + if not req.input: + raise HTTPException( + status_code=400, + detail="Request 'input' field cannot be empty.", + ) - if not endpoint.metadata.tool_support and req.tools: - raise HTTPException( - status_code=400, - detail="Model does not support tool usage, remove tools from request", - ) + endpoint = await state.get_model(model_name) + if endpoint is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Invalid model name {model_name}, check /v1/models for options", + ) - has_multimodal = req.has_multimodal_content() - if has_multimodal and (not endpoint.metadata.multimodal_support or req.web_search): - raise HTTPException( - status_code=400, - detail="Model does not support multimodal content, remove image inputs from request", - ) + # Now we have a valid model, set it in log context + log_ctx.set_model(model_name) - model_url = endpoint.url + "/v1/" + if not endpoint.metadata.tool_support and req.tools: + raise HTTPException( + status_code=400, + detail="Model does not support tool usage, remove tools from request", + ) - client = AsyncOpenAI(base_url=model_url, api_key="") - if auth_info.prompt_document: - try: - nildb_prompt: str = await get_prompt_from_nildb(auth_info.prompt_document) - req.ensure_instructions(nildb_prompt) - except Exception as e: + has_multimodal = req.has_multimodal_content() + if has_multimodal and ( + not endpoint.metadata.multimodal_support or req.web_search + ): raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=f"Unable to extract prompt from nilDB: {str(e)}", + status_code=400, + detail="Model does not support multimodal content, remove image inputs from request", ) - input_items = req.input - instructions = req.instructions - sources: Optional[List[Source]] = None + model_url = endpoint.url + "/v1/" - if req.web_search: - logger.info(f"[responses] web_search start request_id={request_id}") - t_ws = time.monotonic() - web_search_result = await handle_web_search_for_responses( - req, model_name, client - ) - input_items = web_search_result.input - instructions = web_search_result.instructions - sources = web_search_result.sources logger.info( - f"[responses] web_search done request_id={request_id} sources={len(sources) if sources else 0} duration_ms={(time.monotonic() - t_ws) * 1000:.0f}" + f"[responses] start request_id={request_id} user={auth_info.user.user_id} model={model_name} stream={req.stream} web_search={bool(req.web_search)} tools={bool(req.tools)} multimodal={has_multimodal} url={model_url}" + ) + log_ctx.set_request_params( + temperature=req.temperature, + max_tokens=req.max_output_tokens, + was_streamed=req.stream or False, + was_multimodal=has_multimodal, + was_nildb=bool(auth_info.prompt_document), + was_nilrag=False, ) - if req.stream: - - async def response_stream_generator() -> AsyncGenerator[str, None]: - t_call = 
time.monotonic() - prompt_token_usage = 0 - completion_token_usage = 0 - + client = AsyncOpenAI(base_url=model_url, api_key="") + if auth_info.prompt_document: try: - logger.info(f"[responses] stream start request_id={request_id}") - request_kwargs = { - "model": req.model, - "input": input_items, - "instructions": instructions, - "stream": True, - "top_p": req.top_p, - "temperature": req.temperature, - "max_output_tokens": req.max_output_tokens, - "extra_body": { - "stream_options": { - "include_usage": True, - "continuous_usage_stats": False, - } - }, - } - if req.tools: - request_kwargs["tools"] = req.tools - - stream = await client.responses.create(**request_kwargs) - - async for event in stream: - payload = event.model_dump(exclude_unset=True) - - if isinstance(event, ResponseCompletedEvent): - if event.response and event.response.usage: - usage = event.response.usage - prompt_token_usage = usage.input_tokens - completion_token_usage = usage.output_tokens - payload["response"]["usage"] = usage.model_dump(mode="json") - - if sources: - if "data" not in payload: - payload["data"] = {} - payload["data"]["sources"] = [ - s.model_dump(mode="json") for s in sources - ] - - yield f"data: {json.dumps(payload)}\n\n" - - await UserManager.update_token_usage( - auth_info.user.userid, - prompt_tokens=prompt_token_usage, - completion_tokens=completion_token_usage, - ) - meter.set_response( - { - "usage": LLMUsage( - prompt_tokens=prompt_token_usage, - completion_tokens=completion_token_usage, - web_searches=len(sources) if sources else 0, - ) - } - ) - await QueryLogManager.log_query( - auth_info.user.userid, - model=req.model, - prompt_tokens=prompt_token_usage, - completion_tokens=completion_token_usage, - web_search_calls=len(sources) if sources else 0, + nildb_prompt: str = await get_prompt_from_nildb( + auth_info.prompt_document ) - logger.info( - "[responses] stream done request_id=%s prompt_tokens=%d completion_tokens=%d duration_ms=%.0f total_ms=%.0f", - request_id, - prompt_token_usage, - completion_token_usage, - (time.monotonic() - t_call) * 1000, - (time.monotonic() - t_start) * 1000, - ) - + req.ensure_instructions(nildb_prompt) except Exception as e: - logger.error( - "[responses] stream error request_id=%s error=%s", request_id, e + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Unable to extract prompt from nilDB: {str(e)}", ) - yield f"data: {json.dumps({'error': 'stream_failed', 'message': str(e)})}\n\n" - return StreamingResponse( - response_stream_generator(), media_type="text/event-stream" - ) + input_items = req.input + instructions = req.instructions + sources: Optional[List[Source]] = None - request_kwargs = { - "model": req.model, - "input": input_items, - "instructions": instructions, - "top_p": req.top_p, - "temperature": req.temperature, - "max_output_tokens": req.max_output_tokens, - } - if req.tools: - request_kwargs["tools"] = req.tools - request_kwargs["tool_choice"] = req.tool_choice - - logger.info(f"[responses] call start request_id={request_id}") - t_call = time.monotonic() - - response = await client.responses.create(**request_kwargs) - logger.info( - f"[responses] call done request_id={request_id} duration_ms={(time.monotonic() - t_call) * 1000:.0f}" - ) + if req.web_search: + logger.info(f"[responses] web_search start request_id={request_id}") + t_ws = time.monotonic() + web_search_result = await handle_web_search_for_responses( + req, model_name, client + ) + input_items = web_search_result.input + instructions = 
web_search_result.instructions + sources = web_search_result.sources + logger.info( + f"[responses] web_search done request_id={request_id} sources={len(sources) if sources else 0} duration_ms={(time.monotonic() - t_ws) * 1000:.0f}" + ) - ( - final_response, - agg_prompt_tokens, - agg_completion_tokens, - ) = await handle_responses_tool_workflow(client, req, input_items, response) + if req.stream: + + async def response_stream_generator() -> AsyncGenerator[str, None]: + t_call = time.monotonic() + prompt_token_usage = 0 + completion_token_usage = 0 + + try: + logger.info(f"[responses] stream start request_id={request_id}") + log_ctx.start_model_timing() + + request_kwargs = { + "model": req.model, + "input": input_items, + "instructions": instructions, + "stream": True, + "top_p": req.top_p, + "temperature": req.temperature, + "max_output_tokens": req.max_output_tokens, + "extra_body": { + "stream_options": { + "include_usage": True, + "continuous_usage_stats": False, + } + }, + } + if req.tools: + request_kwargs["tools"] = req.tools + + stream = await client.responses.create(**request_kwargs) + + async for event in stream: + payload = event.model_dump(exclude_unset=True) + + if isinstance(event, ResponseCompletedEvent): + if event.response and event.response.usage: + usage = event.response.usage + prompt_token_usage = usage.input_tokens + completion_token_usage = usage.output_tokens + payload["response"]["usage"] = usage.model_dump( + mode="json" + ) + + if sources: + if "data" not in payload: + payload["data"] = {} + payload["data"]["sources"] = [ + s.model_dump(mode="json") for s in sources + ] + + yield f"data: {json.dumps(payload)}\n\n" + + log_ctx.end_model_timing() + meter.set_response( + { + "usage": LLMUsage( + prompt_tokens=prompt_token_usage, + completion_tokens=completion_token_usage, + web_searches=len(sources) if sources else 0, + ) + } + ) + log_ctx.set_usage( + prompt_tokens=prompt_token_usage, + completion_tokens=completion_token_usage, + web_search_calls=len(sources) if sources else 0, + ) + background_tasks.add_task(log_ctx.commit) + logger.info( + "[responses] stream done request_id=%s prompt_tokens=%d completion_tokens=%d duration_ms=%.0f total_ms=%.0f", + request_id, + prompt_token_usage, + completion_token_usage, + (time.monotonic() - t_call) * 1000, + (time.monotonic() - t_start) * 1000, + ) + + except Exception as e: + logger.error( + "[responses] stream error request_id=%s error=%s", request_id, e + ) + log_ctx.set_error(error_code=500, error_message=str(e)) + await log_ctx.commit() + yield f"data: {json.dumps({'error': 'stream_failed', 'message': str(e)})}\n\n" + + return StreamingResponse( + response_stream_generator(), media_type="text/event-stream" + ) - model_response = SignedResponse( - **final_response.model_dump(), - signature="", - sources=sources, - ) + request_kwargs = { + "model": req.model, + "input": input_items, + "instructions": instructions, + "top_p": req.top_p, + "temperature": req.temperature, + "max_output_tokens": req.max_output_tokens, + } + if req.tools: + request_kwargs["tools"] = req.tools + request_kwargs["tool_choice"] = req.tool_choice + + logger.info(f"[responses] call start request_id={request_id}") + t_call = time.monotonic() + log_ctx.start_model_timing() + response = await client.responses.create(**request_kwargs) + log_ctx.end_model_timing() + logger.info( + f"[responses] call done request_id={request_id} duration_ms={(time.monotonic() - t_call) * 1000:.0f}" + ) - if model_response.usage is None: - raise HTTPException( - 
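For the Responses API, aggregate usage arrives on the final ResponseCompletedEvent rather than on every chunk. A reduced sketch of pulling it out of the stream, assuming an AsyncOpenAI client and the OpenAI SDK's event types:

import json

from openai import AsyncOpenAI
from openai.types.responses import ResponseCompletedEvent


async def stream_response(client: AsyncOpenAI, model: str, input_items, instructions=None):
    input_tokens = output_tokens = 0
    stream = await client.responses.create(
        model=model,
        input=input_items,
        instructions=instructions,
        stream=True,
    )
    async for event in stream:
        if isinstance(event, ResponseCompletedEvent) and event.response.usage:
            input_tokens = event.response.usage.input_tokens
            output_tokens = event.response.usage.output_tokens
        yield f"data: {json.dumps(event.model_dump(exclude_unset=True))}\n\n"
    # input_tokens / output_tokens feed the meter and the query log once the stream ends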
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Model response does not contain usage statistics", + # Handle tool workflow + log_ctx.start_tool_timing() + ( + final_response, + agg_prompt_tokens, + agg_completion_tokens, + ) = await handle_responses_tool_workflow(client, req, input_items, response) + log_ctx.end_tool_timing() + + model_response = SignedResponse( + **final_response.model_dump(), + signature="", + sources=sources, ) - if agg_prompt_tokens or agg_completion_tokens: - model_response.usage.input_tokens += agg_prompt_tokens - model_response.usage.output_tokens += agg_completion_tokens + if model_response.usage is None: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Model response does not contain usage statistics", + ) - prompt_tokens = model_response.usage.input_tokens - completion_tokens = model_response.usage.output_tokens + if agg_prompt_tokens or agg_completion_tokens: + model_response.usage.input_tokens += agg_prompt_tokens + model_response.usage.output_tokens += agg_completion_tokens - await UserManager.update_token_usage( - auth_info.user.userid, - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - ) - meter.set_response( - { - "usage": LLMUsage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - web_searches=len(sources) if sources else 0, - ) - } - ) - await QueryLogManager.log_query( - auth_info.user.userid, - model=req.model, - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - web_search_calls=len(sources) if sources else 0, - ) + prompt_tokens = model_response.usage.input_tokens + completion_tokens = model_response.usage.output_tokens - response_json = model_response.model_dump_json() - signature = sign_message(state.private_key, response_json) - model_response.signature = b64encode(signature).decode() + meter.set_response( + { + "usage": LLMUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + web_searches=len(sources) if sources else 0, + ) + } + ) - logger.info( - f"[responses] done request_id={request_id} prompt_tokens={prompt_tokens} completion_tokens={completion_tokens} total_ms={(time.monotonic() - t_start) * 1000:.0f}" - ) - return model_response + # Log query with context + # Note: Response object structure for tools might differ from Chat, + # but we'll assume basic usage logging is sufficient or adapt if needed. + # For now, we don't count tool calls explicitly in log_ctx for responses unless we parse them. + # Chat.py does: tool_calls_count = len(final_completion.choices[0].message.tool_calls) + # Responses API structure is different. `final_response` is a Response object. + # It might not have 'choices'. It has 'output'. 
+ + log_ctx.set_usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + web_search_calls=len(sources) if sources else 0, + ) + background_tasks.add_task(log_ctx.commit) + + response_json = model_response.model_dump_json() + signature = sign_message(state.private_key, response_json) + model_response.signature = b64encode(signature).decode() + + logger.info( + f"[responses] done request_id={request_id} prompt_tokens={prompt_tokens} completion_tokens={completion_tokens} total_ms={(time.monotonic() - t_start) * 1000:.0f}" + ) + return model_response + + except HTTPException as e: + error_code = e.status_code + error_message = str(e.detail) if e.detail else str(e) + logger.error( + f"[responses] HTTPException request_id={request_id} user={auth_info.user.user_id} " + f"model={model_name} error_code={error_code} error={error_message}", + exc_info=True, + ) + + if error_code >= 500: + if log_ctx.model is None: + log_ctx.set_model(model_name) + log_ctx.set_error(error_code=error_code, error_message=error_message) + await log_ctx.commit() + + raise + except Exception as e: + error_message = str(e) + logger.error( + f"[responses] unexpected error request_id={request_id} user={auth_info.user.user_id} " + f"model={model_name} error={error_message}", + exc_info=True, + ) + if log_ctx.model is None: + log_ctx.set_model(model_name) + log_ctx.set_error(error_code=500, error_message=error_message) + await log_ctx.commit() + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Internal server error: {error_message}", + ) diff --git a/nilai-api/src/nilai_api/routers/private.py b/nilai-api/src/nilai_api/routers/private.py index 386dce2d..5cbcbf64 100644 --- a/nilai-api/src/nilai_api/routers/private.py +++ b/nilai-api/src/nilai_api/routers/private.py @@ -5,6 +5,7 @@ from nilai_api.attestation import get_attestation_report from nilai_api.auth import get_auth_info, AuthenticationInfo +from nilai_api.db.logs import QueryLogManager from nilai_api.handlers.nildb.api_model import ( PromptDelegationRequest, PromptDelegationToken, @@ -17,7 +18,6 @@ from nilai_common import ( AttestationReport, ModelMetadata, - Nonce, Usage, ) @@ -32,14 +32,10 @@ @router.get("/v1/delegation") async def get_prompt_store_delegation( prompt_delegation_request: PromptDelegationRequest, - auth_info: AuthenticationInfo = Depends(get_auth_info), + _: AuthenticationInfo = Depends( + get_auth_info + ), # This is to satisfy that the user is authenticated ) -> PromptDelegationToken: - if not auth_info.user.is_subscription_owner: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=f"Prompt storage is reserved to subscription owners: {auth_info.user} is not a subscription owner, apikey: {auth_info.user}", - ) - try: return await get_nildb_delegation_token(prompt_delegation_request) except Exception as e: @@ -63,22 +59,26 @@ async def get_usage(auth_info: AuthenticationInfo = Depends(get_auth_info)) -> U usage = await get_usage(user) ``` """ - return Usage( - prompt_tokens=auth_info.user.prompt_tokens, - completion_tokens=auth_info.user.completion_tokens, - total_tokens=auth_info.user.prompt_tokens + auth_info.user.completion_tokens, - queries=auth_info.user.queries, # type: ignore # FIXME this field is not part of Usage + user_usage: Optional[Usage] = await QueryLogManager.get_user_token_usage( + auth_info.user.user_id ) + if user_usage is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found", + ) + return user_usage 
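Since usage is now derived from query_logs on demand, the usage route returns the shared Usage model. A quick client-side check (the mount path is assumed to be /v1/usage, matching the e2e tests):

import httpx


def fetch_usage(base_url: str, token: str) -> dict:
    resp = httpx.get(
        f"{base_url}/v1/usage",
        headers={"Authorization": f"Bearer {token}"},
        verify=False,  # the e2e environment uses a self-signed certificate
    )
    resp.raise_for_status()
    # Expected keys per the updated tests: prompt_tokens, completion_tokens, total_tokens
    return resp.json()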
@router.get("/v1/attestation/report", tags=["Attestation"]) async def get_attestation( + nonce: str, auth_info: AuthenticationInfo = Depends(get_auth_info), ) -> AttestationReport: """ Generate a cryptographic attestation report. - - **attestation_request**: Attestation request containing a nonce + - **nonce**: Nonce for the attestation request (64 character hex string) - **user**: Authenticated user information (through HTTP Bearer header) - **Returns**: Attestation details for service verification @@ -91,7 +91,7 @@ async def get_attestation( Provides cryptographic proof of the service's integrity and environment. """ - attestation_report = await get_attestation_report() + attestation_report = await get_attestation_report(nonce) attestation_report.verifying_key = state.b64_public_key return attestation_report diff --git a/packages/nilai-common/src/nilai_common/__init__.py b/packages/nilai-common/src/nilai_common/__init__.py index 3d822cba..574b7437 100644 --- a/packages/nilai-common/src/nilai_common/__init__.py +++ b/packages/nilai-common/src/nilai_common/__init__.py @@ -32,6 +32,7 @@ ResponseInputItemParam, EasyInputMessageParam, ResponseFunctionToolCallParam, + Usage, ) from nilai_common.config import SETTINGS, MODEL_SETTINGS, MODEL_CAPABILITIES from nilai_common.discovery import ModelServiceDiscovery @@ -74,4 +75,5 @@ "ResponseInputItemParam", "EasyInputMessageParam", "ResponseFunctionToolCallParam", + "Usage", ] diff --git a/packages/nilai-common/src/nilai_common/api_models/__init__.py b/packages/nilai-common/src/nilai_common/api_models/__init__.py index 95243ff2..de76a353 100644 --- a/packages/nilai-common/src/nilai_common/api_models/__init__.py +++ b/packages/nilai-common/src/nilai_common/api_models/__init__.py @@ -30,6 +30,7 @@ MessageAdapter, ImageContent, TextContent, + Usage, ) from nilai_common.api_models.responses_model import ( @@ -74,6 +75,7 @@ "MessageAdapter", "ImageContent", "TextContent", + "Usage", "Response", "ResponseCompletedEvent", "ResponseRequest", diff --git a/packages/nilai-common/src/nilai_common/api_models/chat_completion_model.py b/packages/nilai-common/src/nilai_common/api_models/chat_completion_model.py index 41082ef5..b34ce06e 100644 --- a/packages/nilai-common/src/nilai_common/api_models/chat_completion_model.py +++ b/packages/nilai-common/src/nilai_common/api_models/chat_completion_model.py @@ -1,6 +1,7 @@ -from __future__ import annotations +import uuid from typing import ( + Annotated, Iterable, List, Optional, @@ -55,10 +56,6 @@ "ResultContent", "Choice", "Source", - "SearchResult", - "Topic", - "TopicResponse", - "TopicQuery", "MessageAdapter", "WebSearchEnhancedMessages", "WebSearchContext", @@ -78,7 +75,6 @@ class ResultContent(BaseModel): text: str truncated: bool = False -Message: TypeAlias = ChatCompletionMessageParam class Choice(OpenaAIChoice): diff --git a/packages/nilai-common/src/nilai_common/discovery.py b/packages/nilai-common/src/nilai_common/discovery.py index 86c90224..345a9a6d 100644 --- a/packages/nilai-common/src/nilai_common/discovery.py +++ b/packages/nilai-common/src/nilai_common/discovery.py @@ -5,7 +5,7 @@ from typing import Dict, Optional import redis.asyncio as redis -from nilai_common.api_model import ModelEndpoint, ModelMetadata +from nilai_common.api_models import ModelEndpoint, ModelMetadata from tenacity import retry, stop_after_attempt, wait_exponential # Configure logging diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py new file mode 100644 index 00000000..fc57b1c6 --- /dev/null +++ b/tests/e2e/conftest.py @@ 
-0,0 +1,239 @@ +from .config import BASE_URL, api_key_getter +from .nuc import ( + get_rate_limited_nuc_token, + get_invalid_rate_limited_nuc_token, + get_document_id_nuc_token, + get_invalid_nildb_nuc_token, +) +import httpx +import pytest +import pytest_asyncio +from openai import OpenAI, AsyncOpenAI + + +# ============================================================================ +# HTTP Client Fixtures (for test_chat_completions_http.py, test_responses_http.py) +# ============================================================================ + + +@pytest.fixture +def http_client(): + """Create an HTTPX client with default headers for HTTP-based tests""" + invocation_token: str = api_key_getter() + print("invocation_token", invocation_token) + return httpx.Client( + base_url=BASE_URL, + headers={ + "accept": "application/json", + "Content-Type": "application/json", + "Authorization": f"Bearer {invocation_token}", + }, + verify=False, + timeout=None, + ) + + +# Alias for backward compatibility +client = http_client + + +@pytest.fixture +def rate_limited_http_client(): + """Create an HTTPX client with rate limiting for HTTP-based tests""" + invocation_token = get_rate_limited_nuc_token(rate_limit=1) + return httpx.Client( + base_url=BASE_URL, + headers={ + "accept": "application/json", + "Content-Type": "application/json", + "Authorization": f"Bearer {invocation_token}", + }, + timeout=None, + verify=False, + ) + + +# Alias for backward compatibility +rate_limited_client = rate_limited_http_client + + +@pytest.fixture +def invalid_rate_limited_http_client(): + """Create an HTTPX client with invalid rate limiting for HTTP-based tests""" + invocation_token = get_invalid_rate_limited_nuc_token() + return httpx.Client( + base_url=BASE_URL, + headers={ + "accept": "application/json", + "Content-Type": "application/json", + "Authorization": f"Bearer {invocation_token}", + }, + timeout=None, + verify=False, + ) + + +# Alias for backward compatibility +invalid_rate_limited_client = invalid_rate_limited_http_client + + +@pytest.fixture +def nillion_2025_client(): + """Create an HTTPX client with default headers""" + return httpx.Client( + base_url=BASE_URL, + headers={ + "accept": "application/json", + "Content-Type": "application/json", + "Authorization": "Bearer Nillion2025", + }, + verify=False, + timeout=None, + ) + + +@pytest.fixture +def document_id_client(): + """Create an HTTPX client with default headers""" + invocation_token = get_document_id_nuc_token() + return httpx.Client( + base_url=BASE_URL, + headers={ + "accept": "application/json", + "Content-Type": "application/json", + "Authorization": f"Bearer {invocation_token}", + }, + verify=False, + timeout=None, + ) + + +@pytest.fixture +def invalid_nildb(): + """Create an HTTPX client with default headers""" + invocation_token = get_invalid_nildb_nuc_token() + return httpx.Client( + base_url=BASE_URL, + headers={ + "accept": "application/json", + "Content-Type": "application/json", + "Authorization": f"Bearer {invocation_token}", + }, + timeout=None, + verify=False, + ) + + +# ============================================================================ +# OpenAI SDK Client Fixtures (for test_chat_completions.py, test_responses.py) +# ============================================================================ + + +def _create_openai_client(api_key: str) -> OpenAI: + """Helper function to create an OpenAI client with SSL verification disabled""" + transport = httpx.HTTPTransport(verify=False) + return OpenAI( + base_url=BASE_URL, + 
api_key=api_key, + http_client=httpx.Client(transport=transport), + ) + + +def _create_async_openai_client(api_key: str) -> AsyncOpenAI: + """Helper function to create an async OpenAI client with SSL verification disabled""" + transport = httpx.AsyncHTTPTransport(verify=False) + return AsyncOpenAI( + base_url=BASE_URL, + api_key=api_key, + http_client=httpx.AsyncClient(transport=transport), + ) + + +@pytest.fixture +def openai_client(): + """Create an OpenAI SDK client configured to use the Nilai API""" + invocation_token: str = api_key_getter() + return _create_openai_client(invocation_token) + + +@pytest_asyncio.fixture +async def async_openai_client(): + """Create an async OpenAI SDK client configured to use the Nilai API""" + invocation_token: str = api_key_getter() + transport = httpx.AsyncHTTPTransport(verify=False) + httpx_client = httpx.AsyncClient(transport=transport) + client = AsyncOpenAI( + base_url=BASE_URL, api_key=invocation_token, http_client=httpx_client + ) + yield client + await httpx_client.aclose() + + +@pytest.fixture +def rate_limited_openai_client(): + """Create an OpenAI SDK client with rate limiting""" + invocation_token = get_rate_limited_nuc_token(rate_limit=1) + return _create_openai_client(invocation_token) + + +@pytest.fixture +def invalid_rate_limited_openai_client(): + """Create an OpenAI SDK client with invalid rate limiting""" + invocation_token = get_invalid_rate_limited_nuc_token() + return _create_openai_client(invocation_token) + + +@pytest.fixture +def document_id_openai_client(): + """Create an OpenAI SDK client with document ID token""" + invocation_token = get_document_id_nuc_token() + return _create_openai_client(invocation_token) + + +@pytest.fixture +def invalid_nildb_openai_client(): + """Create an OpenAI SDK client with document ID token""" + invocation_token = get_invalid_nildb_nuc_token() + return _create_openai_client(invocation_token) + + +@pytest.fixture +def high_web_search_rate_limit(monkeypatch): + """Set high rate limits for web search for RPS tests""" + monkeypatch.setenv("WEB_SEARCH_RATE_LIMIT_MINUTE", "9999") + monkeypatch.setenv("WEB_SEARCH_RATE_LIMIT_HOUR", "9999") + monkeypatch.setenv("WEB_SEARCH_RATE_LIMIT_DAY", "9999") + monkeypatch.setenv("WEB_SEARCH_RATE_LIMIT", "9999") + monkeypatch.setenv("USER_RATE_LIMIT_MINUTE", "9999") + monkeypatch.setenv("USER_RATE_LIMIT_HOUR", "9999") + monkeypatch.setenv("USER_RATE_LIMIT_DAY", "9999") + monkeypatch.setenv("USER_RATE_LIMIT", "9999") + monkeypatch.setenv( + "MODEL_CONCURRENT_RATE_LIMIT", + ( + '{"meta-llama/Llama-3.2-1B-Instruct": 500, ' + '"meta-llama/Llama-3.2-3B-Instruct": 500, ' + '"meta-llama/Llama-3.1-8B-Instruct": 300, ' + '"cognitivecomputations/Dolphin3.0-Llama3.1-8B": 300, ' + '"deepseek-ai/DeepSeek-R1-Distill-Qwen-14B": 50, ' + '"hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4": 50, ' + '"openai/gpt-oss-20b": 500, ' + '"google/gemma-3-27b-it": 500, ' + '"default": 500}' + ), + ) + + +# ============================================================================ +# Convenience Aliases for OpenAI SDK Tests +# These allow test files to use 'client' instead of 'openai_client' +# Note: These will be shadowed by local fixtures in test_chat_completions.py +# and test_responses.py if those files redefine them +# ============================================================================ + +# Uncomment these if you want to use the conftest fixtures without shadowing: +# client = openai_client +# async_client = async_openai_client +# rate_limited_client = 
rate_limited_openai_client +# invalid_rate_limited_client = invalid_rate_limited_openai_client +nildb_client = document_id_openai_client diff --git a/tests/e2e/nuc.py b/tests/e2e/nuc.py index 0dac8edc..f63248c8 100644 --- a/tests/e2e/nuc.py +++ b/tests/e2e/nuc.py @@ -2,6 +2,9 @@ NilAuthPrivateKey, ) +from nuc.builder import NucTokenBuilder +from nuc.token import Did, InvocationBody, Command + from nilai_py import ( Client, DelegationTokenServer, @@ -124,3 +127,24 @@ def get_document_id_nuc_client() -> Client: def get_document_id_nuc_token() -> str: """Convenience function for getting NILDB NUC tokens.""" return get_document_id_nuc_client()._get_invocation_token() + + +def get_invalid_nildb_nuc_token() -> str: + """Convenience function for getting NILDB NUC tokens.""" + private_key = NilAuthPrivateKey(bytes.fromhex(PRIVATE_KEY)) + http_client = DefaultHttpxClient(verify=False) + client = Client( + base_url="https://localhost/nuc/v1", + auth_type=AuthType.API_KEY, + http_client=http_client, + api_key=PRIVATE_KEY, + ) + + invocation_token: str = ( + NucTokenBuilder.extending(client.root_token) + .body(InvocationBody(args={})) + .audience(Did(client.nilai_public_key.serialize())) + .command(Command(["nil", "db", "generate"])) + .build(private_key) + ) + return invocation_token diff --git a/tests/e2e/test_chat_completions.py b/tests/e2e/test_chat_completions.py index b24137a1..dd362885 100644 --- a/tests/e2e/test_chat_completions.py +++ b/tests/e2e/test_chat_completions.py @@ -11,98 +11,41 @@ import json import os import re -import httpx import pytest import pytest_asyncio -from openai import OpenAI -from openai import AsyncOpenAI from openai.types.chat import ChatCompletion from .config import BASE_URL, ENVIRONMENT, test_models, AUTH_STRATEGY, api_key_getter -from .nuc import ( - get_rate_limited_nuc_token, - get_invalid_rate_limited_nuc_token, -) - - -def _create_openai_client(api_key: str) -> OpenAI: - """Helper function to create an OpenAI client with SSL verification disabled""" - transport = httpx.HTTPTransport(verify=False) - return OpenAI( - base_url=BASE_URL, - api_key=api_key, - http_client=httpx.Client(transport=transport), - ) -def _create_async_openai_client(api_key: str) -> AsyncOpenAI: - transport = httpx.AsyncHTTPTransport(verify=False) - return AsyncOpenAI( - base_url=BASE_URL, - api_key=api_key, - http_client=httpx.AsyncClient(transport=transport), - ) +# ============================================================================ +# Fixture Aliases for OpenAI SDK Tests +# These create local aliases that reference the centralized fixtures in conftest.py +# This allows tests to use 'client' instead of 'openai_client', maintaining backward compatibility +# ============================================================================ @pytest.fixture -def client(): - """Create an OpenAI client configured to use the Nilai API""" - invocation_token: str = api_key_getter() - - return _create_openai_client(invocation_token) +def client(openai_client): + """Alias for openai_client fixture from conftest.py""" + return openai_client @pytest_asyncio.fixture -async def async_client(): - invocation_token: str = api_key_getter() - transport = httpx.AsyncHTTPTransport(verify=False) - httpx_client = httpx.AsyncClient(transport=transport) - client = AsyncOpenAI( - base_url=BASE_URL, api_key=invocation_token, http_client=httpx_client - ) - yield client - await httpx_client.aclose() +async def async_client(async_openai_client): + """Alias for async_openai_client fixture from conftest.py""" + 
return async_openai_client @pytest.fixture -def rate_limited_client(): - """Create an OpenAI client configured to use the Nilai API with rate limiting""" - invocation_token = get_rate_limited_nuc_token(rate_limit=1) - return _create_openai_client(invocation_token) +def rate_limited_client(rate_limited_openai_client): + """Alias for rate_limited_openai_client fixture from conftest.py""" + return rate_limited_openai_client @pytest.fixture -def invalid_rate_limited_client(): - """Create an OpenAI client configured to use the Nilai API with rate limiting""" - invocation_token = get_invalid_rate_limited_nuc_token() - print(f"invocation_token: {invocation_token}") - return _create_openai_client(invocation_token) - - -@pytest.fixture -def high_web_search_rate_limit(monkeypatch): - """Set high rate limits for web search for RPS tests""" - monkeypatch.setenv("WEB_SEARCH_RATE_LIMIT_MINUTE", "9999") - monkeypatch.setenv("WEB_SEARCH_RATE_LIMIT_HOUR", "9999") - monkeypatch.setenv("WEB_SEARCH_RATE_LIMIT_DAY", "9999") - monkeypatch.setenv("WEB_SEARCH_RATE_LIMIT", "9999") - monkeypatch.setenv("USER_RATE_LIMIT_MINUTE", "9999") - monkeypatch.setenv("USER_RATE_LIMIT_HOUR", "9999") - monkeypatch.setenv("USER_RATE_LIMIT_DAY", "9999") - monkeypatch.setenv("USER_RATE_LIMIT", "9999") - monkeypatch.setenv( - "MODEL_CONCURRENT_RATE_LIMIT", - ( - '{"meta-llama/Llama-3.2-1B-Instruct": 500, ' - '"meta-llama/Llama-3.2-3B-Instruct": 500, ' - '"meta-llama/Llama-3.1-8B-Instruct": 300, ' - '"cognitivecomputations/Dolphin3.0-Llama3.1-8B": 300, ' - '"deepseek-ai/DeepSeek-R1-Distill-Qwen-14B": 50, ' - '"hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4": 50, ' - '"openai/gpt-oss-20b": 500, ' - '"google/gemma-3-27b-it": 500, ' - '"default": 500}' - ), - ) +def invalid_rate_limited_client(invalid_rate_limited_openai_client): + """Alias for invalid_rate_limited_openai_client fixture from conftest.py""" + return invalid_rate_limited_openai_client @pytest.mark.parametrize( @@ -480,12 +423,7 @@ def test_usage_endpoint(client): assert isinstance(usage_data, dict), "Usage data should be a dictionary" # Check for expected keys - expected_keys = [ - "total_tokens", - "completion_tokens", - "prompt_tokens", - "queries", - ] + expected_keys = ["total_tokens", "completion_tokens", "prompt_tokens"] for key in expected_keys: assert key in usage_data, f"Expected key {key} not found in usage data" diff --git a/tests/e2e/test_chat_completions_http.py b/tests/e2e/test_chat_completions_http.py index 5b8b9da7..6791f830 100644 --- a/tests/e2e/test_chat_completions_http.py +++ b/tests/e2e/test_chat_completions_http.py @@ -12,96 +12,11 @@ import os import re -from .config import BASE_URL, ENVIRONMENT, test_models, AUTH_STRATEGY, api_key_getter -from .nuc import ( - get_rate_limited_nuc_token, - get_invalid_rate_limited_nuc_token, - get_document_id_nuc_token, -) +from .config import BASE_URL, ENVIRONMENT, test_models, AUTH_STRATEGY import httpx import pytest -@pytest.fixture -def client(): - """Create an HTTPX client with default headers""" - invocation_token: str = api_key_getter() - print("invocation_token", invocation_token) - return httpx.Client( - base_url=BASE_URL, - headers={ - "accept": "application/json", - "Content-Type": "application/json", - "Authorization": f"Bearer {invocation_token}", - }, - verify=False, - timeout=None, - ) - - -@pytest.fixture -def rate_limited_client(): - """Create an HTTPX client with default headers""" - invocation_token = get_rate_limited_nuc_token(rate_limit=1) - return httpx.Client( - base_url=BASE_URL, - 
headers={ - "accept": "application/json", - "Content-Type": "application/json", - "Authorization": f"Bearer {invocation_token}", - }, - timeout=None, - verify=False, - ) - - -@pytest.fixture -def invalid_rate_limited_client(): - """Create an HTTPX client with default headers""" - invocation_token = get_invalid_rate_limited_nuc_token() - return httpx.Client( - base_url=BASE_URL, - headers={ - "accept": "application/json", - "Content-Type": "application/json", - "Authorization": f"Bearer {invocation_token}", - }, - timeout=None, - verify=False, - ) - - -@pytest.fixture -def nillion_2025_client(): - """Create an HTTPX client with default headers""" - return httpx.Client( - base_url=BASE_URL, - headers={ - "accept": "application/json", - "Content-Type": "application/json", - "Authorization": "Bearer Nillion2025", - }, - verify=False, - timeout=None, - ) - - -@pytest.fixture -def document_id_client(): - """Create an HTTPX client with default headers""" - invocation_token = get_document_id_nuc_token() - return httpx.Client( - base_url=BASE_URL, - headers={ - "accept": "application/json", - "Content-Type": "application/json", - "Authorization": f"Bearer {invocation_token}", - }, - verify=False, - timeout=None, - ) - - def test_health_endpoint(client): """Test the health endpoint""" response = client.get("health") @@ -139,7 +54,6 @@ def test_usage_endpoint(client): "total_tokens", "completion_tokens", "prompt_tokens", - "queries", ] for key in expected_keys: diff --git a/tests/e2e/test_responses.py b/tests/e2e/test_responses.py index f5f931c5..79679d7c 100644 --- a/tests/e2e/test_responses.py +++ b/tests/e2e/test_responses.py @@ -1,80 +1,52 @@ import json import os -import httpx import pytest import pytest_asyncio -from openai import OpenAI -from openai import AsyncOpenAI - -from .config import BASE_URL, test_models, AUTH_STRATEGY, api_key_getter -from .nuc import ( - get_rate_limited_nuc_token, - get_invalid_rate_limited_nuc_token, - get_nildb_nuc_token, -) - -def _create_openai_client(api_key: str) -> OpenAI: - """Helper function to create an OpenAI client with SSL verification disabled""" - transport = httpx.HTTPTransport(verify=False) - return OpenAI( - base_url=BASE_URL, - api_key=api_key, - http_client=httpx.Client(transport=transport), - ) +from .config import BASE_URL, test_models, AUTH_STRATEGY, api_key_getter, ENVIRONMENT -def _create_async_openai_client(api_key: str) -> AsyncOpenAI: - transport = httpx.AsyncHTTPTransport(verify=False) - return AsyncOpenAI( - base_url=BASE_URL, - api_key=api_key, - http_client=httpx.AsyncClient(transport=transport), - ) +# ============================================================================ +# Fixture Aliases for OpenAI SDK Tests +# These create local aliases that reference the centralized fixtures in conftest.py +# This allows tests to use 'client' instead of 'openai_client', maintaining backward compatibility +# ============================================================================ @pytest.fixture -def client(): - invocation_token: str = api_key_getter() - return _create_openai_client(invocation_token) +def client(openai_client): + """Alias for openai_client fixture from conftest.py""" + return openai_client @pytest_asyncio.fixture -async def async_client(): - invocation_token: str = api_key_getter() - transport = httpx.AsyncHTTPTransport(verify=False) - httpx_client = httpx.AsyncClient(transport=transport) - client = AsyncOpenAI( - base_url=BASE_URL, api_key=invocation_token, http_client=httpx_client - ) - yield client - await 
httpx_client.aclose() +async def async_client(async_openai_client): + """Alias for async_openai_client fixture from conftest.py""" + return async_openai_client @pytest.fixture -def rate_limited_client(): - invocation_token = get_rate_limited_nuc_token(rate_limit=1) - return _create_openai_client(invocation_token.token) +def rate_limited_client(rate_limited_openai_client): + """Alias for rate_limited_openai_client fixture from conftest.py""" + return rate_limited_openai_client @pytest.fixture -def invalid_rate_limited_client(): - invocation_token = get_invalid_rate_limited_nuc_token() - return _create_openai_client(invocation_token.token) +def invalid_rate_limited_client(invalid_rate_limited_openai_client): + """Alias for invalid_rate_limited_openai_client fixture from conftest.py""" + return invalid_rate_limited_openai_client @pytest.fixture -def nildb_client(): - invocation_token = get_nildb_nuc_token() - return _create_openai_client(invocation_token.token) +def nildb_client(document_id_openai_client): + """Alias for document_id_openai_client fixture from conftest.py""" + return document_id_openai_client @pytest.fixture -def high_web_search_rate_limit(monkeypatch): - monkeypatch.setenv("WEB_SEARCH_RATE_LIMIT_MINUTE", "9999") - monkeypatch.setenv("WEB_SEARCH_RATE_LIMIT_HOUR", "9999") - monkeypatch.setenv("WEB_SEARCH_RATE_LIMIT_DAY", "9999") - monkeypatch.setenv("WEB_SEARCH_RATE_LIMIT", "9999") +def invalid_nildb(invalid_nildb_openai_client): + """Alias for invalid_nildb_openai_client fixture from conftest.py""" + return invalid_nildb_openai_client @pytest.mark.parametrize("model", test_models) @@ -195,14 +167,14 @@ def test_invalid_rate_limiting_nucs(invalid_rate_limited_client, model): @pytest.mark.skipif( AUTH_STRATEGY != "nuc", reason="NUC rate limiting not used with API key" ) -def test_invalid_nildb_command_nucs(nildb_client, model): +def test_invalid_nildb_command_nucs(invalid_nildb, model): """Test invalid NILDB command handling""" import openai forbidden = False for _ in range(4): try: - nildb_client.responses.create( + invalid_nildb.responses.create( model=model, input="What is the capital of France?", instructions="You are a helpful assistant that provides accurate and concise information.", @@ -481,7 +453,6 @@ def test_usage_endpoint(client): "total_tokens", "completion_tokens", "prompt_tokens", - "queries", ] for key in expected_keys: assert key in usage_data, f"Expected key {key} not found in usage data" @@ -492,6 +463,10 @@ def test_usage_endpoint(client): pytest.fail(f"Error testing usage endpoint: {str(e)}") +@pytest.mark.skipif( + ENVIRONMENT != "mainnet", + reason="Attestation endpoint not available in non-mainnet environment", +) def test_attestation_endpoint(client): """Test retrieving attestation report""" try: diff --git a/tests/e2e/test_responses_http.py b/tests/e2e/test_responses_http.py index a92c8ddf..17d632a3 100644 --- a/tests/e2e/test_responses_http.py +++ b/tests/e2e/test_responses_http.py @@ -4,102 +4,7 @@ import httpx import pytest -from .config import BASE_URL, test_models, AUTH_STRATEGY, api_key_getter -from .nuc import ( - get_rate_limited_nuc_token, - get_invalid_rate_limited_nuc_token, - get_nildb_nuc_token, - get_document_id_nuc_token, -) - - -@pytest.fixture -def client(): - invocation_token: str = api_key_getter() - return httpx.Client( - base_url=BASE_URL, - headers={ - "accept": "application/json", - "Content-Type": "application/json", - "Authorization": f"Bearer {invocation_token}", - }, - verify=False, - timeout=None, - ) - - -@pytest.fixture 
-def rate_limited_client(): - invocation_token = get_rate_limited_nuc_token(rate_limit=1) - return httpx.Client( - base_url=BASE_URL, - headers={ - "accept": "application/json", - "Content-Type": "application/json", - "Authorization": f"Bearer {invocation_token.token}", - }, - timeout=None, - verify=False, - ) - - -@pytest.fixture -def invalid_rate_limited_client(): - invocation_token = get_invalid_rate_limited_nuc_token() - return httpx.Client( - base_url=BASE_URL, - headers={ - "accept": "application/json", - "Content-Type": "application/json", - "Authorization": f"Bearer {invocation_token.token}", - }, - timeout=None, - verify=False, - ) - - -@pytest.fixture -def nildb_client(): - invocation_token = get_nildb_nuc_token() - return httpx.Client( - base_url=BASE_URL, - headers={ - "accept": "application/json", - "Content-Type": "application/json", - "Authorization": f"Bearer {invocation_token.token}", - }, - timeout=None, - verify=False, - ) - - -@pytest.fixture -def nillion_2025_client(): - return httpx.Client( - base_url=BASE_URL, - headers={ - "accept": "application/json", - "Content-Type": "application/json", - "Authorization": "Bearer Nillion2025", - }, - verify=False, - timeout=None, - ) - - -@pytest.fixture -def document_id_client(): - invocation_token = get_document_id_nuc_token() - return httpx.Client( - base_url=BASE_URL, - headers={ - "accept": "application/json", - "Content-Type": "application/json", - "Authorization": f"Bearer {invocation_token.token}", - }, - verify=False, - timeout=None, - ) +from .config import BASE_URL, test_models, AUTH_STRATEGY, api_key_getter, ENVIRONMENT @pytest.mark.parametrize("model", test_models) @@ -530,12 +435,12 @@ def test_invalid_rate_limiting_nucs(invalid_rate_limited_client): @pytest.mark.skipif( AUTH_STRATEGY != "nuc", reason="NUC rate limiting not used with API key" ) -def test_invalid_nildb_command_nucs(nildb_client): +def test_invalid_nildb_command_nucs(invalid_nildb): payload = { "model": test_models[0], "input": "What is your name?", } - response = nildb_client.post("/responses", json=payload) + response = invalid_nildb.post("/responses", json=payload) assert response.status_code == 401, "Invalid NILDB command should return 401" @@ -722,7 +627,6 @@ def test_usage_endpoint(client): "total_tokens", "completion_tokens", "prompt_tokens", - "queries", ] for key in expected_keys: assert key in usage_data, f"Expected key {key} not found in usage data" @@ -733,6 +637,10 @@ def test_usage_endpoint(client): pytest.fail(f"Error testing usage endpoint: {str(e)}") +@pytest.mark.skipif( + ENVIRONMENT != "mainnet", + reason="Attestation endpoint not available in non-mainnet environment", +) def test_attestation_endpoint(client): try: import requests diff --git a/tests/integration/nilai_api/test_users_db_integration.py b/tests/integration/nilai_api/test_users_db_integration.py index 82d8d022..a9f5663a 100644 --- a/tests/integration/nilai_api/test_users_db_integration.py +++ b/tests/integration/nilai_api/test_users_db_integration.py @@ -4,6 +4,7 @@ These tests use a real PostgreSQL database via testcontainers. 
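
The clean_database fixture these tests depend on is not part of this diff; a rough sketch of the testcontainers setup it presumably builds on (image tag and fixture name are assumptions, not taken from the repository):

    import pytest
    from testcontainers.postgres import PostgresContainer

    @pytest.fixture(scope="session")
    def postgres_url():
        # start a throwaway PostgreSQL container for the test session
        with PostgresContainer("postgres:16") as pg:
            yield pg.get_connection_url()

    # clean_database would point the app's engine at postgres_url and reset
    # tables between tests; the real wiring lives in the tests' conftest.
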
""" +import uuid import pytest import json @@ -17,37 +18,19 @@ class TestUserManagerIntegration: async def test_simple_user_creation(self, clean_database): """Test creating a simple user and retrieving it.""" # Insert user with minimal data - user = await UserManager.insert_user(name="Simple Test User") + user = await UserManager.insert_user(user_id="Simple Test User") # Verify user creation - assert user.name == "Simple Test User" - assert user.userid is not None - assert user.apikey is not None - assert user.userid != user.apikey # Should be different UUIDs + assert user.user_id == "Simple Test User" + assert user.rate_limits is None, ( + f"Rate limits are not set for user {user.user_id}" + ) # Retrieve user by ID - found_user = await UserManager.check_user(user.userid) + found_user = await UserManager.check_user(user.user_id) assert found_user is not None - assert found_user.userid == user.userid - assert found_user.name == "Simple Test User" - assert found_user.apikey == user.apikey - - @pytest.mark.asyncio - async def test_api_key_validation(self, clean_database): - """Test API key validation functionality.""" - # Create user - user = await UserManager.insert_user("API Test User") - - # Validate correct API key - api_user = await UserManager.check_api_key(user.apikey) - assert api_user is not None - assert api_user.apikey == user.apikey - assert api_user.userid == user.userid - assert api_user.name == "API Test User" - - # Test invalid API key - invalid_user = await UserManager.check_api_key("invalid-api-key") - assert invalid_user is None + assert found_user.user_id == user.user_id + assert found_user.rate_limits == user.rate_limits @pytest.mark.asyncio async def test_rate_limits_json_crud_basic(self, clean_database): @@ -66,14 +49,14 @@ async def test_rate_limits_json_crud_basic(self, clean_database): # CREATE: Insert user with rate limits user = await UserManager.insert_user( - name="Rate Limits Test User", rate_limits=rate_limits + user_id="Rate Limits Test User", rate_limits=rate_limits ) # Verify rate limits are stored as JSON assert user.rate_limits == rate_limits.model_dump() # READ: Retrieve user and verify rate limits JSON - retrieved_user = await UserManager.check_user(user.userid) + retrieved_user = await UserManager.check_user(user.user_id) assert retrieved_user is not None assert retrieved_user.rate_limits == rate_limits.model_dump() @@ -98,11 +81,11 @@ async def test_rate_limits_json_update(self, clean_database): ) user = await UserManager.insert_user( - name="Update Rate Limits User", rate_limits=initial_rate_limits + user_id="Update Rate Limits User", rate_limits=initial_rate_limits ) # Verify initial rate limits - retrieved_user = await UserManager.check_user(user.userid) + retrieved_user = await UserManager.check_user(user.user_id) assert retrieved_user is not None assert retrieved_user.rate_limits == initial_rate_limits.model_dump() @@ -125,19 +108,19 @@ async def test_rate_limits_json_update(self, clean_database): stmt = sa.text(""" UPDATE users SET rate_limits = :rate_limits_json - WHERE userid = :userid + WHERE user_id = :user_id """) await session.execute( stmt, { "rate_limits_json": updated_rate_limits.model_dump_json(), - "userid": user.userid, + "user_id": user.user_id, }, ) await session.commit() # READ: Verify the update worked - updated_user = await UserManager.check_user(user.userid) + updated_user = await UserManager.check_user(user.user_id) assert updated_user is not None assert updated_user.rate_limits == updated_rate_limits.model_dump() @@ -162,11 
+145,11 @@ async def test_rate_limits_json_partial_update(self, clean_database): ) user = await UserManager.insert_user( - name="Partial Rate Limits User", rate_limits=partial_rate_limits + user_id="Partial Rate Limits User", rate_limits=partial_rate_limits ) # Verify partial data is stored correctly - retrieved_user = await UserManager.check_user(user.userid) + retrieved_user = await UserManager.check_user(user.user_id) assert retrieved_user is not None assert retrieved_user.rate_limits == partial_rate_limits.model_dump() @@ -183,13 +166,13 @@ async def test_rate_limits_json_partial_update(self, clean_database): '{user_rate_limit_hour}', '75' ) - WHERE userid = :userid + WHERE user_id = :user_id """) - await session.execute(stmt, {"userid": user.userid}) + await session.execute(stmt, {"user_id": user.user_id}) await session.commit() # Verify partial update worked - updated_user = await UserManager.check_user(user.userid) + updated_user = await UserManager.check_user(user.user_id) assert updated_user is not None expected_data = partial_rate_limits.model_dump() @@ -211,7 +194,7 @@ async def test_rate_limits_json_null_and_delete(self, clean_database): ) user = await UserManager.insert_user( - name="Delete Rate Limits User", rate_limits=rate_limits + user_id="Delete Rate Limits User", rate_limits=rate_limits ) # DELETE: Set rate_limits to NULL @@ -219,12 +202,14 @@ async def test_rate_limits_json_null_and_delete(self, clean_database): import sqlalchemy as sa async with get_db_session() as session: - stmt = sa.text("UPDATE users SET rate_limits = NULL WHERE userid = :userid") - await session.execute(stmt, {"userid": user.userid}) + stmt = sa.text( + "UPDATE users SET rate_limits = NULL WHERE user_id = :user_id" + ) + await session.execute(stmt, {"user_id": user.user_id}) await session.commit() # Verify NULL handling - null_user = await UserManager.check_user(user.userid) + null_user = await UserManager.check_user(user.user_id) assert null_user is not None assert null_user.rate_limits is None @@ -239,15 +224,15 @@ async def test_rate_limits_json_null_and_delete(self, clean_database): # First set some data new_data = {"user_rate_limit_day": 500, "web_search_rate_limit_day": 25} stmt = sa.text( - "UPDATE users SET rate_limits = :data WHERE userid = :userid" + "UPDATE users SET rate_limits = :data WHERE user_id = :user_id" ) await session.execute( - stmt, {"data": json.dumps(new_data), "userid": user.userid} + stmt, {"data": json.dumps(new_data), "user_id": user.user_id} ) await session.commit() # Verify data was set - updated_user = await UserManager.check_user(user.userid) + updated_user = await UserManager.check_user(user.user_id) assert updated_user is not None assert updated_user.rate_limits == new_data @@ -256,13 +241,13 @@ async def test_rate_limits_json_null_and_delete(self, clean_database): stmt = sa.text(""" UPDATE users SET rate_limits = rate_limits::jsonb - 'web_search_rate_limit_day' - WHERE userid = :userid + WHERE user_id = :user_id """) - await session.execute(stmt, {"userid": user.userid}) + await session.execute(stmt, {"user_id": user.user_id}) await session.commit() # Verify field was removed - final_user = await UserManager.check_user(user.userid) + final_user = await UserManager.check_user(user.user_id) expected_final_data = {"user_rate_limit_day": 500} assert final_user is not None assert final_user.rate_limits == expected_final_data @@ -293,15 +278,15 @@ async def test_rate_limits_json_validation_and_conversion(self, clean_database): for i, test_data in 
enumerate(test_cases): async with get_db_session() as session: stmt = sa.text( - "UPDATE users SET rate_limits = :data WHERE userid = :userid" + "UPDATE users SET rate_limits = :data WHERE user_id = :user_id" ) await session.execute( - stmt, {"data": json.dumps(test_data), "userid": user.userid} + stmt, {"data": json.dumps(test_data), "user_id": user.user_id} ) await session.commit() # Retrieve and verify - updated_user = await UserManager.check_user(user.userid) + updated_user = await UserManager.check_user(user.user_id) assert updated_user is not None assert updated_user.rate_limits == test_data @@ -327,11 +312,13 @@ async def test_rate_limits_json_validation_and_conversion(self, clean_database): # Test empty JSON object async with get_db_session() as session: - stmt = sa.text("UPDATE users SET rate_limits = '{}' WHERE userid = :userid") - await session.execute(stmt, {"userid": user.userid}) + stmt = sa.text( + "UPDATE users SET rate_limits = '{}' WHERE user_id = :user_id" + ) + await session.execute(stmt, {"user_id": user.user_id}) await session.commit() - empty_user = await UserManager.check_user(user.userid) + empty_user = await UserManager.check_user(user.user_id) assert empty_user is not None assert empty_user.rate_limits == {} empty_rate_limits_obj = empty_user.rate_limits_obj @@ -343,18 +330,18 @@ async def test_rate_limits_json_validation_and_conversion(self, clean_database): async with get_db_session() as session: # This should work as PostgreSQL JSONB validates JSON stmt = sa.text( - "UPDATE users SET rate_limits = :data WHERE userid = :userid" + "UPDATE users SET rate_limits = :data WHERE user_id = :user_id" ) await session.execute( stmt, { "data": '{"user_rate_limit_day": 5000}', # Valid JSON string - "userid": user.userid, + "user_id": user.user_id, }, ) await session.commit() - json_string_user = await UserManager.check_user(user.userid) + json_string_user = await UserManager.check_user(user.user_id) assert json_string_user is not None assert json_string_user.rate_limits == {"user_rate_limit_day": 5000} @@ -366,16 +353,15 @@ async def test_rate_limits_json_validation_and_conversion(self, clean_database): async def test_rate_limits_update_workflow(self, clean_database): """Test complete workflow: create user with no rate limits -> update rate limits -> verify update.""" # Step 1: Create user with NO rate limits - user = await UserManager.insert_user(name="Rate Limits Workflow User") + user_id = str(uuid.uuid4()) + user = await UserManager.insert_user(user_id=user_id) # Verify user was created with no rate limits - assert user.name == "Rate Limits Workflow User" - assert user.userid is not None - assert user.apikey is not None + assert user.user_id == user_id assert user.rate_limits is None # No rate limits initially # Step 2: Retrieve user and confirm no rate limits - retrieved_user = await UserManager.check_user(user.userid) + retrieved_user = await UserManager.check_user(user.user_id) assert retrieved_user is not None print(retrieved_user.to_pydantic()) assert retrieved_user is not None @@ -401,12 +387,12 @@ async def test_rate_limits_update_workflow(self, clean_database): # Step 4: Update the user's rate limits using the new function update_success = await UserManager.update_rate_limits( - user.userid, new_rate_limits + user.user_id, new_rate_limits ) assert update_success is True # Step 5: Retrieve user again and verify rate limits were updated - updated_user = await UserManager.check_user(user.userid) + updated_user = await UserManager.check_user(user.user_id) assert 
updated_user is not None assert updated_user.rate_limits is not None assert updated_user.rate_limits == new_rate_limits.model_dump() @@ -431,12 +417,12 @@ async def test_rate_limits_update_workflow(self, clean_database): ) partial_update_success = await UserManager.update_rate_limits( - user.userid, partial_rate_limits + user.user_id, partial_rate_limits ) assert partial_update_success is True # Step 8: Verify partial update worked - final_user = await UserManager.check_user(user.userid) + final_user = await UserManager.check_user(user.user_id) assert final_user is not None assert final_user.rate_limits == partial_rate_limits.model_dump() @@ -447,8 +433,8 @@ async def test_rate_limits_update_workflow(self, clean_database): # Other fields should have config defaults (not None due to get_effective_limits) # Step 9: Test error case - update non-existent user - fake_userid = "non-existent-user-id" + fake_user_id = "non-existent-user-id" error_update = await UserManager.update_rate_limits( - fake_userid, new_rate_limits + fake_user_id, new_rate_limits ) assert error_update is False diff --git a/tests/unit/nilai-common/test_discovery.py b/tests/unit/nilai-common/test_discovery.py index 8899eaff..363e99d0 100644 --- a/tests/unit/nilai-common/test_discovery.py +++ b/tests/unit/nilai-common/test_discovery.py @@ -2,7 +2,7 @@ import pytest import pytest_asyncio -from nilai_common.api_model import ModelEndpoint, ModelMetadata +from nilai_common.api_models import ModelEndpoint, ModelMetadata from nilai_common.discovery import ModelServiceDiscovery diff --git a/tests/unit/nilai_api/__init__.py b/tests/unit/nilai_api/__init__.py index 0be52613..7cbc1237 100644 --- a/tests/unit/nilai_api/__init__.py +++ b/tests/unit/nilai_api/__init__.py @@ -21,11 +21,11 @@ def generate_api_key(self) -> str: async def insert_user(self, name: str, email: str) -> Dict[str, str]: """Insert a new user into the mock database.""" - userid = self.generate_user_id() + user_id = self.generate_user_id() apikey = self.generate_api_key() user_data = { - "userid": userid, + "user_id": user_id, "name": name, "email": email, "apikey": apikey, @@ -36,34 +36,34 @@ async def insert_user(self, name: str, email: str) -> Dict[str, str]: "last_activity": None, } - self.users[userid] = user_data - return {"userid": userid, "apikey": apikey} + self.users[user_id] = user_data + return {"user_id": user_id, "apikey": apikey} async def check_api_key(self, api_key: str) -> Optional[dict]: """Validate an API key in the mock database.""" for user in self.users.values(): if user["apikey"] == api_key: - return {"name": user["name"], "userid": user["userid"]} + return {"name": user["name"], "user_id": user["user_id"]} return None async def update_token_usage( - self, userid: str, prompt_tokens: int, completion_tokens: int + self, user_id: str, prompt_tokens: int, completion_tokens: int ): """Update token usage for a specific user.""" - if userid in self.users: - user = self.users[userid] + if user_id in self.users: + user = self.users[user_id] user["prompt_tokens"] += prompt_tokens user["completion_tokens"] += completion_tokens user["queries"] += 1 user["last_activity"] = datetime.now(timezone.utc) async def log_query( - self, userid: str, model: str, prompt_tokens: int, completion_tokens: int + self, user_id: str, model: str, prompt_tokens: int, completion_tokens: int ): """Log a user's query in the mock database.""" query_log = { "id": self._next_query_log_id, - "userid": userid, + "user_id": user_id, "query_timestamp": datetime.now(timezone.utc), 
"model": model, "prompt_tokens": prompt_tokens, @@ -74,9 +74,9 @@ async def log_query( self.query_logs[self._next_query_log_id] = query_log self._next_query_log_id += 1 - async def get_token_usage(self, userid: str) -> Optional[Dict[str, Any]]: + async def get_token_usage(self, user_id: str) -> Optional[Dict[str, Any]]: """Get token usage for a specific user.""" - user = self.users.get(userid) + user = self.users.get(user_id) if user: return { "prompt_tokens": user["prompt_tokens"], @@ -90,9 +90,9 @@ async def get_all_users(self) -> Optional[List[Dict[str, Any]]]: """Retrieve all users from the mock database.""" return list(self.users.values()) if self.users else None - async def get_user_token_usage(self, userid: str) -> Optional[Dict[str, int]]: + async def get_user_token_usage(self, user_id: str) -> Optional[Dict[str, int]]: """Retrieve total token usage for a user.""" - user = self.users.get(userid) + user = self.users.get(user_id) if user: return { "prompt_tokens": user["prompt_tokens"], diff --git a/tests/unit/nilai_api/auth/test_auth.py b/tests/unit/nilai_api/auth/test_auth.py index 591c447a..47559272 100644 --- a/tests/unit/nilai_api/auth/test_auth.py +++ b/tests/unit/nilai_api/auth/test_auth.py @@ -1,10 +1,8 @@ -from datetime import datetime, timezone import logging from unittest.mock import MagicMock from nilai_api.db.users import RateLimits import pytest -from fastapi import HTTPException from fastapi.security import HTTPAuthorizationCredentials from nilai_api.config import CONFIG as config @@ -14,13 +12,9 @@ @pytest.fixture -def mock_user_manager(mocker): - from nilai_api.db.users import UserManager - - """Fixture to mock UserManager methods.""" - mocker.patch.object(UserManager, "check_api_key") - mocker.patch.object(UserManager, "update_last_activity") - return UserManager +def mock_validate_credential(mocker): + """Fixture to mock validate_credential function.""" + return mocker.patch("nilai_api.auth.strategies.validate_credential") @pytest.fixture @@ -28,14 +22,7 @@ def mock_user_model(): from nilai_api.db.users import UserModel mock = MagicMock(spec=UserModel) - mock.name = "Test User" - mock.userid = "test-user-id" - mock.apikey = "test-api-key" - mock.prompt_tokens = 0 - mock.completion_tokens = 0 - mock.queries = 0 - mock.signup_date = datetime.now(timezone.utc) - mock.last_activity = datetime.now(timezone.utc) + mock.user_id = "test-user-id" mock.rate_limits = RateLimits().get_effective_limits().model_dump_json() mock.rate_limits_obj = RateLimits().get_effective_limits() return mock @@ -49,53 +36,38 @@ def mock_user_data(mock_user_model): return UserData.from_sqlalchemy(mock_user_model) -@pytest.fixture -def mock_auth_info(): - from nilai_api.auth import AuthenticationInfo - - mock = MagicMock(spec=AuthenticationInfo) - mock.user = mock_user_data - return mock - - @pytest.mark.asyncio -async def test_get_auth_info_valid_token( - mock_user_manager, mock_auth_info, mock_user_model -): +async def test_get_auth_info_valid_token(mock_validate_credential, mock_user_model): from nilai_api.auth import get_auth_info """Test get_auth_info with a valid token.""" - mock_user_manager.check_api_key.return_value = mock_user_model + mock_validate_credential.return_value = mock_user_model credentials = HTTPAuthorizationCredentials( scheme="Bearer", credentials="valid-token" ) auth_info = await get_auth_info(credentials) print(auth_info) - assert auth_info.user.name == "Test User", ( - f"Expected Test User but got {auth_info.user.name}" - ) - assert auth_info.user.userid == 
"test-user-id", ( - f"Expected test-user-id but got {auth_info.user.userid}" + + assert auth_info.user.user_id == "test-user-id", ( + f"Expected test-user-id but got {auth_info.user.user_id}" ) @pytest.mark.asyncio -async def test_get_auth_info_invalid_token(mock_user_manager): +async def test_get_auth_info_invalid_token(mock_validate_credential): from nilai_api.auth import get_auth_info + from nilai_api.auth.common import AuthenticationError """Test get_auth_info with an invalid token.""" - mock_user_manager.check_api_key.return_value = None + mock_validate_credential.side_effect = AuthenticationError("Credential not found") credentials = HTTPAuthorizationCredentials( scheme="Bearer", credentials="invalid-token" ) - with pytest.raises(HTTPException) as exc_info: + with pytest.raises(AuthenticationError) as exc_info: auth_infor = await get_auth_info(credentials) print(auth_infor) print(exc_info) - assert exc_info.value.status_code == 401, ( - f"Expected status code 401 but got {exc_info.value.status_code}" - ) - assert exc_info.value.detail == "Missing or invalid API key", ( - f"Expected Missing or invalid API key but got {exc_info.value.detail}" + assert "Credential not found" in str(exc_info.value.detail), ( + f"Expected 'Credential not found' but got {exc_info.value.detail}" ) diff --git a/tests/unit/nilai_api/auth/test_strategies.py b/tests/unit/nilai_api/auth/test_strategies.py index 0c169f53..5b65c5b0 100644 --- a/tests/unit/nilai_api/auth/test_strategies.py +++ b/tests/unit/nilai_api/auth/test_strategies.py @@ -1,7 +1,6 @@ import pytest from unittest.mock import patch, MagicMock from datetime import datetime, timezone, timedelta -from fastapi import HTTPException from nilai_api.auth.strategies import api_key_strategy, nuc_strategy from nilai_api.auth.common import AuthenticationInfo, PromptDocument @@ -15,14 +14,7 @@ class TestAuthStrategies: def mock_user_model(self): """Mock UserModel fixture""" mock = MagicMock(spec=UserModel) - mock.name = "Test User" - mock.userid = "test-user-id" - mock.apikey = "test-api-key" - mock.prompt_tokens = 0 - mock.completion_tokens = 0 - mock.queries = 0 - mock.signup_date = datetime.now(timezone.utc) - mock.last_activity = datetime.now(timezone.utc) + mock.user_id = "test-user-id" mock.rate_limits = RateLimits().get_effective_limits().model_dump_json() mock.rate_limits_obj = RateLimits().get_effective_limits() return mock @@ -37,27 +29,26 @@ def mock_prompt_document(self): @pytest.mark.asyncio async def test_api_key_strategy_success(self, mock_user_model): """Test successful API key authentication""" - with patch("nilai_api.auth.strategies.UserManager.check_api_key") as mock_check: - mock_check.return_value = mock_user_model + with patch("nilai_api.auth.strategies.validate_credential") as mock_validate: + mock_validate.return_value = mock_user_model result = await api_key_strategy("test-api-key") assert isinstance(result, AuthenticationInfo) - assert result.user.name == "Test User" assert result.token_rate_limit is None assert result.prompt_document is None + mock_validate.assert_called_once_with("test-api-key", is_public=False) @pytest.mark.asyncio async def test_api_key_strategy_invalid_key(self): """Test API key authentication with invalid key""" - with patch("nilai_api.auth.strategies.UserManager.check_api_key") as mock_check: - mock_check.return_value = None + from nilai_api.auth.common import AuthenticationError - with pytest.raises(HTTPException) as exc_info: - await api_key_strategy("invalid-key") + with 
patch("nilai_api.auth.strategies.validate_credential") as mock_validate: + mock_validate.side_effect = AuthenticationError("Credential not found") - assert exc_info.value.status_code == 401 - assert "Missing or invalid API key" in str(exc_info.value.detail) + with pytest.raises(AuthenticationError, match="Credential not found"): + await api_key_strategy("invalid-key") @pytest.mark.asyncio async def test_nuc_strategy_existing_user_with_prompt_document( @@ -65,7 +56,7 @@ async def test_nuc_strategy_existing_user_with_prompt_document( ): """Test NUC authentication with existing user and prompt document""" with ( - patch("nilai_api.auth.strategies.validate_nuc") as mock_validate, + patch("nilai_api.auth.strategies.validate_nuc") as mock_validate_nuc, patch( "nilai_api.auth.strategies.get_token_rate_limit" ) as mock_get_rate_limit, @@ -73,23 +64,27 @@ async def test_nuc_strategy_existing_user_with_prompt_document( "nilai_api.auth.strategies.get_token_prompt_document" ) as mock_get_prompt_doc, patch( - "nilai_api.auth.strategies.UserManager.check_user" - ) as mock_check_user, + "nilai_api.auth.strategies.validate_credential" + ) as mock_validate_credential, ): - mock_validate.return_value = ("subscription_holder", "user_id") + mock_validate_nuc.return_value = ("subscription_holder", "user_id") mock_get_rate_limit.return_value = None mock_get_prompt_doc.return_value = mock_prompt_document - mock_check_user.return_value = mock_user_model + mock_validate_credential.return_value = mock_user_model result = await nuc_strategy("nuc-token") assert isinstance(result, AuthenticationInfo) - assert result.user.name == "Test User" assert result.token_rate_limit is None assert result.prompt_document == mock_prompt_document + mock_validate_credential.assert_called_once_with( + "subscription_holder", is_public=True + ) @pytest.mark.asyncio - async def test_nuc_strategy_new_user_with_token_limits(self, mock_prompt_document): + async def test_nuc_strategy_new_user_with_token_limits( + self, mock_prompt_document, mock_user_model + ): """Test NUC authentication creating new user with token limits""" from nilai_api.auth.nuc_helpers.usage import TokenRateLimits, TokenRateLimit @@ -104,7 +99,7 @@ async def test_nuc_strategy_new_user_with_token_limits(self, mock_prompt_documen ) with ( - patch("nilai_api.auth.strategies.validate_nuc") as mock_validate, + patch("nilai_api.auth.strategies.validate_nuc") as mock_validate_nuc, patch( "nilai_api.auth.strategies.get_token_rate_limit" ) as mock_get_rate_limit, @@ -112,30 +107,28 @@ async def test_nuc_strategy_new_user_with_token_limits(self, mock_prompt_documen "nilai_api.auth.strategies.get_token_prompt_document" ) as mock_get_prompt_doc, patch( - "nilai_api.auth.strategies.UserManager.check_user" - ) as mock_check_user, - patch( - "nilai_api.auth.strategies.UserManager.insert_user_model" - ) as mock_insert, + "nilai_api.auth.strategies.validate_credential" + ) as mock_validate_credential, ): - mock_validate.return_value = ("subscription_holder", "new_user_id") + mock_validate_nuc.return_value = ("subscription_holder", "new_user_id") mock_get_rate_limit.return_value = mock_token_limits mock_get_prompt_doc.return_value = mock_prompt_document - mock_check_user.return_value = None - mock_insert.return_value = None + mock_validate_credential.return_value = mock_user_model result = await nuc_strategy("nuc-token") assert isinstance(result, AuthenticationInfo) assert result.token_rate_limit == mock_token_limits assert result.prompt_document == mock_prompt_document - 
mock_insert.assert_called_once() + mock_validate_credential.assert_called_once_with( + "subscription_holder", is_public=True + ) @pytest.mark.asyncio async def test_nuc_strategy_no_prompt_document(self, mock_user_model): """Test NUC authentication when no prompt document is found""" with ( - patch("nilai_api.auth.strategies.validate_nuc") as mock_validate, + patch("nilai_api.auth.strategies.validate_nuc") as mock_validate_nuc, patch( "nilai_api.auth.strategies.get_token_rate_limit" ) as mock_get_rate_limit, @@ -143,18 +136,17 @@ async def test_nuc_strategy_no_prompt_document(self, mock_user_model): "nilai_api.auth.strategies.get_token_prompt_document" ) as mock_get_prompt_doc, patch( - "nilai_api.auth.strategies.UserManager.check_user" - ) as mock_check_user, + "nilai_api.auth.strategies.validate_credential" + ) as mock_validate_credential, ): - mock_validate.return_value = ("subscription_holder", "user_id") + mock_validate_nuc.return_value = ("subscription_holder", "user_id") mock_get_rate_limit.return_value = None mock_get_prompt_doc.return_value = None - mock_check_user.return_value = mock_user_model + mock_validate_credential.return_value = mock_user_model result = await nuc_strategy("nuc-token") assert isinstance(result, AuthenticationInfo) - assert result.user.name == "Test User" assert result.token_rate_limit is None assert result.prompt_document is None @@ -171,7 +163,7 @@ async def test_nuc_strategy_validation_error(self): async def test_nuc_strategy_get_prompt_document_error(self, mock_user_model): """Test NUC authentication when get_token_prompt_document fails""" with ( - patch("nilai_api.auth.strategies.validate_nuc") as mock_validate, + patch("nilai_api.auth.strategies.validate_nuc") as mock_validate_nuc, patch( "nilai_api.auth.strategies.get_token_rate_limit" ) as mock_get_rate_limit, @@ -179,15 +171,15 @@ async def test_nuc_strategy_get_prompt_document_error(self, mock_user_model): "nilai_api.auth.strategies.get_token_prompt_document" ) as mock_get_prompt_doc, patch( - "nilai_api.auth.strategies.UserManager.check_user" - ) as mock_check_user, + "nilai_api.auth.strategies.validate_credential" + ) as mock_validate_credential, ): - mock_validate.return_value = ("subscription_holder", "user_id") + mock_validate_nuc.return_value = ("subscription_holder", "user_id") mock_get_rate_limit.return_value = None mock_get_prompt_doc.side_effect = Exception( "Prompt document extraction failed" ) - mock_check_user.return_value = mock_user_model + mock_validate_credential.return_value = mock_user_model # The function should let the exception bubble up or handle it gracefully # Based on the diff, it looks like it doesn't catch exceptions from get_token_prompt_document @@ -200,29 +192,22 @@ async def test_all_strategies_return_authentication_info_with_prompt_document_fi ): """Test that all strategies return AuthenticationInfo with prompt_document field""" mock_user_model = MagicMock(spec=UserModel) - mock_user_model.name = "Test" - mock_user_model.userid = "test" - mock_user_model.apikey = "test" - mock_user_model.prompt_tokens = 0 - mock_user_model.completion_tokens = 0 - mock_user_model.queries = 0 - mock_user_model.signup_date = datetime.now(timezone.utc) - mock_user_model.last_activity = datetime.now(timezone.utc) + mock_user_model.user_id = "test" mock_user_model.rate_limits = ( RateLimits().get_effective_limits().model_dump_json() ) mock_user_model.rate_limits_obj = RateLimits().get_effective_limits() # Test API key strategy - with 
patch("nilai_api.auth.strategies.UserManager.check_api_key") as mock_check: - mock_check.return_value = mock_user_model + with patch("nilai_api.auth.strategies.validate_credential") as mock_validate: + mock_validate.return_value = mock_user_model result = await api_key_strategy("test-key") assert hasattr(result, "prompt_document") assert result.prompt_document is None # Test NUC strategy with ( - patch("nilai_api.auth.strategies.validate_nuc") as mock_validate, + patch("nilai_api.auth.strategies.validate_nuc") as mock_validate_nuc, patch( "nilai_api.auth.strategies.get_token_rate_limit" ) as mock_get_rate_limit, @@ -230,13 +215,13 @@ async def test_all_strategies_return_authentication_info_with_prompt_document_fi "nilai_api.auth.strategies.get_token_prompt_document" ) as mock_get_prompt_doc, patch( - "nilai_api.auth.strategies.UserManager.check_user" - ) as mock_check_user, + "nilai_api.auth.strategies.validate_credential" + ) as mock_validate_credential, ): - mock_validate.return_value = ("subscription_holder", "user_id") + mock_validate_nuc.return_value = ("subscription_holder", "user_id") mock_get_rate_limit.return_value = None mock_get_prompt_doc.return_value = None - mock_check_user.return_value = mock_user_model + mock_validate_credential.return_value = mock_user_model result = await nuc_strategy("nuc-token") assert hasattr(result, "prompt_document") diff --git a/tests/unit/nilai_api/routers/test_chat_completions_private.py b/tests/unit/nilai_api/routers/test_chat_completions_private.py index 2ba88cc8..4c1f30b2 100644 --- a/tests/unit/nilai_api/routers/test_chat_completions_private.py +++ b/tests/unit/nilai_api/routers/test_chat_completions_private.py @@ -20,15 +20,7 @@ async def test_runs_in_a_loop(): @pytest.fixture def mock_user(): mock = MagicMock(spec=UserModel) - mock.userid = "test-user-id" - mock.name = "Test User" - mock.apikey = "test-api-key" - mock.prompt_tokens = 100 - mock.completion_tokens = 50 - mock.total_tokens = 150 - mock.completion_tokens_details = None - mock.prompt_tokens_details = None - mock.queries = 10 + mock.user_id = "test-user-id" mock.rate_limits = RateLimits().get_effective_limits().model_dump_json() mock.rate_limits_obj = RateLimits().get_effective_limits() return mock @@ -38,62 +30,30 @@ def mock_user(): def mock_user_manager(mock_user, mocker): from nilai_api.db.logs import QueryLogManager from nilai_api.db.users import UserManager + from nilai_common import Usage + # Mock QueryLogManager for usage tracking mocker.patch.object( - UserManager, - "get_token_usage", - return_value={ - "prompt_tokens": 100, - "completion_tokens": 50, - "total_tokens": 150, - "queries": 10, - }, - ) - mocker.patch.object(UserManager, "update_token_usage") - mocker.patch.object( - UserManager, + QueryLogManager, "get_user_token_usage", - return_value={ - "prompt_tokens": 100, - "completion_tokens": 50, - "total_tokens": 150, - "completion_tokens_details": None, - "prompt_tokens_details": None, - "queries": 10, - }, - ) - mocker.patch.object( - UserManager, - "insert_user", - return_value={ - "userid": "test-user-id", - "apikey": "test-api-key", - "rate_limits": RateLimits().get_effective_limits().model_dump_json(), - }, + new_callable=AsyncMock, + return_value=Usage( + prompt_tokens=100, + completion_tokens=50, + total_tokens=150, + completion_tokens_details=None, + prompt_tokens_details=None, + ), ) - mocker.patch.object( - UserManager, - "check_api_key", + mocker.patch.object(QueryLogManager, "log_query", new_callable=AsyncMock) + + # Mock validate_credential for 
authentication + mocker.patch( + "nilai_api.auth.strategies.validate_credential", + new_callable=AsyncMock, return_value=mock_user, ) - mocker.patch.object( - UserManager, - "get_all_users", - return_value=[ - { - "userid": "test-user-id", - "apikey": "test-api-key", - "rate_limits": RateLimits().get_effective_limits().model_dump_json(), - }, - { - "userid": "test-user-id-2", - "apikey": "test-api-key", - "rate_limits": RateLimits().get_effective_limits().model_dump_json(), - }, - ], - ) - mocker.patch.object(QueryLogManager, "log_query") - mocker.patch.object(UserManager, "update_last_activity") + return UserManager @@ -117,6 +77,7 @@ def mock_state(mocker): # Patch get_attestation method attestation_response = AttestationReport( + nonce="0" * 64, verifying_key="test-verifying-key", cpu_attestation="test-cpu-attestation", gpu_attestation="test-gpu-attestation", @@ -173,7 +134,6 @@ def test_get_usage(mock_user, mock_user_manager, mock_state, client): "total_tokens": 150, "completion_tokens_details": None, "prompt_tokens_details": None, - "queries": 10, } @@ -220,6 +180,7 @@ def test_chat_completion(mock_user, mock_state, mock_user_manager, mocker, clien "nilai_api.routers.endpoints.chat.handle_tool_workflow", return_value=(response_data, 0, 0), ) + mocker.patch("nilai_api.db.logs.QueryLogContext.commit", new_callable=AsyncMock) response = client.post( "/v1/chat/completions", json={ @@ -260,6 +221,7 @@ def test_chat_completion_stream_includes_sources( "nilai_api.routers.endpoints.chat.handle_web_search", new=AsyncMock(return_value=mock_web_search_result), ) + mocker.patch("nilai_api.db.logs.QueryLogContext.commit", new_callable=AsyncMock) class MockChunk: def __init__(self, data, usage=None): diff --git a/tests/unit/nilai_api/routers/test_nildb_endpoints.py b/tests/unit/nilai_api/routers/test_nildb_endpoints.py index b54b664c..ff1ecdd6 100644 --- a/tests/unit/nilai_api/routers/test_nildb_endpoints.py +++ b/tests/unit/nilai_api/routers/test_nildb_endpoints.py @@ -7,7 +7,6 @@ from nilai_api.handlers.nildb.api_model import ( PromptDelegationToken, ) -from datetime import datetime, timezone from nilai_common import ResponseRequest @@ -18,14 +17,7 @@ class TestNilDBEndpoints: def mock_subscription_owner_user(self): """Mock user data for subscription owner""" mock_user_model = MagicMock(spec=UserModel) - mock_user_model.name = "Subscription Owner" - mock_user_model.userid = "owner-id" - mock_user_model.apikey = "owner-id" # Same as userid for subscription owner - mock_user_model.prompt_tokens = 0 - mock_user_model.completion_tokens = 0 - mock_user_model.queries = 0 - mock_user_model.signup_date = datetime.now(timezone.utc) - mock_user_model.last_activity = datetime.now(timezone.utc) + mock_user_model.user_id = "owner-id" mock_user_model.rate_limits = ( RateLimits().get_effective_limits().model_dump_json() ) @@ -37,14 +29,7 @@ def mock_subscription_owner_user(self): def mock_regular_user(self): """Mock user data for regular user (not subscription owner)""" mock_user_model = MagicMock(spec=UserModel) - mock_user_model.name = "Regular User" - mock_user_model.userid = "user-id" - mock_user_model.apikey = "different-api-key" # Different from userid - mock_user_model.prompt_tokens = 0 - mock_user_model.completion_tokens = 0 - mock_user_model.queries = 0 - mock_user_model.signup_date = datetime.now(timezone.utc) - mock_user_model.last_activity = datetime.now(timezone.utc) + mock_user_model.user_id = "user-id" mock_user_model.rate_limits = ( RateLimits().get_effective_limits().model_dump_json() ) @@ 
-99,21 +84,25 @@ async def test_get_prompt_store_delegation_success( mock_get_delegation.assert_called_once_with("user-123") @pytest.mark.asyncio - async def test_get_prompt_store_delegation_forbidden_regular_user( - self, mock_auth_info_regular_user + async def test_get_prompt_store_delegation_success_regular_user( + self, mock_auth_info_regular_user, mock_prompt_delegation_token ): - """Test delegation token request by regular user (not subscription owner)""" + """Test delegation token request by regular user (endpoint no longer checks subscription ownership)""" from nilai_api.routers.private import get_prompt_store_delegation - request = "user-123" + with patch( + "nilai_api.routers.private.get_nildb_delegation_token" + ) as mock_get_delegation: + mock_get_delegation.return_value = mock_prompt_delegation_token - with pytest.raises(HTTPException) as exc_info: - await get_prompt_store_delegation(request, mock_auth_info_regular_user) + request = "user-123" - assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN - assert "Prompt storage is reserved to subscription owners" in str( - exc_info.value.detail - ) + result = await get_prompt_store_delegation( + request, mock_auth_info_regular_user + ) + + assert isinstance(result, PromptDelegationToken) + assert result.token == "delegation_token_123" @pytest.mark.asyncio async def test_get_prompt_store_delegation_handler_error( @@ -150,7 +139,7 @@ async def test_chat_completion_with_prompt_document_injection(self): ) mock_user = MagicMock() - mock_user.userid = "test-user-id" + mock_user.user_id = "test-user-id" mock_user.name = "Test User" mock_user.apikey = "test-api-key" mock_user.rate_limits = RateLimits().get_effective_limits() @@ -163,6 +152,14 @@ async def test_chat_completion_with_prompt_document_injection(self): mock_meter = MagicMock() mock_meter.set_response = MagicMock() + # Mock log context + mock_log_ctx = MagicMock() + mock_log_ctx.set_user = MagicMock() + mock_log_ctx.set_model = MagicMock() + mock_log_ctx.set_request_params = MagicMock() + mock_log_ctx.start_model_timing = MagicMock() + mock_log_ctx.end_model_timing = MagicMock() + request = ChatRequest( model="test-model", messages=[{"role": "user", "content": "Hello"}] ) @@ -179,12 +176,6 @@ async def test_chat_completion_with_prompt_document_injection(self): patch( "nilai_api.routers.endpoints.chat.handle_web_search" ) as mock_handle_web_search, - patch( - "nilai_api.routers.endpoints.chat.UserManager.update_token_usage" - ) as mock_update_usage, - patch( - "nilai_api.routers.endpoints.chat.QueryLogManager.log_query" - ) as mock_log_query, patch( "nilai_api.routers.endpoints.chat.handle_tool_workflow" ) as mock_handle_tool_workflow, @@ -205,10 +196,6 @@ async def test_chat_completion_with_prompt_document_injection(self): mock_web_search_result.sources = [] mock_handle_web_search.return_value = mock_web_search_result - # Mock async database operations - mock_update_usage.return_value = None - mock_log_query.return_value = None - # Mock OpenAI client mock_client_instance = MagicMock() mock_response = MagicMock() @@ -250,7 +237,10 @@ async def test_chat_completion_with_prompt_document_injection(self): # Call the function (this will test the prompt injection logic) await chat_completion( - req=request, auth_info=mock_auth_info, meter=mock_meter + req=request, + auth_info=mock_auth_info, + meter=mock_meter, + log_ctx=mock_log_ctx, ) mock_get_prompt.assert_called_once_with(mock_prompt_document) @@ -266,9 +256,7 @@ async def 
test_chat_completion_prompt_document_extraction_error(self):
         )
 
         mock_user = MagicMock()
-        mock_user.userid = "test-user-id"
-        mock_user.name = "Test User"
-        mock_user.apikey = "test-api-key"
+        mock_user.user_id = "test-user-id"
         mock_user.rate_limits = RateLimits().get_effective_limits()
 
         mock_auth_info = AuthenticationInfo(
@@ -279,6 +267,12 @@ async def test_chat_completion_prompt_document_extraction_error(self):
         mock_meter = MagicMock()
         mock_meter.set_response = MagicMock()
 
+        # Mock log context
+        mock_log_ctx = MagicMock()
+        mock_log_ctx.set_user = MagicMock()
+        mock_log_ctx.set_model = MagicMock()
+        mock_log_ctx.set_request_params = MagicMock()
+
         request = ChatRequest(
             model="test-model", messages=[{"role": "user", "content": "Hello"}]
         )
@@ -300,7 +294,10 @@ async def test_chat_completion_prompt_document_extraction_error(self):
 
             with pytest.raises(HTTPException) as exc_info:
                 await chat_completion(
-                    req=request, auth_info=mock_auth_info, meter=mock_meter
+                    req=request,
+                    auth_info=mock_auth_info,
+                    meter=mock_meter,
+                    log_ctx=mock_log_ctx,
                 )
 
             assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
@@ -316,9 +313,7 @@ async def test_chat_completion_without_prompt_document(self):
         from nilai_common import ChatRequest
 
         mock_user = MagicMock()
-        mock_user.userid = "test-user-id"
-        mock_user.name = "Test User"
-        mock_user.apikey = "test-api-key"
+        mock_user.user_id = "test-user-id"
         mock_user.rate_limits = RateLimits().get_effective_limits()
 
         mock_auth_info = AuthenticationInfo(
@@ -331,6 +326,14 @@ async def test_chat_completion_without_prompt_document(self):
         mock_meter = MagicMock()
         mock_meter.set_response = MagicMock()
 
+        # Mock log context
+        mock_log_ctx = MagicMock()
+        mock_log_ctx.set_user = MagicMock()
+        mock_log_ctx.set_model = MagicMock()
+        mock_log_ctx.set_request_params = MagicMock()
+        mock_log_ctx.start_model_timing = MagicMock()
+        mock_log_ctx.end_model_timing = MagicMock()
+
         request = ChatRequest(
             model="test-model", messages=[{"role": "user", "content": "Hello"}]
         )
@@ -347,12 +350,6 @@ async def test_chat_completion_without_prompt_document(self):
             patch(
                 "nilai_api.routers.endpoints.chat.handle_web_search"
             ) as mock_handle_web_search,
-            patch(
-                "nilai_api.routers.endpoints.chat.UserManager.update_token_usage"
-            ) as mock_update_usage,
-            patch(
-                "nilai_api.routers.endpoints.chat.QueryLogManager.log_query"
-            ) as mock_log_query,
             patch(
                 "nilai_api.routers.endpoints.chat.handle_tool_workflow"
             ) as mock_handle_tool_workflow,
@@ -371,10 +368,6 @@ async def test_chat_completion_without_prompt_document(self):
             mock_web_search_result.sources = []
             mock_handle_web_search.return_value = mock_web_search_result
 
-            # Mock async database operations
-            mock_update_usage.return_value = None
-            mock_log_query.return_value = None
-
             # Mock OpenAI client
             mock_client_instance = MagicMock()
             mock_response = MagicMock()
@@ -412,7 +405,10 @@ async def test_chat_completion_without_prompt_document(self):
 
             # Call the function
             await chat_completion(
-                req=request, auth_info=mock_auth_info, meter=mock_meter
+                req=request,
+                auth_info=mock_auth_info,
+                meter=mock_meter,
+                log_ctx=mock_log_ctx,
             )
 
             # Should not call get_prompt_from_nildb when no prompt document
@@ -428,7 +424,7 @@ async def test_responses_with_prompt_document_injection(self):
         )
 
         mock_user = MagicMock()
-        mock_user.userid = "test-user-id"
+        mock_user.user_id = "test-user-id"
         mock_user.name = "Test User"
         mock_user.apikey = "test-api-key"
         mock_user.rate_limits = RateLimits().get_effective_limits()
@@ -468,12 +464,6 @@ async def test_responses_with_prompt_document_injection(self):
             patch(
                 "nilai_api.routers.endpoints.responses.state.get_model"
             ) as mock_get_model,
-            patch(
-                "nilai_api.routers.endpoints.responses.UserManager.update_token_usage"
-            ) as mock_update_usage,
-            patch(
-                "nilai_api.routers.endpoints.responses.QueryLogManager.log_query"
-            ) as mock_log_query,
             patch(
                 "nilai_api.routers.endpoints.responses.handle_responses_tool_workflow"
             ) as mock_handle_tool_workflow,
@@ -486,9 +476,6 @@ async def test_responses_with_prompt_document_injection(self):
             mock_model_endpoint.metadata.multimodal_support = True
             mock_get_model.return_value = mock_model_endpoint
 
-            mock_update_usage.return_value = None
-            mock_log_query.return_value = None
-
             mock_client_instance = MagicMock()
             mock_response = MagicMock()
             mock_response.model_dump.return_value = response_payload
@@ -502,8 +489,13 @@ async def test_responses_with_prompt_document_injection(self):
             mock_meter = MagicMock()
             mock_meter.set_response = MagicMock()
 
+            mock_log_ctx = MagicMock()
+
             await create_response(
-                req=request, auth_info=mock_auth_info, meter=mock_meter
+                req=request,
+                auth_info=mock_auth_info,
+                meter=mock_meter,
+                log_ctx=mock_log_ctx,
             )
 
             mock_get_prompt.assert_called_once_with(mock_prompt_document)
@@ -518,7 +510,7 @@ async def test_responses_prompt_document_extraction_error(self):
         )
 
         mock_user = MagicMock()
-        mock_user.userid = "test-user-id"
+        mock_user.user_id = "test-user-id"
         mock_user.name = "Test User"
         mock_user.apikey = "test-api-key"
         mock_user.rate_limits = RateLimits().get_effective_limits()
@@ -545,8 +537,16 @@ async def test_responses_prompt_document_extraction_error(self):
 
             mock_get_prompt.side_effect = Exception("Unable to extract prompt")
 
+            mock_meter = MagicMock()
+            mock_log_ctx = MagicMock()
+
             with pytest.raises(HTTPException) as exc_info:
-                await create_response(req=request, auth_info=mock_auth_info)
+                await create_response(
+                    req=request,
+                    auth_info=mock_auth_info,
+                    meter=mock_meter,
+                    log_ctx=mock_log_ctx,
+                )
 
             assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
             assert (
@@ -560,7 +560,7 @@ async def test_responses_without_prompt_document(self):
         from nilai_api.routers.endpoints.responses import create_response
 
         mock_user = MagicMock()
-        mock_user.userid = "test-user-id"
+        mock_user.user_id = "test-user-id"
         mock_user.name = "Test User"
         mock_user.apikey = "test-api-key"
         mock_user.rate_limits = RateLimits().get_effective_limits()
@@ -602,12 +602,6 @@ async def test_responses_without_prompt_document(self):
             patch(
                 "nilai_api.routers.endpoints.responses.state.get_model"
             ) as mock_get_model,
-            patch(
-                "nilai_api.routers.endpoints.responses.UserManager.update_token_usage"
-            ) as mock_update_usage,
-            patch(
-                "nilai_api.routers.endpoints.responses.QueryLogManager.log_query"
-            ) as mock_log_query,
             patch(
                 "nilai_api.routers.endpoints.responses.handle_responses_tool_workflow"
             ) as mock_handle_tool_workflow,
@@ -618,9 +612,6 @@ async def test_responses_without_prompt_document(self):
             mock_model_endpoint.metadata.multimodal_support = True
             mock_get_model.return_value = mock_model_endpoint
 
-            mock_update_usage.return_value = None
-            mock_log_query.return_value = None
-
             mock_client_instance = MagicMock()
             mock_response = MagicMock()
             mock_response.model_dump.return_value = response_payload
@@ -634,8 +625,13 @@ async def test_responses_without_prompt_document(self):
             mock_meter = MagicMock()
             mock_meter.set_response = MagicMock()
 
+            mock_log_ctx = MagicMock()
+
             await create_response(
-                req=request, auth_info=mock_auth_info, meter=mock_meter
+                req=request,
+                auth_info=mock_auth_info,
+                meter=mock_meter,
+                log_ctx=mock_log_ctx,
             )
 
             mock_get_prompt.assert_not_called()
@@ -658,12 +654,10 @@ def test_prompt_delegation_token_model_validation(self):
         assert token.token == "delegation_token_123"
         assert token.did == "did:nil:builder123"
 
-    def test_user_is_subscription_owner_property(
-        self, mock_subscription_owner_user, mock_regular_user
-    ):
-        """Test the is_subscription_owner property"""
-        # Subscription owner (userid == apikey)
-        assert mock_subscription_owner_user.is_subscription_owner is True
-
-        # Regular user (userid != apikey)
-        assert mock_regular_user.is_subscription_owner is False
+    def test_user_data_structure(self, mock_subscription_owner_user, mock_regular_user):
+        """Test the UserData structure has required fields"""
+        # Check that UserData has the expected fields
+        assert hasattr(mock_subscription_owner_user, "user_id")
+        assert hasattr(mock_subscription_owner_user, "rate_limits")
+        assert hasattr(mock_regular_user, "user_id")
+        assert hasattr(mock_regular_user, "rate_limits")
diff --git a/tests/unit/nilai_api/routers/test_responses_private.py b/tests/unit/nilai_api/routers/test_responses_private.py
index cf122c79..d5962dfc 100644
--- a/tests/unit/nilai_api/routers/test_responses_private.py
+++ b/tests/unit/nilai_api/routers/test_responses_private.py
@@ -24,7 +24,7 @@ async def test_runs_in_a_loop():
 @pytest.fixture
 def mock_user():
     mock = MagicMock(spec=UserModel)
-    mock.userid = "test-user-id"
+    mock.user_id = "test-user-id"
     mock.name = "Test User"
     mock.apikey = "test-api-key"
     mock.prompt_tokens = 100
@@ -43,61 +43,26 @@ def mock_user_manager(mock_user, mocker):
     from nilai_api.db.users import UserManager
     from nilai_api.db.logs import QueryLogManager
 
+    # Patch QueryLogManager for usage
     mocker.patch.object(
-        UserManager,
-        "get_token_usage",
-        return_value={
-            "prompt_tokens": 100,
-            "completion_tokens": 50,
-            "total_tokens": 150,
-            "queries": 10,
-        },
-    )
-    mocker.patch.object(UserManager, "update_token_usage")
-    mocker.patch.object(
-        UserManager,
+        QueryLogManager,
         "get_user_token_usage",
         return_value={
             "prompt_tokens": 100,
             "completion_tokens": 50,
             "total_tokens": 150,
-            "completion_tokens_details": None,
-            "prompt_tokens_details": None,
             "queries": 10,
         },
     )
-    mocker.patch.object(
-        UserManager,
-        "insert_user",
-        return_value={
-            "userid": "test-user-id",
-            "apikey": "test-api-key",
-            "rate_limits": RateLimits().get_effective_limits().model_dump_json(),
-        },
-    )
-    mocker.patch.object(
-        UserManager,
-        "check_api_key",
+    mocker.patch.object(QueryLogManager, "log_query")
+
+    # Mock validate_credential for authentication
+    mocker.patch(
+        "nilai_api.auth.strategies.validate_credential",
+        new_callable=AsyncMock,
         return_value=mock_user,
     )
-    mocker.patch.object(
-        UserManager,
-        "get_all_users",
-        return_value=[
-            {
-                "userid": "test-user-id",
-                "apikey": "test-api-key",
-                "rate_limits": RateLimits().get_effective_limits().model_dump_json(),
-            },
-            {
-                "userid": "test-user-id-2",
-                "apikey": "test-api-key",
-                "rate_limits": RateLimits().get_effective_limits().model_dump_json(),
-            },
-        ],
-    )
-    mocker.patch.object(QueryLogManager, "log_query")
-    mocker.patch.object(UserManager, "update_last_activity")
+    return UserManager
@@ -107,6 +72,7 @@ def mock_state(mocker):
     mock_discovery_service = mocker.Mock()
     mock_discovery_service.discover_models = AsyncMock(return_value=expected_models)
+    mock_discovery_service.initialize = AsyncMock()
 
     mocker.patch.object(state, "discovery_service", mock_discovery_service)
@@ -138,7 +104,7 @@ def mock_metering_context(mocker):
 
 
 @pytest.fixture
-def client(mock_user_manager, mock_metering_context):
+def client(mock_user_manager, mock_state, mock_metering_context):
     from nilai_api.app import app
     from nilai_api.credit import LLMMeter
@@ -210,6 +176,11 @@ def test_create_response(mock_user, mock_state, mock_user_manager, mocker, clien
         "nilai_api.routers.endpoints.responses.handle_responses_tool_workflow",
         return_value=(response_data, 0, 0),
     )
+    mocker.patch(
+        "nilai_api.routers.endpoints.responses.state.get_model",
+        return_value=model_endpoint,
+    )
+    mocker.patch("nilai_api.db.logs.QueryLogContext.commit", new_callable=AsyncMock)
 
     payload = {
         "model": "meta-llama/Llama-3.2-1B-Instruct",
@@ -308,6 +279,11 @@ async def chunk_generator():
         "nilai_api.routers.endpoints.responses.AsyncOpenAI",
         return_value=mock_async_openai_instance,
     )
+    mocker.patch(
+        "nilai_api.routers.endpoints.responses.state.get_model",
+        return_value=model_endpoint,
+    )
+    mocker.patch("nilai_api.db.logs.QueryLogContext.commit", new_callable=AsyncMock)
 
     payload = {
         "model": "meta-llama/Llama-3.2-1B-Instruct",
diff --git a/tests/unit/nilai_api/test_db.py b/tests/unit/nilai_api/test_db.py
index dff0fd8b..3979321d 100644
--- a/tests/unit/nilai_api/test_db.py
+++ b/tests/unit/nilai_api/test_db.py
@@ -15,7 +15,7 @@ async def test_insert_user(mock_db):
     """Test user insertion functionality."""
     user = await mock_db.insert_user("Test User", "test@example.com")
 
-    assert "userid" in user
+    assert "user_id" in user
     assert "apikey" in user
     assert len(mock_db.users) == 1
@@ -38,9 +38,9 @@ async def test_token_usage(mock_db):
     """Test token usage tracking."""
     user = await mock_db.insert_user("Test User", "test@example.com")
 
-    await mock_db.update_token_usage(user["userid"], 50, 20)
+    await mock_db.update_token_usage(user["user_id"], 50, 20)
 
-    token_usage = await mock_db.get_token_usage(user["userid"])
+    token_usage = await mock_db.get_token_usage(user["user_id"])
     assert token_usage["prompt_tokens"] == 50
     assert token_usage["completion_tokens"] == 20
     assert token_usage["queries"] == 1
@@ -51,9 +51,9 @@ async def test_query_logging(mock_db):
     """Test query logging functionality."""
     user = await mock_db.insert_user("Test User", "test@example.com")
 
-    await mock_db.log_query(user["userid"], "test-model", 10, 15)
+    await mock_db.log_query(user["user_id"], "test-model", 10, 15)
 
     assert len(mock_db.query_logs) == 1
     log_entry = list(mock_db.query_logs.values())[0]
-    assert log_entry["userid"] == user["userid"]
+    assert log_entry["user_id"] == user["user_id"]
     assert log_entry["model"] == "test-model"
diff --git a/tests/unit/nilai_api/test_rate_limiting.py b/tests/unit/nilai_api/test_rate_limiting.py
index 27a5c1bc..bd8a41e4 100644
--- a/tests/unit/nilai_api/test_rate_limiting.py
+++ b/tests/unit/nilai_api/test_rate_limiting.py
@@ -45,7 +45,7 @@ async def test_concurrent_rate_limit(req):
     rate_limit = RateLimit(concurrent_extractor=lambda _: (5, "test"))
 
     user_limits = UserRateLimits(
-        subscription_holder=random_id(),
+        user_id=random_id(),
         token_rate_limit=None,
         rate_limits=RateLimits(
             user_rate_limit_day=None,
@@ -86,7 +86,7 @@ async def web_search_extractor(_):
     rate_limit = RateLimit(web_search_extractor=web_search_extractor)
 
     user_limits = UserRateLimits(
-        subscription_holder=random_id(),
+        user_id=random_id(),
         token_rate_limit=None,
         rate_limits=RateLimits(
            user_rate_limit_day=None,
@@ -117,7 +117,7 @@ async def web_search_extractor(_):
     "user_limits",
     [
         UserRateLimits(
-            subscription_holder=random_id(),
+            user_id=random_id(),
             token_rate_limit=None,
             rate_limits=RateLimits(
                 user_rate_limit_day=10,
@@ -131,7 +131,7 @@ async def web_search_extractor(_):
             ),
         ),
         UserRateLimits(
-            subscription_holder=random_id(),
+            user_id=random_id(),
             token_rate_limit=None,
             rate_limits=RateLimits(
                 user_rate_limit_day=None,
@@ -145,7 +145,7 @@ async def web_search_extractor(_):
             ),
         ),
         UserRateLimits(
-            subscription_holder=random_id(),
+            user_id=random_id(),
             token_rate_limit=None,
             rate_limits=RateLimits(
                 user_rate_limit_day=None,
@@ -159,7 +159,7 @@ async def web_search_extractor(_):
             ),
         ),
         UserRateLimits(
-            subscription_holder=random_id(),
+            user_id=random_id(),
             token_rate_limit=TokenRateLimits(
                 limits=[
                     TokenRateLimit(
@@ -220,7 +220,7 @@ async def web_search_extractor(request):
     rate_limit = RateLimit(web_search_extractor=web_search_extractor)
 
     user_limits = UserRateLimits(
-        subscription_holder=apikey,
+        user_id=apikey,
         token_rate_limit=None,
         rate_limits=RateLimits(
             user_rate_limit_day=None,
diff --git a/uv.lock b/uv.lock
index 65292720..03918e23 100644
--- a/uv.lock
+++ b/uv.lock
@@ -2238,7 +2238,7 @@ requires-dist = [
     { name = "gunicorn", specifier = ">=23.0.0" },
     { name = "httpx", specifier = ">=0.27.2" },
     { name = "nilai-common", editable = "packages/nilai-common" },
-    { name = "nilauth-credit-middleware", specifier = "==0.1.1" },
+    { name = "nilauth-credit-middleware", specifier = ">=0.1.2" },
     { name = "nilrag", specifier = ">=0.1.11" },
     { name = "nuc", specifier = ">=0.1.0" },
     { name = "openai", specifier = ">=1.59.9" },
@@ -2328,7 +2328,7 @@ dev = [
 
 [[package]]
 name = "nilauth-credit-middleware"
-version = "0.1.1"
+version = "0.1.2"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "fastapi", extra = ["standard"] },
@@ -2336,9 +2336,9 @@ dependencies = [
     { name = "nuc" },
     { name = "pydantic" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/9f/cf/7716fa5f4aca83ef39d6f9f8bebc1d80d194c52c9ce6e75ee6bd1f401217/nilauth_credit_middleware-0.1.1.tar.gz", hash = "sha256:ae32c4c1e6bc083c8a7581d72a6da271ce9c0f0f9271a1694acb81ccd0a4a8bd", size = 10259, upload-time = "2025-10-16T11:15:03.918Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/46/bc/ae9b2c26919151fc7193b406a98831eeef197f6ec46b0c075138e66ec016/nilauth_credit_middleware-0.1.2.tar.gz", hash = "sha256:66423a4d18aba1eb5f5d47a04c8f7ae6a19ab4e34433475aa9dc1ba398483fdd", size = 11979, upload-time = "2025-10-30T16:21:20.538Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/a7/b5/6e4090ae2ae8848d12e43f82d8d995cd1dff9de8e947cf5fb2b8a72a828e/nilauth_credit_middleware-0.1.1-py3-none-any.whl", hash = "sha256:10a0fda4ac11f51b9a5dd7b3a8fbabc0b28ff92a170a7729ac11eb15c7b37887", size = 14919, upload-time = "2025-10-16T11:15:02.201Z" },
+    { url = "https://files.pythonhosted.org/packages/05/c3/73d55667aad701a64f3d1330d66c90a8c292fd19f054093ca74960aca1fb/nilauth_credit_middleware-0.1.2-py3-none-any.whl", hash = "sha256:31f3233e6706c6167b6246a4edb9a405d587eccb1399231223f95c0cdf1ce57c", size = 18121, upload-time = "2025-10-30T16:21:19.547Z" },
 ]

[[package]]
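
Reviewer note: the updated tests above repeatedly build a mocked per-request log context and an awaited QueryLogContext.commit. The snippet below is a minimal, self-contained sketch of that mock wiring, using only the attribute names visible in this diff (set_user, set_model, set_request_params, start_model_timing, end_model_timing, commit); the helper name build_log_ctx_mock and the __main__ usage are illustrative only and are not part of nilai_api.

# Sketch only: mirrors the mock shape used by the updated tests in this diff.
# `build_log_ctx_mock` is a hypothetical helper, not a nilai_api API.
from unittest.mock import AsyncMock, MagicMock


def build_log_ctx_mock() -> MagicMock:
    """Return a stand-in for the per-request query-log context."""
    log_ctx = MagicMock()
    log_ctx.set_user = MagicMock()
    log_ctx.set_model = MagicMock()
    log_ctx.set_request_params = MagicMock()
    log_ctx.start_model_timing = MagicMock()
    log_ctx.end_model_timing = MagicMock()
    log_ctx.commit = AsyncMock()  # commit is awaited in the real flow, so use AsyncMock
    return log_ctx


if __name__ == "__main__":
    mock_log_ctx = build_log_ctx_mock()
    # An endpoint under test would then be called roughly as in the tests above:
    #   await chat_completion(req=request, auth_info=mock_auth_info,
    #                         meter=mock_meter, log_ctx=mock_log_ctx)
    mock_log_ctx.set_model("test-model")
    mock_log_ctx.set_model.assert_called_once_with("test-model")
    print("log context mock wired as in the tests above")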