diff --git a/.env.ci b/.env.ci index e3daaac7..0272eec5 100644 --- a/.env.ci +++ b/.env.ci @@ -47,3 +47,10 @@ GF_SECURITY_ADMIN_PASSWORD = "password" # WebSearch Settings BRAVE_SEARCH_API = "Your API here" + +# Optional: Override default testnet URLs if needed +NILDB_NILCHAIN_URL=http://rpc.testnet.nilchain-rpc-proxy.nilogy.xyz +NILDB_NILAUTH_URL=https://nilauth.sandbox.app-cluster.sandbox.nilogy.xyz +NILDB_NODES=https://nildb-stg-n1.nillion.network,https://nildb-stg-n2.nillion.network,https://nildb-stg-n3.nillion.network +NILDB_BUILDER_PRIVATE_KEY=0x1234567890abcdef1234567890abcdef12345678 +NILDB_COLLECTION=12345678-1234-1234-1234-123456789012 diff --git a/.github/workflows/cicd.yml b/.github/workflows/cicd.yml index 31189916..1ed12f0b 100644 --- a/.github/workflows/cicd.yml +++ b/.github/workflows/cicd.yml @@ -44,6 +44,13 @@ jobs: - name: Run Ruff linting run: uv run ruff check --exclude packages/verifier/ + - name: Create .env for tests + run: | + cp .env.ci .env + # Set dummy secrets for unit tests + sed -i 's/HF_TOKEN=.*/HF_TOKEN=dummy_token/' .env + sed -i 's/BRAVE_SEARCH_API=.*/BRAVE_SEARCH_API=dummy_api/' .env + - name: Run tests run: uv run pytest -v tests/unit @@ -135,6 +142,8 @@ jobs: # Copy secret into .env replacing the existing HF_TOKEN sed -i 's/HF_TOKEN=.*/HF_TOKEN=${{ secrets.HF_TOKEN }}/' .env sed -i 's/BRAVE_SEARCH_API=.*/BRAVE_SEARCH_API=${{ secrets.BRAVE_SEARCH_API }}/' .env + sed -i 's/NILDB_BUILDER_PRIVATE_KEY=.*/NILDB_BUILDER_PRIVATE_KEY=${{ secrets.NILDB_BUILDER_PRIVATE_KEY }}/' .env + sed -i 's/NILDB_COLLECTION=.*/NILDB_COLLECTION=${{ secrets.NILDB_COLLECTION }}/' .env - name: Compose docker-compose.yml run: python3 ./scripts/docker-composer.py --dev -f docker/compose/docker-compose.llama-1b-gpu.ci.yml -o development-compose.yml diff --git a/conftest.py b/conftest.py new file mode 100644 index 00000000..e8495c0a --- /dev/null +++ b/conftest.py @@ -0,0 +1,22 @@ +"""Global pytest configuration.""" + +import asyncio +import warnings + + +def pytest_configure(config): + """Configure pytest to suppress StreamWriter errors.""" + # Suppress warnings + warnings.filterwarnings("ignore", category=DeprecationWarning) + warnings.filterwarnings("ignore", category=RuntimeWarning) + + # Monkey patch StreamWriter.__del__ to suppress exceptions + original_del = asyncio.StreamWriter.__del__ + + def silent_del(self): + try: + original_del(self) + except Exception: + pass + + asyncio.StreamWriter.__del__ = silent_del diff --git a/nilai-api/pyproject.toml b/nilai-api/pyproject.toml index e2b3462f..a0aab93f 100644 --- a/nilai-api/pyproject.toml +++ b/nilai-api/pyproject.toml @@ -32,8 +32,9 @@ dependencies = [ "web3>=7.8.0", "click>=8.1.8", "nuc-helpers", - "nuc", + "nuc>=0.1.0", "pyyaml>=6.0.1", + "secretvaults", ] @@ -45,4 +46,5 @@ build-backend = "hatchling.build" nilai-common = { workspace = true } nuc-helpers = { workspace = true } -nuc = { git = "https://github.com/NillionNetwork/nuc-py.git", rev = "4922b5e9354e611cc31322d681eb29da05be584e" } +# TODO: Remove this once the secretvaults package is released with the fix +secretvaults = { git = "https://github.com/jcabrero/secretvaults-py", rev = "main" } diff --git a/nilai-api/src/nilai_api/auth/common.py b/nilai-api/src/nilai_api/auth/common.py index 99a3bfea..793aed2a 100644 --- a/nilai-api/src/nilai_api/auth/common.py +++ b/nilai-api/src/nilai_api/auth/common.py @@ -3,6 +3,7 @@ from fastapi import HTTPException, status from nilai_api.db.users import UserData from nuc_helpers.usage import TokenRateLimits, TokenRateLimit +from nuc_helpers.nildb_document import PromptDocument class AuthenticationError(HTTPException): @@ -17,6 +18,7 @@ def __init__(self, detail: str): class AuthenticationInfo(BaseModel): user: UserData token_rate_limit: Optional[TokenRateLimits] + prompt_document: Optional[PromptDocument] __all__ = [ @@ -24,4 +26,5 @@ class AuthenticationInfo(BaseModel): "AuthenticationInfo", "TokenRateLimits", "TokenRateLimit", + "PromptDocument", ] diff --git a/nilai-api/src/nilai_api/auth/nuc.py b/nilai-api/src/nilai_api/auth/nuc.py index 0a300d19..8274376d 100644 --- a/nilai-api/src/nilai_api/auth/nuc.py +++ b/nilai-api/src/nilai_api/auth/nuc.py @@ -12,6 +12,7 @@ from nilai_common.logger import setup_logger from nuc_helpers.usage import TokenRateLimits +from nuc_helpers.nildb_document import PromptDocument logger = setup_logger(__name__) @@ -120,3 +121,8 @@ def get_token_rate_limit(nuc_token: str) -> Optional[TokenRateLimits]: raise AuthenticationError("Token has expired") return token_rate_limits + + +def get_token_prompt_document(nuc_token: str) -> Optional[PromptDocument]: + prompt_document = PromptDocument.from_token(nuc_token) + return prompt_document diff --git a/nilai-api/src/nilai_api/auth/strategies.py b/nilai-api/src/nilai_api/auth/strategies.py index 6b9a0498..dc279d73 100644 --- a/nilai-api/src/nilai_api/auth/strategies.py +++ b/nilai-api/src/nilai_api/auth/strategies.py @@ -1,11 +1,16 @@ -from typing import Callable, Awaitable +from typing import Callable, Awaitable, Optional from datetime import datetime, timezone from nilai_api.db.users import UserManager, UserModel, UserData from nilai_api.auth.jwt import validate_jwt -from nilai_api.auth.nuc import validate_nuc, get_token_rate_limit +from nilai_api.auth.nuc import ( + validate_nuc, + get_token_rate_limit, + get_token_prompt_document, +) from nilai_api.config import DOCS_TOKEN from nilai_api.auth.common import ( + PromptDocument, TokenRateLimits, AuthenticationInfo, AuthenticationError, @@ -55,6 +60,7 @@ async def wrapper(token) -> AuthenticationInfo: return AuthenticationInfo( user=UserData.from_sqlalchemy(user_model), token_rate_limit=None, + prompt_document=None, ) return await function(token) @@ -65,10 +71,12 @@ async def wrapper(token) -> AuthenticationInfo: @allow_token(DOCS_TOKEN) async def api_key_strategy(api_key: str) -> AuthenticationInfo: - user_model: UserModel | None = await UserManager.check_api_key(api_key) + user_model: Optional[UserModel] = await UserManager.check_api_key(api_key) if user_model: return AuthenticationInfo( - user=UserData.from_sqlalchemy(user_model), token_rate_limit=None + user=UserData.from_sqlalchemy(user_model), + token_rate_limit=None, + prompt_document=None, ) raise AuthenticationError("Missing or invalid API key") @@ -76,10 +84,14 @@ async def api_key_strategy(api_key: str) -> AuthenticationInfo: @allow_token(DOCS_TOKEN) async def jwt_strategy(jwt_creds: str) -> AuthenticationInfo: result = validate_jwt(jwt_creds) - user_model: UserModel | None = await UserManager.check_api_key(result.user_address) + user_model: Optional[UserModel] = await UserManager.check_api_key( + result.user_address + ) if user_model: return AuthenticationInfo( - user=UserData.from_sqlalchemy(user_model), token_rate_limit=None + user=UserData.from_sqlalchemy(user_model), + token_rate_limit=None, + prompt_document=None, ) else: user_model = UserModel( @@ -89,7 +101,9 @@ async def jwt_strategy(jwt_creds: str) -> AuthenticationInfo: ) await UserManager.insert_user_model(user_model) return AuthenticationInfo( - user=UserData.from_sqlalchemy(user_model), token_rate_limit=None + user=UserData.from_sqlalchemy(user_model), + token_rate_limit=None, + prompt_document=None, ) @@ -99,12 +113,15 @@ async def nuc_strategy(nuc_token) -> AuthenticationInfo: Validate a NUC token and return the user model """ subscription_holder, user = validate_nuc(nuc_token) - token_rate_limits: TokenRateLimits | None = get_token_rate_limit(nuc_token) - user_model: UserModel | None = await UserManager.check_user(user) + token_rate_limits: Optional[TokenRateLimits] = get_token_rate_limit(nuc_token) + prompt_document: Optional[PromptDocument] = get_token_prompt_document(nuc_token) + + user_model: Optional[UserModel] = await UserManager.check_user(user) if user_model: return AuthenticationInfo( user=UserData.from_sqlalchemy(user_model), token_rate_limit=token_rate_limits, + prompt_document=prompt_document, ) user_model = UserModel( @@ -114,7 +131,9 @@ async def nuc_strategy(nuc_token) -> AuthenticationInfo: ) await UserManager.insert_user_model(user_model) return AuthenticationInfo( - user=UserData.from_sqlalchemy(user_model), token_rate_limit=token_rate_limits + user=UserData.from_sqlalchemy(user_model), + token_rate_limit=token_rate_limits, + prompt_document=prompt_document, ) diff --git a/nilai-api/src/nilai_api/db/users.py b/nilai-api/src/nilai_api/db/users.py index 91a5d3f5..92dfa354 100644 --- a/nilai-api/src/nilai_api/db/users.py +++ b/nilai-api/src/nilai_api/db/users.py @@ -92,6 +92,10 @@ def from_sqlalchemy(cls, user: UserModel) -> "UserData": web_search_ratelimit_minute=user.web_search_ratelimit_minute, ) + @property + def is_subscription_owner(self): + return self.userid == self.apikey + class UserManager: @staticmethod diff --git a/nilai-api/src/nilai_api/handlers/nildb/__init__.py b/nilai-api/src/nilai_api/handlers/nildb/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/nilai-api/src/nilai_api/handlers/nildb/api_model.py b/nilai-api/src/nilai_api/handlers/nildb/api_model.py new file mode 100644 index 00000000..047930ad --- /dev/null +++ b/nilai-api/src/nilai_api/handlers/nildb/api_model.py @@ -0,0 +1,13 @@ +from pydantic import BaseModel, ConfigDict +from typing import TypeAlias + +PromptDelegationRequest: TypeAlias = str + + +class PromptDelegationToken(BaseModel): + """Delegation token model""" + + model_config = ConfigDict(validate_assignment=True) + + token: str + did: str diff --git a/nilai-api/src/nilai_api/handlers/nildb/config.py b/nilai-api/src/nilai_api/handlers/nildb/config.py new file mode 100644 index 00000000..2a94ac03 --- /dev/null +++ b/nilai-api/src/nilai_api/handlers/nildb/config.py @@ -0,0 +1,38 @@ +import os +from typing import Optional +from dotenv import load_dotenv +from pydantic import BaseModel +from pydantic import Field + +from secretvaults.common.types import Uuid + +load_dotenv() + + +class NilDBConfig(BaseModel): + NILCHAIN_URL: str = Field(..., description="The URL of the Nilchain") + NILAUTH_URL: str = Field(..., description="The URL of the Nilauth") + NODES: list[str] = Field(..., description="The URLs of the Nildb nodes") + BUILDER_PRIVATE_KEY: str = Field(..., description="The private key of the builder") + COLLECTION: Uuid = Field(..., description="The ID of the collection") + + +def get_required_env_var(name: str) -> str: + """Get a required environment variable, raising an error if not set.""" + value: Optional[str] = os.getenv(name, None) + if value is None: + raise ValueError(f"Required environment variable {name} is not set") + return value + + +# Validate environment variables at import time +CONFIG = NilDBConfig( + NILCHAIN_URL=get_required_env_var("NILDB_NILCHAIN_URL"), + NILAUTH_URL=get_required_env_var("NILDB_NILAUTH_URL"), + NODES=get_required_env_var("NILDB_NODES").split(","), + BUILDER_PRIVATE_KEY=get_required_env_var("NILDB_BUILDER_PRIVATE_KEY"), + COLLECTION=Uuid(get_required_env_var("NILDB_COLLECTION")), +) + + +__all__ = ["CONFIG"] diff --git a/nilai-api/src/nilai_api/handlers/nildb/handler.py b/nilai-api/src/nilai_api/handlers/nildb/handler.py new file mode 100644 index 00000000..d419ef74 --- /dev/null +++ b/nilai-api/src/nilai_api/handlers/nildb/handler.py @@ -0,0 +1,140 @@ +from typing import Optional +from nilai_api.handlers.nildb.config import CONFIG + +from secretvaults import SecretVaultBuilderClient, SecretVaultUserClient +from secretvaults.common.keypair import Keypair +from secretvaults.common.blindfold import BlindfoldFactoryConfig, BlindfoldOperation + +from secretvaults.common.utils import into_seconds_from_now +from nuc.builder import NucTokenBuilder +from nuc.token import Command, Did +from secretvaults.common.nuc_cmd import NucCmd +from secretvaults.common.types import Uuid +from secretvaults.dto.users import ( + ReadDataRequestParams, +) + +from nilai_api.auth.common import PromptDocument +from nilai_api.handlers.nildb.api_model import PromptDelegationToken + +import datetime + + +BUILDER_CLIENT: Optional[SecretVaultBuilderClient] = None +USER_CLIENT: Optional[SecretVaultUserClient] = None + + +async def create_builder_client(): + """Create and return a builder client using proper initialization pattern""" + global BUILDER_CLIENT + if BUILDER_CLIENT is not None: + return BUILDER_CLIENT + + # Create keypair from private key + keypair = Keypair.from_hex(CONFIG.BUILDER_PRIVATE_KEY) + + # Prepare URLs for the builder client + urls = { + "chain": [CONFIG.NILCHAIN_URL], + "auth": CONFIG.NILAUTH_URL, + "dbs": CONFIG.NODES, + } + + # Create SecretVaultBuilderClient with proper initialization + BUILDER_CLIENT = await SecretVaultBuilderClient.from_options( + keypair=keypair, + urls=urls, + blindfold=BlindfoldFactoryConfig( + operation=BlindfoldOperation.STORE, use_cluster_key=True + ), + ) + + # Get root token for use in other functions + await BUILDER_CLIENT.refresh_root_token() + + return BUILDER_CLIENT + + +async def create_user_client() -> SecretVaultUserClient: + """Create and return a user client using proper initialization pattern""" + global USER_CLIENT + if USER_CLIENT is not None: + return USER_CLIENT + + # Create keypair from private key + keypair = Keypair.from_hex(CONFIG.BUILDER_PRIVATE_KEY) + USER_CLIENT = await SecretVaultUserClient.from_options( + keypair=keypair, + base_urls=CONFIG.NODES, + blindfold=BlindfoldFactoryConfig( + operation=BlindfoldOperation.STORE, use_cluster_key=True + ), + ) + + return USER_CLIENT + + +async def get_nildb_delegation_token(user_did: str) -> PromptDelegationToken: + """Get a delegation token for the builder - core functionality without UI concerns""" + # Get builder's root token + builder_client = await create_builder_client() + root_token_envelope = builder_client.root_token + + if not root_token_envelope: + raise ValueError("Couldn't extract root NUC token from nilDB profile") + + # Create delegation token extending the root token envelope + delegation_token = ( + NucTokenBuilder.extending(root_token_envelope) + .command(Command(NucCmd.NIL_DB_DATA_CREATE.value.split("."))) + .audience(Did.parse(user_did)) + .expires_at(datetime.datetime.fromtimestamp(into_seconds_from_now(60))) + .build(builder_client.keypair.private_key()) + ) + + builder_did = builder_client.keypair.to_did_string() + return PromptDelegationToken(token=delegation_token, did=builder_did) + + +""" Read nilDB records from owned data collection based on the store id given by the user on the request """ + + +async def get_prompt_from_nildb(prompt_document: PromptDocument) -> str: + """Read a specific document - core functionality""" + read_params = ReadDataRequestParams( + collection=CONFIG.COLLECTION, + document=Uuid(prompt_document.document_id), + subject=Uuid(prompt_document.owner_did), + ) + user_client = await create_user_client() + document_response = await user_client.read_data(read_params) + + if not document_response: + raise ValueError("Couldn't get document response from nilDB nodes") + + # Check if response has data attribute (wrapped response) + if hasattr(document_response, "data") and document_response.data: + document_data = document_response.data + else: + document_data = document_response + + # Convert to dict to avoid pyright attribute errors based on flexible typing of output dictionary + if hasattr(document_data, "__dict__"): + data_dict = document_data.__dict__ + elif hasattr(document_data, "model_dump"): + data_dict = document_data.model_dump() + else: + data_dict = dict(document_data) if document_data else {} + + if data_dict.get("owner", None) != str(prompt_document.owner_did): + raise ValueError( + "Non-owning entity trying to invoke access to a document resource" + ) + + if "prompt" not in data_dict: + raise ValueError("Couldn't find prompt field in document response from nilDB") + + prompt = data_dict.get("prompt") + if prompt is None: + raise ValueError("Prompt field is None in document response from nilDB") + return prompt diff --git a/nilai-api/src/nilai_api/routers/private.py b/nilai-api/src/nilai_api/routers/private.py index 376e6570..b3a50153 100644 --- a/nilai-api/src/nilai_api/routers/private.py +++ b/nilai-api/src/nilai_api/routers/private.py @@ -19,6 +19,15 @@ from nilai_api.rate_limiting import RateLimit from nilai_api.state import state +from nilai_api.handlers.nildb.api_model import ( + PromptDelegationRequest, + PromptDelegationToken, +) +from nilai_api.handlers.nildb.handler import ( + get_nildb_delegation_token, + get_prompt_from_nildb, +) + # Internal libraries from nilai_common import ( AttestationReport, @@ -40,6 +49,26 @@ router = APIRouter() +@router.get("/v1/delegation") +async def get_prompt_store_delegation( + prompt_delegation_request: PromptDelegationRequest, + auth_info: AuthenticationInfo = Depends(get_auth_info), +) -> PromptDelegationToken: + if not auth_info.user.is_subscription_owner: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Prompt storage is reserved to subscription owners", + ) + + try: + return await get_nildb_delegation_token(prompt_delegation_request) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Server unable to produce delegation tokens: {str(e)}", + ) + + @router.get("/v1/usage", tags=["Usage"]) async def get_usage(auth_info: AuthenticationInfo = Depends(get_auth_info)) -> Usage: """ @@ -241,6 +270,17 @@ async def chat_completion( ) client = AsyncOpenAI(base_url=model_url, api_key="") + if auth_info.prompt_document: + try: + nildb_prompt: str = await get_prompt_from_nildb(auth_info.prompt_document) + req.messages.insert( + 0, MessageAdapter.new_message(role="system", content=nildb_prompt) + ) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Unable to extract prompt from nilDB: {str(e)}", + ) if req.nilrag: logger.info(f"[chat] nilrag start request_id={request_id}") diff --git a/nilai-auth/nuc-helpers/pyproject.toml b/nilai-auth/nuc-helpers/pyproject.toml index 517fffd6..1a38b274 100644 --- a/nilai-auth/nuc-helpers/pyproject.toml +++ b/nilai-auth/nuc-helpers/pyproject.toml @@ -12,7 +12,7 @@ dependencies = [ "pydantic>=2.11.2", "secp256k1>=0.14.0", "httpx>=0.28.1", - "nuc", + "nuc>=0.1.0", ] [build-system] @@ -20,6 +20,3 @@ requires = ["hatchling"] build-backend = "hatchling.build" [tool.uv.sources] - -nuc = { git = "https://github.com/NillionNetwork/nuc-py.git", rev = "4922b5e9354e611cc31322d681eb29da05be584e" } - diff --git a/nilai-auth/nuc-helpers/src/nuc_helpers/helpers.py b/nilai-auth/nuc-helpers/src/nuc_helpers/helpers.py index 9af37535..cd631c2d 100644 --- a/nilai-auth/nuc-helpers/src/nuc_helpers/helpers.py +++ b/nilai-auth/nuc-helpers/src/nuc_helpers/helpers.py @@ -189,6 +189,8 @@ def get_delegation_token( user_public_key: NilAuthPublicKey, usage_limit: int | None = None, expires_at: datetime.datetime | None = None, + document_id: str | None = None, + document_owner_did: str | None = None, ) -> DelegationToken: """ Delegate the root token to the delegated key @@ -200,6 +202,10 @@ def get_delegation_token( Returns: The delegation token """ + if bool(document_id) != bool(document_owner_did): + raise ValueError( + f"If Document ID or document owner DID provided, the other must also be provided: Document ID: {document_id} Document Owner DID: {document_owner_did}" + ) root_token_envelope = NucTokenEnvelope.parse(root_token.token) delegated_token = ( @@ -212,7 +218,13 @@ def get_delegation_token( ) .audience(Did(user_public_key.serialize())) .command(Command(["nil", "ai", "generate"])) - .meta({"usage_limit": usage_limit}) + .meta( + { + "usage_limit": usage_limit, + "document_id": document_id, + "document_owner_did": document_owner_did, + } + ) .build(private_key) ) return DelegationToken(token=delegated_token) diff --git a/nilai-auth/nuc-helpers/src/nuc_helpers/nildb_document.py b/nilai-auth/nuc-helpers/src/nuc_helpers/nildb_document.py new file mode 100644 index 00000000..3a3fa35c --- /dev/null +++ b/nilai-auth/nuc-helpers/src/nuc_helpers/nildb_document.py @@ -0,0 +1,64 @@ +from functools import lru_cache +from typing import Optional +from nuc.envelope import NucTokenEnvelope +import logging +from pydantic import BaseModel + +from nuc.token import Did + +logger = logging.getLogger(__name__) + + +class PromptDocument(BaseModel): + document_id: str + owner_did: str + + @staticmethod + @lru_cache(maxsize=128) + def from_token(token: str) -> Optional["PromptDocument"]: + """ + Extracts the prompt_document_id from the NUC token if there is one. + + This serves to determine which document if there is one to extract from nilDB to be used as a prompt + + + This function parses the provided token and inspects all associated proofs from upwards down to the invocation + token (if present) to determine the applicable document. The behavior is as follows: + + - The invocation token is never considered as it is created by the user. + - The uppermost token containing a `document_id` in their metadata is the one considered. + - If two `document_id` are present, only the uppermost in the chain is considered. + + The function is cached based on the token string to avoid redundant parsing and validation. + + Note: This function is cached, so it will return the same result for the same token string. + If you need to invalidate the cache, call `get_usage_limit.cache_clear()`. + + + Args: + token (str): The serialized delegation token. + + Returns: + PromptDocumentId: The document_id and the issuer did to be matched to the database + """ + token_envelope = NucTokenEnvelope.parse(token) + + # Iterate over proofs and collect the first document_id found together with issuer. + for i, proof in enumerate(token_envelope.proofs[::-1]): + meta = proof.token.meta if proof.token else None + logger.debug(f"Proof {i} meta: {meta}") + if ( + meta is not None + and meta.get("document_id", None) is not None + and meta.get("document_owner_did", None) is not None + ): + if Did.parse(meta["document_owner_did"]) != proof.token.issuer: + raise ValueError( + f"Document owner DID {meta['document_owner_did']} does not match issuer {proof.token.issuer}" + ) + return PromptDocument( + document_id=meta["document_id"], + owner_did=meta["document_owner_did"], + ) + + return None diff --git a/nilai-auth/nuc-helpers/src/nuc_helpers/usage.py b/nilai-auth/nuc-helpers/src/nuc_helpers/usage.py index 0f21ae41..a9f4467b 100644 --- a/nilai-auth/nuc-helpers/src/nuc_helpers/usage.py +++ b/nilai-auth/nuc-helpers/src/nuc_helpers/usage.py @@ -85,7 +85,7 @@ def from_token(token: str) -> Optional["TokenRateLimits"]: token (str): The serialized delegation token. Returns: - Tuple[str, str, Optional[int]]: The signature, the effective usage limit, and the expiration date, or `None` if no usage limit is found. + TokenRateLimit: The signature, the effective usage limit, and the expiration date, or `None` if no usage limit is found. Raises: UsageLimitInconsistencyError: If usage limits across proofs or invocation are inconsistent. diff --git a/packages/nilai-common/src/nilai_common/api_model.py b/packages/nilai-common/src/nilai_common/api_model.py index 9ab0962d..80f20a8a 100644 --- a/packages/nilai-common/src/nilai_common/api_model.py +++ b/packages/nilai-common/src/nilai_common/api_model.py @@ -166,9 +166,9 @@ class WebSearchContext(BaseModel): class ChatRequest(BaseModel): model: str messages: List[Message] = Field(..., min_length=1) - temperature: Optional[float] = Field(default=0.2, ge=0.0, le=5.0) - top_p: Optional[float] = Field(default=0.95, ge=0.0, le=1.0) - max_tokens: Optional[int] = Field(default=2048, ge=1, le=100000) + temperature: Optional[float] = Field(default=None, ge=0.0, le=5.0) + top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0) + max_tokens: Optional[int] = Field(default=None, ge=1, le=100000) stream: Optional[bool] = False tools: Optional[Iterable[ChatCompletionToolParam]] = None nilrag: Optional[dict] = {} diff --git a/scripts/docker-composer.py b/scripts/docker-composer.py index 834dc2fb..964b0926 100755 --- a/scripts/docker-composer.py +++ b/scripts/docker-composer.py @@ -205,9 +205,16 @@ def restore_files_variable(output_file, files_placeholder): print("Restored ${FILES} variable") -def process_compose_yaml(output_file): +def process_compose_yaml(output_file, preserve_volumes=False): """Process the compose YAML file to remove volumes and convert bind mount formats""" - print("Processing compose YAML for volume removal and bind mount conversions...") + if preserve_volumes: + print( + "Processing compose YAML for bind mount conversions (preserving volumes)..." + ) + else: + print( + "Processing compose YAML for volume removal and bind mount conversions..." + ) with open(output_file, "r") as f: content = f.read() @@ -216,8 +223,8 @@ def process_compose_yaml(output_file): # Parse YAML compose_data = yaml.safe_load(content) - # Remove global volumes section entirely - if "volumes" in compose_data: + # Remove global volumes section entirely (unless preserving volumes) + if "volumes" in compose_data and not preserve_volumes: print("Removing global volumes section...") del compose_data["volumes"] @@ -245,11 +252,15 @@ def process_compose_yaml(output_file): else: new_volumes.append(f"{source}:{target}") elif volume.get("type") == "volume": - # Remove volume mounts entirely - print( - f"Removing volume mount from service {service_name}: {volume}" - ) - continue + if preserve_volumes: + # Keep volume mounts when preserving volumes + new_volumes.append(volume) + else: + # Remove volume mounts entirely + print( + f"Removing volume mount from service {service_name}: {volume}" + ) + continue else: # Keep other types as-is new_volumes.append(volume) @@ -267,17 +278,25 @@ def process_compose_yaml(output_file): # It's a bind mount (absolute path, relative path, or variable) new_volumes.append(volume) else: - # It's a volume mount (named volume) + if preserve_volumes: + # Keep volume mount when preserving volumes + new_volumes.append(volume) + else: + # It's a volume mount (named volume) + print( + f"Removing volume mount from service {service_name}: {volume}" + ) + continue + else: + if preserve_volumes: + # Keep volume mount when preserving volumes + new_volumes.append(volume) + else: + # Single name without colon - likely a volume mount print( f"Removing volume mount from service {service_name}: {volume}" ) continue - else: - # Single name without colon - likely a volume mount - print( - f"Removing volume mount from service {service_name}: {volume}" - ) - continue if new_volumes: service_config["volumes"] = new_volumes else: @@ -370,7 +389,7 @@ def main(): restore_files_variable(args.output, files_placeholder) # Process YAML for volume and mount conversions - process_compose_yaml(args.output) + process_compose_yaml(args.output, preserve_volumes=args.dev) else: # Generate config and apply modifications temp_file = f"{args.output}.tmp" @@ -385,7 +404,7 @@ def main(): restore_files_variable(args.output, files_placeholder) # Process YAML for volume and mount conversions - process_compose_yaml(args.output) + process_compose_yaml(args.output, preserve_volumes=args.dev) # Apply image substitutions apply_image_substitutions(args.output, image_substitutions) diff --git a/tests/e2e/nuc.py b/tests/e2e/nuc.py index 491c4fcf..9259baf6 100644 --- a/tests/e2e/nuc.py +++ b/tests/e2e/nuc.py @@ -4,7 +4,7 @@ pay_for_subscription, get_root_token, get_nilai_public_key, - get_invocation_token, + get_invocation_token as nuc_helpers_get_invocation_token, validate_token, InvocationToken, RootToken, @@ -17,11 +17,19 @@ from nuc.token import Did from nuc.validate import ValidationParameters, InvocationRequirement +# These correspond to the key used to test with nilAuth. Otherwise the OWNER DID would not match the issuer +DOCUMENT_ID = "bb93f3a4-ba4c-4e20-8f2e-c0650c75a372" +DOCUMENT_OWNER_DID = ( + "did:nil:030923f2e7120c50e42905b857ddd2947f6ecced6bb02aab64e63b28e9e2e06d10" +) + def get_nuc_token( usage_limit: int | None = None, expires_at: datetime | None = None, blind_module: BlindModule = BlindModule.NILAI, + document_id: str | None = None, + document_owner_did: str | None = None, create_delegation: bool = False, create_invalid_delegation: bool = False, ) -> InvocationToken: @@ -48,6 +56,8 @@ def get_nuc_token( server_wallet, server_keypair, server_private_key = get_wallet_and_private_key( PRIVATE_KEY ) + + print("Public key: ", server_private_key.pubkey) nilauth_client = NilauthClient(f"http://{NILAUTH_ENDPOINT}") if not server_private_key.pubkey: @@ -99,6 +109,8 @@ def get_nuc_token( user_public_key, usage_limit=delegation_usage_limit, expires_at=delegation_expires_at, + document_id=document_id, + document_owner_did=document_owner_did, ) # Create invalid delegation chain if requested (for testing) @@ -109,17 +121,19 @@ def get_nuc_token( user_public_key, usage_limit=5, expires_at=datetime.now(timezone.utc) + timedelta(minutes=5), + document_id=document_id, + document_owner_did=document_owner_did, ) # Create invocation token from delegation - invocation_token: InvocationToken = get_invocation_token( + invocation_token: InvocationToken = nuc_helpers_get_invocation_token( delegation_token, nilai_public_key, user_private_key, ) else: # Create invocation token directly from root token - invocation_token: InvocationToken = get_invocation_token( + invocation_token: InvocationToken = nuc_helpers_get_invocation_token( root_token, nilai_public_key, server_private_key, @@ -149,6 +163,16 @@ def get_rate_limited_nuc_token(rate_limit: int = 3) -> InvocationToken: ) +def get_document_id_nuc_token() -> InvocationToken: + """Convenience function for getting NILDB NUC tokens.""" + print("DOCUMENT_ID", DOCUMENT_ID) + return get_nuc_token( + create_delegation=True, + document_id=DOCUMENT_ID, + document_owner_did=DOCUMENT_OWNER_DID, + ) + + def get_invalid_rate_limited_nuc_token() -> InvocationToken: """Convenience function for getting invalid rate-limited tokens (for testing).""" return get_nuc_token( diff --git a/tests/e2e/test_http.py b/tests/e2e/test_http.py index 9e105366..e8354f51 100644 --- a/tests/e2e/test_http.py +++ b/tests/e2e/test_http.py @@ -10,11 +10,13 @@ import json + from .config import BASE_URL, test_models, AUTH_STRATEGY, api_key_getter from .nuc import ( get_rate_limited_nuc_token, get_invalid_rate_limited_nuc_token, get_nildb_nuc_token, + get_document_id_nuc_token, ) import httpx import pytest @@ -99,6 +101,22 @@ def nillion_2025_client(): ) +@pytest.fixture +def document_id_client(): + """Create an HTTPX client with default headers""" + invocation_token = get_document_id_nuc_token() + return httpx.Client( + base_url=BASE_URL, + headers={ + "accept": "application/json", + "Content-Type": "application/json", + "Authorization": f"Bearer {invocation_token.token}", + }, + verify=False, + timeout=None, + ) + + def test_health_endpoint(client): """Test the health endpoint""" response = client.get("health") @@ -791,3 +809,71 @@ def test_model_streaming_request_high_token(client): assert chunk_count > 0, ( "Should receive at least one chunk for high token streaming request" ) + + +@pytest.mark.skipif( + AUTH_STRATEGY != "nuc", reason="NUC required for this tests on nilDB" +) +def test_nildb_delegation(client: httpx.Client): + """Tests getting a delegation token for nilDB and validating that token to be valid""" + from secretvaults.common.keypair import Keypair + from nuc.envelope import NucTokenEnvelope + from nuc.validate import NucTokenValidator, ValidationParameters + from nuc.nilauth import NilauthClient + from nilai_api.handlers.nildb.config import CONFIG + from nuc.token import Did + + keypair = Keypair.generate() + did = keypair.to_did_string() + + response = client.get("/delegation", params={"prompt_delegation_request": did}) + + assert response.status_code == 200, ( + f"Delegation token should be returned: {response.text}" + ) + assert "token" in response.json(), "Delegation token should be returned" + assert "did" in response.json(), "Delegation did should be returned" + token = response.json()["token"] + did = response.json()["did"] + assert token is not None, "Delegation token should be returned" + assert did is not None, "Delegation did should be returned" + + # Validate the token with nilAuth url for nilDB + nuc_token_envelope = NucTokenEnvelope.parse(token) + nilauth_public_keys = [ + Did(NilauthClient(CONFIG.NILAUTH_URL).about().public_key.serialize()) + ] + NucTokenValidator(nilauth_public_keys).validate( + nuc_token_envelope, context={}, parameters=ValidationParameters.default() + ) + + +@pytest.mark.parametrize( + "model", + test_models, +) +@pytest.mark.skipif( + AUTH_STRATEGY != "nuc", reason="NUC required for this tests on nilDB" +) +def test_nildb_prompt_document(document_id_client: httpx.Client, model): + """Tests getting a prompt document from nilDB and executing a chat completion with it""" + payload = { + "model": model, + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant.", + }, + {"role": "user", "content": "Can you make a small rhyme?"}, + ], + "temperature": 0.2, + } + + response = document_id_client.post("/chat/completions", json=payload, timeout=30) + + assert response.status_code == 200, ( + f"Response should be successful: {response.text}" + ) + # Response must talk about cheese which is what the prompt document contains + message: str = response.json()["choices"][0].get("message", {}).get("content", None) + assert "cheese" in message.lower(), "Response should contain cheese" diff --git a/tests/unit/nilai_api/auth/test_nuc.py b/tests/unit/nilai_api/auth/test_nuc.py new file mode 100644 index 00000000..a3785554 --- /dev/null +++ b/tests/unit/nilai_api/auth/test_nuc.py @@ -0,0 +1,40 @@ +import pytest +from unittest.mock import patch +from nilai_api.auth.nuc import get_token_prompt_document +from nilai_api.auth.common import PromptDocument + + +class TestNucAuthFunctions: + """Test class for NUC authentication functions""" + + @patch("nilai_api.auth.nuc.PromptDocument.from_token") + def test_get_token_prompt_document_success(self, mock_from_token): + """Test successful prompt document extraction""" + mock_prompt_doc = PromptDocument( + document_id="test-doc-123", owner_did=f"did:nil:{'1' * 66}" + ) + mock_from_token.return_value = mock_prompt_doc + + result = get_token_prompt_document("test_token") + + assert result == mock_prompt_doc + mock_from_token.assert_called_once_with("test_token") + + @patch("nilai_api.auth.nuc.PromptDocument.from_token") + def test_get_token_prompt_document_none(self, mock_from_token): + """Test when no prompt document is found""" + mock_from_token.return_value = None + + result = get_token_prompt_document("test_token") + + assert result is None + mock_from_token.assert_called_once_with("test_token") + + @patch("nilai_api.auth.nuc.PromptDocument.from_token") + def test_get_token_prompt_document_exception(self, mock_from_token): + """Test when PromptDocument.from_token raises an exception""" + mock_from_token.side_effect = Exception("Token parsing failed") + + # The function should let the exception bubble up + with pytest.raises(Exception, match="Token parsing failed"): + get_token_prompt_document("invalid_token") diff --git a/tests/unit/nilai_api/auth/test_strategies.py b/tests/unit/nilai_api/auth/test_strategies.py new file mode 100644 index 00000000..88d85c3b --- /dev/null +++ b/tests/unit/nilai_api/auth/test_strategies.py @@ -0,0 +1,308 @@ +import pytest +from unittest.mock import patch, MagicMock +from datetime import datetime, timezone, timedelta +from fastapi import HTTPException + +from nilai_api.auth.strategies import api_key_strategy, jwt_strategy, nuc_strategy +from nilai_api.auth.common import AuthenticationInfo, PromptDocument +from nilai_api.db.users import UserModel + + +class TestAuthStrategies: + """Test class for authentication strategies with nilDB integration""" + + @pytest.fixture + def mock_user_model(self): + """Mock UserModel fixture""" + mock = MagicMock(spec=UserModel) + mock.name = "Test User" + mock.userid = "test-user-id" + mock.apikey = "test-api-key" + mock.prompt_tokens = 0 + mock.completion_tokens = 0 + mock.queries = 0 + mock.signup_date = datetime.now(timezone.utc) + mock.last_activity = datetime.now(timezone.utc) + mock.ratelimit_day = 1000 + mock.ratelimit_hour = 1000 + mock.ratelimit_minute = 1000 + mock.web_search_ratelimit_day = 100 + mock.web_search_ratelimit_hour = 50 + mock.web_search_ratelimit_minute = 10 + return mock + + @pytest.fixture + def mock_prompt_document(self): + """Mock PromptDocument fixture""" + return PromptDocument( + document_id="test-document-123", owner_did=f"did:nil:{'1' * 66}" + ) + + @pytest.mark.asyncio + async def test_api_key_strategy_success(self, mock_user_model): + """Test successful API key authentication""" + with patch("nilai_api.auth.strategies.UserManager.check_api_key") as mock_check: + mock_check.return_value = mock_user_model + + result = await api_key_strategy("test-api-key") + + assert isinstance(result, AuthenticationInfo) + assert result.user.name == "Test User" + assert result.token_rate_limit is None + assert result.prompt_document is None + + @pytest.mark.asyncio + async def test_api_key_strategy_invalid_key(self): + """Test API key authentication with invalid key""" + with patch("nilai_api.auth.strategies.UserManager.check_api_key") as mock_check: + mock_check.return_value = None + + with pytest.raises(HTTPException) as exc_info: + await api_key_strategy("invalid-key") + + assert exc_info.value.status_code == 401 + assert "Missing or invalid API key" in str(exc_info.value.detail) + + @pytest.mark.asyncio + async def test_jwt_strategy_existing_user(self, mock_user_model): + """Test JWT authentication with existing user""" + with ( + patch("nilai_api.auth.strategies.validate_jwt") as mock_validate, + patch("nilai_api.auth.strategies.UserManager.check_api_key") as mock_check, + ): + mock_jwt_result = MagicMock() + mock_jwt_result.user_address = "test-address" + mock_jwt_result.pub_key = "test-pub-key" + mock_validate.return_value = mock_jwt_result + mock_check.return_value = mock_user_model + + result = await jwt_strategy("jwt-token") + + assert isinstance(result, AuthenticationInfo) + assert result.user.name == "Test User" + assert result.token_rate_limit is None + assert result.prompt_document is None + + @pytest.mark.asyncio + async def test_jwt_strategy_new_user(self): + """Test JWT authentication creating new user""" + with ( + patch("nilai_api.auth.strategies.validate_jwt") as mock_validate, + patch("nilai_api.auth.strategies.UserManager.check_api_key") as mock_check, + patch( + "nilai_api.auth.strategies.UserManager.insert_user_model" + ) as mock_insert, + ): + mock_jwt_result = MagicMock() + mock_jwt_result.user_address = "new-user-address" + mock_jwt_result.pub_key = "new-user-pub-key" + mock_validate.return_value = mock_jwt_result + mock_check.return_value = None + mock_insert.return_value = None + + result = await jwt_strategy("jwt-token") + + assert isinstance(result, AuthenticationInfo) + assert result.token_rate_limit is None + assert result.prompt_document is None + mock_insert.assert_called_once() + + @pytest.mark.asyncio + async def test_nuc_strategy_existing_user_with_prompt_document( + self, mock_user_model, mock_prompt_document + ): + """Test NUC authentication with existing user and prompt document""" + with ( + patch("nilai_api.auth.strategies.validate_nuc") as mock_validate, + patch( + "nilai_api.auth.strategies.get_token_rate_limit" + ) as mock_get_rate_limit, + patch( + "nilai_api.auth.strategies.get_token_prompt_document" + ) as mock_get_prompt_doc, + patch( + "nilai_api.auth.strategies.UserManager.check_user" + ) as mock_check_user, + ): + mock_validate.return_value = ("subscription_holder", "user_id") + mock_get_rate_limit.return_value = None + mock_get_prompt_doc.return_value = mock_prompt_document + mock_check_user.return_value = mock_user_model + + result = await nuc_strategy("nuc-token") + + assert isinstance(result, AuthenticationInfo) + assert result.user.name == "Test User" + assert result.token_rate_limit is None + assert result.prompt_document == mock_prompt_document + + @pytest.mark.asyncio + async def test_nuc_strategy_new_user_with_token_limits(self, mock_prompt_document): + """Test NUC authentication creating new user with token limits""" + from nuc_helpers.usage import TokenRateLimits, TokenRateLimit + + mock_token_limits = TokenRateLimits( + limits=[ + TokenRateLimit( + signature="test-signature", + expires_at=datetime.now(timezone.utc) + timedelta(days=1), + usage_limit=1, + ) + ] + ) + + with ( + patch("nilai_api.auth.strategies.validate_nuc") as mock_validate, + patch( + "nilai_api.auth.strategies.get_token_rate_limit" + ) as mock_get_rate_limit, + patch( + "nilai_api.auth.strategies.get_token_prompt_document" + ) as mock_get_prompt_doc, + patch( + "nilai_api.auth.strategies.UserManager.check_user" + ) as mock_check_user, + patch( + "nilai_api.auth.strategies.UserManager.insert_user_model" + ) as mock_insert, + ): + mock_validate.return_value = ("subscription_holder", "new_user_id") + mock_get_rate_limit.return_value = mock_token_limits + mock_get_prompt_doc.return_value = mock_prompt_document + mock_check_user.return_value = None + mock_insert.return_value = None + + result = await nuc_strategy("nuc-token") + + assert isinstance(result, AuthenticationInfo) + assert result.token_rate_limit == mock_token_limits + assert result.prompt_document == mock_prompt_document + mock_insert.assert_called_once() + + @pytest.mark.asyncio + async def test_nuc_strategy_no_prompt_document(self, mock_user_model): + """Test NUC authentication when no prompt document is found""" + with ( + patch("nilai_api.auth.strategies.validate_nuc") as mock_validate, + patch( + "nilai_api.auth.strategies.get_token_rate_limit" + ) as mock_get_rate_limit, + patch( + "nilai_api.auth.strategies.get_token_prompt_document" + ) as mock_get_prompt_doc, + patch( + "nilai_api.auth.strategies.UserManager.check_user" + ) as mock_check_user, + ): + mock_validate.return_value = ("subscription_holder", "user_id") + mock_get_rate_limit.return_value = None + mock_get_prompt_doc.return_value = None + mock_check_user.return_value = mock_user_model + + result = await nuc_strategy("nuc-token") + + assert isinstance(result, AuthenticationInfo) + assert result.user.name == "Test User" + assert result.token_rate_limit is None + assert result.prompt_document is None + + @pytest.mark.asyncio + async def test_nuc_strategy_validation_error(self): + """Test NUC authentication when validation fails""" + with patch("nilai_api.auth.strategies.validate_nuc") as mock_validate: + mock_validate.side_effect = Exception("Invalid NUC token") + + with pytest.raises(Exception, match="Invalid NUC token"): + await nuc_strategy("invalid-nuc-token") + + @pytest.mark.asyncio + async def test_nuc_strategy_get_prompt_document_error(self, mock_user_model): + """Test NUC authentication when get_token_prompt_document fails""" + with ( + patch("nilai_api.auth.strategies.validate_nuc") as mock_validate, + patch( + "nilai_api.auth.strategies.get_token_rate_limit" + ) as mock_get_rate_limit, + patch( + "nilai_api.auth.strategies.get_token_prompt_document" + ) as mock_get_prompt_doc, + patch( + "nilai_api.auth.strategies.UserManager.check_user" + ) as mock_check_user, + ): + mock_validate.return_value = ("subscription_holder", "user_id") + mock_get_rate_limit.return_value = None + mock_get_prompt_doc.side_effect = Exception( + "Prompt document extraction failed" + ) + mock_check_user.return_value = mock_user_model + + # The function should let the exception bubble up or handle it gracefully + # Based on the diff, it looks like it doesn't catch exceptions from get_token_prompt_document + with pytest.raises(Exception, match="Prompt document extraction failed"): + await nuc_strategy("nuc-token") + + @pytest.mark.asyncio + async def test_all_strategies_return_authentication_info_with_prompt_document_field( + self, + ): + """Test that all strategies return AuthenticationInfo with prompt_document field""" + mock_user_model = MagicMock(spec=UserModel) + mock_user_model.name = "Test" + mock_user_model.userid = "test" + mock_user_model.apikey = "test" + mock_user_model.prompt_tokens = 0 + mock_user_model.completion_tokens = 0 + mock_user_model.queries = 0 + mock_user_model.signup_date = datetime.now(timezone.utc) + mock_user_model.last_activity = datetime.now(timezone.utc) + mock_user_model.ratelimit_day = 1000 + mock_user_model.ratelimit_hour = 1000 + mock_user_model.ratelimit_minute = 1000 + mock_user_model.web_search_ratelimit_day = 100 + mock_user_model.web_search_ratelimit_hour = 50 + mock_user_model.web_search_ratelimit_minute = 10 + + # Test API key strategy + with patch("nilai_api.auth.strategies.UserManager.check_api_key") as mock_check: + mock_check.return_value = mock_user_model + result = await api_key_strategy("test-key") + assert hasattr(result, "prompt_document") + assert result.prompt_document is None + + # Test JWT strategy + with ( + patch("nilai_api.auth.strategies.validate_jwt") as mock_validate, + patch("nilai_api.auth.strategies.UserManager.check_api_key") as mock_check, + ): + mock_jwt_result = MagicMock() + mock_jwt_result.user_address = "test-address" + mock_jwt_result.pub_key = "test-pub-key" + mock_validate.return_value = mock_jwt_result + mock_check.return_value = mock_user_model + + result = await jwt_strategy("jwt-token") + assert hasattr(result, "prompt_document") + assert result.prompt_document is None + + # Test NUC strategy + with ( + patch("nilai_api.auth.strategies.validate_nuc") as mock_validate, + patch( + "nilai_api.auth.strategies.get_token_rate_limit" + ) as mock_get_rate_limit, + patch( + "nilai_api.auth.strategies.get_token_prompt_document" + ) as mock_get_prompt_doc, + patch( + "nilai_api.auth.strategies.UserManager.check_user" + ) as mock_check_user, + ): + mock_validate.return_value = ("subscription_holder", "user_id") + mock_get_rate_limit.return_value = None + mock_get_prompt_doc.return_value = None + mock_check_user.return_value = mock_user_model + + result = await nuc_strategy("nuc-token") + assert hasattr(result, "prompt_document") + assert result.prompt_document is None diff --git a/tests/unit/nilai_api/conftest.py b/tests/unit/nilai_api/conftest.py index 245b6a9a..fe00a029 100644 --- a/tests/unit/nilai_api/conftest.py +++ b/tests/unit/nilai_api/conftest.py @@ -1,8 +1,19 @@ import pytest +from unittest.mock import patch, MagicMock from testcontainers.redis import RedisContainer from nilai_api import config +@pytest.fixture(scope="session", autouse=True) +def mock_sentence_transformer(): + """Mock SentenceTransformer to avoid downloading models during tests.""" + mock_model = MagicMock() + mock_model.encode.return_value = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + + with patch("sentence_transformers.SentenceTransformer", return_value=mock_model): + yield mock_model + + @pytest.fixture(scope="session", autouse=True) def redis_server(): container = RedisContainer() diff --git a/tests/unit/nilai_api/handlers/__init__.py b/tests/unit/nilai_api/handlers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/nilai_api/handlers/test_nildb_handler.py b/tests/unit/nilai_api/handlers/test_nildb_handler.py new file mode 100644 index 00000000..850aaa94 --- /dev/null +++ b/tests/unit/nilai_api/handlers/test_nildb_handler.py @@ -0,0 +1,377 @@ +import pytest +from unittest.mock import MagicMock, AsyncMock, patch +from nilai_api.handlers.nildb.handler import ( + get_nildb_delegation_token, + get_prompt_from_nildb, + create_builder_client, + create_user_client, +) +from nilai_api.handlers.nildb.api_model import PromptDelegationToken +from nilai_api.auth.common import PromptDocument +from secretvaults.common.types import Uuid + + +class TestNilDBHandler: + """Test class for nilDB handler functions""" + + @pytest.fixture + def mock_config(self): + """Mock configuration for tests""" + with patch("nilai_api.handlers.nildb.handler.CONFIG") as mock_config: + mock_config.NILCHAIN_URL = "http://test-nilchain.com" + mock_config.NILAUTH_URL = "http://test-nilauth.com" + mock_config.NODES = ["http://node1.com", "http://node2.com"] + mock_config.BUILDER_PRIVATE_KEY = "0x1234567890abcdef" + mock_config.COLLECTION = Uuid("12345678-1234-1234-1234-123456789012") + yield mock_config + + @pytest.fixture + def mock_prompt_document(self): + """Mock PromptDocument for tests""" + return PromptDocument( + document_id="test-document-123", owner_did="did:nil:" + "1" * 66 + ) + + @pytest.fixture + def mock_keypair(self): + """Mock keypair for tests""" + mock_keypair = MagicMock() + mock_keypair.private_key.return_value = "mock_private_key" + mock_keypair.to_did_string.return_value = "did:nil:builder123" + return mock_keypair + + @pytest.fixture + def mock_builder_client(self, mock_keypair): + """Mock builder client for tests""" + client = MagicMock() + client.keypair = mock_keypair + + # Mock the root_token to be a proper envelope-like object + mock_envelope = MagicMock() + mock_token = MagicMock() + mock_envelope.token.token = mock_token + client.root_token = mock_envelope + + client.refresh_root_token = AsyncMock() + return client + + @pytest.fixture + def mock_user_client(self): + """Mock user client for tests""" + client = MagicMock() + client.read_data = AsyncMock() + return client + + @pytest.mark.asyncio + async def test_create_builder_client(self, mock_config): + """Test creating builder client""" + with ( + patch( + "secretvaults.common.keypair.Keypair.from_hex" + ) as mock_keypair_from_hex, + patch( + "nilai_api.handlers.nildb.handler.SecretVaultBuilderClient.from_options" + ) as mock_from_options, + ): + mock_keypair = MagicMock() + mock_keypair_from_hex.return_value = mock_keypair + + mock_client = MagicMock() + mock_client.refresh_root_token = AsyncMock() + mock_from_options.return_value = mock_client + + # Clear the cache first + + result = await create_builder_client() + + mock_keypair_from_hex.assert_called_once_with( + mock_config.BUILDER_PRIVATE_KEY + ) + mock_from_options.assert_called_once() + mock_client.refresh_root_token.assert_called_once() + assert result == mock_client + + @pytest.mark.asyncio + async def test_create_user_client(self, mock_config): + """Test creating user client""" + with ( + patch( + "secretvaults.common.keypair.Keypair.from_hex" + ) as mock_keypair_from_hex, + patch( + "nilai_api.handlers.nildb.handler.SecretVaultUserClient.from_options" + ) as mock_from_options, + ): + mock_keypair = MagicMock() + mock_keypair_from_hex.return_value = mock_keypair + + mock_client = MagicMock() + mock_from_options.return_value = mock_client + + # Clear the cache first + + result = await create_user_client() + + mock_keypair_from_hex.assert_called_once_with( + mock_config.BUILDER_PRIVATE_KEY + ) + mock_from_options.assert_called_once() + assert result == mock_client + + @pytest.mark.asyncio + async def test_get_nildb_delegation_token_success( + self, mock_config, mock_builder_client + ): + """Test successful delegation token generation""" + user_did = f"did:nil:{'1' * 66}" + + with ( + patch( + "nilai_api.handlers.nildb.handler.create_builder_client", + new_callable=AsyncMock, + ) as mock_create_builder, + patch( + "nilai_api.handlers.nildb.handler.into_seconds_from_now" + ) as mock_into_seconds, + ): + mock_create_builder.return_value = mock_builder_client + mock_into_seconds.return_value = 1234567890 + + # Mock the entire NucTokenBuilder class to return a string for the chain + with patch( + "nilai_api.handlers.nildb.handler.NucTokenBuilder" + ) as mock_token_builder: + mock_builder_chain = MagicMock() + mock_builder_chain.command.return_value = mock_builder_chain + mock_builder_chain.audience.return_value = mock_builder_chain + mock_builder_chain.expires_at.return_value = mock_builder_chain + mock_builder_chain.build.return_value = "delegation_token" + + mock_token_builder.extending.return_value = mock_builder_chain + + result = await get_nildb_delegation_token(user_did) + + assert isinstance(result, PromptDelegationToken) + assert result.token == "delegation_token" + assert result.did == "did:nil:builder123" + + mock_token_builder.extending.assert_called_once() + + @pytest.mark.asyncio + async def test_get_nildb_delegation_token_no_root_token(self, mock_config): + """Test delegation token generation when no root token is available""" + user_did = f"did:nil:{'1' * 66}" + + with patch( + "nilai_api.handlers.nildb.handler.create_builder_client", + new_callable=AsyncMock, + ) as mock_create_builder: + mock_builder = MagicMock() + mock_builder.root_token = None + mock_create_builder.return_value = mock_builder + + with pytest.raises( + ValueError, match="Couldn't extract root NUC token from nilDB profile" + ): + await get_nildb_delegation_token(user_did) + + @pytest.mark.asyncio + async def test_get_prompt_from_nildb_success( + self, mock_config, mock_prompt_document, mock_user_client + ): + """Test successful prompt retrieval from nilDB""" + with patch( + "nilai_api.handlers.nildb.handler.create_user_client", + new_callable=AsyncMock, + ) as mock_create_user: + mock_create_user.return_value = mock_user_client + + # Mock successful document response + class MockData: + def __init__(self): + self.owner = "did:nil:" + "1" * 66 + self.prompt = "This is a test prompt" + + class MockResponse: + def __init__(self): + self.data = MockData() + + mock_response = MockResponse() + + mock_user_client.read_data.return_value = mock_response + + result = await get_prompt_from_nildb(mock_prompt_document) + + assert result == "This is a test prompt" + mock_user_client.read_data.assert_called_once() + + @pytest.mark.asyncio + async def test_get_prompt_from_nildb_no_response( + self, mock_config, mock_prompt_document, mock_user_client + ): + """Test prompt retrieval when no response is received""" + with patch( + "nilai_api.handlers.nildb.handler.create_user_client", + new_callable=AsyncMock, + ) as mock_create_user: + mock_create_user.return_value = mock_user_client + mock_user_client.read_data.return_value = None + + with pytest.raises( + ValueError, match="Couldn't get document response from nilDB nodes" + ): + await get_prompt_from_nildb(mock_prompt_document) + + @pytest.mark.asyncio + async def test_get_prompt_from_nildb_wrong_owner( + self, mock_config, mock_prompt_document, mock_user_client + ): + """Test prompt retrieval when document owner doesn't match""" + with patch( + "nilai_api.handlers.nildb.handler.create_user_client", + new_callable=AsyncMock, + ) as mock_create_user: + mock_create_user.return_value = mock_user_client + + # Mock response with different owner + class MockDataWrongOwner: + def __init__(self): + self.owner = "did:nil:" + "2" * 66 + self.prompt = "This is a test prompt" + + class MockResponseWrongOwner: + def __init__(self): + self.data = MockDataWrongOwner() + + mock_response = MockResponseWrongOwner() + + mock_user_client.read_data.return_value = mock_response + + with pytest.raises( + ValueError, + match="Non-owning entity trying to invoke access to a document resource", + ): + await get_prompt_from_nildb(mock_prompt_document) + + @pytest.mark.asyncio + async def test_get_prompt_from_nildb_no_prompt_field( + self, mock_config, mock_prompt_document, mock_user_client + ): + """Test prompt retrieval when prompt field is missing""" + with patch( + "nilai_api.handlers.nildb.handler.create_user_client", + new_callable=AsyncMock, + ) as mock_create_user: + mock_create_user.return_value = mock_user_client + + # Mock response without prompt field + class MockDataNoPrompt: + def __init__(self): + self.owner = "did:nil:" + "1" * 66 + # No prompt attribute + + class MockResponseNoPrompt: + def __init__(self): + self.data = MockDataNoPrompt() + + mock_response = MockResponseNoPrompt() + + mock_user_client.read_data.return_value = mock_response + + with pytest.raises( + ValueError, + match="Couldn't find prompt field in document response from nilDB", + ): + await get_prompt_from_nildb(mock_prompt_document) + + @pytest.mark.asyncio + async def test_get_prompt_from_nildb_null_prompt( + self, mock_config, mock_prompt_document, mock_user_client + ): + """Test prompt retrieval when prompt field is None""" + with patch( + "nilai_api.handlers.nildb.handler.create_user_client", + new_callable=AsyncMock, + ) as mock_create_user: + mock_create_user.return_value = mock_user_client + + # Mock response with None prompt + class MockDataNullPrompt: + def __init__(self): + self.owner = "did:nil:" + "1" * 66 + self.prompt = None + + class MockResponseNullPrompt: + def __init__(self): + self.data = MockDataNullPrompt() + + mock_response = MockResponseNullPrompt() + + mock_user_client.read_data.return_value = mock_response + + with pytest.raises( + ValueError, match="Prompt field is None in document response from nilDB" + ): + await get_prompt_from_nildb(mock_prompt_document) + + @pytest.mark.asyncio + async def test_get_prompt_from_nildb_with_model_dump( + self, mock_config, mock_prompt_document, mock_user_client + ): + """Test prompt retrieval using model_dump method""" + with patch( + "nilai_api.handlers.nildb.handler.create_user_client", + new_callable=AsyncMock, + ) as mock_create_user: + mock_create_user.return_value = mock_user_client + + # Mock response with model_dump method + class MockDataModelDump: + def model_dump(self): + return { + "owner": "did:nil:" + "1" * 66, + "prompt": "Test prompt from model_dump", + } + + # Don't provide __dict__ by overriding __getattribute__ + def __getattribute__(self, name): + if name == "__dict__": + raise AttributeError("__dict__") + return super().__getattribute__(name) + + class MockResponseModelDump: + def __init__(self): + self.data = MockDataModelDump() + + mock_response = MockResponseModelDump() + mock_user_client.read_data.return_value = mock_response + + result = await get_prompt_from_nildb(mock_prompt_document) + + assert result == "Test prompt from model_dump" + + @pytest.mark.asyncio + async def test_get_prompt_from_nildb_direct_response( + self, mock_config, mock_prompt_document, mock_user_client + ): + """Test prompt retrieval with direct response (no data attribute)""" + with patch( + "nilai_api.handlers.nildb.handler.create_user_client", + new_callable=AsyncMock, + ) as mock_create_user: + mock_create_user.return_value = mock_user_client + + # Create a simple object to act as direct response + class MockDirectResponse: + def __init__(self): + self.owner = "did:nil:" + "1" * 66 + self.prompt = "Direct response prompt" + self.data = None + + mock_response = MockDirectResponse() + + mock_user_client.read_data.return_value = mock_response + + result = await get_prompt_from_nildb(mock_prompt_document) + + assert result == "Direct response prompt" diff --git a/tests/unit/nilai_api/routers/test_nildb_endpoints.py b/tests/unit/nilai_api/routers/test_nildb_endpoints.py new file mode 100644 index 00000000..d2702d7a --- /dev/null +++ b/tests/unit/nilai_api/routers/test_nildb_endpoints.py @@ -0,0 +1,409 @@ +import pytest +from unittest.mock import patch, MagicMock, AsyncMock +from fastapi import HTTPException, status + +from nilai_api.auth.common import AuthenticationInfo, PromptDocument +from nilai_api.db.users import UserData, UserModel +from nilai_api.handlers.nildb.api_model import ( + PromptDelegationToken, +) +from datetime import datetime, timezone + + +class TestNilDBEndpoints: + """Test class for nilDB-related API endpoints""" + + @pytest.fixture + def mock_subscription_owner_user(self): + """Mock user data for subscription owner""" + mock_user_model = MagicMock(spec=UserModel) + mock_user_model.name = "Subscription Owner" + mock_user_model.userid = "owner-id" + mock_user_model.apikey = "owner-id" # Same as userid for subscription owner + mock_user_model.prompt_tokens = 0 + mock_user_model.completion_tokens = 0 + mock_user_model.queries = 0 + mock_user_model.signup_date = datetime.now(timezone.utc) + mock_user_model.last_activity = datetime.now(timezone.utc) + mock_user_model.ratelimit_day = 1000 + mock_user_model.ratelimit_hour = 100 + mock_user_model.ratelimit_minute = 10 + mock_user_model.web_search_ratelimit_day = 100 + mock_user_model.web_search_ratelimit_hour = 50 + mock_user_model.web_search_ratelimit_minute = 5 + + return UserData.from_sqlalchemy(mock_user_model) + + @pytest.fixture + def mock_regular_user(self): + """Mock user data for regular user (not subscription owner)""" + mock_user_model = MagicMock(spec=UserModel) + mock_user_model.name = "Regular User" + mock_user_model.userid = "user-id" + mock_user_model.apikey = "different-api-key" # Different from userid + mock_user_model.prompt_tokens = 0 + mock_user_model.completion_tokens = 0 + mock_user_model.queries = 0 + mock_user_model.signup_date = datetime.now(timezone.utc) + mock_user_model.last_activity = datetime.now(timezone.utc) + mock_user_model.ratelimit_day = 1000 + mock_user_model.ratelimit_hour = 100 + mock_user_model.ratelimit_minute = 10 + mock_user_model.web_search_ratelimit_day = 100 + mock_user_model.web_search_ratelimit_hour = 50 + mock_user_model.web_search_ratelimit_minute = 5 + + return UserData.from_sqlalchemy(mock_user_model) + + @pytest.fixture + def mock_auth_info_subscription_owner(self, mock_subscription_owner_user): + """Mock AuthenticationInfo for subscription owner""" + return AuthenticationInfo( + user=mock_subscription_owner_user, + token_rate_limit=None, + prompt_document=None, + ) + + @pytest.fixture + def mock_auth_info_regular_user(self, mock_regular_user): + """Mock AuthenticationInfo for regular user""" + return AuthenticationInfo( + user=mock_regular_user, token_rate_limit=None, prompt_document=None + ) + + @pytest.fixture + def mock_prompt_delegation_token(self): + """Mock PromptDelegationToken""" + return PromptDelegationToken( + token="delegation_token_123", did="did:nil:builder123" + ) + + @pytest.mark.asyncio + async def test_get_prompt_store_delegation_success( + self, mock_auth_info_subscription_owner, mock_prompt_delegation_token + ): + """Test successful delegation token request""" + from nilai_api.routers.private import get_prompt_store_delegation + + with patch( + "nilai_api.routers.private.get_nildb_delegation_token" + ) as mock_get_delegation: + mock_get_delegation.return_value = mock_prompt_delegation_token + + request = "user-123" + + result = await get_prompt_store_delegation( + request, mock_auth_info_subscription_owner + ) + + assert isinstance(result, PromptDelegationToken) + assert result.token == "delegation_token_123" + assert result.did == "did:nil:builder123" + mock_get_delegation.assert_called_once_with("user-123") + + @pytest.mark.asyncio + async def test_get_prompt_store_delegation_forbidden_regular_user( + self, mock_auth_info_regular_user + ): + """Test delegation token request by regular user (not subscription owner)""" + from nilai_api.routers.private import get_prompt_store_delegation + + request = "user-123" + + with pytest.raises(HTTPException) as exc_info: + await get_prompt_store_delegation(request, mock_auth_info_regular_user) + + assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN + assert "Prompt storage is reserved to subscription owners" in str( + exc_info.value.detail + ) + + @pytest.mark.asyncio + async def test_get_prompt_store_delegation_handler_error( + self, mock_auth_info_subscription_owner + ): + """Test delegation token request when handler raises an exception""" + from nilai_api.routers.private import get_prompt_store_delegation + + with patch( + "nilai_api.routers.private.get_nildb_delegation_token" + ) as mock_get_delegation: + mock_get_delegation.side_effect = Exception("Handler failed") + + request = "user-123" + + with pytest.raises(HTTPException) as exc_info: + await get_prompt_store_delegation( + request, mock_auth_info_subscription_owner + ) + + assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert "Server unable to produce delegation tokens: Handler failed" in str( + exc_info.value.detail + ) + + @pytest.mark.asyncio + async def test_chat_completion_with_prompt_document_injection(self): + """Test chat completion with prompt document injection""" + from nilai_api.routers.private import chat_completion + from nilai_common import ChatRequest + + mock_prompt_document = PromptDocument( + document_id="test-doc-123", owner_did="did:nil:" + "1" * 66 + ) + + mock_user = MagicMock() + mock_user.userid = "test-user-id" + mock_user.name = "Test User" + mock_user.apikey = "test-api-key" + + mock_auth_info = AuthenticationInfo( + user=mock_user, token_rate_limit=None, prompt_document=mock_prompt_document + ) + + request = ChatRequest( + model="test-model", messages=[{"role": "user", "content": "Hello"}] + ) + + with ( + patch("nilai_api.routers.private.get_prompt_from_nildb") as mock_get_prompt, + patch("nilai_api.routers.private.AsyncOpenAI") as mock_openai_client, + patch("nilai_api.routers.private.state.get_model") as mock_get_model, + patch("nilai_api.routers.private.handle_nilrag") as mock_handle_nilrag, + patch( + "nilai_api.routers.private.handle_web_search" + ) as mock_handle_web_search, + patch( + "nilai_api.routers.private.UserManager.update_token_usage" + ) as mock_update_usage, + patch( + "nilai_api.routers.private.QueryLogManager.log_query" + ) as mock_log_query, + ): + mock_get_prompt.return_value = "System prompt from nilDB" + + # Mock state.get_model() to return a ModelEndpoint + mock_model_endpoint = MagicMock() + mock_model_endpoint.url = "http://test-model-endpoint" + mock_model_endpoint.metadata.tool_support = True + mock_model_endpoint.metadata.multimodal_support = True + mock_get_model.return_value = mock_model_endpoint + + # Mock handle_nilrag and handle_web_search + mock_handle_nilrag.return_value = None + mock_web_search_result = MagicMock() + mock_web_search_result.messages = request.messages + mock_web_search_result.sources = [] + mock_handle_web_search.return_value = mock_web_search_result + + # Mock async database operations + mock_update_usage.return_value = None + mock_log_query.return_value = None + + # Mock OpenAI client + mock_client_instance = MagicMock() + mock_response = MagicMock() + # Mock the response object that will be awaited + mock_response.model_dump.return_value = { + "id": "test-response-id", + "object": "chat.completion", + "created": 1234567890, + "model": "test-model", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "Test response"}, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + }, + } + # Make the create method itself an AsyncMock that returns the response + mock_client_instance.chat.completions.create = AsyncMock( + return_value=mock_response + ) + mock_client_instance.close = AsyncMock() + mock_openai_client.return_value = mock_client_instance + + # + # Call the function (this will test the prompt injection logic) + # Note: We can't easily test the full endpoint without setting up the FastAPI app + # But we can test that get_prompt_from_nildb is called + try: + await chat_completion(req=request, auth_info=mock_auth_info) + except Exception as e: + # Expected to fail due to incomplete mocking, but we should still see the prompt call + print("The exception is: ", str(e)) + raise e + + mock_get_prompt.assert_called_once_with(mock_prompt_document) + + @pytest.mark.asyncio + async def test_chat_completion_prompt_document_extraction_error(self): + """Test chat completion when prompt document extraction fails""" + from nilai_api.routers.private import chat_completion + from nilai_common import ChatRequest + + mock_prompt_document = PromptDocument( + document_id="test-doc-123", owner_did="did:nil:" + "1" * 66 + ) + + mock_user = MagicMock() + mock_user.userid = "test-user-id" + mock_user.name = "Test User" + mock_user.apikey = "test-api-key" + + mock_auth_info = AuthenticationInfo( + user=mock_user, token_rate_limit=None, prompt_document=mock_prompt_document + ) + + request = ChatRequest( + model="test-model", messages=[{"role": "user", "content": "Hello"}] + ) + + with ( + patch("nilai_api.routers.private.get_prompt_from_nildb") as mock_get_prompt, + patch("nilai_api.routers.private.state.get_model") as mock_get_model, + ): + # Mock state.get_model() to return a ModelEndpoint + mock_model_endpoint = MagicMock() + mock_model_endpoint.url = "http://test-model-endpoint" + mock_model_endpoint.metadata.tool_support = True + mock_model_endpoint.metadata.multimodal_support = True + mock_get_model.return_value = mock_model_endpoint + + mock_get_prompt.side_effect = Exception("Unable to extract prompt") + + with pytest.raises(HTTPException) as exc_info: + await chat_completion(req=request, auth_info=mock_auth_info) + + assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN + assert ( + "Unable to extract prompt from nilDB: Unable to extract prompt" + in str(exc_info.value.detail) + ) + + @pytest.mark.asyncio + async def test_chat_completion_without_prompt_document(self): + """Test chat completion when no prompt document is present""" + from nilai_api.routers.private import chat_completion + from nilai_common import ChatRequest + + mock_user = MagicMock() + mock_user.userid = "test-user-id" + mock_user.name = "Test User" + mock_user.apikey = "test-api-key" + + mock_auth_info = AuthenticationInfo( + user=mock_user, + token_rate_limit=None, + prompt_document=None, # No prompt document + ) + + request = ChatRequest( + model="test-model", messages=[{"role": "user", "content": "Hello"}] + ) + + with ( + patch("nilai_api.routers.private.get_prompt_from_nildb") as mock_get_prompt, + patch("nilai_api.routers.private.AsyncOpenAI") as mock_openai_client, + patch("nilai_api.routers.private.state.get_model") as mock_get_model, + patch("nilai_api.routers.private.handle_nilrag") as mock_handle_nilrag, + patch( + "nilai_api.routers.private.handle_web_search" + ) as mock_handle_web_search, + patch( + "nilai_api.routers.private.UserManager.update_token_usage" + ) as mock_update_usage, + patch( + "nilai_api.routers.private.QueryLogManager.log_query" + ) as mock_log_query, + ): + # Mock state.get_model() to return a ModelEndpoint + mock_model_endpoint = MagicMock() + mock_model_endpoint.url = "http://test-model-endpoint" + mock_model_endpoint.metadata.tool_support = True + mock_model_endpoint.metadata.multimodal_support = True + mock_get_model.return_value = mock_model_endpoint + + # Mock handle_nilrag and handle_web_search + mock_handle_nilrag.return_value = None + mock_web_search_result = MagicMock() + mock_web_search_result.messages = request.messages + mock_web_search_result.sources = [] + mock_handle_web_search.return_value = mock_web_search_result + + # Mock async database operations + mock_update_usage.return_value = None + mock_log_query.return_value = None + + # Mock OpenAI client + mock_client_instance = MagicMock() + mock_response = MagicMock() + # Mock the response object that will be awaited + mock_response.model_dump.return_value = { + "id": "test-response-id", + "object": "chat.completion", + "created": 1234567890, + "model": "test-model", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "Test response"}, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + }, + } + # Make the create method itself an AsyncMock that returns the response + mock_client_instance.chat.completions.create = AsyncMock( + return_value=mock_response + ) + mock_client_instance.close = AsyncMock() + mock_openai_client.return_value = mock_client_instance + + # Call the function + try: + await chat_completion(req=request, auth_info=mock_auth_info) + except Exception: + # Expected to fail due to incomplete mocking + pass + + # Should not call get_prompt_from_nildb when no prompt document + mock_get_prompt.assert_not_called() + + def test_prompt_delegation_request_model_validation(self): + """Test PromptDelegationRequest model validation""" + # Valid request + valid_request = "user-123" + assert valid_request == "user-123" + + # Test with different types of user IDs + request_with_uuid = "550e8400-e29b-41d4-a716-446655440000" + assert request_with_uuid == "550e8400-e29b-41d4-a716-446655440000" + + def test_prompt_delegation_token_model_validation(self): + """Test PromptDelegationToken model validation""" + token = PromptDelegationToken( + token="delegation_token_123", did="did:nil:builder123" + ) + assert token.token == "delegation_token_123" + assert token.did == "did:nil:builder123" + + def test_user_is_subscription_owner_property( + self, mock_subscription_owner_user, mock_regular_user + ): + """Test the is_subscription_owner property""" + # Subscription owner (userid == apikey) + assert mock_subscription_owner_user.is_subscription_owner is True + + # Regular user (userid != apikey) + assert mock_regular_user.is_subscription_owner is False diff --git a/tests/unit/nilai_api/routers/test_private.py b/tests/unit/nilai_api/routers/test_private.py index 183d5cdb..e0b4790e 100644 --- a/tests/unit/nilai_api/routers/test_private.py +++ b/tests/unit/nilai_api/routers/test_private.py @@ -83,7 +83,7 @@ def mock_user_manager(mock_user, mocker): @pytest.fixture -def mock_state(mocker, event_loop): +def mock_state(mocker): # Prepare expected models data expected_models = {"ABC": model_endpoint} @@ -110,7 +110,8 @@ def mock_state(mocker, event_loop): ) # Patch the get_attestation_report function mocker.patch( - "nilai_api.attestation.get_attestation_report", + "nilai_api.routers.private.get_attestation_report", + new_callable=AsyncMock, return_value=attestation_response, ) diff --git a/tests/unit/nuc_helpers/__init__.py b/tests/unit/nuc_helpers/__init__.py new file mode 100644 index 00000000..c0f8cc03 --- /dev/null +++ b/tests/unit/nuc_helpers/__init__.py @@ -0,0 +1,34 @@ +""" +Shared test utilities for nuc-helpers tests. + +This module contains common dummy classes and utilities used across +multiple test files to avoid code duplication. +""" + +from datetime import datetime, timedelta, timezone +from nuc.token import Did + + +class DummyNucToken: + """Dummy NUC token for testing purposes.""" + + def __init__(self, meta=None, issuer=None, expires_at=None): + self.meta = meta or {} + self.issuer = issuer or Did.parse(f"did:nil:{'1' * 66}") + self.expires_at = expires_at or (datetime.now(timezone.utc) + timedelta(days=1)) + + +class DummyDecodedNucToken: + """Dummy decoded NUC token for testing purposes.""" + + def __init__(self, meta=None, issuer=None, expires_at=None): + self.token = DummyNucToken(meta, issuer, expires_at) + self.signature = b"\x01\x02" + + +class DummyNucTokenEnvelope: + """Dummy NUC token envelope for testing purposes.""" + + def __init__(self, proofs, invocation_meta=None): + self.proofs = proofs + self.token = DummyDecodedNucToken(invocation_meta) diff --git a/tests/unit/nuc_helpers/test_nildb_document.py b/tests/unit/nuc_helpers/test_nildb_document.py new file mode 100644 index 00000000..860f7b8d --- /dev/null +++ b/tests/unit/nuc_helpers/test_nildb_document.py @@ -0,0 +1,176 @@ +import unittest +from unittest.mock import patch, MagicMock +from nuc.token import Did +from nuc_helpers.nildb_document import PromptDocument +from ..nuc_helpers import DummyDecodedNucToken, DummyNucTokenEnvelope + + +class TestPromptDocument(unittest.TestCase): + def setUp(self): + """Clear the cache before each test""" + PromptDocument.from_token.cache_clear() + + @patch("nuc.envelope.NucTokenEnvelope.parse") + def test_from_token_no_document_id_returns_none(self, mock_parse): + """Test that from_token returns None when no document_id is found""" + proofs = [ + DummyDecodedNucToken({}), + DummyDecodedNucToken({"other_field": "value"}), + ] + envelope = DummyNucTokenEnvelope(proofs) + mock_parse.return_value = envelope + + result = PromptDocument.from_token("dummy_token") + self.assertIsNone(result) + + @patch("nuc.envelope.NucTokenEnvelope.parse") + def test_from_token_with_document_id_returns_prompt_document(self, mock_parse): + """Test that from_token returns PromptDocument when document_id is found""" + issuer_did = f"did:nil:{'1' * 66}" + document_id = "test-document-123" + + proofs = [ + DummyDecodedNucToken( + {"document_id": document_id, "document_owner_did": issuer_did}, + Did.parse(issuer_did), + ), + DummyDecodedNucToken({}), + ] + envelope = DummyNucTokenEnvelope(proofs) + mock_parse.return_value = envelope + + result = PromptDocument.from_token( + "dummy_token" + ) # will return the envelope above + + self.assertIsNotNone(result) + self.assertEqual(result.document_id, document_id) # type: ignore + self.assertEqual(result.owner_did, issuer_did) # type: ignore + + @patch("nuc.envelope.NucTokenEnvelope.parse") + def test_from_token_multiple_document_ids_returns_first(self, mock_parse): + """Test that from_token returns the first document_id found (uppermost in chain)""" + issuer_did_1 = f"did:nil:{'1' * 66}" + issuer_did_2 = f"did:nil:{'2' * 66}" + document_id_1 = "first-document" + document_id_2 = "second-document" + + # Note: proofs are processed in reverse order, so the last one is "uppermost" + proofs = [ + DummyDecodedNucToken( + {"document_id": document_id_2, "document_owner_did": issuer_did_2}, + Did.parse(issuer_did_2), + ), + DummyDecodedNucToken( + {"document_id": document_id_1, "document_owner_did": issuer_did_1}, + Did.parse(issuer_did_1), + ), + ] + envelope = DummyNucTokenEnvelope(proofs) + mock_parse.return_value = envelope + + result = PromptDocument.from_token( + "dummy_token" + ) # will return the envelope above + + self.assertIsNotNone(result) + self.assertEqual(result.document_id, document_id_1) # type: ignore + self.assertEqual(result.owner_did, issuer_did_1) # type: ignore + + @patch("nuc.envelope.NucTokenEnvelope.parse") + def test_from_token_with_none_document_id_skips(self, mock_parse): + """Test that from_token skips proofs with None document_id""" + issuer_did = f"did:nil:{'1' * 66}" + document_id = "valid-document" + + proofs = [ + DummyDecodedNucToken( + {"document_id": None, "document_owner_did": issuer_did}, + Did.parse(issuer_did), + ), + DummyDecodedNucToken( + {"document_id": document_id, "document_owner_did": issuer_did}, + Did.parse(issuer_did), + ), + ] + envelope = DummyNucTokenEnvelope(proofs) + mock_parse.return_value = envelope + + result = PromptDocument.from_token( + "dummy_token" + ) # will return the envelope above + + self.assertIsNotNone(result) + self.assertEqual(result.document_id, document_id) # type: ignore + self.assertEqual(result.owner_did, issuer_did) # type: ignore + + @patch("nuc.envelope.NucTokenEnvelope.parse") + def test_from_token_with_null_token_meta_skips(self, mock_parse): + """Test that from_token skips proofs with null token meta""" + issuer_did = f"did:nil:{'1' * 66}" + document_id = "valid-document" + + # Create a proof with null token + proof_with_null_token = MagicMock() + proof_with_null_token.token = None + + proofs = [ + proof_with_null_token, + DummyDecodedNucToken( + {"document_id": document_id, "document_owner_did": issuer_did}, + Did.parse(issuer_did), + ), + ] + envelope = DummyNucTokenEnvelope(proofs) + mock_parse.return_value = envelope + + result = PromptDocument.from_token( + "dummy_token" + ) # will return the envelope above + + self.assertIsNotNone(result) + self.assertEqual(result.document_id, document_id) # type: ignore + self.assertEqual(result.owner_did, issuer_did) # type: ignore + + def test_prompt_document_model_validation(self): + """Test that PromptDocument model validates correctly""" + issuer_did = f"did:nil:{'1' * 66}" + document_id = "test-document-123" + + prompt_doc = PromptDocument(document_id=document_id, owner_did=issuer_did) + + self.assertEqual(prompt_doc.document_id, document_id) # type: ignore + self.assertEqual(prompt_doc.owner_did, issuer_did) # type: ignore + + def test_cache_functionality(self): + """Test that the cache works correctly""" + with patch("nuc.envelope.NucTokenEnvelope.parse") as mock_parse: + issuer_did = f"did:nil:{'1' * 66}" + document_id = "cached-document" + + proofs = [ + DummyDecodedNucToken( + {"document_id": document_id, "document_owner_did": issuer_did}, + Did.parse(issuer_did), + ) + ] + envelope = DummyNucTokenEnvelope(proofs) + mock_parse.return_value = envelope + + token = "test_token" + + # First call + result1 = PromptDocument.from_token(token) # will return the envelope above + # Second call - should use cache + result2 = PromptDocument.from_token(token) + + # Should only parse once due to caching + mock_parse.assert_called_once() + + # Both results should be identical + self.assertEqual(result1.document_id, result2.document_id) # type: ignore + self.assertEqual(result1.owner_did, result2.owner_did) # type: ignore + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/nuc-helpers/test_usage.py b/tests/unit/nuc_helpers/test_usage.py similarity index 94% rename from tests/unit/nuc-helpers/test_usage.py rename to tests/unit/nuc_helpers/test_usage.py index 2c181338..445b3286 100644 --- a/tests/unit/nuc-helpers/test_usage.py +++ b/tests/unit/nuc_helpers/test_usage.py @@ -1,31 +1,11 @@ import unittest from unittest.mock import patch from nuc_helpers.usage import TokenRateLimits, UsageLimitError, UsageLimitKind +from ..nuc_helpers import DummyDecodedNucToken, DummyNucTokenEnvelope from datetime import datetime, timedelta, timezone -# Dummy token envelope structure to simulate nuc.envelope -class DummyNucToken: - def __init__( - self, meta=None, expires_at=datetime.now(timezone.utc) + timedelta(days=1) - ): - self.meta = meta or {} - self.expires_at = expires_at - - -class DummyDecodedNucToken: - def __init__(self, meta=None): - self.token = DummyNucToken(meta) - self.signature = b"\x01\x02" - - -class DummyNucTokenEnvelope: - def __init__(self, proofs, invocation_meta=None): - self.proofs = proofs - self.token = DummyDecodedNucToken(invocation_meta) - - class GetUsageLimitTests(unittest.TestCase): def setUp(self): """Clear the cache before each test, because the cache is global and we use the same dummy token for all tests.""" diff --git a/uv.lock b/uv.lock index 0ae7887e..01cde052 100644 --- a/uv.lock +++ b/uv.lock @@ -289,6 +289,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/09/71/54e999902aed72baf26bca0d50781b01838251a462612966e9fc4891eadd/black-25.1.0-py3-none-any.whl", hash = "sha256:95e8176dae143ba9097f351d174fdaf0ccd29efb414b362ae3fd72bf0f710717", size = 207646, upload-time = "2025-01-29T04:15:38.082Z" }, ] +[[package]] +name = "blindfold" +version = "0.1.0rc0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "bcl" }, + { name = "lagrange" }, + { name = "pailliers" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0c/1e/53d07fa342006e4f468148e76b8905245d3aca7124db41a82ed3f21d5cc1/blindfold-0.1.0rc0.tar.gz", hash = "sha256:b0e309b6427873dc81d2c397ee71664a0a071e932e29ad879825b9fc8e7e9786", size = 20370, upload-time = "2025-06-24T01:58:33.887Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3a/c6/a161c9592167ac0ad0a60586638bbc9dd65e229873980e08d391f37625f8/blindfold-0.1.0rc0-py3-none-any.whl", hash = "sha256:249d87fc6cea556bb6c96f171d6d3352381b7b1988adec7a3334ee2f6e43716e", size = 14425, upload-time = "2025-06-24T01:58:32.614Z" }, +] + [[package]] name = "certifi" version = "2025.1.31" @@ -1142,6 +1156,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/0f/8910b19ac0670a0f80ce1008e5e751c4a57e14d2c4c13a482aa6079fa9d6/jsonschema_specifications-2024.10.1-py3-none-any.whl", hash = "sha256:a09a0680616357d9a0ecf05c12ad234479f549239d0f5b55f3deea67475da9bf", size = 18459, upload-time = "2024-10-08T12:29:30.439Z" }, ] +[[package]] +name = "lagrange" +version = "3.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/07/9d/4b6470fd6769b0943fbda9b30e2068bb8d9940be2977b1e80a184d527fa6/lagrange-3.0.1.tar.gz", hash = "sha256:272f352a676679ee318b0b302054f667f23afb73d10063cd3926c612527e09f1", size = 6894, upload-time = "2025-01-01T01:33:14.999Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3e/d8/f1c3ff60a8b3e114cfb3e9eed75140d2a3e1e766791cfe2f210a5c736d61/lagrange-3.0.1-py3-none-any.whl", hash = "sha256:d473913d901f0c257456c505e4a94450f2e4a2f147460a68ad0cfb9ea33a6d0a", size = 6905, upload-time = "2025-01-01T01:33:11.031Z" }, +] + [[package]] name = "mako" version = "1.3.9" @@ -1368,6 +1391,7 @@ dependencies = [ { name = "python-dotenv" }, { name = "pyyaml" }, { name = "redis" }, + { name = "secretvaults" }, { name = "sqlalchemy" }, { name = "uvicorn" }, { name = "verifier" }, @@ -1388,7 +1412,7 @@ requires-dist = [ { name = "httpx", specifier = ">=0.27.2" }, { name = "nilai-common", editable = "packages/nilai-common" }, { name = "nilrag", specifier = ">=0.1.11" }, - { name = "nuc", git = "https://github.com/NillionNetwork/nuc-py.git?rev=4922b5e9354e611cc31322d681eb29da05be584e" }, + { name = "nuc", specifier = ">=0.1.0" }, { name = "nuc-helpers", editable = "nilai-auth/nuc-helpers" }, { name = "openai", specifier = ">=1.59.9" }, { name = "pg8000", specifier = ">=1.31.2" }, @@ -1396,6 +1420,7 @@ requires-dist = [ { name = "python-dotenv", specifier = ">=1.0.1" }, { name = "pyyaml", specifier = ">=6.0.1" }, { name = "redis", specifier = ">=5.2.1" }, + { name = "secretvaults", git = "https://github.com/jcabrero/secretvaults-py?rev=main" }, { name = "sqlalchemy", specifier = ">=2.0.36" }, { name = "uvicorn", specifier = ">=0.32.1" }, { name = "verifier" }, @@ -1512,13 +1537,17 @@ wheels = [ [[package]] name = "nuc" -version = "0.0.0a0" -source = { git = "https://github.com/NillionNetwork/nuc-py.git?rev=4922b5e9354e611cc31322d681eb29da05be584e#4922b5e9354e611cc31322d681eb29da05be584e" } +version = "0.1.0" +source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cosmpy" }, { name = "requests" }, { name = "secp256k1" }, ] +sdist = { url = "https://files.pythonhosted.org/packages/e9/58/acfdbdd6dc8e8575a1bc2ade9eedf7d33d99ac428573df5a46a4f4b76949/nuc-0.1.0.tar.gz", hash = "sha256:6a715bf07a8adf2901b68c9597ba44ae28506c3fb0fa03530c092bc0f8ba22f0", size = 29586, upload-time = "2025-07-01T14:46:55.774Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/80/ba/a99b12ee5132976d974fe65f9dbeaaafe4183a8558859c72bd271f87e25c/nuc-0.1.0-py3-none-any.whl", hash = "sha256:6845133866f2d41592be74ca2a41295d09d7a6d89886a5a1181dceefd4fe5a65", size = 22513, upload-time = "2025-07-01T14:46:54.685Z" }, +] [[package]] name = "nuc-helpers" @@ -1536,7 +1565,7 @@ dependencies = [ requires-dist = [ { name = "cosmpy", specifier = "==0.9.2" }, { name = "httpx", specifier = ">=0.28.1" }, - { name = "nuc", git = "https://github.com/NillionNetwork/nuc-py.git?rev=4922b5e9354e611cc31322d681eb29da05be584e" }, + { name = "nuc", specifier = ">=0.1.0" }, { name = "pydantic", specifier = ">=2.11.2" }, { name = "secp256k1", specifier = ">=0.14.0" }, ] @@ -2115,11 +2144,11 @@ wheels = [ [[package]] name = "python-dotenv" -version = "1.1.0" +version = "1.1.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/88/2c/7bb1416c5620485aa793f2de31d3df393d3686aa8a8506d11e10e13c5baf/python_dotenv-1.1.0.tar.gz", hash = "sha256:41f90bc6f5f177fb41f53e87666db362025010eb28f60a01c9143bfa33a2b2d5", size = 39920, upload-time = "2025-03-25T10:14:56.835Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f6/b0/4bc07ccd3572a2f9df7e6782f52b0c6c90dcbb803ac4a167702d7d0dfe1e/python_dotenv-1.1.1.tar.gz", hash = "sha256:a8a6399716257f45be6a007360200409fce5cda2661e3dec71d23dc15f6189ab", size = 41978, upload-time = "2025-06-24T04:21:07.341Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1e/18/98a99ad95133c6a6e2005fe89faedf294a748bd5dc803008059409ac9b1e/python_dotenv-1.1.0-py3-none-any.whl", hash = "sha256:d7c01d9e2293916c18baf562d95698754b0dbbb5e74d457c45d4f6561fb9d55d", size = 20256, upload-time = "2025-03-25T10:14:55.034Z" }, + { url = "https://files.pythonhosted.org/packages/5f/ed/539768cf28c661b5b068d66d96a2f155c4971a5d55684a514c1a0e0dec2f/python_dotenv-1.1.1-py3-none-any.whl", hash = "sha256:31f23644fe2602f88ff55e1f5c79ba497e01224ee7737937930c448e4d0e24dc", size = 20556, upload-time = "2025-06-24T04:21:06.073Z" }, ] [[package]] @@ -2484,6 +2513,19 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/9b/41/bb668a6e4192303542d2d90c3b38d564af3c17c61bd7d4039af4f29405fe/secp256k1-0.14.0.tar.gz", hash = "sha256:82c06712d69ef945220c8b53c1a0d424c2ff6a1f64aee609030df79ad8383397", size = 2420607, upload-time = "2021-11-06T01:36:10.707Z" } +[[package]] +name = "secretvaults" +version = "0.2.1" +source = { git = "https://github.com/jcabrero/secretvaults-py?rev=main#498ee5304fdcc730d1810fcf6172e56fa6dd7d16" } +dependencies = [ + { name = "aiohttp" }, + { name = "blindfold" }, + { name = "nuc" }, + { name = "pydantic" }, + { name = "python-dotenv" }, + { name = "structlog" }, +] + [[package]] name = "sentence-transformers" version = "4.0.2" @@ -2580,6 +2622,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a0/4b/528ccf7a982216885a1ff4908e886b8fb5f19862d1962f56a3fce2435a70/starlette-0.46.1-py3-none-any.whl", hash = "sha256:77c74ed9d2720138b25875133f3a2dae6d854af2ec37dceb56aef370c1d8a227", size = 71995, upload-time = "2025-03-08T10:55:32.662Z" }, ] +[[package]] +name = "structlog" +version = "25.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/79/b9/6e672db4fec07349e7a8a8172c1a6ae235c58679ca29c3f86a61b5e59ff3/structlog-25.4.0.tar.gz", hash = "sha256:186cd1b0a8ae762e29417095664adf1d6a31702160a46dacb7796ea82f7409e4", size = 1369138, upload-time = "2025-06-02T08:21:12.971Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/4a/97ee6973e3a73c74c8120d59829c3861ea52210667ec3e7a16045c62b64d/structlog-25.4.0-py3-none-any.whl", hash = "sha256:fe809ff5c27e557d14e613f45ca441aabda051d119ee5a0102aaba6ce40eed2c", size = 68720, upload-time = "2025-06-02T08:21:11.43Z" }, +] + [[package]] name = "sympy" version = "1.13.1"