diff --git a/.env.ci b/.env.ci index f448185b..d3717128 100644 --- a/.env.ci +++ b/.env.ci @@ -7,8 +7,8 @@ HF_TOKEN="Hugging Face API Token" ENVIRONMENT = "mainnet" -NILAI_GUNICORN_WORKERS = 10 -AUTH_STRATEGY = "nuc" +NILAI_GUNICORN_WORKERS = 2 +AUTH_STRATEGY = "api_key" # The domain name of the server # - It must be written as "localhost" or "test.nilai.nillion" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 42a51ad1..3b451665 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -63,7 +63,7 @@ jobs: github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }} runners-per-machine: 1 number-of-machines: 1 - ec2-image-id: ami-0ac221d824dd88706 + ec2-image-id: ami-0174a246556e8750b ec2-instance-type: g4dn.xlarge subnet-id: subnet-0ec4c353621eabae2 security-group-id: sg-03ee5c56e1f467aa0 @@ -94,7 +94,7 @@ jobs: cache-dependency-glob: "**/pyproject.toml" - name: Install dependencies run: | - apt-get update && apt-get install curl git pkg-config automake file -y + apt-get update && apt-get install curl git pkg-config automake file python3.12-dev -y uv sync - name: Build vllm @@ -103,7 +103,7 @@ jobs: - name: Build attestation run: docker build -t nillion/nilai-attestation:latest -f docker/attestation.Dockerfile . - - name: Build nilal API + - name: Build nilai API run: docker build -t nillion/nilai-api:latest -f docker/api.Dockerfile --target nilai --platform linux/amd64 . 
- name: Create .env @@ -124,12 +124,22 @@ jobs: - name: Wait for services to be healthy run: bash scripts/wait_for_ci_services.sh - - name: Run E2E tests + - name: Run E2E tests for NUC run: | set -e export ENVIRONMENT=ci uv run pytest -v tests/e2e + - name: Run E2E tests for API Key + run: | + set -e + # Create a user with a rate limit of 1000 requests per minute, hour, and day + export AUTH_TOKEN=$(docker exec nilai-api uv run src/nilai_api/commands/add_user.py --name test1 --ratelimit-minute 1000 --ratelimit-hour 1000 --ratelimit-day 1000 | jq ".apikey" -r) + export ENVIRONMENT=ci + # Set the environment variable for the API key + export AUTH_STRATEGY=api_key + uv run pytest -v tests/e2e + - name: Stop Services run: | docker-compose -f docker-compose.yml \ diff --git a/caddy/Caddyfile b/caddy/Caddyfile index 61135b2c..063028d8 100644 --- a/caddy/Caddyfile +++ b/caddy/Caddyfile @@ -2,7 +2,7 @@ tls { protocols tls1.2 tls1.3 } - } +} {$NILAI_SERVER_DOMAIN} { import ssl_config @@ -12,12 +12,12 @@ reverse_proxy grafana:3000 } - handle_path /grafana { - uri strip_prefix /grafana - reverse_proxy grafana:3000 + handle_path /nuc/* { + uri strip_prefix /nuc + reverse_proxy nuc-api:8080 } handle { reverse_proxy api:8080 } - } +} diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index efd40057..c33ebbe0 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -6,6 +6,17 @@ services: volumes: - ./nilai-api/:/app/nilai-api/ - ./packages/:/app/packages/ + - ./nilai-auth/nuc-helpers/:/app/nilai-auth/nuc-helpers/ + networks: + - nilauth + nuc-api: + platform: linux/amd64 # for macOS to force running on Rosetta 2 + ports: + - "8088:8080" + volumes: + - ./nilai-api/:/app/nilai-api/ + - ./packages/:/app/packages/ + - ./nilai-auth/nuc-helpers/:/app/nilai-auth/nuc-helpers/ networks: - nilauth attestation: @@ -20,6 +31,9 @@ services: postgres: ports: - "5432:5432" + nuc-postgres: + ports: + - "5433:5432" grafana: ports: - "3000:3000" diff --git 
a/docker-compose.yml b/docker-compose.yml index 22210cc9..11b7f155 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -40,6 +40,25 @@ services: retries: 5 start_period: 10s timeout: 10s + + nuc-postgres: + image: postgres:16 + container_name: nuc-postgres + restart: always + env_file: + - .env + environment: + - POSTGRES_HOST=nuc-postgres + networks: + - frontend_net + volumes: + - nuc_postgres_data:/var/lib/postgresql/data + healthcheck: + test: ["CMD", "sh", "-c", "pg_isready -U ${POSTGRES_USER} -d ${POSTGRES_DB} -h localhost"] + interval: 30s + retries: 5 + start_period: 10s + timeout: 10s prometheus: container_name: prometheus image: prom/prometheus:v3.1.0 @@ -128,6 +147,35 @@ services: retries: 3 start_period: 15s timeout: 10s + nuc-api: + container_name: nilai-nuc-api + image: nillion/nilai-api:latest + privileged: true + volumes: + - /dev/sev-guest:/dev/sev-guest # for AMD SEV + depends_on: + etcd: + condition: service_healthy + nuc-postgres: + condition: service_healthy + api: + condition: service_healthy + restart: unless-stopped + networks: + - frontend_net + - backend_net + - proxy_net + env_file: + - .env + environment: + - AUTH_STRATEGY=nuc # Overwrite the default strategy + - POSTGRES_HOST=nuc-postgres + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8080/v1/health"] + interval: 30s + retries: 3 + start_period: 15s + timeout: 10s attestation: image: nillion/nilai-attestation:latest restart: unless-stopped @@ -171,3 +219,4 @@ networks: volumes: postgres_data: + nuc_postgres_data: diff --git a/nilai-api/alembic/env.py b/nilai-api/alembic/env.py index 936d88e0..7ffeb8fb 100644 --- a/nilai-api/alembic/env.py +++ b/nilai-api/alembic/env.py @@ -9,8 +9,13 @@ from alembic import context from nilai_api.db import Base +from nilai_api.db.users import UserModel +from nilai_api.db.logs import QueryLog import nilai_api.config as nilai_config +# If we don't use the models, they remain unused, and the migration fails +# This is a workaround to 
ensure the models are loaded +_, _ = UserModel, QueryLog # this is the Alembic Config object, which provides # access to the values within the .ini file in use. config = context.config diff --git a/nilai-api/alembic/versions/b9642f45db1d_fix_changed_to_timestamps_with_timezone.py b/nilai-api/alembic/versions/b9642f45db1d_fix_changed_to_timestamps_with_timezone.py new file mode 100644 index 00000000..e95d18ef --- /dev/null +++ b/nilai-api/alembic/versions/b9642f45db1d_fix_changed_to_timestamps_with_timezone.py @@ -0,0 +1,75 @@ +"""fix: changed to timestamps with timezone + +Revision ID: b9642f45db1d +Revises: ca76e3ebe6ee +Create Date: 2025-05-13 09:47:30.506632 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "b9642f45db1d" +down_revision: Union[str, None] = "ca76e3ebe6ee" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column( + "query_logs", + "query_timestamp", + existing_type=postgresql.TIMESTAMP(), + type_=sa.DateTime(timezone=True), + existing_nullable=False, + existing_server_default=sa.text("now()"), + ) + op.alter_column( + "users", + "signup_date", + existing_type=postgresql.TIMESTAMP(), + type_=sa.DateTime(timezone=True), + existing_nullable=False, + existing_server_default=sa.text("now()"), + ) + op.alter_column( + "users", + "last_activity", + existing_type=postgresql.TIMESTAMP(), + type_=sa.DateTime(timezone=True), + existing_nullable=True, + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.alter_column( + "users", + "last_activity", + existing_type=sa.DateTime(timezone=True), + type_=postgresql.TIMESTAMP(), + existing_nullable=True, + ) + op.alter_column( + "users", + "signup_date", + existing_type=sa.DateTime(timezone=True), + type_=postgresql.TIMESTAMP(), + existing_nullable=False, + existing_server_default=sa.text("now()"), + ) + op.alter_column( + "query_logs", + "query_timestamp", + existing_type=sa.DateTime(timezone=True), + type_=postgresql.TIMESTAMP(), + existing_nullable=False, + existing_server_default=sa.text("now()"), + ) + # ### end Alembic commands ### diff --git a/nilai-api/pyproject.toml b/nilai-api/pyproject.toml index 7dbae40e..866bfa76 100644 --- a/nilai-api/pyproject.toml +++ b/nilai-api/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "web3>=7.8.0", "click>=8.1.8", "nuc", + "nuc-helpers", ] @@ -42,3 +43,4 @@ build-backend = "hatchling.build" [tool.uv.sources] nilai-common = { workspace = true } nuc = { git = "https://github.com/NillionNetwork/nuc-py.git" } +nuc-helpers = { workspace = true } diff --git a/nilai-api/src/nilai_api/app.py b/nilai-api/src/nilai_api/app.py index 9b215094..5d11667c 100644 --- a/nilai-api/src/nilai_api/app.py +++ b/nilai-api/src/nilai_api/app.py @@ -3,7 +3,7 @@ from prometheus_fastapi_instrumentator import Instrumentator from fastapi import Depends, FastAPI -from nilai_api.auth import get_user +from nilai_api.auth import get_auth_info from nilai_api.rate_limiting import setup_redis_conn from nilai_api.routers import private, public from nilai_api import config @@ -86,7 +86,7 @@ async def lifespan(app: FastAPI): app.include_router(public.router) -app.include_router(private.router, dependencies=[Depends(get_user)]) +app.include_router(private.router, dependencies=[Depends(get_auth_info)]) origins = [ "https://docs.nillion.com", # TODO: When users want to connect from browser diff --git a/nilai-api/src/nilai_api/auth/__init__.py b/nilai-api/src/nilai_api/auth/__init__.py index 17fa9146..46101227 
100644 --- a/nilai-api/src/nilai_api/auth/__init__.py +++ b/nilai-api/src/nilai_api/auth/__init__.py @@ -1,49 +1,53 @@ -from fastapi import HTTPException, Security, status +from fastapi import Security from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from logging import getLogger from nilai_api import config -from nilai_api.auth.jwt import validate_jwt -from nilai_api.db.users import UserManager, UserModel -from nilai_api.auth.strategies import STRATEGIES +from nilai_api.db.users import UserManager +from nilai_api.auth.strategies import AuthenticationStrategy from nuc.validate import ValidationException +from nuc_helpers.usage import UsageLimitError + +from nilai_api.auth.common import ( + AuthenticationInfo, + AuthenticationError, + TokenRateLimit, + TokenRateLimits, +) logger = getLogger(__name__) bearer_scheme = HTTPBearer() -class AuthenticationError(HTTPException): - def __init__(self, detail: str): - super().__init__( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=detail, - headers={"WWW-Authenticate": "Bearer"}, - ) - - -async def get_user( +async def get_auth_info( credentials: HTTPAuthorizationCredentials = Security(bearer_scheme), -) -> UserModel: +) -> AuthenticationInfo: try: - if config.AUTH_STRATEGY not in STRATEGIES: - logger.error(f"Invalid auth strategy: {config.AUTH_STRATEGY}") - raise AuthenticationError("Server misconfiguration: invalid auth strategy") - - user = await STRATEGIES[config.AUTH_STRATEGY](credentials.credentials) - if not user: - raise AuthenticationError("Missing or invalid API key") - await UserManager.update_last_activity(userid=user.userid) - return user + strategy_name: str = config.AUTH_STRATEGY.upper() + + try: + strategy = AuthenticationStrategy[strategy_name] + except KeyError: # If the strategy is not found, we raise an error + logger.error(f"Invalid auth strategy: {strategy_name}") + raise AuthenticationError( + f"Server misconfiguration: invalid auth strategy: {strategy_name}" + ) + + auth_info 
= await strategy(credentials.credentials) + await UserManager.update_last_activity(userid=auth_info.user.userid) + return auth_info except AuthenticationError as e: raise e except ValueError as e: raise AuthenticationError(detail="Authentication failed: " + str(e)) except ValidationException as e: raise AuthenticationError(detail="NUC validation failed: " + str(e)) + except UsageLimitError as e: + raise AuthenticationError(detail="Usage limit error: " + str(e)) except Exception as e: raise AuthenticationError(detail="Unexpected authentication error: " + str(e)) -__all__ = ["get_user", "validate_jwt"] +__all__ = ["get_auth_info", "AuthenticationInfo", "TokenRateLimits", "TokenRateLimit"] diff --git a/nilai-api/src/nilai_api/auth/common.py b/nilai-api/src/nilai_api/auth/common.py new file mode 100644 index 00000000..99a3bfea --- /dev/null +++ b/nilai-api/src/nilai_api/auth/common.py @@ -0,0 +1,27 @@ +from pydantic import BaseModel +from typing import Optional +from fastapi import HTTPException, status +from nilai_api.db.users import UserData +from nuc_helpers.usage import TokenRateLimits, TokenRateLimit + + +class AuthenticationError(HTTPException): + def __init__(self, detail: str): + super().__init__( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=detail, + headers={"WWW-Authenticate": "Bearer"}, + ) + + +class AuthenticationInfo(BaseModel): + user: UserData + token_rate_limit: Optional[TokenRateLimits] + + +__all__ = [ + "AuthenticationError", + "AuthenticationInfo", + "TokenRateLimits", + "TokenRateLimit", +] diff --git a/nilai-api/src/nilai_api/auth/nuc.py b/nilai-api/src/nilai_api/auth/nuc.py index 0b9ab92a..cf1cc4eb 100644 --- a/nilai-api/src/nilai_api/auth/nuc.py +++ b/nilai-api/src/nilai_api/auth/nuc.py @@ -1,4 +1,5 @@ -from typing import Tuple +from datetime import datetime, timezone +from typing import Optional, Tuple from nuc.validate import NucTokenValidator, ValidationParameters, InvocationRequirement from nuc.envelope import NucTokenEnvelope from 
nuc.nilauth import NilauthClient @@ -6,9 +7,12 @@ from functools import lru_cache from nilai_api.config import NILAUTH_TRUSTED_ROOT_ISSUERS from nilai_api.state import state +from nilai_api.auth.common import AuthenticationError from nilai_common.logger import setup_logger +from nuc_helpers.usage import TokenRateLimits + logger = setup_logger(__name__) @@ -72,3 +76,28 @@ def validate_nuc(nuc_token: str) -> Tuple[str, str]: logger.info(f"Subscription holder: {subscription_holder}") logger.info(f"User: {user}") return subscription_holder, user + + +def get_token_rate_limit(nuc_token: str) -> Optional[TokenRateLimits]: + """ + Get the rate limit for the NUC token + + Args: + nuc_token: The NUC token to get the rate limit for + + Returns: + The rate limit for the NUC token + + Raises: + UsageLimitError: If the usage limit is not found or is invalid + """ + token_rate_limits = TokenRateLimits.from_token(nuc_token) + if not token_rate_limits: + return None + for limit in token_rate_limits.limits: + if limit.usage_limit is None: + raise AuthenticationError("Token has no usage limit") + if limit.expires_at < datetime.now(timezone.utc): + raise AuthenticationError("Token has expired") + + return token_rate_limits diff --git a/nilai-api/src/nilai_api/auth/strategies.py b/nilai-api/src/nilai_api/auth/strategies.py index 871ad2ce..a595dce1 100644 --- a/nilai-api/src/nilai_api/auth/strategies.py +++ b/nilai-api/src/nilai_api/auth/strategies.py @@ -1,39 +1,58 @@ -from nilai_api.db.users import UserManager, UserModel +from nilai_api.db.users import UserManager, UserModel, UserData from nilai_api.auth.jwt import validate_jwt -from nilai_api.auth.nuc import validate_nuc +from nilai_api.auth.nuc import validate_nuc, get_token_rate_limit +from nilai_api.auth.common import ( + TokenRateLimits, + AuthenticationInfo, + AuthenticationError, +) +from enum import Enum # All strategies must return a UserModel # The strategies can raise any exception, which will be caught and converted to an 
AuthenticationError # The exception detail will be passed to the client -async def api_key_strategy(api_key) -> UserModel: - return await UserManager.check_api_key(api_key) +async def api_key_strategy(api_key: str) -> AuthenticationInfo: + user_model: UserModel | None = await UserManager.check_api_key(api_key) + if user_model: + return AuthenticationInfo( + user=UserData.from_sqlalchemy(user_model), token_rate_limit=None + ) + raise AuthenticationError("Missing or invalid API key") -async def jwt_strategy(jwt_creds) -> UserModel: +async def jwt_strategy(jwt_creds: str) -> AuthenticationInfo: result = validate_jwt(jwt_creds) - user = await UserManager.check_api_key(result.user_address) - if user: - return user - user = UserModel( - userid=result.user_address, - name=result.pub_key, - apikey=result.user_address, - ) - await UserManager.insert_user_model(user) - return user + user_model: UserModel | None = await UserManager.check_api_key(result.user_address) + if user_model: + return AuthenticationInfo( + user=UserData.from_sqlalchemy(user_model), token_rate_limit=None + ) + else: + user_model = UserModel( + userid=result.user_address, + name=result.pub_key, + apikey=result.user_address, + ) + await UserManager.insert_user_model(user_model) + return AuthenticationInfo( + user=UserData.from_sqlalchemy(user_model), token_rate_limit=None + ) -async def nuc_strategy(nuc_token) -> UserModel: +async def nuc_strategy(nuc_token) -> AuthenticationInfo: """ Validate a NUC token and return the user model """ subscription_holder, user = validate_nuc(nuc_token) - - user_model = await UserManager.check_user(user) + token_rate_limits: TokenRateLimits | None = get_token_rate_limit(nuc_token) + user_model: UserModel | None = await UserManager.check_user(user) if user_model: - return user_model + return AuthenticationInfo( + user=UserData.from_sqlalchemy(user_model), + token_rate_limit=token_rate_limits, + ) user_model = UserModel( userid=user, @@ -41,13 +60,18 @@ async def 
nuc_strategy(nuc_token) -> UserModel: apikey=subscription_holder, ) await UserManager.insert_user_model(user_model) - return user_model + return AuthenticationInfo( + user=UserData.from_sqlalchemy(user_model), token_rate_limit=token_rate_limits + ) + + +class AuthenticationStrategy(Enum): + API_KEY = (api_key_strategy, "API Key") + JWT = (jwt_strategy, "JWT") + NUC = (nuc_strategy, "NUC") + async def __call__(self, *args, **kwargs) -> AuthenticationInfo: + return await self.value[0](*args, **kwargs) -STRATEGIES = { - "api_key": api_key_strategy, - "jwt": jwt_strategy, - "nuc": nuc_strategy, -} -__all__ = ["STRATEGIES"] +__all__ = ["AuthenticationStrategy"] diff --git a/nilai-api/src/nilai_api/db/logs.py b/nilai-api/src/nilai_api/db/logs.py index 36ab6e7c..4940de90 100644 --- a/nilai-api/src/nilai_api/db/logs.py +++ b/nilai-api/src/nilai_api/db/logs.py @@ -1,5 +1,5 @@ import logging -from datetime import datetime +from datetime import datetime, timezone import sqlalchemy @@ -20,7 +20,7 @@ class QueryLog(Base): String(75), ForeignKey(UserModel.userid), nullable=False, index=True ) # type: ignore query_timestamp: datetime = Column( - DateTime, server_default=sqlalchemy.func.now(), nullable=False + DateTime(timezone=True), server_default=sqlalchemy.func.now(), nullable=False ) # type: ignore model: str = Column(Text, nullable=False) # type: ignore prompt_tokens: int = Column(Integer, nullable=False) # type: ignore @@ -55,7 +55,7 @@ async def log_query( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, - query_timestamp=datetime.now(), + query_timestamp=datetime.now(timezone.utc), ) session.add(query_log) await session.commit() diff --git a/nilai-api/src/nilai_api/db/users.py b/nilai-api/src/nilai_api/db/users.py index b94e796d..7eb71c94 100644 --- a/nilai-api/src/nilai_api/db/users.py +++ b/nilai-api/src/nilai_api/db/users.py @@ -1,8 +1,8 @@ import logging import uuid +from pydantic import BaseModel, ConfigDict -from datetime 
import datetime -from dataclasses import dataclass +from datetime import datetime, timezone from typing import Any, Dict, List, Optional import sqlalchemy @@ -30,9 +30,9 @@ class UserModel(Base): completion_tokens: int = Column(Integer, default=0, nullable=False) # type: ignore queries: int = Column(Integer, default=0, nullable=False) # type: ignore signup_date: datetime = Column( - DateTime, server_default=sqlalchemy.func.now(), nullable=False + DateTime(timezone=True), server_default=sqlalchemy.func.now(), nullable=False ) # type: ignore - last_activity: datetime = Column(DateTime, nullable=True) # type: ignore + last_activity: datetime = Column(DateTime(timezone=True), nullable=True) # type: ignore ratelimit_day: int = Column(Integer, default=USER_RATE_LIMIT_DAY, nullable=True) # type: ignore ratelimit_hour: int = Column(Integer, default=USER_RATE_LIMIT_HOUR, nullable=True) # type: ignore ratelimit_minute: int = Column( @@ -43,14 +43,36 @@ def __repr__(self): return f"" -@dataclass -class UserData: +class UserData(BaseModel): userid: str name: str apikey: str - input_tokens: int - generated_tokens: int - queries: int + prompt_tokens: int = 0 + completion_tokens: int = 0 + queries: int = 0 + signup_date: datetime + last_activity: Optional[datetime] = None + ratelimit_day: Optional[int] = None + ratelimit_hour: Optional[int] = None + ratelimit_minute: Optional[int] = None + + model_config = ConfigDict(from_attributes=True) + + @classmethod + def from_sqlalchemy(cls, user: UserModel) -> "UserData": + return cls( + userid=user.userid, + name=user.name, + apikey=user.apikey, + prompt_tokens=user.prompt_tokens or 0, + completion_tokens=user.completion_tokens or 0, + queries=user.queries or 0, + signup_date=user.signup_date or datetime.now(timezone.utc), + last_activity=user.last_activity, + ratelimit_day=user.ratelimit_day, + ratelimit_hour=user.ratelimit_hour, + ratelimit_minute=user.ratelimit_minute, + ) class UserManager: @@ -76,7 +98,7 @@ async def 
update_last_activity(userid: str): async with get_db_session() as session: user = await session.get(UserModel, userid) if user: - user.last_activity = datetime.now() + user.last_activity = datetime.now(timezone.utc) await session.commit() logger.info(f"Updated last activity for user {userid}") else: @@ -252,9 +274,14 @@ async def get_all_users() -> Optional[List[UserData]]: userid=user.userid, name=user.name, apikey=user.apikey, - input_tokens=user.prompt_tokens, - generated_tokens=user.completion_tokens, + prompt_tokens=user.prompt_tokens, + completion_tokens=user.completion_tokens, queries=user.queries, + signup_date=user.signup_date, + last_activity=user.last_activity, + ratelimit_day=user.ratelimit_day, + ratelimit_hour=user.ratelimit_hour, + ratelimit_minute=user.ratelimit_minute, ) for user in users ] diff --git a/nilai-api/src/nilai_api/rate_limiting.py b/nilai-api/src/nilai_api/rate_limiting.py index e888ba4c..af7971b0 100644 --- a/nilai-api/src/nilai_api/rate_limiting.py +++ b/nilai-api/src/nilai_api/rate_limiting.py @@ -7,8 +7,7 @@ from fastapi import status, HTTPException, Request from redis.asyncio import from_url, Redis -from nilai_api.auth import get_user -from nilai_api.db.users import UserModel +from nilai_api.auth import get_auth_info, AuthenticationInfo, TokenRateLimits LUA_RATE_LIMIT_SCRIPT = """ local key = KEYS[1] @@ -41,13 +40,16 @@ async def setup_redis_conn(redis_url): class UserRateLimits(BaseModel): - id: str + subscription_holder: str day_limit: int | None hour_limit: int | None minute_limit: int | None + token_rate_limit: TokenRateLimits | None -def get_user_limits(user: Annotated[UserModel, Depends(get_user)]) -> UserRateLimits: +def get_user_limits( + auth_info: Annotated[AuthenticationInfo, Depends(get_auth_info)], +) -> UserRateLimits: # TODO: When the only allowed strategy is NUC, we can change the apikey name to subscription_holder # In apikey mode, the apikey is unique as the userid. 
# In nuc mode, the apikey is associated with a subscription holder and the userid is the user @@ -55,10 +57,11 @@ def get_user_limits(user: Annotated[UserModel, Depends(get_user)]) -> UserRateLi # In JWT mode, the apikey is the userid too # So we use the apikey as the id return UserRateLimits( - id=user.apikey, - day_limit=user.ratelimit_day, - hour_limit=user.ratelimit_hour, - minute_limit=user.ratelimit_minute, + subscription_holder=auth_info.user.apikey, + day_limit=auth_info.user.ratelimit_day, + hour_limit=auth_info.user.ratelimit_hour, + minute_limit=auth_info.user.ratelimit_minute, + token_rate_limit=auth_info.token_rate_limit, ) @@ -90,24 +93,43 @@ async def __call__( await self.check_bucket( redis, redis_rate_limit_command, - f"minute:{user_limits.id}", + f"minute:{user_limits.subscription_holder}", user_limits.minute_limit, MINUTE_MS, ) await self.check_bucket( redis, redis_rate_limit_command, - f"hour:{user_limits.id}", + f"hour:{user_limits.subscription_holder}", user_limits.hour_limit, HOUR_MS, ) await self.check_bucket( redis, redis_rate_limit_command, - f"day:{user_limits.id}", + f"day:{user_limits.subscription_holder}", user_limits.day_limit, DAY_MS, ) + + if ( + user_limits.token_rate_limit + ): # If the token rate limit is not None, we need to check it + # We create a record in redis for the signature + # The key is the signature + # The value is the usage limit + # The expiration is the time remaining in validity of the token + # We use the time remaining to check if the token rate limit is exceeded + + for limit in user_limits.token_rate_limit.limits: + await self.check_bucket( + redis, + redis_rate_limit_command, + f"token:{limit.signature}", + limit.usage_limit, + limit.ms_remaining, + ) + key = await self.check_concurrent_and_increment(redis, request) try: yield diff --git a/nilai-api/src/nilai_api/routers/private.py b/nilai-api/src/nilai_api/routers/private.py index b579e0b6..dbc9fc42 100644 --- a/nilai-api/src/nilai_api/routers/private.py 
+++ b/nilai-api/src/nilai_api/routers/private.py @@ -1,7 +1,6 @@ # Fast API and serving import asyncio import logging -import os from base64 import b64encode from typing import AsyncGenerator, Optional, Union, List, Tuple from nilai_api.attestation import get_attestation_report @@ -9,11 +8,11 @@ from fastapi import APIRouter, Body, Depends, HTTPException, status, Request from fastapi.responses import StreamingResponse -from nilai_api.auth import get_user +from nilai_api.auth import get_auth_info, AuthenticationInfo from nilai_api.config import MODEL_CONCURRENT_RATE_LIMIT from nilai_api.crypto import sign_message from nilai_api.db.logs import QueryLogManager -from nilai_api.db.users import UserManager, UserModel +from nilai_api.db.users import UserManager from nilai_api.rate_limiting import RateLimit from nilai_api.state import state @@ -35,7 +34,7 @@ @router.get("/v1/usage", tags=["Usage"]) -async def get_usage(user: UserModel = Depends(get_user)) -> Usage: +async def get_usage(auth_info: AuthenticationInfo = Depends(get_auth_info)) -> Usage: """ Retrieve the current token usage for the authenticated user. 
@@ -49,16 +48,17 @@ async def get_usage(user: UserModel = Depends(get_user)) -> Usage: ``` """ return Usage( - prompt_tokens=user.prompt_tokens, - completion_tokens=user.completion_tokens, - total_tokens=user.prompt_tokens + user.completion_tokens, - queries=user.queries, # type: ignore # FIXME this field is not part of Usage + prompt_tokens=auth_info.user.prompt_tokens, + completion_tokens=auth_info.user.completion_tokens, + total_tokens=auth_info.user.prompt_tokens + auth_info.user.completion_tokens, + queries=auth_info.user.queries, # type: ignore # FIXME this field is not part of Usage ) @router.get("/v1/attestation/report", tags=["Attestation"]) async def get_attestation( - nonce: Optional[Nonce] = None, user: UserModel = Depends(get_user) + nonce: Optional[Nonce] = None, + auth_info: AuthenticationInfo = Depends(get_auth_info), ) -> AttestationReport: """ Generate a cryptographic attestation report. @@ -82,7 +82,9 @@ async def get_attestation( @router.get("/v1/models", tags=["Model"]) -async def get_models(user: UserModel = Depends(get_user)) -> List[ModelMetadata]: +async def get_models( + auth_info: AuthenticationInfo = Depends(get_auth_info), +) -> List[ModelMetadata]: """ List all available models in the system. @@ -95,7 +97,6 @@ async def get_models(user: UserModel = Depends(get_user)) -> List[ModelMetadata] models = await get_models(user) ``` """ - logger.info(f"Retrieving models for user {user.userid} from pid {os.getpid()}") return [endpoint.metadata for endpoint in (await state.models).values()] @@ -125,7 +126,7 @@ async def chat_completion( ) ), _=Depends(RateLimit(concurrent_extractor=chat_completion_concurrent_rate_limit)), - user: UserModel = Depends(get_user), + auth_info: AuthenticationInfo = Depends(get_auth_info), ) -> Union[SignedChatCompletion, StreamingResponse]: """ Generate a chat completion response from the AI model. 
@@ -187,7 +188,7 @@ async def chat_completion( model_url = endpoint.url + "/v1/" logger.info( - f"Chat completion request for model {model_name} from user {user.userid} on url: {model_url}" + f"Chat completion request for model {model_name} from user {auth_info.user.userid} on url: {model_url}" ) if req.nilrag: @@ -214,32 +215,29 @@ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]: } }, ) # type: ignore - prompt_token_usage: int = 0 completion_token_usage: int = 0 async for chunk in response: - if ( - chunk.usage is not None - and chunk.usage.prompt_tokens is not None - and chunk.usage.completion_tokens is not None - ): - prompt_token_usage = chunk.usage.prompt_tokens - completion_token_usage += chunk.usage.completion_tokens - - logger.info( - f"Prompt token usage: {chunk.usage.prompt_tokens}/{prompt_token_usage}, Completion token usage: {chunk.usage.completion_tokens}/{completion_token_usage}" - ) data = chunk.model_dump_json(exclude_unset=True) yield f"data: {data}\n\n" await asyncio.sleep(0) + prompt_token_usage = ( + chunk.usage.prompt_tokens if chunk.usage else prompt_token_usage + ) + completion_token_usage = ( + chunk.usage.completion_tokens + if chunk.usage + else completion_token_usage + ) + await UserManager.update_token_usage( - user.userid, + auth_info.user.userid, prompt_tokens=prompt_token_usage, completion_tokens=completion_token_usage, ) await QueryLogManager.log_query( - user.userid, + auth_info.user.userid, model=req.model, prompt_tokens=prompt_token_usage, completion_tokens=completion_token_usage, @@ -276,13 +274,13 @@ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]: ) # Update token usage await UserManager.update_token_usage( - user.userid, + auth_info.user.userid, prompt_tokens=model_response.usage.prompt_tokens, completion_tokens=model_response.usage.completion_tokens, ) await QueryLogManager.log_query( - user.userid, + auth_info.user.userid, model=req.model, 
prompt_tokens=model_response.usage.prompt_tokens, completion_tokens=model_response.usage.completion_tokens, diff --git a/nilai-auth/nuc-helpers/src/nuc_helpers/helpers.py b/nilai-auth/nuc-helpers/src/nuc_helpers/helpers.py index 8dc7b937..2e366043 100644 --- a/nilai-auth/nuc-helpers/src/nuc_helpers/helpers.py +++ b/nilai-auth/nuc-helpers/src/nuc_helpers/helpers.py @@ -93,8 +93,6 @@ def get_unil_balance(address: Address, grpc_endpoint: str) -> int: Returns: The balance of the user in UNIL """ - logger.info("grpc_endpoint", grpc_endpoint) - cfg = NetworkConfig( chain_id="nillion-chain-devnet", url="grpc+" + grpc_endpoint, @@ -168,9 +166,11 @@ def pay_for_subscription( def get_delegation_token( - root_token: RootToken, + root_token: RootToken | DelegationToken, private_key: NilAuthPrivateKey, user_public_key: NilAuthPublicKey, + usage_limit: int | None = None, + expires_at: datetime.datetime | None = None, ) -> DelegationToken: """ Delegate the root token to the delegated key @@ -186,8 +186,15 @@ def get_delegation_token( root_token_envelope = NucTokenEnvelope.parse(root_token.token) delegated_token = ( NucTokenBuilder.extending(root_token_envelope) + .expires_at( + expires_at + if expires_at + else datetime.datetime.now(datetime.timezone.utc) + + datetime.timedelta(minutes=5) + ) .audience(Did(user_public_key.serialize())) .command(Command(["nil", "ai", "generate"])) + .meta({"usage_limit": usage_limit}) .build(private_key) ) return DelegationToken(token=delegated_token) @@ -222,7 +229,6 @@ def get_invocation_token( nilai_public_key: The nilai public key delegated_key: The private key """ - logger.info(f"Delegation token: {delegation_token}") delegated_token_envelope = NucTokenEnvelope.parse(delegation_token.token) invocation = ( @@ -259,8 +265,7 @@ def validate_token( token: The token to validate validation_parameters: The validation parameters """ + token_envelope = NucTokenEnvelope.parse(token) validator = NucTokenValidator([get_nilauth_public_key(nilauth_url)]) - 
validator.validate(NucTokenEnvelope.parse(token), validation_parameters) - - logger.info("[>] Token validated") + validator.validate(token_envelope, validation_parameters) diff --git a/nilai-auth/nuc-helpers/src/nuc_helpers/main.py b/nilai-auth/nuc-helpers/src/nuc_helpers/main.py index c106f5d0..b6833f7e 100644 --- a/nilai-auth/nuc-helpers/src/nuc_helpers/main.py +++ b/nilai-auth/nuc-helpers/src/nuc_helpers/main.py @@ -17,6 +17,86 @@ from nuc.validate import ValidationParameters, InvocationRequirement +def b2b2b2c_test(): + # Services must be running for this to work + PRIVATE_KEY = "l/SYifzu2Iqc3dsWoWHRP2oSMHwrORY/PDw5fDwtJDQ=" # This is an example private key with funds for testing devnet, and should not be used in production + NILAI_ENDPOINT = "localhost:8080" + NILAUTH_ENDPOINT = "localhost:30921" + NILCHAIN_GRPC = "localhost:26649" + + # Server private key + server_wallet, server_keypair, server_private_key = get_wallet_and_private_key( + PRIVATE_KEY + ) + nilauth_client = NilauthClient(f"http://{NILAUTH_ENDPOINT}") + + # Pay for the subscription + pay_for_subscription( + nilauth_client, + server_wallet, + server_keypair, + server_private_key, + f"http://{NILCHAIN_GRPC}", + ) + + # Create a root token + root_token: RootToken = get_root_token(nilauth_client, server_private_key) + + # Create a user private key and public key + user_private_key = NilAuthPrivateKey() + user_public_key = user_private_key.pubkey + + if user_public_key is None: + raise Exception("Failed to get public key") + # b64_public_key = base64.b64encode(public_key.serialize()).decode("utf-8") + + delegation_token: DelegationToken = get_delegation_token( + root_token, + server_private_key, + user_public_key, + ) + + validate_token( + f"http://{NILAUTH_ENDPOINT}", + delegation_token.token, + ValidationParameters.default(), + ) + for i in range(2): + delegation_token: DelegationToken = get_delegation_token( + delegation_token, + user_private_key, + user_public_key, + ) + validate_token( + 
f"http://{NILAUTH_ENDPOINT}", + delegation_token.token, + ValidationParameters.default(), + ) + print("[>] Validated delegation token: ", type(delegation_token)) + + nilai_public_key: NilAuthPublicKey = get_nilai_public_key( + f"http://{NILAI_ENDPOINT}" + ) + + invocation_token: InvocationToken = get_invocation_token( + delegation_token, + nilai_public_key, + user_private_key, + ) + + print("Root token type: ", type(root_token)) + default_validation_parameters = ValidationParameters.default() + default_validation_parameters.token_requirements = InvocationRequirement( + audience=Did(nilai_public_key.serialize()) + ) + + validate_token( + f"http://{NILAUTH_ENDPOINT}", + invocation_token.token, + default_validation_parameters, + ) + + def b2b2c_test(): # Services must be running for this to work PRIVATE_KEY = "l/SYifzu2Iqc3dsWoWHRP2oSMHwrORY/PDw5fDwtJDQ=" # This is an example private key with funds for testing devnet, and should not be used in production @@ -130,8 +210,9 @@ def main(): """ Main function to test the helpers """ - b2b2c_test() - b2c_test() + b2b2b2c_test() + # b2b2c_test() + # b2c_test() if __name__ == "__main__": diff --git a/nilai-auth/nuc-helpers/src/nuc_helpers/usage.py b/nilai-auth/nuc-helpers/src/nuc_helpers/usage.py new file mode 100644 index 00000000..0f21ae41 --- /dev/null +++ b/nilai-auth/nuc-helpers/src/nuc_helpers/usage.py @@ -0,0 +1,139 @@ +from datetime import datetime, timedelta, timezone +from functools import lru_cache +from typing import Optional, List +from nuc.envelope import NucTokenEnvelope +from enum import Enum +import logging +from pydantic import BaseModel, Field + +logger = logging.getLogger(__name__) + + +class TokenRateLimit(BaseModel): + signature: str + expires_at: datetime + usage_limit: Optional[int] + + @property + def ms_remaining(self) -> int: + if self.expires_at is None: + return 0 # Or handle as infinite, e.g., float('inf'), or raise error + return int( + (self.expires_at - 
datetime.now(timezone.utc)).total_seconds() * 1000 + ) + + +class UsageLimitKind(Enum): + INCONSISTENT = "Inconsistent usage limit across proofs" + INVALID_TYPE = "Invalid usage limit type. Usage limit must be an integer." + + +class UsageLimitError(Exception): + """ + Usage limit error. + """ + + def __init__(self, kind: UsageLimitKind, message: str) -> None: + super().__init__(self, f"validation failed: {kind}: {message}") + self.kind = kind + + +def is_reduction_of(base: int, reduced: int) -> bool: + """Check if `reduced` is a valid reduction of `base`.""" + return 0 < reduced <= base + + +class TokenRateLimits(BaseModel): + limits: List[TokenRateLimit] = Field(default_factory=list, min_length=1) + + @property + def last(self) -> TokenRateLimit: + if len(self.limits) == 0: + raise ValueError("No limits found") + return self.limits[-1] + + def get_limit(self, signature: str) -> Optional[TokenRateLimit]: + for limit in self.limits: + if limit.signature == signature: + return limit + return None + + @staticmethod + @lru_cache(maxsize=128) + def from_token(token: str) -> Optional["TokenRateLimits"]: + """ + Extracts the effective usage limits from a valid NUC delegation token proof chain, ensuring consistency across proofs. + + The token is expected to be a valid NUC delegation token. If the token is invalid, + the function may not raise an error, but the result can be incorrect. + + This function parses the provided token and inspects all associated proofs and the invocation + token (if present) to determine the applicable usage limits. The behavior is as follows: + + - If multiple proofs include a `usage_limit` in their metadata, they must all be reductions of + the same base usage limit. Inconsistencies will raise an error. + - If the invocation token includes a `usage_limit`, it is ignored. + - If no usage limits are found in either proofs or invocation, the function returns `None`. 
+ + The function is cached based on the token string to avoid redundant parsing and validation. + + Note: This function is cached, so it will return the same result for the same token string. + If you need to invalidate the cache, call `TokenRateLimits.from_token.cache_clear()`. + + + Args: + token (str): The serialized delegation token. + + Returns: + Optional[TokenRateLimits]: One `TokenRateLimit` (signature, expiration date, usage limit) per proof that declares a usage limit, ordered from the root token to the last delegation token, or `None` if no usage limit is found. + + Raises: + UsageLimitError: If a usage limit is not an integer or is not a reduction of the previous usage limit in the proof chain. + """ + token_envelope = NucTokenEnvelope.parse(token) + + usage_limits = [] + + # Iterate over proofs and collect usage limits from the root token -> last delegation token + for i, proof in enumerate(token_envelope.proofs[::-1]): + meta = proof.token.meta if proof.token else None + logger.info(f"Proof {i} meta: {meta}") + if meta and "usage_limit" in meta and meta["usage_limit"] is not None: + token_usage_limit = meta["usage_limit"] + logger.info(f"Proof {i} usage limit: {token_usage_limit}") + if not isinstance(token_usage_limit, int): + logger.error( + f"Proof {i} has invalid usage limit type: {type(token_usage_limit)} and value: {token_usage_limit}." 
+ ) + raise UsageLimitError( + UsageLimitKind.INVALID_TYPE, + f"Proof {i} has invalid usage limit type: {type(token_usage_limit)} and value: {token_usage_limit}.", + ) + # We have a usage limit, we need to check if it is a reduction of the previous usage limit + if len(usage_limits) > 0 and not is_reduction_of( + usage_limits[-1].usage_limit, token_usage_limit + ): + error_message = f"Inconsistent usage limit: {token_usage_limit} is not a reduction of {usage_limits[-1].usage_limit}" + logger.error(error_message) + raise UsageLimitError( + UsageLimitKind.INCONSISTENT, + error_message, + ) + logger.info(f"Usage limit updated to: {token_usage_limit}") + sig = proof.signature.hex() + expires_at = ( + proof.token.expires_at + if proof.token.expires_at is not None + else datetime.now(timezone.utc) - timedelta(days=1) + ) # Set to a past date to indicate that the token is expired and invalid + usage_limits.append( + TokenRateLimit( + signature=sig, + expires_at=expires_at, + usage_limit=token_usage_limit, + ) + ) + + if len(usage_limits) == 0: + return None + return TokenRateLimits(limits=usage_limits) diff --git a/packages/nilai-common/src/nilai_common/discovery.py b/packages/nilai-common/src/nilai_common/discovery.py index a0ebdb61..7d3b1cf8 100644 --- a/packages/nilai-common/src/nilai_common/discovery.py +++ b/packages/nilai-common/src/nilai_common/discovery.py @@ -3,7 +3,7 @@ from typing import Dict, Optional from asyncio import CancelledError -from datetime import datetime +from datetime import datetime, timezone from tenacity import retry, wait_exponential, stop_after_attempt @@ -130,7 +130,7 @@ async def unregister_model(self, model_id: str): ) async def _refresh_lease(self, lease): lease.refresh() - self.last_refresh = datetime.now() + self.last_refresh = datetime.now(timezone.utc) self.is_healthy = True async def keep_alive(self, lease): diff --git a/scripts/wait_for_ci_services.sh b/scripts/wait_for_ci_services.sh index ec67638e..210780b4 100755 --- 
a/scripts/wait_for_ci_services.sh +++ b/scripts/wait_for_ci_services.sh @@ -3,16 +3,17 @@ # Wait for the services to be ready API_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-api 2>/dev/null) MODEL_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-llama_1b_gpu 2>/dev/null) +NUC_API_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-nuc-api 2>/dev/null) MAX_ATTEMPTS=20 ATTEMPT=1 while [ $ATTEMPT -le $MAX_ATTEMPTS ]; do - echo "Waiting for nilai to become healthy... API:[$API_HEALTH_STATUS] MODEL:[$MODEL_HEALTH_STATUS] (Attempt $ATTEMPT/$MAX_ATTEMPTS)" + echo "Waiting for nilai to become healthy... API:[$API_HEALTH_STATUS] MODEL:[$MODEL_HEALTH_STATUS] NUC_API:[$NUC_API_HEALTH_STATUS] (Attempt $ATTEMPT/$MAX_ATTEMPTS)" sleep 30 API_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-api 2>/dev/null) MODEL_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-llama_1b_gpu 2>/dev/null) - - if [ "$API_HEALTH_STATUS" = "healthy" ] && [ "$MODEL_HEALTH_STATUS" = "healthy" ]; then + NUC_API_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-nuc-api 2>/dev/null) + if [ "$API_HEALTH_STATUS" = "healthy" ] && [ "$MODEL_HEALTH_STATUS" = "healthy" ] && [ "$NUC_API_HEALTH_STATUS" = "healthy" ]; then break fi @@ -30,3 +31,9 @@ if [ "$MODEL_HEALTH_STATUS" != "healthy" ]; then echo "Error: nilai-llama_1b_gpu failed to become healthy after $MAX_ATTEMPTS attempts" exit 1 fi + +echo "NUC_API_HEALTH_STATUS: $NUC_API_HEALTH_STATUS" +if [ "$NUC_API_HEALTH_STATUS" != "healthy" ]; then + echo "Error: nilai-nuc-api failed to become healthy after $MAX_ATTEMPTS attempts" + exit 1 +fi diff --git a/tests/e2e/config.py b/tests/e2e/config.py index 14d90078..c49eff97 100644 --- a/tests/e2e/config.py +++ b/tests/e2e/config.py @@ -1,19 +1,27 @@ import os +from .nuc import get_nuc_token -ENVIRONMENT = os.getenv("ENVIRONMENT", "dev") +ENVIRONMENT = os.getenv("ENVIRONMENT", 
"ci") # Left for API key for backwards compatibility AUTH_TOKEN = os.getenv("AUTH_TOKEN", "") +AUTH_STRATEGY = os.getenv("AUTH_STRATEGY", "nuc") -if ENVIRONMENT == "dev": - BASE_URL = "http://localhost:8080/v1" -elif ENVIRONMENT == "ci": - BASE_URL = "http://127.0.0.1:8080/v1" -elif ENVIRONMENT == "mainnet": - BASE_URL = "https://nilai-e176.nillion.network/v1" -else: - raise ValueError(f"Invalid environment: {ENVIRONMENT}") +match AUTH_STRATEGY: + case "nuc": + BASE_URL = "https://localhost/nuc/v1" + def api_key_getter(): + return get_nuc_token().token + case "api_key": + BASE_URL = "https://localhost/v1" + def api_key_getter(): + return AUTH_TOKEN + case _: + raise ValueError(f"Invalid AUTH_STRATEGY: {AUTH_STRATEGY}") + + +print(f"USING {AUTH_STRATEGY}") models = { "mainnet": [ "meta-llama/Llama-3.2-3B-Instruct", @@ -29,4 +37,9 @@ ], } +if ENVIRONMENT not in models: + ENVIRONMENT = "ci" + print( + f"Environment {ENVIRONMENT} not found in models, using {ENVIRONMENT} as default" + ) test_models = models[ENVIRONMENT] diff --git a/tests/e2e/nuc.py b/tests/e2e/nuc.py index c7c27adb..27cf6f30 100644 --- a/tests/e2e/nuc.py +++ b/tests/e2e/nuc.py @@ -1,3 +1,4 @@ +from datetime import datetime, timedelta, timezone from nuc_helpers import ( get_wallet_and_private_key, pay_for_subscription, @@ -8,6 +9,9 @@ InvocationToken, RootToken, NilAuthPublicKey, + NilAuthPrivateKey, + get_delegation_token, + DelegationToken, ) from nuc.nilauth import NilauthClient from nuc.token import Did @@ -60,3 +64,141 @@ def get_nuc_token() -> InvocationToken: ) return invocation_token + + +def get_rate_limited_nuc_token(rate_limit: int = 3) -> InvocationToken: + # Services must be running for this to work + PRIVATE_KEY = "l/SYifzu2Iqc3dsWoWHRP2oSMHwrORY/PDw5fDwtJDQ=" # This is an example private key with funds for testing devnet, and should not be used in production + NILAI_ENDPOINT = "localhost:8080" + NILAUTH_ENDPOINT = "localhost:30921" + NILCHAIN_GRPC = "localhost:26649" + + # Server private 
key + server_wallet, server_keypair, server_private_key = get_wallet_and_private_key( + PRIVATE_KEY + ) + nilauth_client = NilauthClient(f"http://{NILAUTH_ENDPOINT}") + + # Pay for the subscription + pay_for_subscription( + nilauth_client, + server_wallet, + server_keypair, + server_private_key, + f"http://{NILCHAIN_GRPC}", + ) + + # Create a root token + root_token: RootToken = get_root_token(nilauth_client, server_private_key) + + nilai_public_key: NilAuthPublicKey = get_nilai_public_key( + f"http://{NILAI_ENDPOINT}" + ) + + # Create a user private key and public key + user_private_key = NilAuthPrivateKey() + user_public_key = user_private_key.pubkey + + if user_public_key is None: + raise Exception("Failed to get public key") + # b64_public_key = base64.b64encode(public_key.serialize()).decode("utf-8") + + delegation_token: DelegationToken = get_delegation_token( + root_token, + server_private_key, + user_public_key, + usage_limit=3, + expires_at=datetime.now(timezone.utc) + timedelta(minutes=5), + ) + + invocation_token: InvocationToken = get_invocation_token( + delegation_token, + nilai_public_key, + user_private_key, + ) + + default_validation_parameters = ValidationParameters.default() + default_validation_parameters.token_requirements = InvocationRequirement( + audience=Did(nilai_public_key.serialize()) + ) + + validate_token( + f"http://{NILAUTH_ENDPOINT}", + invocation_token.token, + default_validation_parameters, + ) + + return invocation_token + + +def get_invalid_rate_limited_nuc_token() -> InvocationToken: + # Services must be running for this to work + PRIVATE_KEY = "l/SYifzu2Iqc3dsWoWHRP2oSMHwrORY/PDw5fDwtJDQ=" # This is an example private key with funds for testing devnet, and should not be used in production + NILAI_ENDPOINT = "localhost:8080" + NILAUTH_ENDPOINT = "localhost:30921" + NILCHAIN_GRPC = "localhost:26649" + + # Server private key + server_wallet, server_keypair, server_private_key = get_wallet_and_private_key( + PRIVATE_KEY + ) + 
nilauth_client = NilauthClient(f"http://{NILAUTH_ENDPOINT}") + + # Pay for the subscription + pay_for_subscription( + nilauth_client, + server_wallet, + server_keypair, + server_private_key, + f"http://{NILCHAIN_GRPC}", + ) + + # Create a root token + root_token: RootToken = get_root_token(nilauth_client, server_private_key) + + nilai_public_key: NilAuthPublicKey = get_nilai_public_key( + f"http://{NILAI_ENDPOINT}" + ) + + # Create a user private key and public key + user_private_key = NilAuthPrivateKey() + user_public_key = user_private_key.pubkey + + if user_public_key is None: + raise Exception("Failed to get public key") + # b64_public_key = base64.b64encode(public_key.serialize()).decode("utf-8") + + delegation_token: DelegationToken = get_delegation_token( + root_token, + server_private_key, + user_public_key, + usage_limit=3, + expires_at=datetime.now(timezone.utc) + timedelta(minutes=5), + ) + + delegation_token: DelegationToken = get_delegation_token( + delegation_token, + user_private_key, + user_public_key, + usage_limit=5, + expires_at=datetime.now(timezone.utc) + timedelta(minutes=5), + ) + + invocation_token: InvocationToken = get_invocation_token( + delegation_token, + nilai_public_key, + user_private_key, + ) + + default_validation_parameters = ValidationParameters.default() + default_validation_parameters.token_requirements = InvocationRequirement( + audience=Did(nilai_public_key.serialize()) + ) + + validate_token( + f"http://{NILAUTH_ENDPOINT}", + invocation_token.token, + default_validation_parameters, + ) + + return invocation_token diff --git a/tests/e2e/test_http.py b/tests/e2e/test_http.py index 704b3169..8c433e15 100644 --- a/tests/e2e/test_http.py +++ b/tests/e2e/test_http.py @@ -10,8 +10,11 @@ import json -from .config import BASE_URL, test_models -from .nuc import get_nuc_token +from .config import BASE_URL, test_models, AUTH_STRATEGY, api_key_getter +from .nuc import ( + get_rate_limited_nuc_token, + get_invalid_rate_limited_nuc_token, 
+) import httpx import pytest @@ -19,7 +22,39 @@ @pytest.fixture def client(): """Create an HTTPX client with default headers""" - invocation_token = get_nuc_token() + invocation_token: str = api_key_getter() + return httpx.Client( + base_url=BASE_URL, + headers={ + "accept": "application/json", + "Content-Type": "application/json", + "Authorization": f"Bearer {invocation_token}", + }, + verify=False, + timeout=None, + ) + + +@pytest.fixture +def rate_limited_client(): + """Create an HTTPX client with default headers""" + invocation_token = get_rate_limited_nuc_token(rate_limit=1) + return httpx.Client( + base_url=BASE_URL, + headers={ + "accept": "application/json", + "Content-Type": "application/json", + "Authorization": f"Bearer {invocation_token.token}", + }, + timeout=None, + verify=False, + ) + + +@pytest.fixture +def invalid_rate_limited_client(): + """Create an HTTPX client with default headers""" + invocation_token = get_invalid_rate_limited_nuc_token() return httpx.Client( base_url=BASE_URL, headers={ @@ -28,6 +63,7 @@ def client(): "Authorization": f"Bearer {invocation_token.token}", }, timeout=None, + verify=False, ) @@ -42,7 +78,9 @@ def test_health_endpoint(client): def test_models_endpoint(client): """Test the models endpoint""" response = client.get("/models") - assert response.status_code == 200, "Models endpoint should return 200 OK" + assert response.status_code == 200, ( + f"Models endpoint should return 200 OK: {response.json()}" + ) assert isinstance(response.json(), list), "Models should be returned as a list" # Check for specific models mentioned in the requests @@ -55,8 +93,9 @@ def test_models_endpoint(client): def test_usage_endpoint(client): """Test the usage endpoint""" response = client.get("/usage") - assert response.status_code == 200, "Usage endpoint should return 200 OK" - + assert response.status_code == 200, ( + f"Usage endpoint should return 200 OK: {response.json()} {BASE_URL}" + ) # Basic usage response validation usage_data = 
response.json() assert isinstance(usage_data, dict), "Usage data should be a dictionary" @@ -367,6 +406,7 @@ def test_invalid_auth_token(client): "Content-Type": "application/json", "Authorization": "Bearer invalid_token_123", }, + verify=False, ) response = invalid_client.get("/attestation/report") @@ -401,6 +441,62 @@ def test_rate_limiting(client): pytest.skip("No rate limiting detected. Manual review may be needed.") +@pytest.mark.skipif( + AUTH_STRATEGY != "nuc", reason="NUC rate limiting not used with API key" +) +def test_rate_limiting_nucs(rate_limited_client): + """Test rate limiting by sending multiple rapid requests""" + # Payload for repeated requests + payload = { + "model": test_models[0], + "messages": [{"role": "user", "content": "What is your name?"}], + } + + # Send multiple rapid requests + responses = [] + for _ in range(4): # Adjust number based on expected rate limits + response = rate_limited_client.post("/chat/completions", json=payload) + responses.append(response) + + # Check for potential rate limit responses + rate_limit_statuses = [429, 403, 503] + rate_limited_responses = [ + r for r in responses if r.status_code in rate_limit_statuses + ] + + assert len(rate_limited_responses) > 0, ( + "No NUC rate limiting detected, when expected" + ) + + +@pytest.mark.skipif( + AUTH_STRATEGY != "nuc", reason="NUC rate limiting not used with API key" +) +def test_invalid_rate_limiting_nucs(invalid_rate_limited_client): + """Test rate limiting by sending multiple rapid requests""" + # Payload for repeated requests + payload = { + "model": test_models[0], + "messages": [{"role": "user", "content": "What is your name?"}], + } + + # Send multiple rapid requests + responses = [] + for _ in range(4): # Adjust number based on expected rate limits + response = invalid_rate_limited_client.post("/chat/completions", json=payload) + responses.append(response) + + # Check for potential rate limit responses + rate_limit_statuses = [401] + rate_limited_responses = 
[ + r for r in responses if r.status_code in rate_limit_statuses + ] + + assert len(rate_limited_responses) > 0, ( + "No NUC rate limiting detected, when expected" + ) + + def test_large_payload_handling(client): """Test handling of large input payloads""" # Create a very large system message diff --git a/tests/e2e/test_openai.py b/tests/e2e/test_openai.py index 374d8142..7a04e376 100644 --- a/tests/e2e/test_openai.py +++ b/tests/e2e/test_openai.py @@ -9,19 +9,46 @@ """ import json - +import httpx import pytest from openai import OpenAI from openai.types.chat import ChatCompletion -from .config import BASE_URL, test_models -from .nuc import get_nuc_token +from .config import BASE_URL, test_models, AUTH_STRATEGY, api_key_getter +from .nuc import ( + get_rate_limited_nuc_token, + get_invalid_rate_limited_nuc_token, +) + + +def _create_openai_client(api_key: str) -> OpenAI: + """Helper function to create an OpenAI client with SSL verification disabled""" + transport = httpx.HTTPTransport(verify=False) + return OpenAI( + base_url=BASE_URL, + api_key=api_key, + http_client=httpx.Client(transport=transport), + ) @pytest.fixture def client(): """Create an OpenAI client configured to use the Nilai API""" - invocation_token = get_nuc_token() - return OpenAI(base_url=BASE_URL, api_key=invocation_token.token) + invocation_token: str = api_key_getter() + return _create_openai_client(invocation_token) + + +@pytest.fixture +def rate_limited_client(): + """Create an OpenAI client configured to use the Nilai API with rate limiting""" + invocation_token = get_rate_limited_nuc_token(rate_limit=1) + return _create_openai_client(invocation_token.token) + + +@pytest.fixture +def invalid_rate_limited_client(): + """Create an OpenAI client configured to use the Nilai API with rate limiting""" + invocation_token = get_invalid_rate_limited_nuc_token() + return _create_openai_client(invocation_token.token) @pytest.mark.parametrize( @@ -80,6 +107,72 @@ def test_chat_completion(client, 
model): pytest.fail(f"Error testing chat completion with {model}: {str(e)}") +@pytest.mark.parametrize( + "model", + test_models, +) +@pytest.mark.skipif( + AUTH_STRATEGY != "nuc", reason="NUC rate limiting not used with API key" +) +def test_rate_limiting_nucs(rate_limited_client, model): + """Test rate limiting by sending multiple rapid requests""" + import openai + + # Send multiple rapid requests + rate_limited = False + for _ in range(4): # Adjust number based on expected rate limits + try: + _ = rate_limited_client.chat.completions.create( + model=model, + messages=[ + { + "role": "system", + "content": "You are a helpful assistant that provides accurate and concise information.", + }, + {"role": "user", "content": "What is the capital of France?"}, + ], + temperature=0.2, + max_tokens=100, + ) + except openai.RateLimitError: + rate_limited = True + + assert rate_limited, "No NUC rate limiting detected, when expected" + + +@pytest.mark.parametrize( + "model", + test_models, +) +@pytest.mark.skipif( + AUTH_STRATEGY != "nuc", reason="NUC rate limiting not used with API key" +) +def test_invalid_rate_limiting_nucs(invalid_rate_limited_client, model): + """Test rate limiting by sending multiple rapid requests""" + import openai + + # Send multiple rapid requests + forbidden = False + for _ in range(4): # Adjust number based on expected rate limits + try: + _ = invalid_rate_limited_client.chat.completions.create( + model=model, + messages=[ + { + "role": "system", + "content": "You are a helpful assistant that provides accurate and concise information.", + }, + {"role": "user", "content": "What is the capital of France?"}, + ], + temperature=0.2, + max_tokens=100, + ) + except openai.AuthenticationError: + forbidden = True + + assert forbidden, "No NUC rate limiting detected, when expected" + + @pytest.mark.parametrize( "model", test_models, @@ -337,15 +430,16 @@ def test_usage_endpoint(client): # The OpenAI client doesn't have a built-in method for this import 
requests - invocation_token = get_nuc_token() + invocation_token = api_key_getter() url = BASE_URL + "/usage" response = requests.get( url, headers={ - "Authorization": f"Bearer {invocation_token.token}", + "Authorization": f"Bearer {invocation_token}", "Content-Type": "application/json", }, + verify=False, ) assert response.status_code == 200, "Usage endpoint should return 200 OK" @@ -375,14 +469,15 @@ def test_attestation_endpoint(client): import requests url = BASE_URL + "/attestation/report" - invocation_token = get_nuc_token() + invocation_token = api_key_getter() response = requests.get( url, headers={ - "Authorization": f"Bearer {invocation_token.token}", + "Authorization": f"Bearer {invocation_token}", "Content-Type": "application/json", }, params={"nonce": "0" * 64}, + verify=False, ) assert response.status_code == 200, "Attestation endpoint should return 200 OK" @@ -414,6 +509,7 @@ def test_health_endpoint(client): "Accept": "application/json", "Content-Type": "application/json", }, + verify=False, ) print(f"Health response: {response.status_code} {response.text}") diff --git a/tests/unit/nilai_api/__init__.py b/tests/unit/nilai_api/__init__.py index 72143cb4..0be52613 100644 --- a/tests/unit/nilai_api/__init__.py +++ b/tests/unit/nilai_api/__init__.py @@ -1,6 +1,6 @@ import pytest import uuid -from datetime import datetime +from datetime import datetime, timezone from typing import Dict, List, Optional, Any @@ -32,7 +32,7 @@ async def insert_user(self, name: str, email: str) -> Dict[str, str]: "prompt_tokens": 0, "completion_tokens": 0, "queries": 0, - "signup_date": datetime.now(), + "signup_date": datetime.now(timezone.utc), "last_activity": None, } @@ -55,7 +55,7 @@ async def update_token_usage( user["prompt_tokens"] += prompt_tokens user["completion_tokens"] += completion_tokens user["queries"] += 1 - user["last_activity"] = datetime.now() + user["last_activity"] = datetime.now(timezone.utc) async def log_query( self, userid: str, model: str, 
prompt_tokens: int, completion_tokens: int @@ -64,7 +64,7 @@ async def log_query( query_log = { "id": self._next_query_log_id, "userid": userid, - "query_timestamp": datetime.now(), + "query_timestamp": datetime.now(timezone.utc), "model": model, "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, diff --git a/tests/unit/nilai_api/auth/test_auth.py b/tests/unit/nilai_api/auth/test_auth.py index 492c0d85..88780a0b 100644 --- a/tests/unit/nilai_api/auth/test_auth.py +++ b/tests/unit/nilai_api/auth/test_auth.py @@ -1,3 +1,4 @@ +from datetime import datetime, timezone from unittest.mock import MagicMock import pytest @@ -27,34 +28,72 @@ def mock_user_model(): mock = MagicMock(spec=UserModel) mock.name = "Test User" mock.userid = "test-user-id" + mock.apikey = "test-api-key" + mock.prompt_tokens = 0 + mock.completion_tokens = 0 + mock.queries = 0 + mock.signup_date = datetime.now(timezone.utc) + mock.last_activity = datetime.now(timezone.utc) + mock.ratelimit_day = 1000 + mock.ratelimit_hour = 1000 + mock.ratelimit_minute = 1000 + return mock + + +@pytest.fixture +def mock_user_data(mock_user_model): + from nilai_api.db.users import UserData + + return UserData.from_sqlalchemy(mock_user_model) + + +@pytest.fixture +def mock_auth_info(): + from nilai_api.auth import AuthenticationInfo + + mock = MagicMock(spec=AuthenticationInfo) + mock.user = mock_user_data return mock @pytest.mark.asyncio -async def test_get_user_valid_token(mock_user_manager, mock_user_model): - from nilai_api.auth import get_user +async def test_get_auth_info_valid_token( + mock_user_manager, mock_auth_info, mock_user_model +): + from nilai_api.auth import get_auth_info - """Test get_user with a valid token.""" + """Test get_auth_info with a valid token.""" mock_user_manager.check_api_key.return_value = mock_user_model credentials = HTTPAuthorizationCredentials( scheme="Bearer", credentials="valid-token" ) - user = await get_user(credentials) - assert user.name == 
mock_user_model.name - assert user.userid == mock_user_model.userid + auth_info = await get_auth_info(credentials) + print(auth_info) + assert auth_info.user.name == "Test User", ( + f"Expected Test User but got {auth_info.user.name}" + ) + assert auth_info.user.userid == "test-user-id", ( + f"Expected test-user-id but got {auth_info.user.userid}" + ) @pytest.mark.asyncio -async def test_get_user_invalid_token(mock_user_manager): - from nilai_api.auth import get_user +async def test_get_auth_info_invalid_token(mock_user_manager): + from nilai_api.auth import get_auth_info - """Test get_user with an invalid token.""" + """Test get_auth_info with an invalid token.""" mock_user_manager.check_api_key.return_value = None credentials = HTTPAuthorizationCredentials( scheme="Bearer", credentials="invalid-token" ) with pytest.raises(HTTPException) as exc_info: - await get_user(credentials) - assert exc_info.value.status_code == 401 - assert exc_info.value.detail == "Missing or invalid API key" + auth_infor = await get_auth_info(credentials) + print(auth_infor) + print(exc_info) + assert exc_info.value.status_code == 401, ( + f"Expected status code 401 but got {exc_info.value.status_code}" + ) + assert exc_info.value.detail == "Missing or invalid API key", ( + f"Expected Missing or invalid API key but got {exc_info.value.detail}" + ) diff --git a/tests/unit/nilai_api/test_rate_limiting.py b/tests/unit/nilai_api/test_rate_limiting.py index e08eb0e3..759d233d 100644 --- a/tests/unit/nilai_api/test_rate_limiting.py +++ b/tests/unit/nilai_api/test_rate_limiting.py @@ -2,7 +2,9 @@ import string import random from unittest.mock import MagicMock +from datetime import datetime, timedelta, timezone +from nilai_api.auth import TokenRateLimit, TokenRateLimits import pytest import pytest_asyncio from fastapi import HTTPException, Request @@ -41,7 +43,11 @@ async def test_concurrent_rate_limit(req): rate_limit = RateLimit(concurrent_extractor=lambda _: (5, "test")) user_limits = 
UserRateLimits( - id=random_id(), day_limit=None, hour_limit=None, minute_limit=None + subscription_holder=random_id(), + day_limit=None, + hour_limit=None, + minute_limit=None, + token_rate_limit=None, ) futures = [consume_generator(rate_limit(req, user_limits)) for _ in range(5)] @@ -63,13 +69,40 @@ async def test_concurrent_rate_limit(req): "user_limits", [ UserRateLimits( - id=random_id(), day_limit=10, hour_limit=None, minute_limit=None + subscription_holder=random_id(), + day_limit=10, + hour_limit=None, + minute_limit=None, + token_rate_limit=None, ), UserRateLimits( - id=random_id(), day_limit=None, hour_limit=11, minute_limit=None + subscription_holder=random_id(), + day_limit=None, + hour_limit=11, + minute_limit=None, + token_rate_limit=None, ), UserRateLimits( - id=random_id(), day_limit=None, hour_limit=None, minute_limit=12 + subscription_holder=random_id(), + day_limit=None, + hour_limit=None, + minute_limit=12, + token_rate_limit=None, + ), + UserRateLimits( + subscription_holder=random_id(), + day_limit=None, + hour_limit=None, + minute_limit=None, + token_rate_limit=TokenRateLimits( + limits=[ + TokenRateLimit( + signature=random_id(), + usage_limit=11, + expires_at=datetime.now(timezone.utc) + timedelta(minutes=5), + ) + ] + ), ), ], ) diff --git a/tests/unit/nuc-helpers/test_usage.py b/tests/unit/nuc-helpers/test_usage.py new file mode 100644 index 00000000..2c181338 --- /dev/null +++ b/tests/unit/nuc-helpers/test_usage.py @@ -0,0 +1,269 @@ +import unittest +from unittest.mock import patch +from nuc_helpers.usage import TokenRateLimits, UsageLimitError, UsageLimitKind + +from datetime import datetime, timedelta, timezone + + +# Dummy token envelope structure to simulate nuc.envelope +class DummyNucToken: + def __init__( + self, meta=None, expires_at=datetime.now(timezone.utc) + timedelta(days=1) + ): + self.meta = meta or {} + self.expires_at = expires_at + + +class DummyDecodedNucToken: + def __init__(self, meta=None): + self.token = 
DummyNucToken(meta) + self.signature = b"\x01\x02" + + +class DummyNucTokenEnvelope: + def __init__(self, proofs, invocation_meta=None): + self.proofs = proofs + self.token = DummyDecodedNucToken(invocation_meta) + + +class GetUsageLimitTests(unittest.TestCase): + def setUp(self): + """Clear the cache before each test, because the cache is global and we use the same dummy token for all tests.""" + TokenRateLimits.from_token.cache_clear() + + @patch("nuc.envelope.NucTokenEnvelope.parse") + def test_no_usage_limit_returns_none(self, mock_parse): + env = DummyNucTokenEnvelope( + proofs=[DummyDecodedNucToken(), DummyDecodedNucToken()] + ) + mock_parse.return_value = env + + limits = TokenRateLimits.from_token("dummy_token") + self.assertIsNone(limits) + + @patch("nuc.envelope.NucTokenEnvelope.parse") + def test_single_usage_limit_returns_value(self, mock_parse): + env = DummyNucTokenEnvelope(proofs=[DummyDecodedNucToken({"usage_limit": 10})]) + mock_parse.return_value = env + + limits = TokenRateLimits.from_token("dummy_token") + if limits is None: + self.fail("Limits should not be None") + self.assertEqual(limits.last.usage_limit, 10) + + @patch("nuc.envelope.NucTokenEnvelope.parse") + def test_multiple_consistent_limits(self, mock_parse): + env = DummyNucTokenEnvelope( + proofs=[ + DummyDecodedNucToken( + {"usage_limit": 25} + ), # This is a second reduction of the base usage limit + DummyDecodedNucToken( + {"usage_limit": 50} + ), # This is a first reduction of the base usage limit + DummyDecodedNucToken( + {"usage_limit": 100} + ), # This is the base usage limit + ] + ) + mock_parse.return_value = env + + limits = TokenRateLimits.from_token("dummy_token") + if limits is None: + self.fail("Limits should not be None") + self.assertEqual(limits.last.usage_limit, 25) + + @patch("nuc.envelope.NucTokenEnvelope.parse") + def test_multiple_consistent_limits_with_none(self, mock_parse): + limits = [25, None, 100] + env = DummyNucTokenEnvelope( + proofs=[ + 
DummyDecodedNucToken( + {"usage_limit": 25} + ), # This is a second reduction of the base usage limit + DummyDecodedNucToken( + {"usage_limit": None} + ), # This is a first reduction of the base usage limit + DummyDecodedNucToken( + {"usage_limit": 100} + ), # This is the base usage limit + ] + ) + mock_parse.return_value = env + + limits = TokenRateLimits.from_token("dummy_token") + if limits is None: + self.fail("Limits should not be None") + self.assertEqual(limits.last.usage_limit, 25) + + @patch("nuc.envelope.NucTokenEnvelope.parse") + def test_multiple_consistent_limits_with_none_2(self, mock_parse): + limits = [25, 100, None] + env = DummyNucTokenEnvelope( + proofs=[DummyDecodedNucToken({"usage_limit": limit}) for limit in limits] + ) + mock_parse.return_value = env + + token_rate_limits = TokenRateLimits.from_token("dummy_token") + if token_rate_limits is None: + self.fail("Limits should not be None") + limits_reversed_without_none = [ + limit for limit in limits[::-1] if limit is not None + ] + for effective_limit, expected_limit in zip( + token_rate_limits.limits, limits_reversed_without_none + ): + self.assertEqual(effective_limit.usage_limit, expected_limit) + + @patch("nuc.envelope.NucTokenEnvelope.parse") + def test_multiple_consistent_limits_long(self, mock_parse): + limits = [25, 32, None, 50, None, None, 75, 100, None] + env = DummyNucTokenEnvelope( + proofs=[DummyDecodedNucToken({"usage_limit": limit}) for limit in limits] + ) + mock_parse.return_value = env + + token_rate_limits = TokenRateLimits.from_token("dummy_token") + if token_rate_limits is None: + self.fail("Limits should not be None") + limits_reversed_without_none = [ + limit for limit in limits[::-1] if limit is not None + ] + for effective_limit, expected_limit in zip( + token_rate_limits.limits, limits_reversed_without_none + ): + self.assertEqual(effective_limit.usage_limit, expected_limit) + + @patch("nuc.envelope.NucTokenEnvelope.parse") + def 
test_inconsistent_usage_limits_raises_error(self, mock_parse): + env = DummyNucTokenEnvelope( + proofs=[ + DummyDecodedNucToken({"usage_limit": 110}), + DummyDecodedNucToken({"usage_limit": 90}), + DummyDecodedNucToken({"usage_limit": 100}), + ] + ) + mock_parse.return_value = env + + with self.assertRaises(UsageLimitError) as cm: + TokenRateLimits.from_token("dummy_token") + self.assertEqual(cm.exception.kind, UsageLimitKind.INCONSISTENT) + + @patch("nuc.envelope.NucTokenEnvelope.parse") + def test_inconsistent_usage_limits_with_none_raises_error(self, mock_parse): + env = DummyNucTokenEnvelope( + proofs=[ + DummyDecodedNucToken({"usage_limit": 110}), + DummyDecodedNucToken({"usage_limit": None}), + DummyDecodedNucToken({"usage_limit": 100}), + ] + ) + mock_parse.return_value = env + + with self.assertRaises(UsageLimitError) as cm: + TokenRateLimits.from_token("dummy_token") + self.assertEqual(cm.exception.kind, UsageLimitKind.INCONSISTENT) + + @patch("nuc.envelope.NucTokenEnvelope.parse") + def test_inconsistent_usage_limits_with_negative_raises_error(self, mock_parse): + env = DummyNucTokenEnvelope( + proofs=[ + DummyDecodedNucToken({"usage_limit": 80}), + DummyDecodedNucToken({"usage_limit": -90}), + DummyDecodedNucToken({"usage_limit": 100}), + ] + ) + mock_parse.return_value = env + + with self.assertRaises(UsageLimitError) as cm: + TokenRateLimits.from_token("dummy_token") + self.assertEqual(cm.exception.kind, UsageLimitKind.INCONSISTENT) + + @patch("nuc.envelope.NucTokenEnvelope.parse") + def test_inconsistent_usage_limits_with_long_chain(self, mock_parse): + env = DummyNucTokenEnvelope( + proofs=[ + DummyDecodedNucToken({"usage_limit": 50}), + DummyDecodedNucToken({"usage_limit": 74}), + DummyDecodedNucToken({"usage_limit": 85}), + DummyDecodedNucToken({"usage_limit": 88}), + DummyDecodedNucToken({"usage_limit": None}), + DummyDecodedNucToken({"usage_limit": -89}), + DummyDecodedNucToken({"usage_limit": 99}), + DummyDecodedNucToken({"usage_limit": None}), 
+ DummyDecodedNucToken({"usage_limit": 100}), + ] + ) + mock_parse.return_value = env + + with self.assertRaises(UsageLimitError) as cm: + TokenRateLimits.from_token("dummy_token") + self.assertEqual(cm.exception.kind, UsageLimitKind.INCONSISTENT) + + @patch("nuc.envelope.NucTokenEnvelope.parse") + def test_invalid_type_usage_limit_raises_error(self, mock_parse): + env = DummyNucTokenEnvelope( + proofs=[ + DummyDecodedNucToken({"usage_limit": "not-an-int"}), + ] + ) + mock_parse.return_value = env + + with self.assertRaises(UsageLimitError) as cm: + TokenRateLimits.from_token("dummy_token") + self.assertEqual(cm.exception.kind, UsageLimitKind.INVALID_TYPE) + + @patch("nuc.envelope.NucTokenEnvelope.parse") + def test_none_type_usage_doesnt_raise_error(self, mock_parse): + env = DummyNucTokenEnvelope( + proofs=[ + DummyDecodedNucToken({"usage_limit": None}), + ] + ) + mock_parse.return_value = env + + limits = TokenRateLimits.from_token("dummy_token") + self.assertIsNone(limits) + + @patch("nuc.envelope.NucTokenEnvelope.parse") + def test_invocation_usage_limit_ignored(self, mock_parse): + env = DummyNucTokenEnvelope( + proofs=[DummyDecodedNucToken({"usage_limit": 5})], + invocation_meta={"usage_limit": 999}, # Should be ignored + ) + mock_parse.return_value = env + + limits = TokenRateLimits.from_token("dummy_token") + if limits is None: + self.fail("Limits should not be None") + self.assertEqual(limits.last.usage_limit, 5) + + @patch("nuc.envelope.NucTokenEnvelope.parse") + def test_caching_behavior(self, mock_parse): + env = DummyNucTokenEnvelope(proofs=[DummyDecodedNucToken({"usage_limit": 10})]) + mock_parse.return_value = env + + TokenRateLimits.from_token("dummy_token") + TokenRateLimits.from_token("dummy_token") + + # NucTokenEnvelope.parse should only be called once due to caching + mock_parse.assert_called_once() + + @patch("nuc.envelope.NucTokenEnvelope.parse") + def test_expires_at_returns_correct_value(self, mock_parse): + env = DummyNucTokenEnvelope( + 
proofs=[DummyDecodedNucToken({"usage_limit": 10, "expires_at": 1715702400})] + ) + mock_parse.return_value = env + + limits = TokenRateLimits.from_token("dummy_token") + if limits is None: + self.fail("Limits should not be None") + expires_at = limits.last.expires_at + + # Check expires_at is less than 1 day from now + self.assertLess(expires_at, datetime.now(timezone.utc) + timedelta(days=1)) # type: ignore + + +if __name__ == "__main__": + unittest.main() diff --git a/uv.lock b/uv.lock index 8f0e1b75..bd07cfc7 100644 --- a/uv.lock +++ b/uv.lock @@ -1361,6 +1361,7 @@ dependencies = [ { name = "nilai-common" }, { name = "nilrag" }, { name = "nuc" }, + { name = "nuc-helpers" }, { name = "openai" }, { name = "pg8000" }, { name = "prometheus-fastapi-instrumentator" }, @@ -1387,6 +1388,7 @@ requires-dist = [ { name = "nilai-common", editable = "packages/nilai-common" }, { name = "nilrag", specifier = ">=0.1.11" }, { name = "nuc", git = "https://github.com/NillionNetwork/nuc-py.git" }, + { name = "nuc-helpers", editable = "nilai-auth/nuc-helpers" }, { name = "openai", specifier = ">=1.59.9" }, { name = "pg8000", specifier = ">=1.31.2" }, { name = "prometheus-fastapi-instrumentator", specifier = ">=7.0.2" },