diff --git a/.env.ci b/.env.ci index 0272eec5..c459c928 100644 --- a/.env.ci +++ b/.env.ci @@ -48,7 +48,7 @@ GF_SECURITY_ADMIN_PASSWORD = "password" # WebSearch Settings BRAVE_SEARCH_API = "Your API here" -# Optional: Override default testnet URLs if needed +# NilDB Configuration (Required) NILDB_NILCHAIN_URL=http://rpc.testnet.nilchain-rpc-proxy.nilogy.xyz NILDB_NILAUTH_URL=https://nilauth.sandbox.app-cluster.sandbox.nilogy.xyz NILDB_NODES=https://nildb-stg-n1.nillion.network,https://nildb-stg-n2.nillion.network,https://nildb-stg-n3.nillion.network diff --git a/.github/workflows/cicd.yml b/.github/workflows/cicd.yml index 1ed12f0b..4291bfac 100644 --- a/.github/workflows/cicd.yml +++ b/.github/workflows/cicd.yml @@ -51,11 +51,13 @@ jobs: sed -i 's/HF_TOKEN=.*/HF_TOKEN=dummy_token/' .env sed -i 's/BRAVE_SEARCH_API=.*/BRAVE_SEARCH_API=dummy_api/' .env + - name: pyright + run: uv run pyright + - name: Run tests run: uv run pytest -v tests/unit - - name: pyright - run: uv run pyright + start-runner: name: Start self-hosted EC2 runner @@ -252,6 +254,7 @@ jobs: run: | set -e export ENVIRONMENT=ci + export AUTH_STRATEGY=nuc uv run pytest -v tests/e2e - name: Run E2E tests for API Key diff --git a/nilai-api/alembic/env.py b/nilai-api/alembic/env.py index 7ffeb8fb..5cc65034 100644 --- a/nilai-api/alembic/env.py +++ b/nilai-api/alembic/env.py @@ -11,7 +11,7 @@ from nilai_api.db import Base from nilai_api.db.users import UserModel from nilai_api.db.logs import QueryLog -import nilai_api.config as nilai_config +from nilai_api.config import CONFIG as nilai_config # If we don't use the models, they remain unused, and the migration fails # This is a workaround to ensure the models are loaded @@ -93,12 +93,12 @@ def run_migrations_online() -> None: load_dotenv() -db_host = nilai_config.DB_HOST +db_host = nilai_config.database.host if db_host: - db_port = nilai_config.DB_PORT - db_user = nilai_config.DB_USER - db_pass = nilai_config.DB_PASS - db_name = nilai_config.DB_NAME + db_port = nilai_config.database.port + db_user = nilai_config.database.user + db_pass = nilai_config.database.password + db_name = nilai_config.database.db config.set_main_option( "sqlalchemy.url", f"postgresql+asyncpg://{db_user}:{db_pass}@{db_host}:{db_port}/{db_name}", diff --git a/nilai-api/alembic/versions/da89d3230653_create_initial_set_of_tables.py b/nilai-api/alembic/versions/da89d3230653_create_initial_set_of_tables.py index b2c9b467..186d6905 100644 --- a/nilai-api/alembic/versions/da89d3230653_create_initial_set_of_tables.py +++ b/nilai-api/alembic/versions/da89d3230653_create_initial_set_of_tables.py @@ -10,11 +10,7 @@ from alembic import op import sqlalchemy as sa -from nilai_api.config import ( - USER_RATE_LIMIT_MINUTE, - USER_RATE_LIMIT_HOUR, - USER_RATE_LIMIT_DAY, -) +from nilai_api.config import CONFIG # revision identifiers, used by Alembic. @@ -39,15 +35,21 @@ def upgrade() -> None: ), sa.Column("last_activity", sa.DateTime, nullable=True), sa.Column( - "ratelimit_day", sa.Integer, default=USER_RATE_LIMIT_DAY, nullable=True + "ratelimit_day", + sa.Integer, + default=CONFIG.rate_limiting.user_rate_limit_day, + nullable=True, ), sa.Column( - "ratelimit_hour", sa.Integer, default=USER_RATE_LIMIT_HOUR, nullable=True + "ratelimit_hour", + sa.Integer, + default=CONFIG.rate_limiting.user_rate_limit_hour, + nullable=True, ), sa.Column( "ratelimit_minute", sa.Integer, - default=USER_RATE_LIMIT_MINUTE, + default=CONFIG.rate_limiting.user_rate_limit_minute, nullable=True, ), ) diff --git a/nilai-api/src/nilai_api/app.py b/nilai-api/src/nilai_api/app.py index 18a24f8a..8a4e7ac4 100644 --- a/nilai-api/src/nilai_api/app.py +++ b/nilai-api/src/nilai_api/app.py @@ -14,7 +14,7 @@ @asynccontextmanager async def lifespan(app: FastAPI): - client, rate_limit_command = await setup_redis_conn(config.REDIS_URL) + client, rate_limit_command = await setup_redis_conn(config.CONFIG.redis.url) yield {"redis": client, "redis_rate_limit_command": rate_limit_command} diff --git a/nilai-api/src/nilai_api/auth/__init__.py b/nilai-api/src/nilai_api/auth/__init__.py index 46101227..9252b76a 100644 --- a/nilai-api/src/nilai_api/auth/__init__.py +++ b/nilai-api/src/nilai_api/auth/__init__.py @@ -3,7 +3,7 @@ from logging import getLogger -from nilai_api import config +from nilai_api.config import CONFIG from nilai_api.db.users import UserManager from nilai_api.auth.strategies import AuthenticationStrategy @@ -25,7 +25,7 @@ async def get_auth_info( credentials: HTTPAuthorizationCredentials = Security(bearer_scheme), ) -> AuthenticationInfo: try: - strategy_name: str = config.AUTH_STRATEGY.upper() + strategy_name: str = CONFIG.auth.auth_strategy.upper() try: strategy = AuthenticationStrategy[strategy_name] diff --git a/nilai-api/src/nilai_api/auth/nuc.py b/nilai-api/src/nilai_api/auth/nuc.py index 8274376d..51d7ee2e 100644 --- a/nilai-api/src/nilai_api/auth/nuc.py +++ b/nilai-api/src/nilai_api/auth/nuc.py @@ -5,7 +5,7 @@ from nuc.nilauth import NilauthClient from nuc.token import Did, NucToken, Command from functools import lru_cache -from nilai_api.config import NILAUTH_TRUSTED_ROOT_ISSUERS +from nilai_api.config import CONFIG from nilai_api.state import state from nilai_api.auth.common import AuthenticationError @@ -32,7 +32,7 @@ def get_validator() -> NucTokenValidator: try: nilauth_public_keys = [ Did(NilauthClient(key).about().public_key.serialize()) - for key in NILAUTH_TRUSTED_ROOT_ISSUERS + for key in CONFIG.auth.nilauth_trusted_root_issuers ] except Exception as e: logger.error(f"Error getting validator: {e}") diff --git a/nilai-api/src/nilai_api/auth/strategies.py b/nilai-api/src/nilai_api/auth/strategies.py index dc279d73..0c64cce6 100644 --- a/nilai-api/src/nilai_api/auth/strategies.py +++ b/nilai-api/src/nilai_api/auth/strategies.py @@ -8,7 +8,7 @@ get_token_rate_limit, get_token_prompt_document, ) -from nilai_api.config import DOCS_TOKEN +from nilai_api.config import CONFIG from nilai_api.auth.common import ( PromptDocument, TokenRateLimits, @@ -69,7 +69,7 @@ async def wrapper(token) -> AuthenticationInfo: return decorator -@allow_token(DOCS_TOKEN) +@allow_token(CONFIG.docs.token) async def api_key_strategy(api_key: str) -> AuthenticationInfo: user_model: Optional[UserModel] = await UserManager.check_api_key(api_key) if user_model: @@ -81,7 +81,7 @@ async def api_key_strategy(api_key: str) -> AuthenticationInfo: raise AuthenticationError("Missing or invalid API key") -@allow_token(DOCS_TOKEN) +@allow_token(CONFIG.docs.token) async def jwt_strategy(jwt_creds: str) -> AuthenticationInfo: result = validate_jwt(jwt_creds) user_model: Optional[UserModel] = await UserManager.check_api_key( @@ -107,7 +107,7 @@ async def jwt_strategy(jwt_creds: str) -> AuthenticationInfo: ) -@allow_token(DOCS_TOKEN) +@allow_token(CONFIG.docs.token) async def nuc_strategy(nuc_token) -> AuthenticationInfo: """ Validate a NUC token and return the user model diff --git a/nilai-api/src/nilai_api/config/__init__.py b/nilai-api/src/nilai_api/config/__init__.py index 7df5b857..fc768f26 100644 --- a/nilai-api/src/nilai_api/config/__init__.py +++ b/nilai-api/src/nilai_api/config/__init__.py @@ -1,83 +1,64 @@ -import os -from typing import List, Dict, Any, Optional -import yaml -from dotenv import load_dotenv -from dataclasses import dataclass - -load_dotenv() - -ENVIRONMENT: str = os.getenv("ENVIRONMENT", "testnet") - -ETCD_HOST: str = os.getenv("ETCD_HOST", "localhost") -ETCD_PORT: int = int(os.getenv("ETCD_PORT", 2379)) - - -REDIS_URL: str = os.getenv("REDIS_URL", "redis://localhost:6379") - -DOCS_TOKEN: str | None = os.getenv("DOCS_TOKEN", None) - -DB_USER: str = os.getenv("POSTGRES_USER", "postgres") -DB_PASS: str = os.getenv("POSTGRES_PASSWORD", "") -DB_HOST: str = os.getenv("POSTGRES_HOST", "localhost") -DB_PORT: int = int(os.getenv("POSTGRES_PORT", 5432)) -DB_NAME: str = os.getenv("POSTGRES_DB", "nilai_users") - - -NILAUTH_TRUSTED_ROOT_ISSUERS: List[str] = os.getenv( - "NILAUTH_TRUSTED_ROOT_ISSUERS", "" -).split(",") - -AUTH_STRATEGY: str = os.getenv("AUTH_STRATEGY", "api_key") - - -# Web Search API configuration -@dataclass -class WebSearchSettings: - api_key: Optional[str] = None - api_path: str = "https://api.search.brave.com/res/v1/web/search" - count: int = 3 - lang: str = "en" - country: str = "us" - timeout: float = 20.0 - max_concurrent_requests: int = 20 - rps: int = 20 - - -WEB_SEARCH_SETTINGS = WebSearchSettings(api_key=os.getenv("BRAVE_SEARCH_API")) - -# Default values -USER_RATE_LIMIT_MINUTE: Optional[int] = 100 -USER_RATE_LIMIT_HOUR: Optional[int] = 1000 -USER_RATE_LIMIT_DAY: Optional[int] = 10000 -WEB_SEARCH_RATE_LIMIT_MINUTE: Optional[int] = 1 -WEB_SEARCH_RATE_LIMIT_HOUR: Optional[int] = 3 -WEB_SEARCH_RATE_LIMIT_DAY: Optional[int] = 72 -MODEL_CONCURRENT_RATE_LIMIT: Dict[str, int] = {} - - -def load_config_from_yaml(config_path: str) -> Dict[str, Any]: - if os.path.exists(config_path): - with open(config_path, "r") as f: - return yaml.safe_load(f) - return {} - - -config_file: str = "config.yaml" -config_path = os.path.join(os.path.dirname(__file__), config_file) - -if not os.path.exists(config_path): - config_file = "config.yaml" - config_path = os.path.join(os.path.dirname(__file__), config_file) - -config_data = load_config_from_yaml(config_path) - -# Overwrite with values from yaml -if config_data: - USER_RATE_LIMIT_MINUTE = config_data.get( - "user_rate_limit_minute", USER_RATE_LIMIT_MINUTE +# Import all configuration models +import json +from .environment import EnvironmentConfig +from .database import DatabaseConfig, EtcdConfig, RedisConfig +from .auth import AuthConfig, DocsConfig +from .nildb import NilDBConfig +from .web_search import WebSearchSettings +from .rate_limiting import RateLimitingConfig +from .utils import create_config_model, CONFIG_DATA +from pydantic import BaseModel +import logging + + +class NilAIConfig(BaseModel): + """Centralized configuration container for the Nilai API.""" + + environment: EnvironmentConfig = create_config_model( + EnvironmentConfig, "", CONFIG_DATA + ) + database: DatabaseConfig = create_config_model( + DatabaseConfig, "database", CONFIG_DATA, "POSTGRES_" + ) + etcd: EtcdConfig = create_config_model(EtcdConfig, "etcd", CONFIG_DATA, "ETCD_") + redis: RedisConfig = create_config_model( + RedisConfig, "redis", CONFIG_DATA, "REDIS_" ) - USER_RATE_LIMIT_HOUR = config_data.get("user_rate_limit_hour", USER_RATE_LIMIT_HOUR) - USER_RATE_LIMIT_DAY = config_data.get("user_rate_limit_day", USER_RATE_LIMIT_DAY) - MODEL_CONCURRENT_RATE_LIMIT = config_data.get( - "model_concurrent_rate_limit", MODEL_CONCURRENT_RATE_LIMIT + auth: AuthConfig = create_config_model(AuthConfig, "auth", CONFIG_DATA) + docs: DocsConfig = create_config_model(DocsConfig, "docs", CONFIG_DATA, "DOCS_") + web_search: WebSearchSettings = create_config_model( + WebSearchSettings, "web_search", CONFIG_DATA, "WEB_SEARCH_" ) + rate_limiting: RateLimitingConfig = create_config_model( + RateLimitingConfig, "rate_limiting", CONFIG_DATA + ) + nildb: NilDBConfig = create_config_model( + NilDBConfig, "nildb", CONFIG_DATA, "NILDB_" + ) + + def prettify(self): + """Print the config in a pretty format removing passwords and other sensitive information""" + config_dict = self.model_dump() + keywords = ["pass", "token", "key"] + for key, value in config_dict.items(): + if isinstance(value, str): + for keyword in keywords: + print(key, keyword, keyword in key) + if keyword in key and value is not None: + config_dict[key] = "***************" + if isinstance(value, dict): + for k, v in value.items(): + for keyword in keywords: + if keyword in k and v is not None: + value[k] = "***************" + return json.dumps(config_dict, indent=4) + + +# Global config instance +CONFIG = NilAIConfig() +__all__ = [ + # Main config object + "CONFIG" +] + +logging.info(CONFIG.prettify()) diff --git a/nilai-api/src/nilai_api/config/auth.py b/nilai-api/src/nilai_api/config/auth.py new file mode 100644 index 00000000..1f044f19 --- /dev/null +++ b/nilai-api/src/nilai_api/config/auth.py @@ -0,0 +1,18 @@ +from typing import List, Optional, Literal +from pydantic import BaseModel, Field + + +class AuthConfig(BaseModel): + auth_strategy: Literal["api_key", "jwt", "nuc"] = Field( + description="Authentication strategy" + ) + nilauth_trusted_root_issuers: List[str] = Field( + description="Trusted root issuers for nilauth" + ) + auth_token: Optional[str] = Field( + default=None, description="Auth token for testing" + ) + + +class DocsConfig(BaseModel): + token: Optional[str] = Field(default=None, description="Documentation access token") diff --git a/nilai-api/src/nilai_api/config/config.yaml b/nilai-api/src/nilai_api/config/config.yaml index 9f42f7bb..3bc82ca3 100644 --- a/nilai-api/src/nilai_api/config/config.yaml +++ b/nilai-api/src/nilai_api/config/config.yaml @@ -1,33 +1,43 @@ # In production, this file is automatically generated by the `ansible` playbook. +# Configuration with structured sections and default values -model_concurrent_rate_limit: - meta-llama/Llama-3.2-1B-Instruct: 45 - meta-llama/Llama-3.2-3B-Instruct: 50 - meta-llama/Llama-3.1-8B-Instruct: 30 - cognitivecomputations/Dolphin3.0-Llama3.1-8B: 30 - deepseek-ai/DeepSeek-R1-Distill-Qwen-14B: 5 - hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4: 5 - openai/gpt-oss-20b: 50 - default: 50 +# Environment Configuration +environment: "mainnet" -user_rate_limit_minute: null -user_rate_limit_hour: null -user_rate_limit_day: null -web_search_rate_limit_minute: 1 -web_search_rate_limit_hour: 3 -web_search_rate_limit_day: 72 +# Authentication Configuration +auth: + strategy: "api_key" + nilauth_trusted_root_issuers: [] -# model_concurrent_rate_limit: -# meta-llama/Llama-3.2-1B-Instruct: 10 -# meta-llama/Llama-3.2-3B-Instruct: 10 -# meta-llama/Llama-3.1-8B-Instruct: 5 -# cognitivecomputations/Dolphin3.0-Llama3.1-8B: 5 -# deepseek-ai/DeepSeek-R1-Distill-Qwen-14B: 5 -# hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4: 5 +# Documentation Configuration +docs: + token: null -# user_rate_limit_minute: 100 -# user_rate_limit_hour: 1000 -# user_rate_limit_day: 10000 -# web_search_rate_limit_minute: 1 -# web_search_rate_limit_hour: 3 -# web_search_rate_limit_day: 72 +# Web Search Configuration +web_search: + api_key: null + api_path: "https://api.search.brave.com/res/v1/web/search" + count: 3 + lang: "en" + country: "us" + timeout: 20.0 + max_concurrent_requests: 20 + rps: 20 + +# Rate Limiting Configuration +rate_limiting: + user_rate_limit_minute: 100 + user_rate_limit_hour: 1000 + user_rate_limit_day: 10000 + web_search_rate_limit_minute: 1 + web_search_rate_limit_hour: 3 + web_search_rate_limit_day: 72 + model_concurrent_rate_limit: + meta-llama/Llama-3.2-1B-Instruct: 45 + meta-llama/Llama-3.2-3B-Instruct: 50 + meta-llama/Llama-3.1-8B-Instruct: 30 + cognitivecomputations/Dolphin3.0-Llama3.1-8B: 30 + deepseek-ai/DeepSeek-R1-Distill-Qwen-14B: 5 + hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4: 5 + openai/gpt-oss-20b: 50 + default: 50 diff --git a/nilai-api/src/nilai_api/config/database.py b/nilai-api/src/nilai_api/config/database.py new file mode 100644 index 00000000..6cc1371a --- /dev/null +++ b/nilai-api/src/nilai_api/config/database.py @@ -0,0 +1,18 @@ +from pydantic import BaseModel, Field + + +class DatabaseConfig(BaseModel): + user: str = Field(description="Database user") + password: str = Field(description="Database password") + host: str = Field(description="Database host") + port: int = Field(description="Database port") + db: str = Field(description="Database name") + + +class EtcdConfig(BaseModel): + host: str = Field(description="ETCD host") + port: int = Field(description="ETCD port") + + +class RedisConfig(BaseModel): + url: str = Field(description="Redis URL") diff --git a/nilai-api/src/nilai_api/config/environment.py b/nilai-api/src/nilai_api/config/environment.py new file mode 100644 index 00000000..52890f7d --- /dev/null +++ b/nilai-api/src/nilai_api/config/environment.py @@ -0,0 +1,8 @@ +from typing import Literal +from pydantic import BaseModel, Field + + +class EnvironmentConfig(BaseModel): + environment: Literal["testnet", "mainnet", "ci"] = Field( + default="mainnet", description="The environment to use" + ) diff --git a/nilai-api/src/nilai_api/config/nildb.py b/nilai-api/src/nilai_api/config/nildb.py new file mode 100644 index 00000000..7ff72cab --- /dev/null +++ b/nilai-api/src/nilai_api/config/nildb.py @@ -0,0 +1,25 @@ +from typing import List +from pydantic import BaseModel, Field, field_validator +from secretvaults.common.types import Uuid + + +class NilDBConfig(BaseModel): + nilchain_url: str = Field(..., description="The URL of the Nilchain") + nilauth_url: str = Field(..., description="The URL of the Nilauth") + nodes: List[str] = Field(..., description="The URLs of the Nildb nodes") + builder_private_key: str = Field(..., description="The private key of the builder") + collection: Uuid = Field(..., description="The ID of the collection") + + @field_validator("nodes", mode="before") + @classmethod + def parse_nodes(cls, v): + if isinstance(v, str): + return v.split(",") + return v + + @field_validator("collection", mode="before") + @classmethod + def parse_collection(cls, v): + if isinstance(v, str): + return Uuid(v) + return v diff --git a/nilai-api/src/nilai_api/config/rate_limiting.py b/nilai-api/src/nilai_api/config/rate_limiting.py new file mode 100644 index 00000000..0efce1b1 --- /dev/null +++ b/nilai-api/src/nilai_api/config/rate_limiting.py @@ -0,0 +1,26 @@ +from typing import Dict, Optional +from pydantic import BaseModel, Field + + +class RateLimitingConfig(BaseModel): + user_rate_limit_minute: Optional[int] = Field( + default=100, description="User requests per minute limit" + ) + user_rate_limit_hour: Optional[int] = Field( + default=1000, description="User requests per hour limit" + ) + user_rate_limit_day: Optional[int] = Field( + default=10000, description="User requests per day limit" + ) + web_search_rate_limit_minute: Optional[int] = Field( + default=1, description="Web search requests per minute limit" + ) + web_search_rate_limit_hour: Optional[int] = Field( + default=3, description="Web search requests per hour limit" + ) + web_search_rate_limit_day: Optional[int] = Field( + default=72, description="Web search requests per day limit" + ) + model_concurrent_rate_limit: Dict[str, int] = Field( + default_factory=dict, description="Model concurrent request limits" + ) diff --git a/nilai-api/src/nilai_api/config/utils.py b/nilai-api/src/nilai_api/config/utils.py new file mode 100644 index 00000000..d26d832b --- /dev/null +++ b/nilai-api/src/nilai_api/config/utils.py @@ -0,0 +1,112 @@ +import os +from typing import Dict, Any, Optional, Type, TypeVar, get_origin +import yaml +from dotenv import load_dotenv +from pydantic import BaseModel +import json + +load_dotenv() + +T = TypeVar("T", bound=BaseModel) + + +def load_config_from_yaml(config_path: str) -> Dict[str, Any]: + """Load configuration from YAML file.""" + if os.path.exists(config_path): + with open(config_path, "r") as f: + return yaml.safe_load(f) + return {} + + +def get_nested_value(data: Dict[str, Any], key_path: str) -> Any: + """Get nested value from dict using dot notation.""" + value = data + for key in key_path.split("."): + if isinstance(value, dict) and key in value: + value = value[key] + else: + return None + return value + + +def create_config_model( + model_class: Type[T], + yaml_section: str, + config_data: Dict[str, Any], + env_prefix: str = "", + custom_env_mapping: Optional[Dict[str, str]] = None, +) -> T: + """Create Pydantic model instance with YAML-first, env override approach.""" + # Get YAML section data + yaml_data = get_nested_value(config_data, yaml_section) or {} + + # Prepare data dict with environment overrides + model_data = {} + custom_env_mapping = custom_env_mapping or {} + + # Get model fields + for field_name, field_info in model_class.model_fields.items(): + # Determine environment variable key + if field_name in custom_env_mapping: + # Use custom mapping first + env_keys = [custom_env_mapping[field_name]] + else: + # Use standard prefix logic + env_keys = [ + f"{env_prefix}{field_name.upper()}" + if env_prefix + else field_name.upper() + ] + + # Add special case for api_key -> BRAVE_SEARCH_API for backward compatibility + if ( + field_name == "api_key" + and "BRAVE_SEARCH_API" + not in [custom_env_mapping.get(field_name, "")] + env_keys + ): + env_keys.append("BRAVE_SEARCH_API") + + # Try environment variables in order + env_value = None + for env_key in env_keys: + env_value = os.getenv(env_key) + if env_value is not None: + break + + if env_value is not None: + # Handle type conversion for environment variables + field_type = field_info.annotation + if field_type is bool: + model_data[field_name] = env_value.lower() in ("true", "1", "yes", "on") + elif field_type is int: + model_data[field_name] = int(env_value) + elif field_type is float: + model_data[field_name] = float(env_value) + elif get_origin(field_type) is list: + model_data[field_name] = env_value.split(",") if env_value else [] + elif field_type is dict or get_origin(field_type) is dict: + try: + model_data[field_name] = json.loads(env_value) + except json.JSONDecodeError: + model_data[field_name] = {} + else: + model_data[field_name] = env_value + elif field_name in yaml_data: + # Use YAML value + model_data[field_name] = yaml_data[field_name] + # If neither env nor yaml has the value, let Pydantic handle defaults + return model_class(**model_data) + + +def get_required_env_var(name: str) -> str: + """Get a required environment variable, raising an error if not set.""" + value: Optional[str] = os.getenv(name, None) + if value is None: + raise ValueError(f"Required environment variable {name} is not set") + return value + + +# Load shared config data +config_file: str = "config.yaml" +config_path = os.path.join(os.path.dirname(__file__), config_file) +CONFIG_DATA = load_config_from_yaml(config_path) diff --git a/nilai-api/src/nilai_api/config/web_search.py b/nilai-api/src/nilai_api/config/web_search.py new file mode 100644 index 00000000..889ee13f --- /dev/null +++ b/nilai-api/src/nilai_api/config/web_search.py @@ -0,0 +1,18 @@ +from typing import Optional +from pydantic import BaseModel, Field + + +class WebSearchSettings(BaseModel): + api_key: Optional[str] = Field(default=None, description="Brave Search API key") + api_path: str = Field( + default="https://api.search.brave.com/res/v1/web/search", + description="Search API endpoint", + ) + count: int = Field(default=3, description="Number of search results") + lang: str = Field(default="en", description="Search language") + country: str = Field(default="us", description="Search country") + timeout: float = Field(default=20.0, description="Request timeout in seconds") + max_concurrent_requests: int = Field( + default=20, description="Maximum concurrent requests" + ) + rps: int = Field(default=20, description="Requests per second limit") diff --git a/nilai-api/src/nilai_api/db/__init__.py b/nilai-api/src/nilai_api/db/__init__.py index b9dd9d3b..ee70ffe0 100644 --- a/nilai-api/src/nilai_api/db/__init__.py +++ b/nilai-api/src/nilai_api/db/__init__.py @@ -12,7 +12,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker -from nilai_api import config +from nilai_api.config import CONFIG _engine: Optional[sqlalchemy.ext.asyncio.AsyncEngine] = None _SessionLocal: Optional[sessionmaker] = None @@ -40,11 +40,11 @@ class DatabaseConfig: def from_env() -> "DatabaseConfig": database_url = sqlalchemy.engine.url.URL.create( drivername="postgresql+asyncpg", # Use asyncpg driver - username=config.DB_USER, - password=config.DB_PASS, - host=config.DB_HOST, - port=config.DB_PORT, - database=config.DB_NAME, + username=CONFIG.database.user, + password=CONFIG.database.password, + host=CONFIG.database.host, + port=CONFIG.database.port, + database=CONFIG.database.db, ) return DatabaseConfig(database_url) diff --git a/nilai-api/src/nilai_api/db/users.py b/nilai-api/src/nilai_api/db/users.py index 92dfa354..23c5106d 100644 --- a/nilai-api/src/nilai_api/db/users.py +++ b/nilai-api/src/nilai_api/db/users.py @@ -10,14 +10,7 @@ from sqlalchemy.exc import SQLAlchemyError from nilai_api.db import Base, Column, get_db_session -from nilai_api.config import ( - USER_RATE_LIMIT_MINUTE, - USER_RATE_LIMIT_HOUR, - USER_RATE_LIMIT_DAY, - WEB_SEARCH_RATE_LIMIT_MINUTE, - WEB_SEARCH_RATE_LIMIT_HOUR, - WEB_SEARCH_RATE_LIMIT_DAY, -) +from nilai_api.config import CONFIG logger = logging.getLogger(__name__) @@ -36,19 +29,25 @@ class UserModel(Base): DateTime(timezone=True), server_default=sqlalchemy.func.now(), nullable=False ) # type: ignore last_activity: datetime = Column(DateTime(timezone=True), nullable=True) # type: ignore - ratelimit_day: int = Column(Integer, default=USER_RATE_LIMIT_DAY, nullable=True) # type: ignore - ratelimit_hour: int = Column(Integer, default=USER_RATE_LIMIT_HOUR, nullable=True) # type: ignore + ratelimit_day: int = Column( + Integer, default=CONFIG.rate_limiting.user_rate_limit_day, nullable=True + ) # type: ignore + ratelimit_hour: int = Column( + Integer, default=CONFIG.rate_limiting.user_rate_limit_hour, nullable=True + ) # type: ignore ratelimit_minute: int = Column( - Integer, default=USER_RATE_LIMIT_MINUTE, nullable=True + Integer, default=CONFIG.rate_limiting.user_rate_limit_minute, nullable=True ) # type: ignore web_search_ratelimit_day: int = Column( - Integer, default=WEB_SEARCH_RATE_LIMIT_DAY, nullable=True + Integer, default=CONFIG.rate_limiting.web_search_rate_limit_day, nullable=True ) # type: ignore web_search_ratelimit_hour: int = Column( - Integer, default=WEB_SEARCH_RATE_LIMIT_HOUR, nullable=True + Integer, default=CONFIG.rate_limiting.web_search_rate_limit_hour, nullable=True ) # type: ignore web_search_ratelimit_minute: int = Column( - Integer, default=WEB_SEARCH_RATE_LIMIT_MINUTE, nullable=True + Integer, + default=CONFIG.rate_limiting.web_search_rate_limit_minute, + nullable=True, ) # type: ignore def __repr__(self): @@ -133,9 +132,9 @@ async def insert_user( name: str, apikey: str | None = None, userid: str | None = None, - ratelimit_day: int | None = USER_RATE_LIMIT_DAY, - ratelimit_hour: int | None = USER_RATE_LIMIT_HOUR, - ratelimit_minute: int | None = USER_RATE_LIMIT_MINUTE, + ratelimit_day: int | None = CONFIG.rate_limiting.user_rate_limit_day, + ratelimit_hour: int | None = CONFIG.rate_limiting.user_rate_limit_hour, + ratelimit_minute: int | None = CONFIG.rate_limiting.user_rate_limit_minute, ) -> UserModel: """ Insert a new user into the database. @@ -153,10 +152,18 @@ async def insert_user( """ userid = userid if userid else UserManager.generate_user_id() apikey = apikey if apikey else UserManager.generate_api_key() - ratelimit_day = ratelimit_day if ratelimit_day else USER_RATE_LIMIT_DAY - ratelimit_hour = ratelimit_hour if ratelimit_hour else USER_RATE_LIMIT_HOUR + ratelimit_day = ( + ratelimit_day if ratelimit_day else CONFIG.rate_limiting.user_rate_limit_day + ) + ratelimit_hour = ( + ratelimit_hour + if ratelimit_hour + else CONFIG.rate_limiting.user_rate_limit_hour + ) ratelimit_minute = ( - ratelimit_minute if ratelimit_minute else USER_RATE_LIMIT_MINUTE + ratelimit_minute + if ratelimit_minute + else CONFIG.rate_limiting.user_rate_limit_minute ) user = UserModel( userid=userid, diff --git a/nilai-api/src/nilai_api/handlers/nildb/config.py b/nilai-api/src/nilai_api/handlers/nildb/config.py deleted file mode 100644 index 2a94ac03..00000000 --- a/nilai-api/src/nilai_api/handlers/nildb/config.py +++ /dev/null @@ -1,38 +0,0 @@ -import os -from typing import Optional -from dotenv import load_dotenv -from pydantic import BaseModel -from pydantic import Field - -from secretvaults.common.types import Uuid - -load_dotenv() - - -class NilDBConfig(BaseModel): - NILCHAIN_URL: str = Field(..., description="The URL of the Nilchain") - NILAUTH_URL: str = Field(..., description="The URL of the Nilauth") - NODES: list[str] = Field(..., description="The URLs of the Nildb nodes") - BUILDER_PRIVATE_KEY: str = Field(..., description="The private key of the builder") - COLLECTION: Uuid = Field(..., description="The ID of the collection") - - -def get_required_env_var(name: str) -> str: - """Get a required environment variable, raising an error if not set.""" - value: Optional[str] = os.getenv(name, None) - if value is None: - raise ValueError(f"Required environment variable {name} is not set") - return value - - -# Validate environment variables at import time -CONFIG = NilDBConfig( - NILCHAIN_URL=get_required_env_var("NILDB_NILCHAIN_URL"), - NILAUTH_URL=get_required_env_var("NILDB_NILAUTH_URL"), - NODES=get_required_env_var("NILDB_NODES").split(","), - BUILDER_PRIVATE_KEY=get_required_env_var("NILDB_BUILDER_PRIVATE_KEY"), - COLLECTION=Uuid(get_required_env_var("NILDB_COLLECTION")), -) - - -__all__ = ["CONFIG"] diff --git a/nilai-api/src/nilai_api/handlers/nildb/handler.py b/nilai-api/src/nilai_api/handlers/nildb/handler.py index d419ef74..80d06182 100644 --- a/nilai-api/src/nilai_api/handlers/nildb/handler.py +++ b/nilai-api/src/nilai_api/handlers/nildb/handler.py @@ -1,5 +1,5 @@ from typing import Optional -from nilai_api.handlers.nildb.config import CONFIG +from nilai_api.config import CONFIG from secretvaults import SecretVaultBuilderClient, SecretVaultUserClient from secretvaults.common.keypair import Keypair @@ -31,13 +31,13 @@ async def create_builder_client(): return BUILDER_CLIENT # Create keypair from private key - keypair = Keypair.from_hex(CONFIG.BUILDER_PRIVATE_KEY) + keypair = Keypair.from_hex(CONFIG.nildb.builder_private_key) # Prepare URLs for the builder client urls = { - "chain": [CONFIG.NILCHAIN_URL], - "auth": CONFIG.NILAUTH_URL, - "dbs": CONFIG.NODES, + "chain": [CONFIG.nildb.nilchain_url], + "auth": CONFIG.nildb.nilauth_url, + "dbs": CONFIG.nildb.nodes, } # Create SecretVaultBuilderClient with proper initialization @@ -62,10 +62,10 @@ async def create_user_client() -> SecretVaultUserClient: return USER_CLIENT # Create keypair from private key - keypair = Keypair.from_hex(CONFIG.BUILDER_PRIVATE_KEY) + keypair = Keypair.from_hex(CONFIG.nildb.builder_private_key) USER_CLIENT = await SecretVaultUserClient.from_options( keypair=keypair, - base_urls=CONFIG.NODES, + base_urls=CONFIG.nildb.nodes, blindfold=BlindfoldFactoryConfig( operation=BlindfoldOperation.STORE, use_cluster_key=True ), @@ -102,7 +102,7 @@ async def get_nildb_delegation_token(user_did: str) -> PromptDelegationToken: async def get_prompt_from_nildb(prompt_document: PromptDocument) -> str: """Read a specific document - core functionality""" read_params = ReadDataRequestParams( - collection=CONFIG.COLLECTION, + collection=CONFIG.nildb.collection, document=Uuid(prompt_document.document_id), subject=Uuid(prompt_document.owner_did), ) diff --git a/nilai-api/src/nilai_api/handlers/web_search.py b/nilai-api/src/nilai_api/handlers/web_search.py index f5ca5ba3..fc027972 100644 --- a/nilai-api/src/nilai_api/handlers/web_search.py +++ b/nilai-api/src/nilai_api/handlers/web_search.py @@ -6,7 +6,7 @@ import httpx from fastapi import HTTPException, status -from nilai_api.config import WEB_SEARCH_SETTINGS +from nilai_api.config import CONFIG from nilai_common.api_model import ( ChatRequest, Message, @@ -26,9 +26,9 @@ _BRAVE_API_PARAMS_BASE = { "summary": 1, - "count": WEB_SEARCH_SETTINGS.count, - "country": WEB_SEARCH_SETTINGS.country, - "lang": WEB_SEARCH_SETTINGS.lang, + "count": CONFIG.web_search.count, + "country": CONFIG.web_search.country, + "lang": CONFIG.web_search.lang, } @@ -40,10 +40,8 @@ def _get_http_client() -> httpx.AsyncClient: An AsyncClient configured with timeouts and connection limits """ return httpx.AsyncClient( - timeout=WEB_SEARCH_SETTINGS.timeout, - limits=httpx.Limits( - max_connections=WEB_SEARCH_SETTINGS.max_concurrent_requests - ), + timeout=CONFIG.web_search.timeout, + limits=httpx.Limits(max_connections=CONFIG.web_search.max_concurrent_requests), ) @@ -59,7 +57,7 @@ async def _make_brave_api_request(query: str) -> Dict[str, Any]: Raises: HTTPException: If API key is missing or API request fails """ - if not WEB_SEARCH_SETTINGS.api_key: + if not CONFIG.web_search.api_key: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Missing BRAVE_SEARCH_API key in environment", @@ -71,7 +69,7 @@ async def _make_brave_api_request(query: str) -> Dict[str, Any]: params = {**_BRAVE_API_PARAMS_BASE, "q": q} headers = { **_BRAVE_API_HEADERS, - "X-Subscription-Token": WEB_SEARCH_SETTINGS.api_key, + "X-Subscription-Token": CONFIG.web_search.api_key, } client = _get_http_client() @@ -83,9 +81,7 @@ async def _make_brave_api_request(query: str) -> Dict[str, Any]: params.get("lang"), params.get("count"), ) - resp = await client.get( - WEB_SEARCH_SETTINGS.api_path, headers=headers, params=params - ) + resp = await client.get(CONFIG.web_search.api_path, headers=headers, params=params) if resp.status_code >= 400: logger.error("Brave API error: %s - %s", resp.status_code, resp.text) diff --git a/nilai-api/src/nilai_api/rate_limiting.py b/nilai-api/src/nilai_api/rate_limiting.py index a56fe019..2fcf611b 100644 --- a/nilai-api/src/nilai_api/rate_limiting.py +++ b/nilai-api/src/nilai_api/rate_limiting.py @@ -3,7 +3,7 @@ from typing import Callable, Tuple, Awaitable, Annotated from pydantic import BaseModel -from nilai_api.config import WEB_SEARCH_SETTINGS +from nilai_api.config import CONFIG from fastapi.params import Depends from fastapi import status, HTTPException, Request @@ -156,11 +156,11 @@ async def __call__( if web_search_enabled: allowed_rps = min( - WEB_SEARCH_SETTINGS.rps, + CONFIG.web_search.rps, max( 1, - WEB_SEARCH_SETTINGS.max_concurrent_requests - // WEB_SEARCH_SETTINGS.count, + CONFIG.web_search.max_concurrent_requests + // CONFIG.web_search.count, ), ) await self.wait_for_bucket( diff --git a/nilai-api/src/nilai_api/routers/private.py b/nilai-api/src/nilai_api/routers/private.py index b3a50153..86bf5c24 100644 --- a/nilai-api/src/nilai_api/routers/private.py +++ b/nilai-api/src/nilai_api/routers/private.py @@ -12,7 +12,7 @@ from fastapi import APIRouter, Body, Depends, HTTPException, status, Request from fastapi.responses import StreamingResponse from nilai_api.auth import get_auth_info, AuthenticationInfo -from nilai_api.config import MODEL_CONCURRENT_RATE_LIMIT +from nilai_api.config import CONFIG from nilai_api.crypto import sign_message from nilai_api.db.logs import QueryLogManager from nilai_api.db.users import UserManager @@ -143,8 +143,9 @@ async def chat_completion_concurrent_rate_limit(request: Request) -> Tuple[int, except ValueError: raise HTTPException(status_code=400, detail="Invalid request body") key = f"chat:{chat_request.model}" - limit = MODEL_CONCURRENT_RATE_LIMIT.get( - chat_request.model, MODEL_CONCURRENT_RATE_LIMIT.get("default", 50) + limit = CONFIG.rate_limiting.model_concurrent_rate_limit.get( + chat_request.model, + CONFIG.rate_limiting.model_concurrent_rate_limit.get("default", 50), ) return limit, key diff --git a/nilai-api/src/nilai_api/state.py b/nilai-api/src/nilai_api/state.py index cc1352a1..14e0e903 100644 --- a/nilai-api/src/nilai_api/state.py +++ b/nilai-api/src/nilai_api/state.py @@ -3,7 +3,7 @@ from asyncio import Semaphore from typing import Dict, Optional -from nilai_api import config +from nilai_api.config import CONFIG from nilai_api.crypto import generate_key_pair from nilai_common import ModelServiceDiscovery from nilai_common.api_model import ModelEndpoint @@ -17,7 +17,7 @@ def __init__(self): self.sem = Semaphore(2) self.discovery_service = ModelServiceDiscovery( - host=config.ETCD_HOST, port=config.ETCD_PORT + host=CONFIG.etcd.host, port=CONFIG.etcd.port ) self._uptime = time.time() diff --git a/pyproject.toml b/pyproject.toml index 8a7aa14e..da1b7321 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dev = [ "uvicorn>=0.32.1", "pytest-asyncio>=0.25.0", "testcontainers>=4.9.1", - "pyright>=1.1.400", + "pyright>=1.1.405", "pre-commit>=4.1.0", "httpx>=0.28.1", ] @@ -47,7 +47,7 @@ nilai-api = { workspace = true } nilai-models = { workspace = true } nuc-helpers = { workspace = true } [tool.pyright] -exclude = [".venv"] +exclude = ["**/.venv", "**/.venv/**"] [tool.ruff] -exclude = [".venv"] +exclude = ["**/.venv", "**/.venv/**"] diff --git a/tests/e2e/config.py b/tests/e2e/config.py index d1b50a51..2ce0c652 100644 --- a/tests/e2e/config.py +++ b/tests/e2e/config.py @@ -1,26 +1,31 @@ -import os from .nuc import get_nuc_token +from nilai_api.config import CONFIG -ENVIRONMENT = os.getenv("ENVIRONMENT", "ci") +ENVIRONMENT = CONFIG.environment.environment # Left for API key for backwards compatibility -AUTH_TOKEN = os.getenv("AUTH_TOKEN", "") -AUTH_STRATEGY = os.getenv("AUTH_STRATEGY", "nuc") +AUTH_TOKEN = CONFIG.auth.auth_token +AUTH_STRATEGY = CONFIG.auth.auth_strategy match AUTH_STRATEGY: case "nuc": BASE_URL = "https://localhost/nuc/v1" - - def api_key_getter(): - return get_nuc_token().token case "api_key": BASE_URL = "https://localhost/v1" - - def api_key_getter(): - return AUTH_TOKEN case _: raise ValueError(f"Invalid AUTH_STRATEGY: {AUTH_STRATEGY}") +def api_key_getter() -> str: + if AUTH_STRATEGY == "nuc": + return get_nuc_token().token + elif AUTH_STRATEGY == "api_key": + if AUTH_TOKEN is None: + raise ValueError("Expected AUTH_TOKEN to be set") + return AUTH_TOKEN + else: + raise ValueError(f"Invalid AUTH_STRATEGY: {AUTH_STRATEGY}") + + print(f"USING {AUTH_STRATEGY}") models = { "mainnet": [ diff --git a/tests/e2e/test_http.py b/tests/e2e/test_http.py index e8354f51..7cf6ba11 100644 --- a/tests/e2e/test_http.py +++ b/tests/e2e/test_http.py @@ -820,7 +820,7 @@ def test_nildb_delegation(client: httpx.Client): from nuc.envelope import NucTokenEnvelope from nuc.validate import NucTokenValidator, ValidationParameters from nuc.nilauth import NilauthClient - from nilai_api.handlers.nildb.config import CONFIG + from nilai_api.config import CONFIG from nuc.token import Did keypair = Keypair.generate() @@ -841,7 +841,7 @@ def test_nildb_delegation(client: httpx.Client): # Validate the token with nilAuth url for nilDB nuc_token_envelope = NucTokenEnvelope.parse(token) nilauth_public_keys = [ - Did(NilauthClient(CONFIG.NILAUTH_URL).about().public_key.serialize()) + Did(NilauthClient(CONFIG.nildb.nilauth_url).about().public_key.serialize()) ] NucTokenValidator(nilauth_public_keys).validate( nuc_token_envelope, context={}, parameters=ValidationParameters.default() diff --git a/tests/e2e/test_openai.py b/tests/e2e/test_openai.py index 5c72d690..bce60198 100644 --- a/tests/e2e/test_openai.py +++ b/tests/e2e/test_openai.py @@ -35,6 +35,7 @@ def _create_openai_client(api_key: str) -> OpenAI: def client(): """Create an OpenAI client configured to use the Nilai API""" invocation_token: str = api_key_getter() + return _create_openai_client(invocation_token) diff --git a/tests/unit/nilai_api/auth/test_auth.py b/tests/unit/nilai_api/auth/test_auth.py index 88780a0b..0cc2fa73 100644 --- a/tests/unit/nilai_api/auth/test_auth.py +++ b/tests/unit/nilai_api/auth/test_auth.py @@ -5,10 +5,10 @@ from fastapi import HTTPException from fastapi.security import HTTPAuthorizationCredentials -import nilai_api.config as config +from nilai_api.config import CONFIG as config # For these tests, we will use the api_key strategy -config.AUTH_STRATEGY = "api_key" +config.auth.auth_strategy = "api_key" @pytest.fixture diff --git a/tests/unit/nilai_api/conftest.py b/tests/unit/nilai_api/conftest.py index fe00a029..592696fb 100644 --- a/tests/unit/nilai_api/conftest.py +++ b/tests/unit/nilai_api/conftest.py @@ -1,7 +1,7 @@ import pytest from unittest.mock import patch, MagicMock from testcontainers.redis import RedisContainer -from nilai_api import config +from nilai_api.config import CONFIG @pytest.fixture(scope="session", autouse=True) @@ -20,5 +20,5 @@ def redis_server(): container.start() host_ip = container.get_container_host_ip() host_port = container.get_exposed_port(6379) - config.REDIS_URL = f"redis://{host_ip}:{host_port}" + CONFIG.redis.url = f"redis://{host_ip}:{host_port}" return container diff --git a/tests/unit/nilai_api/handlers/test_nildb_handler.py b/tests/unit/nilai_api/handlers/test_nildb_handler.py index 850aaa94..20087a38 100644 --- a/tests/unit/nilai_api/handlers/test_nildb_handler.py +++ b/tests/unit/nilai_api/handlers/test_nildb_handler.py @@ -18,11 +18,14 @@ class TestNilDBHandler: def mock_config(self): """Mock configuration for tests""" with patch("nilai_api.handlers.nildb.handler.CONFIG") as mock_config: - mock_config.NILCHAIN_URL = "http://test-nilchain.com" - mock_config.NILAUTH_URL = "http://test-nilauth.com" - mock_config.NODES = ["http://node1.com", "http://node2.com"] - mock_config.BUILDER_PRIVATE_KEY = "0x1234567890abcdef" - mock_config.COLLECTION = Uuid("12345678-1234-1234-1234-123456789012") + # Mock the nested nildb config structure + mock_nildb = MagicMock() + mock_nildb.nilchain_url = "http://test-nilchain.com" + mock_nildb.nilauth_url = "http://test-nilauth.com" + mock_nildb.nodes = ["http://node1.com", "http://node2.com"] + mock_nildb.builder_private_key = "0x1234567890abcdef" + mock_nildb.collection = Uuid("12345678-1234-1234-1234-123456789012") + mock_config.nildb = mock_nildb yield mock_config @pytest.fixture @@ -85,7 +88,7 @@ async def test_create_builder_client(self, mock_config): result = await create_builder_client() mock_keypair_from_hex.assert_called_once_with( - mock_config.BUILDER_PRIVATE_KEY + mock_config.nildb.builder_private_key ) mock_from_options.assert_called_once() mock_client.refresh_root_token.assert_called_once() @@ -113,7 +116,7 @@ async def test_create_user_client(self, mock_config): result = await create_user_client() mock_keypair_from_hex.assert_called_once_with( - mock_config.BUILDER_PRIVATE_KEY + mock_config.nildb.builder_private_key ) mock_from_options.assert_called_once() assert result == mock_client diff --git a/tests/unit/nilai_api/test_rate_limiting.py b/tests/unit/nilai_api/test_rate_limiting.py index 940c6199..94a55388 100644 --- a/tests/unit/nilai_api/test_rate_limiting.py +++ b/tests/unit/nilai_api/test_rate_limiting.py @@ -179,12 +179,12 @@ async def web_search_extractor(request): @pytest.mark.asyncio async def test_global_web_search_rps_limit(req, redis_client, monkeypatch): - from nilai_api import rate_limiting as rl + from nilai_api.config import CONFIG await redis_client[0].delete("global:web_search:rps") - monkeypatch.setattr(rl.WEB_SEARCH_SETTINGS, "rps", 20) - monkeypatch.setattr(rl.WEB_SEARCH_SETTINGS, "max_concurrent_requests", 20) - monkeypatch.setattr(rl.WEB_SEARCH_SETTINGS, "count", 1) + monkeypatch.setattr(CONFIG.web_search, "rps", 20) + monkeypatch.setattr(CONFIG.web_search, "max_concurrent_requests", 20) + monkeypatch.setattr(CONFIG.web_search, "count", 1) rate_limit = RateLimit(web_search_extractor=lambda _: True) user_limits = UserRateLimits( @@ -216,12 +216,12 @@ async def run_guarded(i, times, t0): @pytest.mark.asyncio async def test_queueing_across_seconds(req, redis_client, monkeypatch): - from nilai_api import rate_limiting as rl + from nilai_api.config import CONFIG await redis_client[0].delete("global:web_search:rps") - monkeypatch.setattr(rl.WEB_SEARCH_SETTINGS, "rps", 20) - monkeypatch.setattr(rl.WEB_SEARCH_SETTINGS, "max_concurrent_requests", 20) - monkeypatch.setattr(rl.WEB_SEARCH_SETTINGS, "count", 1) + monkeypatch.setattr(CONFIG.web_search, "rps", 20) + monkeypatch.setattr(CONFIG.web_search, "max_concurrent_requests", 20) + monkeypatch.setattr(CONFIG.web_search, "count", 1) rate_limit = RateLimit(web_search_extractor=lambda _: True) user_limits = UserRateLimits( diff --git a/tests/unit/nilai_api/test_web_search.py b/tests/unit/nilai_api/test_web_search.py index 49755226..4fad7c22 100644 --- a/tests/unit/nilai_api/test_web_search.py +++ b/tests/unit/nilai_api/test_web_search.py @@ -33,7 +33,7 @@ async def test_perform_web_search_async_success(): } with ( - patch("nilai_api.handlers.web_search.WEB_SEARCH_SETTINGS.api_key", "test-key"), + patch("nilai_api.config.CONFIG.web_search.api_key", "test-key"), patch( "nilai_api.handlers.web_search._make_brave_api_request", return_value=mock_data, @@ -61,7 +61,7 @@ async def test_perform_web_search_async_no_results(): mock_data = {"web": {"results": []}} with ( - patch("nilai_api.handlers.web_search.WEB_SEARCH_SETTINGS.api_key", "test-key"), + patch("nilai_api.handlers.web_search.CONFIG.web_search.api_key", "test-key"), patch( "nilai_api.handlers.web_search._make_brave_api_request", return_value=mock_data, @@ -101,7 +101,7 @@ async def test_perform_web_search_async_concurrent_queries(): } with ( - patch("nilai_api.handlers.web_search.WEB_SEARCH_SETTINGS.api_key", "test-key"), + patch("nilai_api.config.CONFIG.web_search.api_key", "test-key"), patch( "nilai_api.handlers.web_search._make_brave_api_request", side_effect=[mock_data_1, mock_data_2], diff --git a/uv.lock b/uv.lock index 01cde052..11ee21e8 100644 --- a/uv.lock +++ b/uv.lock @@ -1357,7 +1357,7 @@ dev = [ { name = "httpx", specifier = ">=0.28.1" }, { name = "isort", specifier = ">=5.13.2" }, { name = "pre-commit", specifier = ">=4.1.0" }, - { name = "pyright", specifier = ">=1.1.400" }, + { name = "pyright", specifier = ">=1.1.405" }, { name = "pytest", specifier = ">=8.3.3" }, { name = "pytest-asyncio", specifier = ">=0.25.0" }, { name = "pytest-mock", specifier = ">=3.14.0" }, @@ -2080,15 +2080,15 @@ crypto = [ [[package]] name = "pyright" -version = "1.1.400" +version = "1.1.405" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nodeenv" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6c/cb/c306618a02d0ee8aed5fb8d0fe0ecfed0dbf075f71468f03a30b5f4e1fe0/pyright-1.1.400.tar.gz", hash = "sha256:b8a3ba40481aa47ba08ffb3228e821d22f7d391f83609211335858bf05686bdb", size = 3846546, upload-time = "2025-04-24T12:55:18.907Z" } +sdist = { url = "https://files.pythonhosted.org/packages/fb/6c/ba4bbee22e76af700ea593a1d8701e3225080956753bee9750dcc25e2649/pyright-1.1.405.tar.gz", hash = "sha256:5c2a30e1037af27eb463a1cc0b9f6d65fec48478ccf092c1ac28385a15c55763", size = 4068319, upload-time = "2025-09-04T03:37:06.776Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c8/a5/5d285e4932cf149c90e3c425610c5efaea005475d5f96f1bfdb452956c62/pyright-1.1.400-py3-none-any.whl", hash = "sha256:c80d04f98b5a4358ad3a35e241dbf2a408eee33a40779df365644f8054d2517e", size = 5563460, upload-time = "2025-04-24T12:55:17.002Z" }, + { url = "https://files.pythonhosted.org/packages/d5/1a/524f832e1ff1962a22a1accc775ca7b143ba2e9f5924bb6749dce566784a/pyright-1.1.405-py3-none-any.whl", hash = "sha256:a2cb13700b5508ce8e5d4546034cb7ea4aedb60215c6c33f56cec7f53996035a", size = 5905038, upload-time = "2025-09-04T03:37:04.913Z" }, ] [[package]]