diff --git a/.gitignore b/.gitignore index 3711fdfd..0c810a8b 100644 --- a/.gitignore +++ b/.gitignore @@ -132,3 +132,4 @@ cache/ oauth_creds/ +usage/ diff --git a/src/proxy_app/main.py b/src/proxy_app/main.py index 12014bdc..950e79f4 100644 --- a/src/proxy_app/main.py +++ b/src/proxy_app/main.py @@ -603,6 +603,8 @@ async def process_credential(provider: str, path: str, provider_instance): max_concurrent_requests_per_key=max_concurrent_requests_per_key, ) + await client.initialize_usage_managers() + # Log loaded credentials summary (compact, always visible for deployment verification) # _api_summary = ', '.join([f"{p}:{len(c)}" for p, c in api_keys.items()]) if api_keys else "none" # _oauth_summary = ', '.join([f"{p}:{len(c)}" for p, c in oauth_credentials.items()]) if oauth_credentials else "none" @@ -956,7 +958,9 @@ async def chat_completions( is_streaming = request_data.get("stream", False) if is_streaming: - response_generator = client.acompletion(request=request, **request_data) + response_generator = await client.acompletion( + request=request, **request_data + ) return StreamingResponse( streaming_response_wrapper( request, request_data, response_generator, raw_logger diff --git a/src/rotator_library/client.py b/src/rotator_library/_client_legacy.py similarity index 100% rename from src/rotator_library/client.py rename to src/rotator_library/_client_legacy.py diff --git a/src/rotator_library/_usage_manager_legacy.py b/src/rotator_library/_usage_manager_legacy.py new file mode 100644 index 00000000..46e30bbc --- /dev/null +++ b/src/rotator_library/_usage_manager_legacy.py @@ -0,0 +1,3980 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +import json +import os +import time +import logging +import asyncio +import random +from datetime import date, datetime, timezone, time as dt_time +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Tuple, Union +import aiofiles +import litellm + +from .error_handler import ClassifiedError, NoAvailableKeysError, mask_credential +from .providers import PROVIDER_PLUGINS +from .utils.resilient_io import ResilientStateWriter +from .utils.paths import get_data_file +from .config import ( + DEFAULT_FAIR_CYCLE_DURATION, + DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD, + DEFAULT_CUSTOM_CAP_COOLDOWN_MODE, + DEFAULT_CUSTOM_CAP_COOLDOWN_VALUE, + COOLDOWN_BACKOFF_TIERS, + COOLDOWN_BACKOFF_MAX, + COOLDOWN_AUTH_ERROR, + COOLDOWN_TRANSIENT_ERROR, + COOLDOWN_RATE_LIMIT_DEFAULT, +) + +lib_logger = logging.getLogger("rotator_library") +lib_logger.propagate = False +if not lib_logger.handlers: + lib_logger.addHandler(logging.NullHandler()) + + +class UsageManager: + """ + Manages usage statistics and cooldowns for API keys with asyncio-safe locking, + asynchronous file I/O, lazy-loading mechanism, and weighted random credential rotation. + + The credential rotation strategy can be configured via the `rotation_tolerance` parameter: + + - **tolerance = 0.0**: Deterministic least-used selection. The credential with + the lowest usage count is always selected. This provides predictable, perfectly balanced + load distribution but may be vulnerable to fingerprinting. + + - **tolerance = 2.0 - 4.0 (default, recommended)**: Balanced weighted randomness. Credentials are selected + randomly with weights biased toward less-used ones. Credentials within 2 uses of the + maximum can still be selected with reasonable probability. This provides security through + unpredictability while maintaining good load balance. + + - **tolerance = 5.0+**: High randomness. Even heavily-used credentials have significant + selection probability. Useful for stress testing or maximum unpredictability, but may + result in less balanced load distribution. + + The weight formula is: `weight = (max_usage - credential_usage) + tolerance + 1` + + This ensures lower-usage credentials are preferred while tolerance controls how much + randomness is introduced into the selection process. + + Additionally, providers can specify a rotation mode: + - "balanced" (default): Rotate credentials to distribute load evenly + - "sequential": Use one credential until exhausted (preserves caching) + """ + + def __init__( + self, + file_path: Optional[Union[str, Path]] = None, + daily_reset_time_utc: Optional[str] = "03:00", + rotation_tolerance: float = 0.0, + provider_rotation_modes: Optional[Dict[str, str]] = None, + provider_plugins: Optional[Dict[str, Any]] = None, + priority_multipliers: Optional[Dict[str, Dict[int, int]]] = None, + priority_multipliers_by_mode: Optional[ + Dict[str, Dict[str, Dict[int, int]]] + ] = None, + sequential_fallback_multipliers: Optional[Dict[str, int]] = None, + fair_cycle_enabled: Optional[Dict[str, bool]] = None, + fair_cycle_tracking_mode: Optional[Dict[str, str]] = None, + fair_cycle_cross_tier: Optional[Dict[str, bool]] = None, + fair_cycle_duration: Optional[Dict[str, int]] = None, + exhaustion_cooldown_threshold: Optional[Dict[str, int]] = None, + custom_caps: Optional[ + Dict[str, Dict[Union[int, Tuple[int, ...], str], Dict[str, Dict[str, Any]]]] + ] = None, + ): + """ + Initialize the UsageManager. + + Args: + file_path: Path to the usage data JSON file. If None, uses get_data_file("key_usage.json"). + Can be absolute Path, relative Path, or string. + daily_reset_time_utc: Time in UTC when daily stats should reset (HH:MM format) + rotation_tolerance: Tolerance for weighted random credential rotation. + - 0.0: Deterministic, least-used credential always selected + - tolerance = 2.0 - 4.0 (default, recommended): Balanced randomness, can pick credentials within 2 uses of max + - 5.0+: High randomness, more unpredictable selection patterns + provider_rotation_modes: Dict mapping provider names to rotation modes. + - "balanced": Rotate credentials to distribute load evenly (default) + - "sequential": Use one credential until exhausted (preserves caching) + provider_plugins: Dict mapping provider names to provider plugin instances. + Used for per-provider usage reset configuration (window durations, field names). + priority_multipliers: Dict mapping provider -> priority -> multiplier. + Universal multipliers that apply regardless of rotation mode. + Example: {"antigravity": {1: 5, 2: 3}} + priority_multipliers_by_mode: Dict mapping provider -> mode -> priority -> multiplier. + Mode-specific overrides. Example: {"antigravity": {"balanced": {3: 1}}} + sequential_fallback_multipliers: Dict mapping provider -> fallback multiplier. + Used in sequential mode when priority not in priority_multipliers. + Example: {"antigravity": 2} + fair_cycle_enabled: Dict mapping provider -> bool to enable fair cycle rotation. + When enabled, credentials must all exhaust before any can be reused. + Default: enabled for sequential mode only. + fair_cycle_tracking_mode: Dict mapping provider -> tracking mode. + - "model_group": Track per quota group or model (default) + - "credential": Track per credential globally + fair_cycle_cross_tier: Dict mapping provider -> bool for cross-tier tracking. + - False: Each tier cycles independently (default) + - True: All credentials must exhaust regardless of tier + fair_cycle_duration: Dict mapping provider -> cycle duration in seconds. + Default: 86400 (24 hours) + exhaustion_cooldown_threshold: Dict mapping provider -> threshold in seconds. + A cooldown must exceed this to qualify as "exhausted". Default: 300 (5 min) + custom_caps: Dict mapping provider -> tier -> model/group -> cap config. + Allows setting custom usage limits per tier, per model or quota group. + See ProviderInterface.default_custom_caps for format details. + """ + # Resolve file_path - use default if not provided + if file_path is None: + self.file_path = str(get_data_file("key_usage.json")) + elif isinstance(file_path, Path): + self.file_path = str(file_path) + else: + # String path - could be relative or absolute + self.file_path = file_path + self.rotation_tolerance = rotation_tolerance + self.provider_rotation_modes = provider_rotation_modes or {} + self.provider_plugins = provider_plugins or PROVIDER_PLUGINS + self.priority_multipliers = priority_multipliers or {} + self.priority_multipliers_by_mode = priority_multipliers_by_mode or {} + self.sequential_fallback_multipliers = sequential_fallback_multipliers or {} + self._provider_instances: Dict[str, Any] = {} # Cache for provider instances + self.key_states: Dict[str, Dict[str, Any]] = {} + + # Fair cycle rotation configuration + self.fair_cycle_enabled = fair_cycle_enabled or {} + self.fair_cycle_tracking_mode = fair_cycle_tracking_mode or {} + self.fair_cycle_cross_tier = fair_cycle_cross_tier or {} + self.fair_cycle_duration = fair_cycle_duration or {} + self.exhaustion_cooldown_threshold = exhaustion_cooldown_threshold or {} + self.custom_caps = custom_caps or {} + # In-memory cycle state: {provider: {tier_key: {tracking_key: {"cycle_started_at": float, "exhausted": Set[str]}}}} + self._cycle_exhausted: Dict[str, Dict[str, Dict[str, Dict[str, Any]]]] = {} + + self._data_lock = asyncio.Lock() + self._usage_data: Optional[Dict] = None + self._initialized = asyncio.Event() + self._init_lock = asyncio.Lock() + + self._timeout_lock = asyncio.Lock() + self._claimed_on_timeout: Set[str] = set() + + # Resilient writer for usage data persistence + self._state_writer = ResilientStateWriter(file_path, lib_logger) + + if daily_reset_time_utc: + hour, minute = map(int, daily_reset_time_utc.split(":")) + self.daily_reset_time_utc = dt_time( + hour=hour, minute=minute, tzinfo=timezone.utc + ) + else: + self.daily_reset_time_utc = None + + def _get_rotation_mode(self, provider: str) -> str: + """ + Get the rotation mode for a provider. + + Args: + provider: Provider name (e.g., "antigravity", "gemini_cli") + + Returns: + "balanced" or "sequential" + """ + return self.provider_rotation_modes.get(provider, "balanced") + + # ========================================================================= + # FAIR CYCLE ROTATION HELPERS + # ========================================================================= + + def _is_fair_cycle_enabled(self, provider: str, rotation_mode: str) -> bool: + """ + Check if fair cycle rotation is enabled for a provider. + + Args: + provider: Provider name + rotation_mode: Current rotation mode ("balanced" or "sequential") + + Returns: + True if fair cycle is enabled + """ + # Check provider-specific setting first + if provider in self.fair_cycle_enabled: + return self.fair_cycle_enabled[provider] + # Default: enabled only for sequential mode + return rotation_mode == "sequential" + + def _get_fair_cycle_tracking_mode(self, provider: str) -> str: + """ + Get fair cycle tracking mode for a provider. + + Returns: + "model_group" or "credential" + """ + return self.fair_cycle_tracking_mode.get(provider, "model_group") + + def _is_fair_cycle_cross_tier(self, provider: str) -> bool: + """ + Check if fair cycle tracks across all tiers (ignoring priority boundaries). + + Returns: + True if cross-tier tracking is enabled + """ + return self.fair_cycle_cross_tier.get(provider, False) + + def _get_fair_cycle_duration(self, provider: str) -> int: + """ + Get fair cycle duration in seconds for a provider. + + Returns: + Duration in seconds (default 86400 = 24 hours) + """ + return self.fair_cycle_duration.get(provider, DEFAULT_FAIR_CYCLE_DURATION) + + def _get_exhaustion_cooldown_threshold(self, provider: str) -> int: + """ + Get exhaustion cooldown threshold in seconds for a provider. + + A cooldown must exceed this duration to qualify as "exhausted" for fair cycle. + + Returns: + Threshold in seconds (default 300 = 5 minutes) + """ + return self.exhaustion_cooldown_threshold.get( + provider, DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD + ) + + # ========================================================================= + # CUSTOM CAPS HELPERS + # ========================================================================= + + def _get_custom_cap_config( + self, + provider: str, + tier_priority: int, + model: str, + ) -> Optional[Dict[str, Any]]: + """ + Get custom cap config for a provider/tier/model combination. + + Resolution order: + 1. tier + model (exact match) + 2. tier + group (model's quota group) + 3. "default" + model + 4. "default" + group + + Args: + provider: Provider name + tier_priority: Credential's priority level + model: Model name (with provider prefix) + + Returns: + Cap config dict or None if no custom cap applies + """ + provider_caps = self.custom_caps.get(provider) + if not provider_caps: + return None + + # Strip provider prefix from model + clean_model = model.split("/")[-1] if "/" in model else model + + # Get quota group for this model + group = self._get_model_quota_group_by_provider(provider, model) + + # Try to find matching tier config + tier_config = None + default_config = None + + for tier_key, models_config in provider_caps.items(): + if tier_key == "default": + default_config = models_config + continue + + # Check if this tier_key matches our priority + if isinstance(tier_key, int) and tier_key == tier_priority: + tier_config = models_config + break + elif isinstance(tier_key, tuple) and tier_priority in tier_key: + tier_config = models_config + break + + # Resolution order for tier config + if tier_config: + # Try model first + if clean_model in tier_config: + return tier_config[clean_model] + # Try group + if group and group in tier_config: + return tier_config[group] + + # Resolution order for default config + if default_config: + # Try model first + if clean_model in default_config: + return default_config[clean_model] + # Try group + if group and group in default_config: + return default_config[group] + + return None + + def _get_model_quota_group_by_provider( + self, provider: str, model: str + ) -> Optional[str]: + """ + Get quota group for a model using provider name instead of credential. + + Args: + provider: Provider name + model: Model name + + Returns: + Group name or None + """ + plugin_instance = self._get_provider_instance(provider) + if plugin_instance and hasattr(plugin_instance, "get_model_quota_group"): + return plugin_instance.get_model_quota_group(model) + return None + + def _resolve_custom_cap_max( + self, + provider: str, + model: str, + cap_config: Dict[str, Any], + actual_max: Optional[int], + ) -> Optional[int]: + """ + Resolve custom cap max_requests value, handling percentages and clamping. + + Args: + provider: Provider name + model: Model name (for logging) + cap_config: Custom cap configuration + actual_max: Actual API max requests (may be None if unknown) + + Returns: + Resolved cap value (clamped), or None if can't be calculated + """ + max_requests = cap_config.get("max_requests") + if max_requests is None: + return None + + # Handle percentage + if isinstance(max_requests, str) and max_requests.endswith("%"): + if actual_max is None: + lib_logger.warning( + f"Custom cap '{max_requests}' for {provider}/{model} requires known max_requests. " + f"Skipping until quota baseline is fetched. Use absolute value for immediate enforcement." + ) + return None + try: + percentage = float(max_requests.rstrip("%")) / 100.0 + calculated = int(actual_max * percentage) + except ValueError: + lib_logger.warning( + f"Invalid percentage cap '{max_requests}' for {provider}/{model}" + ) + return None + else: + # Absolute value + try: + calculated = int(max_requests) + except (ValueError, TypeError): + lib_logger.warning( + f"Invalid cap value '{max_requests}' for {provider}/{model}" + ) + return None + + # Clamp to actual max (can only be MORE restrictive) + if actual_max is not None: + return min(calculated, actual_max) + return calculated + + def _calculate_custom_cooldown_until( + self, + cap_config: Dict[str, Any], + window_start_ts: Optional[float], + natural_reset_ts: Optional[float], + ) -> Optional[float]: + """ + Calculate when custom cap cooldown should end, clamped to natural reset. + + Args: + cap_config: Custom cap configuration + window_start_ts: When first request was made (for fixed mode) + natural_reset_ts: Natural quota reset timestamp + + Returns: + Cooldown end timestamp (clamped), or None if can't calculate + """ + mode = cap_config.get("cooldown_mode", DEFAULT_CUSTOM_CAP_COOLDOWN_MODE) + value = cap_config.get("cooldown_value", DEFAULT_CUSTOM_CAP_COOLDOWN_VALUE) + + if mode == "quota_reset": + calculated = natural_reset_ts + elif mode == "offset": + if natural_reset_ts is None: + return None + calculated = natural_reset_ts + value + elif mode == "fixed": + if window_start_ts is None: + return None + calculated = window_start_ts + value + else: + lib_logger.warning(f"Unknown cooldown_mode '{mode}', using quota_reset") + calculated = natural_reset_ts + + if calculated is None: + return None + + # Clamp to natural reset (can only be MORE restrictive = longer cooldown) + if natural_reset_ts is not None: + return max(calculated, natural_reset_ts) + return calculated + + def _check_and_apply_custom_cap( + self, + credential: str, + model: str, + request_count: int, + ) -> bool: + """ + Check if custom cap is exceeded and apply cooldown if so. + + This should be called after incrementing request_count in record_success(). + + Args: + credential: Credential identifier + model: Model name (with provider prefix) + request_count: Current request count for this model + + Returns: + True if cap exceeded and cooldown applied, False otherwise + """ + provider = self._get_provider_from_credential(credential) + if not provider: + return False + + priority = self._get_credential_priority(credential, provider) + cap_config = self._get_custom_cap_config(provider, priority, model) + if not cap_config: + return False + + # Get model data for actual max and timing info + key_data = self._usage_data.get(credential, {}) + model_data = key_data.get("models", {}).get(model, {}) + actual_max = model_data.get("quota_max_requests") + window_start_ts = model_data.get("window_start_ts") + natural_reset_ts = model_data.get("quota_reset_ts") + + # Resolve custom cap max + custom_max = self._resolve_custom_cap_max( + provider, model, cap_config, actual_max + ) + if custom_max is None: + return False + + # Check if exceeded + if request_count < custom_max: + return False + + # Calculate cooldown end time + cooldown_until = self._calculate_custom_cooldown_until( + cap_config, window_start_ts, natural_reset_ts + ) + if cooldown_until is None: + # Can't calculate cooldown, use natural reset if available + if natural_reset_ts: + cooldown_until = natural_reset_ts + else: + lib_logger.warning( + f"Custom cap hit for {mask_credential(credential)}/{model} but can't calculate cooldown. " + f"Skipping cooldown application." + ) + return False + + now_ts = time.time() + + # Apply cooldown + model_cooldowns = key_data.setdefault("model_cooldowns", {}) + model_cooldowns[model] = cooldown_until + + # Store custom cap info in model data for reference + model_data["custom_cap_max"] = custom_max + model_data["custom_cap_hit_at"] = now_ts + model_data["custom_cap_cooldown_until"] = cooldown_until + + hours_until = (cooldown_until - now_ts) / 3600 + lib_logger.info( + f"Custom cap hit: {mask_credential(credential)} reached {request_count}/{custom_max} " + f"for {model}. Cooldown for {hours_until:.1f}h" + ) + + # Sync cooldown across quota group + group = self._get_model_quota_group(credential, model) + if group: + grouped_models = self._get_grouped_models(credential, group) + for grouped_model in grouped_models: + if grouped_model != model: + model_cooldowns[grouped_model] = cooldown_until + + # Check if this should trigger fair cycle exhaustion + cooldown_duration = cooldown_until - now_ts + threshold = self._get_exhaustion_cooldown_threshold(provider) + if cooldown_duration > threshold: + rotation_mode = self._get_rotation_mode(provider) + if self._is_fair_cycle_enabled(provider, rotation_mode): + tier_key = self._get_tier_key(provider, priority) + tracking_key = self._get_tracking_key(credential, model, provider) + self._mark_credential_exhausted( + credential, provider, tier_key, tracking_key + ) + + return True + + def _get_tier_key(self, provider: str, priority: int) -> str: + """ + Get the tier key for cycle tracking based on cross_tier setting. + + Args: + provider: Provider name + priority: Credential priority level + + Returns: + "__all_tiers__" if cross-tier enabled, else str(priority) + """ + if self._is_fair_cycle_cross_tier(provider): + return "__all_tiers__" + return str(priority) + + def _get_tracking_key(self, credential: str, model: str, provider: str) -> str: + """ + Get the key for exhaustion tracking based on tracking mode. + + Args: + credential: Credential identifier + model: Model name (with provider prefix) + provider: Provider name + + Returns: + Tracking key string (quota group name, model name, or "__credential__") + """ + mode = self._get_fair_cycle_tracking_mode(provider) + if mode == "credential": + return "__credential__" + # model_group mode: use quota group if exists, else model + group = self._get_model_quota_group(credential, model) + return group if group else model + + def _get_credential_priority(self, credential: str, provider: str) -> int: + """ + Get the priority level for a credential. + + Args: + credential: Credential identifier + provider: Provider name + + Returns: + Priority level (default 999 if unknown) + """ + plugin_instance = self._get_provider_instance(provider) + if plugin_instance and hasattr(plugin_instance, "get_credential_priority"): + priority = plugin_instance.get_credential_priority(credential) + if priority is not None: + return priority + return 999 + + def _get_cycle_data( + self, provider: str, tier_key: str, tracking_key: str + ) -> Optional[Dict[str, Any]]: + """ + Get cycle data for a provider/tier/tracking key combination. + + Returns: + Cycle data dict or None if not exists + """ + return ( + self._cycle_exhausted.get(provider, {}).get(tier_key, {}).get(tracking_key) + ) + + def _ensure_cycle_structure( + self, provider: str, tier_key: str, tracking_key: str + ) -> Dict[str, Any]: + """ + Ensure the nested cycle structure exists and return the cycle data dict. + """ + if provider not in self._cycle_exhausted: + self._cycle_exhausted[provider] = {} + if tier_key not in self._cycle_exhausted[provider]: + self._cycle_exhausted[provider][tier_key] = {} + if tracking_key not in self._cycle_exhausted[provider][tier_key]: + self._cycle_exhausted[provider][tier_key][tracking_key] = { + "cycle_started_at": None, + "exhausted": set(), + } + return self._cycle_exhausted[provider][tier_key][tracking_key] + + def _mark_credential_exhausted( + self, + credential: str, + provider: str, + tier_key: str, + tracking_key: str, + ) -> None: + """ + Mark a credential as exhausted for fair cycle tracking. + + Starts the cycle timer on first exhaustion. + Skips if credential is already in the exhausted set (prevents duplicate logging). + """ + cycle_data = self._ensure_cycle_structure(provider, tier_key, tracking_key) + + # Skip if already exhausted in this cycle (prevents duplicate logging) + if credential in cycle_data.get("exhausted", set()): + return + + # Start cycle timer on first exhaustion + if cycle_data["cycle_started_at"] is None: + cycle_data["cycle_started_at"] = time.time() + lib_logger.info( + f"Fair cycle started for {provider} tier={tier_key} tracking='{tracking_key}'" + ) + + cycle_data["exhausted"].add(credential) + lib_logger.info( + f"Fair cycle: marked {mask_credential(credential)} exhausted " + f"for {tracking_key} ({len(cycle_data['exhausted'])} total)" + ) + + def _is_credential_exhausted_in_cycle( + self, + credential: str, + provider: str, + tier_key: str, + tracking_key: str, + ) -> bool: + """ + Check if a credential was exhausted in the current cycle. + """ + cycle_data = self._get_cycle_data(provider, tier_key, tracking_key) + if cycle_data is None: + return False + return credential in cycle_data.get("exhausted", set()) + + def _is_cycle_expired( + self, provider: str, tier_key: str, tracking_key: str + ) -> bool: + """ + Check if the current cycle has exceeded its duration. + """ + cycle_data = self._get_cycle_data(provider, tier_key, tracking_key) + if cycle_data is None: + return False + cycle_started = cycle_data.get("cycle_started_at") + if cycle_started is None: + return False + duration = self._get_fair_cycle_duration(provider) + return time.time() >= cycle_started + duration + + def _should_reset_cycle( + self, + provider: str, + tier_key: str, + tracking_key: str, + all_credentials_in_tier: List[str], + available_not_on_cooldown: Optional[List[str]] = None, + ) -> bool: + """ + Check if cycle should reset. + + Returns True if: + 1. Cycle duration has expired, OR + 2. No credentials remain available (after cooldown + fair cycle exclusion), OR + 3. All credentials in the tier have been marked exhausted (fallback) + """ + # Check duration first + if self._is_cycle_expired(provider, tier_key, tracking_key): + return True + + cycle_data = self._get_cycle_data(provider, tier_key, tracking_key) + if cycle_data is None: + return False + + # If available credentials are provided, reset when none remain usable + if available_not_on_cooldown is not None: + has_available = any( + not self._is_credential_exhausted_in_cycle( + cred, provider, tier_key, tracking_key + ) + for cred in available_not_on_cooldown + ) + if not has_available and len(all_credentials_in_tier) > 0: + return True + + exhausted = cycle_data.get("exhausted", set()) + # All must be exhausted (and there must be at least one credential) + return ( + len(exhausted) >= len(all_credentials_in_tier) + and len(all_credentials_in_tier) > 0 + ) + + def _reset_cycle(self, provider: str, tier_key: str, tracking_key: str) -> None: + """ + Reset exhaustion tracking for a completed cycle. + """ + cycle_data = self._get_cycle_data(provider, tier_key, tracking_key) + if cycle_data: + exhausted_count = len(cycle_data.get("exhausted", set())) + lib_logger.info( + f"Fair cycle complete for {provider} tier={tier_key} " + f"tracking='{tracking_key}' - resetting ({exhausted_count} credentials cycled)" + ) + cycle_data["cycle_started_at"] = None + cycle_data["exhausted"] = set() + + def _get_all_credentials_for_tier_key( + self, + provider: str, + tier_key: str, + available_keys: List[str], + credential_priorities: Optional[Dict[str, int]], + ) -> List[str]: + """ + Get all credentials that belong to a tier key. + + Args: + provider: Provider name + tier_key: Either "__all_tiers__" or str(priority) + available_keys: List of available credential identifiers + credential_priorities: Dict mapping credentials to priorities + + Returns: + List of credentials belonging to this tier key + """ + if tier_key == "__all_tiers__": + # Cross-tier: all credentials for this provider + return list(available_keys) + else: + # Within-tier: only credentials with matching priority + priority = int(tier_key) + if credential_priorities: + return [ + k + for k in available_keys + if credential_priorities.get(k, 999) == priority + ] + return list(available_keys) + + def _count_fair_cycle_excluded( + self, + provider: str, + tier_key: str, + tracking_key: str, + candidates: List[str], + ) -> int: + """ + Count how many candidates are excluded by fair cycle. + + Args: + provider: Provider name + tier_key: Tier key for tracking + tracking_key: Model/group tracking key + candidates: List of candidate credentials (not on cooldown) + + Returns: + Number of candidates excluded by fair cycle + """ + count = 0 + for cred in candidates: + if self._is_credential_exhausted_in_cycle( + cred, provider, tier_key, tracking_key + ): + count += 1 + return count + + def _get_priority_multiplier( + self, provider: str, priority: int, rotation_mode: str + ) -> int: + """ + Get the concurrency multiplier for a provider/priority/mode combination. + + Lookup order: + 1. Mode-specific tier override: priority_multipliers_by_mode[provider][mode][priority] + 2. Universal tier multiplier: priority_multipliers[provider][priority] + 3. Sequential fallback (if mode is sequential): sequential_fallback_multipliers[provider] + 4. Global default: 1 (no multiplier effect) + + Args: + provider: Provider name (e.g., "antigravity") + priority: Priority level (1 = highest priority) + rotation_mode: Current rotation mode ("sequential" or "balanced") + + Returns: + Multiplier value + """ + provider_lower = provider.lower() + + # 1. Check mode-specific override + if provider_lower in self.priority_multipliers_by_mode: + mode_multipliers = self.priority_multipliers_by_mode[provider_lower] + if rotation_mode in mode_multipliers: + if priority in mode_multipliers[rotation_mode]: + return mode_multipliers[rotation_mode][priority] + + # 2. Check universal tier multiplier + if provider_lower in self.priority_multipliers: + if priority in self.priority_multipliers[provider_lower]: + return self.priority_multipliers[provider_lower][priority] + + # 3. Sequential fallback (only for sequential mode) + if rotation_mode == "sequential": + if provider_lower in self.sequential_fallback_multipliers: + return self.sequential_fallback_multipliers[provider_lower] + + # 4. Global default + return 1 + + def _get_provider_from_credential(self, credential: str) -> Optional[str]: + """ + Extract provider name from credential path or identifier. + + Supports multiple credential formats: + - OAuth: "oauth_creds/antigravity_oauth_15.json" -> "antigravity" + - OAuth: "C:\\...\\oauth_creds\\gemini_cli_oauth_1.json" -> "gemini_cli" + - OAuth filename only: "antigravity_oauth_1.json" -> "antigravity" + - API key style: extracted from model names in usage data (e.g., "firmware/model" -> "firmware") + + Args: + credential: The credential identifier (path or key) + + Returns: + Provider name string or None if cannot be determined + """ + import re + + # Pattern: env:// URI format (e.g., "env://antigravity/1" -> "antigravity") + if credential.startswith("env://"): + parts = credential[6:].split("/") # Remove "env://" prefix + if parts and parts[0]: + return parts[0].lower() + # Malformed env:// URI (empty provider name) + lib_logger.warning(f"Malformed env:// credential URI: {credential}") + return None + + # Normalize path separators + normalized = credential.replace("\\", "/") + + # Pattern: path ending with {provider}_oauth_{number}.json + match = re.search(r"/([a-z_]+)_oauth_\d+\.json$", normalized, re.IGNORECASE) + if match: + return match.group(1).lower() + + # Pattern: oauth_creds/{provider}_... + match = re.search(r"oauth_creds/([a-z_]+)_", normalized, re.IGNORECASE) + if match: + return match.group(1).lower() + + # Pattern: filename only {provider}_oauth_{number}.json (no path) + match = re.match(r"([a-z_]+)_oauth_\d+\.json$", normalized, re.IGNORECASE) + if match: + return match.group(1).lower() + + # Pattern: API key prefixes for specific providers + # These are raw API keys with recognizable prefixes + api_key_prefixes = { + "sk-nano-": "nanogpt", + "sk-or-": "openrouter", + "sk-ant-": "anthropic", + } + for prefix, provider in api_key_prefixes.items(): + if credential.startswith(prefix): + return provider + + # Fallback: For raw API keys, extract provider from model names in usage data + # This handles providers like firmware, chutes, nanogpt that use credential-level quota + if self._usage_data and credential in self._usage_data: + cred_data = self._usage_data[credential] + + # Check "models" section first (for per_model mode and quota tracking) + models_data = cred_data.get("models", {}) + if models_data: + # Get first model name and extract provider prefix + first_model = next(iter(models_data.keys()), None) + if first_model and "/" in first_model: + provider = first_model.split("/")[0].lower() + return provider + + # Fallback to "daily" section (legacy structure) + daily_data = cred_data.get("daily", {}) + daily_models = daily_data.get("models", {}) + if daily_models: + # Get first model name and extract provider prefix + first_model = next(iter(daily_models.keys()), None) + if first_model and "/" in first_model: + provider = first_model.split("/")[0].lower() + return provider + + return None + + def _get_provider_instance(self, provider: str) -> Optional[Any]: + """ + Get or create a provider plugin instance. + + Args: + provider: The provider name + + Returns: + Provider plugin instance or None + """ + if not provider: + return None + + plugin_class = self.provider_plugins.get(provider) + if not plugin_class: + return None + + # Get or create provider instance from cache + if provider not in self._provider_instances: + # Instantiate the plugin if it's a class, or use it directly if already an instance + if isinstance(plugin_class, type): + self._provider_instances[provider] = plugin_class() + else: + self._provider_instances[provider] = plugin_class + + return self._provider_instances[provider] + + def _get_usage_reset_config(self, credential: str) -> Optional[Dict[str, Any]]: + """ + Get the usage reset configuration for a credential from its provider plugin. + + Args: + credential: The credential identifier + + Returns: + Configuration dict with window_seconds, field_name, etc. + or None to use default daily reset. + """ + provider = self._get_provider_from_credential(credential) + plugin_instance = self._get_provider_instance(provider) + + if plugin_instance and hasattr(plugin_instance, "get_usage_reset_config"): + return plugin_instance.get_usage_reset_config(credential) + + return None + + def _get_reset_mode(self, credential: str) -> str: + """ + Get the reset mode for a credential: 'credential' or 'per_model'. + + Args: + credential: The credential identifier + + Returns: + "per_model" or "credential" (default) + """ + config = self._get_usage_reset_config(credential) + return config.get("mode", "credential") if config else "credential" + + def _get_model_quota_group(self, credential: str, model: str) -> Optional[str]: + """ + Get the quota group for a model, if the provider defines one. + + Args: + credential: The credential identifier + model: Model name (with or without provider prefix) + + Returns: + Group name (e.g., "claude") or None if not grouped + """ + provider = self._get_provider_from_credential(credential) + plugin_instance = self._get_provider_instance(provider) + + if plugin_instance and hasattr(plugin_instance, "get_model_quota_group"): + return plugin_instance.get_model_quota_group(model) + + return None + + def _get_grouped_models(self, credential: str, group: str) -> List[str]: + """ + Get all model names in a quota group (with provider prefix), normalized. + + Returns only public-facing model names, deduplicated. Internal variants + (e.g., claude-sonnet-4-5-thinking) are normalized to their public name + (e.g., claude-sonnet-4.5). + + Args: + credential: The credential identifier + group: Group name (e.g., "claude") + + Returns: + List of normalized, deduplicated model names with provider prefix + (e.g., ["antigravity/claude-sonnet-4.5", "antigravity/claude-opus-4.5"]) + """ + provider = self._get_provider_from_credential(credential) + plugin_instance = self._get_provider_instance(provider) + + if plugin_instance and hasattr(plugin_instance, "get_models_in_quota_group"): + models = plugin_instance.get_models_in_quota_group(group) + + # Normalize and deduplicate + if hasattr(plugin_instance, "normalize_model_for_tracking"): + seen = set() + normalized = [] + for m in models: + prefixed = f"{provider}/{m}" + norm = plugin_instance.normalize_model_for_tracking(prefixed) + if norm not in seen: + seen.add(norm) + normalized.append(norm) + return normalized + + # Fallback: just add provider prefix + return [f"{provider}/{m}" for m in models] + + return [] + + def _get_model_usage_weight(self, credential: str, model: str) -> int: + """ + Get the usage weight for a model when calculating grouped usage. + + Args: + credential: The credential identifier + model: Model name (with or without provider prefix) + + Returns: + Weight multiplier (default 1 if not configured) + """ + provider = self._get_provider_from_credential(credential) + plugin_instance = self._get_provider_instance(provider) + + if plugin_instance and hasattr(plugin_instance, "get_model_usage_weight"): + return plugin_instance.get_model_usage_weight(model) + + return 1 + + def _normalize_model(self, credential: str, model: str) -> str: + """ + Normalize model name using provider's mapping. + + Converts internal model names (e.g., claude-sonnet-4-5-thinking) to + public-facing names (e.g., claude-sonnet-4.5) for consistent storage. + + Args: + credential: The credential identifier + model: Model name (with or without provider prefix) + + Returns: + Normalized model name (provider prefix preserved if present) + """ + provider = self._get_provider_from_credential(credential) + plugin_instance = self._get_provider_instance(provider) + + if plugin_instance and hasattr(plugin_instance, "normalize_model_for_tracking"): + return plugin_instance.normalize_model_for_tracking(model) + + return model + + # Providers where request_count should be used for credential selection + # instead of success_count (because failed requests also consume quota) + _REQUEST_COUNT_PROVIDERS = {"antigravity", "gemini_cli", "chutes", "nanogpt"} + + def _get_grouped_usage_count(self, key: str, model: str) -> int: + """ + Get usage count for credential selection, considering quota groups. + + For providers in _REQUEST_COUNT_PROVIDERS (e.g., antigravity), uses + request_count instead of success_count since failed requests also + consume quota. + + If the model belongs to a quota group, the request_count is already + synced across all models in the group (by record_success/record_failure), + so we just read from the requested model directly. + + Args: + key: Credential identifier + model: Model name (with provider prefix, e.g., "antigravity/claude-sonnet-4-5") + + Returns: + Usage count for the model (synced across group if applicable) + """ + # Determine usage field based on provider + # Some providers (antigravity) count failed requests against quota + provider = self._get_provider_from_credential(key) + usage_field = ( + "request_count" + if provider in self._REQUEST_COUNT_PROVIDERS + else "success_count" + ) + + # For providers with synced quota groups (antigravity), request_count + # is already synced across all models in the group, so just read directly. + # For other providers, we still need to sum success_count across group. + if provider in self._REQUEST_COUNT_PROVIDERS: + # request_count is synced - just read the model's value + return self._get_usage_count(key, model, usage_field) + + # For non-synced providers, check if model is in a quota group and sum + group = self._get_model_quota_group(key, model) + + if group: + # Get all models in the group + grouped_models = self._get_grouped_models(key, group) + + # Sum weighted usage across all models in the group + total_weighted_usage = 0 + for grouped_model in grouped_models: + usage = self._get_usage_count(key, grouped_model, usage_field) + weight = self._get_model_usage_weight(key, grouped_model) + total_weighted_usage += usage * weight + return total_weighted_usage + + # Not grouped - return individual model usage (no weight applied) + return self._get_usage_count(key, model, usage_field) + + def _get_quota_display(self, key: str, model: str) -> str: + """ + Get a formatted quota display string for logging. + + For antigravity (providers in _REQUEST_COUNT_PROVIDERS), returns: + "quota: 170/250 [32%]" format + + For other providers, returns: + "usage: 170" format (no max available) + + Args: + key: Credential identifier + model: Model name (with provider prefix) + + Returns: + Formatted string for logging + """ + provider = self._get_provider_from_credential(key) + + if provider not in self._REQUEST_COUNT_PROVIDERS: + # Non-antigravity: just show usage count + usage = self._get_usage_count(key, model, "success_count") + return f"usage: {usage}" + + # Antigravity: show quota display with remaining percentage + if self._usage_data is None: + return "quota: 0/? [100%]" + + # Normalize model name for consistent lookup (data is stored under normalized names) + model = self._normalize_model(key, model) + + key_data = self._usage_data.get(key, {}) + model_data = key_data.get("models", {}).get(model, {}) + + request_count = model_data.get("request_count", 0) + max_requests = model_data.get("quota_max_requests") + + if max_requests: + remaining = max_requests - request_count + remaining_pct = ( + int((remaining / max_requests) * 100) if max_requests > 0 else 0 + ) + return f"quota: {request_count}/{max_requests} [{remaining_pct}%]" + else: + return f"quota: {request_count}" + + def _get_usage_field_name(self, credential: str) -> str: + """ + Get the usage tracking field name for a credential. + + Returns the provider-specific field name if configured, + otherwise falls back to "daily". + + Args: + credential: The credential identifier + + Returns: + Field name string (e.g., "5h_window", "weekly", "daily") + """ + config = self._get_usage_reset_config(credential) + if config and "field_name" in config: + return config["field_name"] + + # Check provider default + provider = self._get_provider_from_credential(credential) + plugin_instance = self._get_provider_instance(provider) + + if plugin_instance and hasattr(plugin_instance, "get_default_usage_field_name"): + return plugin_instance.get_default_usage_field_name() + + return "daily" + + def _get_usage_count( + self, key: str, model: str, field: str = "success_count" + ) -> int: + """ + Get the current usage count for a model from the appropriate usage structure. + + Supports both: + - New per-model structure: {"models": {"model_name": {"success_count": N, ...}}} + - Legacy structure: {"daily": {"models": {"model_name": {"success_count": N, ...}}}} + + Args: + key: Credential identifier + model: Model name + field: The field to read for usage count (default: "success_count"). + Use "request_count" for providers where failed requests also + consume quota (e.g., antigravity). + + Returns: + Usage count for the model in the current window/period + """ + if self._usage_data is None: + return 0 + + # Normalize model name for consistent lookup (data is stored under normalized names) + model = self._normalize_model(key, model) + + key_data = self._usage_data.get(key, {}) + reset_mode = self._get_reset_mode(key) + + if reset_mode == "per_model": + # New per-model structure: key_data["models"][model][field] + return key_data.get("models", {}).get(model, {}).get(field, 0) + else: + # Legacy structure: key_data["daily"]["models"][model][field] + return ( + key_data.get("daily", {}).get("models", {}).get(model, {}).get(field, 0) + ) + + # ========================================================================= + # TIMESTAMP FORMATTING HELPERS + # ========================================================================= + + def _format_timestamp_local(self, ts: Optional[float]) -> Optional[str]: + """ + Format Unix timestamp as local time string with timezone offset. + + Args: + ts: Unix timestamp or None + + Returns: + Formatted string like "2025-12-07 14:30:17 +0100" or None + """ + if ts is None: + return None + try: + dt = datetime.fromtimestamp(ts).astimezone() # Local timezone + # Use UTC offset for conciseness (works on all platforms) + return dt.strftime("%Y-%m-%d %H:%M:%S %z") + except (OSError, ValueError, OverflowError): + return None + + def _add_readable_timestamps(self, data: Dict) -> Dict: + """ + Add human-readable timestamp fields to usage data before saving. + + Adds 'window_started' and 'quota_resets' fields derived from + Unix timestamps for easier debugging and monitoring. + + Args: + data: The usage data dict to enhance + + Returns: + The same dict with readable timestamp fields added + """ + for key, key_data in data.items(): + # Handle per-model structure + models = key_data.get("models", {}) + for model_name, model_stats in models.items(): + if not isinstance(model_stats, dict): + continue + + # Add readable window start time + window_start = model_stats.get("window_start_ts") + if window_start: + model_stats["window_started"] = self._format_timestamp_local( + window_start + ) + elif "window_started" in model_stats: + del model_stats["window_started"] + + # Add readable reset time + quota_reset = model_stats.get("quota_reset_ts") + if quota_reset: + model_stats["quota_resets"] = self._format_timestamp_local( + quota_reset + ) + elif "quota_resets" in model_stats: + del model_stats["quota_resets"] + + return data + + def _sort_sequential( + self, + candidates: List[Tuple[str, int]], + credential_priorities: Optional[Dict[str, int]] = None, + ) -> List[Tuple[str, int]]: + """ + Sort credentials for sequential mode with position retention. + + Credentials maintain their position based on established usage patterns, + ensuring that actively-used credentials remain primary until exhausted. + + Sorting order (within each sort key, lower value = higher priority): + 1. Priority tier (lower number = higher priority) + 2. Usage count (higher = more established in rotation, maintains position) + 3. Last used timestamp (higher = more recent, tiebreaker for stickiness) + 4. Credential ID (alphabetical, stable ordering) + + Args: + candidates: List of (credential_id, usage_count) tuples + credential_priorities: Optional dict mapping credentials to priority levels + + Returns: + Sorted list of candidates (same format as input) + """ + if not candidates: + return [] + + if len(candidates) == 1: + return candidates + + def sort_key(item: Tuple[str, int]) -> Tuple[int, int, float, str]: + cred, usage_count = item + priority = ( + credential_priorities.get(cred, 999) if credential_priorities else 999 + ) + last_used = ( + self._usage_data.get(cred, {}).get("last_used_ts", 0) + if self._usage_data + else 0 + ) + return ( + priority, # ASC: lower priority number = higher priority + -usage_count, # DESC: higher usage = more established + -last_used, # DESC: more recent = preferred for ties + cred, # ASC: stable alphabetical ordering + ) + + sorted_candidates = sorted(candidates, key=sort_key) + + # Debug logging - show top 3 credentials in ordering + if lib_logger.isEnabledFor(logging.DEBUG): + order_info = [ + f"{mask_credential(c)}(p={credential_priorities.get(c, 999) if credential_priorities else 'N/A'}, u={u})" + for c, u in sorted_candidates[:3] + ] + lib_logger.debug(f"Sequential ordering: {' → '.join(order_info)}") + + return sorted_candidates + + # ========================================================================= + # FAIR CYCLE PERSISTENCE + # ========================================================================= + + def _serialize_cycle_state(self) -> Dict[str, Any]: + """ + Serialize in-memory cycle state for JSON persistence. + + Converts sets to lists for JSON compatibility. + """ + result: Dict[str, Any] = {} + for provider, tier_data in self._cycle_exhausted.items(): + result[provider] = {} + for tier_key, tracking_data in tier_data.items(): + result[provider][tier_key] = {} + for tracking_key, cycle_data in tracking_data.items(): + result[provider][tier_key][tracking_key] = { + "cycle_started_at": cycle_data.get("cycle_started_at"), + "exhausted": list(cycle_data.get("exhausted", set())), + } + return result + + def _deserialize_cycle_state(self, data: Dict[str, Any]) -> None: + """ + Deserialize cycle state from JSON and populate in-memory structure. + + Converts lists back to sets and validates expired cycles. + """ + self._cycle_exhausted = {} + now_ts = time.time() + + for provider, tier_data in data.items(): + if not isinstance(tier_data, dict): + continue + self._cycle_exhausted[provider] = {} + + for tier_key, tracking_data in tier_data.items(): + if not isinstance(tracking_data, dict): + continue + self._cycle_exhausted[provider][tier_key] = {} + + for tracking_key, cycle_data in tracking_data.items(): + if not isinstance(cycle_data, dict): + continue + + cycle_started = cycle_data.get("cycle_started_at") + exhausted_list = cycle_data.get("exhausted", []) + + # Check if cycle has expired + if cycle_started is not None: + duration = self._get_fair_cycle_duration(provider) + if now_ts >= cycle_started + duration: + # Cycle expired - skip (don't restore) + lib_logger.debug( + f"Fair cycle expired for {provider}/{tier_key}/{tracking_key} - not restoring" + ) + continue + + # Restore valid cycle + self._cycle_exhausted[provider][tier_key][tracking_key] = { + "cycle_started_at": cycle_started, + "exhausted": set(exhausted_list) if exhausted_list else set(), + } + + # Log restoration summary + total_cycles = sum( + len(tracking) + for tier in self._cycle_exhausted.values() + for tracking in tier.values() + ) + if total_cycles > 0: + lib_logger.info(f"Restored {total_cycles} active fair cycle(s) from disk") + + async def _lazy_init(self): + """Initializes the usage data by loading it from the file asynchronously.""" + async with self._init_lock: + if not self._initialized.is_set(): + await self._load_usage() + await self._reset_daily_stats_if_needed() + self._initialized.set() + + async def _load_usage(self): + """Loads usage data from the JSON file asynchronously with resilience.""" + async with self._data_lock: + if not os.path.exists(self.file_path): + self._usage_data = {} + return + + try: + async with aiofiles.open(self.file_path, "r") as f: + content = await f.read() + self._usage_data = json.loads(content) if content.strip() else {} + except FileNotFoundError: + # File deleted between exists check and open + self._usage_data = {} + except json.JSONDecodeError as e: + lib_logger.warning( + f"Corrupted usage file {self.file_path}: {e}. Starting fresh." + ) + self._usage_data = {} + except (OSError, PermissionError, IOError) as e: + lib_logger.warning( + f"Cannot read usage file {self.file_path}: {e}. Using empty state." + ) + self._usage_data = {} + + # Restore fair cycle state from persisted data + fair_cycle_data = self._usage_data.get("__fair_cycle__", {}) + if fair_cycle_data: + self._deserialize_cycle_state(fair_cycle_data) + + async def _save_usage(self): + """Saves the current usage data using the resilient state writer.""" + if self._usage_data is None: + return + + async with self._data_lock: + # Add human-readable timestamp fields before saving + self._add_readable_timestamps(self._usage_data) + + # Persist fair cycle state (separate from credential data) + if self._cycle_exhausted: + self._usage_data["__fair_cycle__"] = self._serialize_cycle_state() + elif "__fair_cycle__" in self._usage_data: + # Clean up empty cycle data + del self._usage_data["__fair_cycle__"] + + # Hand off to resilient writer - handles retries and disk failures + self._state_writer.write(self._usage_data) + + async def _get_usage_data_snapshot(self) -> Dict[str, Any]: + """ + Get a shallow copy of the current usage data. + + Returns: + Copy of usage data dict (safe for reading without lock) + """ + await self._lazy_init() + async with self._data_lock: + return dict(self._usage_data) if self._usage_data else {} + + async def get_available_credentials_for_model( + self, credentials: List[str], model: str + ) -> List[str]: + """ + Get credentials that are not on cooldown for a specific model. + + Filters out credentials where: + - key_cooldown_until > now (key-level cooldown) + - model_cooldowns[model] > now (model-specific cooldown, includes quota exhausted) + + Args: + credentials: List of credential identifiers to check + model: Model name to check cooldowns for + + Returns: + List of credentials that are available (not on cooldown) for this model + """ + await self._lazy_init() + now = time.time() + available = [] + + async with self._data_lock: + for key in credentials: + key_data = self._usage_data.get(key, {}) + + # Skip if key-level cooldown is active + if (key_data.get("key_cooldown_until") or 0) > now: + continue + + # Normalize model name for consistent cooldown lookup + # (cooldowns are stored under normalized names by record_failure) + # For providers without normalize_model_for_tracking (non-Antigravity), + # this returns the model unchanged, so cooldown lookups work as before. + normalized_model = self._normalize_model(key, model) + + # Skip if model-specific cooldown is active + if ( + key_data.get("model_cooldowns", {}).get(normalized_model) or 0 + ) > now: + continue + + available.append(key) + + return available + + async def get_credential_availability_stats( + self, + credentials: List[str], + model: str, + credential_priorities: Optional[Dict[str, int]] = None, + ) -> Dict[str, int]: + """ + Get credential availability statistics including cooldown and fair cycle exclusions. + + This is used for logging to show why credentials are excluded. + + Args: + credentials: List of credential identifiers to check + model: Model name to check + credential_priorities: Optional dict mapping credentials to priorities + + Returns: + Dict with: + "total": Total credentials + "on_cooldown": Count on cooldown + "fair_cycle_excluded": Count excluded by fair cycle + "available": Count available for selection + """ + await self._lazy_init() + now = time.time() + + total = len(credentials) + on_cooldown = 0 + not_on_cooldown = [] + + # First pass: check cooldowns + async with self._data_lock: + for key in credentials: + key_data = self._usage_data.get(key, {}) + + # Check if key-level or model-level cooldown is active + normalized_model = self._normalize_model(key, model) + if (key_data.get("key_cooldown_until") or 0) > now or ( + key_data.get("model_cooldowns", {}).get(normalized_model) or 0 + ) > now: + on_cooldown += 1 + else: + not_on_cooldown.append(key) + + # Second pass: check fair cycle exclusions (only for non-cooldown credentials) + fair_cycle_excluded = 0 + if not_on_cooldown: + provider = self._get_provider_from_credential(not_on_cooldown[0]) + if provider: + rotation_mode = self._get_rotation_mode(provider) + if self._is_fair_cycle_enabled(provider, rotation_mode): + # Check each credential against its own tier's exhausted set + for key in not_on_cooldown: + key_priority = ( + credential_priorities.get(key, 999) + if credential_priorities + else 999 + ) + tier_key = self._get_tier_key(provider, key_priority) + tracking_key = self._get_tracking_key(key, model, provider) + + if self._is_credential_exhausted_in_cycle( + key, provider, tier_key, tracking_key + ): + fair_cycle_excluded += 1 + + available = total - on_cooldown - fair_cycle_excluded + + return { + "total": total, + "on_cooldown": on_cooldown, + "fair_cycle_excluded": fair_cycle_excluded, + "available": available, + } + + async def get_soonest_cooldown_end( + self, + credentials: List[str], + model: str, + ) -> Optional[float]: + """ + Find the soonest time when any credential will come off cooldown. + + This is used for smart waiting logic - if no credentials are available, + we can determine whether to wait (if soonest cooldown < deadline) or + fail fast (if soonest cooldown > deadline). + + Args: + credentials: List of credential identifiers to check + model: Model name to check cooldowns for + + Returns: + Timestamp of soonest cooldown end, or None if no credentials are on cooldown + """ + await self._lazy_init() + now = time.time() + soonest_end = None + + async with self._data_lock: + for key in credentials: + key_data = self._usage_data.get(key, {}) + normalized_model = self._normalize_model(key, model) + + # Check key-level cooldown + key_cooldown = key_data.get("key_cooldown_until") or 0 + if key_cooldown > now: + if soonest_end is None or key_cooldown < soonest_end: + soonest_end = key_cooldown + + # Check model-level cooldown + model_cooldown = ( + key_data.get("model_cooldowns", {}).get(normalized_model) or 0 + ) + if model_cooldown > now: + if soonest_end is None or model_cooldown < soonest_end: + soonest_end = model_cooldown + + return soonest_end + + async def _reset_daily_stats_if_needed(self): + """ + Checks if usage stats need to be reset for any key. + + Supports three reset modes: + 1. per_model: Each model has its own window, resets based on quota_reset_ts or fallback window + 2. credential: One window per credential (legacy with custom window duration) + 3. daily: Legacy daily reset at daily_reset_time_utc + """ + if self._usage_data is None: + return + + now_utc = datetime.now(timezone.utc) + now_ts = time.time() + today_str = now_utc.date().isoformat() + needs_saving = False + + for key, data in self._usage_data.items(): + reset_config = self._get_usage_reset_config(key) + + if reset_config: + reset_mode = reset_config.get("mode", "credential") + + if reset_mode == "per_model": + # Per-model window reset + needs_saving |= await self._check_per_model_resets( + key, data, reset_config, now_ts + ) + else: + # Credential-level window reset (legacy) + needs_saving |= await self._check_window_reset( + key, data, reset_config, now_ts + ) + elif self.daily_reset_time_utc: + # Legacy daily reset + needs_saving |= await self._check_daily_reset( + key, data, now_utc, today_str, now_ts + ) + + if needs_saving: + await self._save_usage() + + async def _check_per_model_resets( + self, + key: str, + data: Dict[str, Any], + reset_config: Dict[str, Any], + now_ts: float, + ) -> bool: + """ + Check and perform per-model resets for a credential. + + Each model resets independently based on: + 1. quota_reset_ts (authoritative, from quota exhausted error) if set + 2. window_start_ts + window_seconds (fallback) otherwise + + Grouped models reset together - all models in a group must be ready. + + Args: + key: Credential identifier + data: Usage data for this credential + reset_config: Provider's reset configuration + now_ts: Current timestamp + + Returns: + True if data was modified and needs saving + """ + window_seconds = reset_config.get("window_seconds", 86400) + models_data = data.get("models", {}) + + if not models_data: + return False + + modified = False + processed_groups = set() + + for model, model_data in list(models_data.items()): + # Check if this model is in a quota group + group = self._get_model_quota_group(key, model) + + if group: + if group in processed_groups: + continue # Already handled this group + + # Check if entire group should reset + if self._should_group_reset( + key, group, models_data, window_seconds, now_ts + ): + # Archive and reset all models in group + grouped_models = self._get_grouped_models(key, group) + archived_count = 0 + + for grouped_model in grouped_models: + if grouped_model in models_data: + gm_data = models_data[grouped_model] + self._archive_model_to_global(data, grouped_model, gm_data) + self._reset_model_data(gm_data) + archived_count += 1 + + if archived_count > 0: + lib_logger.info( + f"Reset model group '{group}' ({archived_count} models) for {mask_credential(key)}" + ) + modified = True + + processed_groups.add(group) + + else: + # Ungrouped model - check individually + if self._should_model_reset(model_data, window_seconds, now_ts): + self._archive_model_to_global(data, model, model_data) + self._reset_model_data(model_data) + lib_logger.info(f"Reset model {model} for {mask_credential(key)}") + modified = True + + # Preserve unexpired cooldowns + if modified: + self._preserve_unexpired_cooldowns(key, data, now_ts) + if "failures" in data: + data["failures"] = {} + + return modified + + def _should_model_reset( + self, model_data: Dict[str, Any], window_seconds: int, now_ts: float + ) -> bool: + """ + Check if a single model should reset. + + Returns True if: + - quota_reset_ts is set AND now >= quota_reset_ts, OR + - quota_reset_ts is NOT set AND now >= window_start_ts + window_seconds + """ + quota_reset = model_data.get("quota_reset_ts") + window_start = model_data.get("window_start_ts") + + if quota_reset: + return now_ts >= quota_reset + elif window_start: + return now_ts >= window_start + window_seconds + return False + + def _should_group_reset( + self, + key: str, + group: str, + models_data: Dict[str, Dict], + window_seconds: int, + now_ts: float, + ) -> bool: + """ + Check if all models in a group should reset. + + All models in the group must be ready to reset. + If any model has an active cooldown/window, the whole group waits. + """ + grouped_models = self._get_grouped_models(key, group) + + # Track if any model in group has data + any_has_data = False + + for grouped_model in grouped_models: + model_data = models_data.get(grouped_model, {}) + + if not model_data or ( + model_data.get("window_start_ts") is None + and model_data.get("success_count", 0) == 0 + ): + continue # No stats for this model yet + + any_has_data = True + + if not self._should_model_reset(model_data, window_seconds, now_ts): + return False # At least one model not ready + + return any_has_data + + def _archive_model_to_global( + self, data: Dict[str, Any], model: str, model_data: Dict[str, Any] + ) -> None: + """Archive a single model's stats to global.""" + global_data = data.setdefault("global", {"models": {}}) + global_model = global_data["models"].setdefault( + model, + { + "success_count": 0, + "prompt_tokens": 0, + "prompt_tokens_cached": 0, + "completion_tokens": 0, + "approx_cost": 0.0, + }, + ) + + global_model["success_count"] += model_data.get("success_count", 0) + global_model["prompt_tokens"] += model_data.get("prompt_tokens", 0) + global_model["prompt_tokens_cached"] = global_model.get( + "prompt_tokens_cached", 0 + ) + model_data.get("prompt_tokens_cached", 0) + global_model["completion_tokens"] += model_data.get("completion_tokens", 0) + global_model["approx_cost"] += model_data.get("approx_cost", 0.0) + + def _reset_model_data(self, model_data: Dict[str, Any]) -> None: + """Reset a model's window and stats.""" + model_data["window_start_ts"] = None + model_data["quota_reset_ts"] = None + model_data["success_count"] = 0 + model_data["failure_count"] = 0 + model_data["request_count"] = 0 + model_data["prompt_tokens"] = 0 + model_data["completion_tokens"] = 0 + model_data["approx_cost"] = 0.0 + # Reset quota baseline fields only if they exist (Antigravity-specific) + # These are added by update_quota_baseline(), only called for Antigravity + if "baseline_remaining_fraction" in model_data: + model_data["baseline_remaining_fraction"] = None + model_data["baseline_fetched_at"] = None + model_data["requests_at_baseline"] = None + # Reset quota display but keep max_requests (it doesn't change between periods) + max_req = model_data.get("quota_max_requests") + if max_req: + model_data["quota_display"] = f"0/{max_req}" + + async def _check_window_reset( + self, + key: str, + data: Dict[str, Any], + reset_config: Dict[str, Any], + now_ts: float, + ) -> bool: + """ + Check and perform rolling window reset for a credential. + + Args: + key: Credential identifier + data: Usage data for this credential + reset_config: Provider's reset configuration + now_ts: Current timestamp + + Returns: + True if data was modified and needs saving + """ + window_seconds = reset_config.get("window_seconds", 86400) # Default 24h + field_name = reset_config.get("field_name", "window") + description = reset_config.get("description", "rolling window") + + # Get current window data + window_data = data.get(field_name, {}) + window_start = window_data.get("start_ts") + + # No window started yet - nothing to reset + if window_start is None: + return False + + # Check if window has expired + window_end = window_start + window_seconds + if now_ts < window_end: + # Window still active + return False + + # Window expired - perform reset + hours_elapsed = (now_ts - window_start) / 3600 + lib_logger.info( + f"Resetting {field_name} for {mask_credential(key)} - " + f"{description} expired after {hours_elapsed:.1f}h" + ) + + # Archive to global + self._archive_to_global(data, window_data) + + # Preserve unexpired cooldowns + self._preserve_unexpired_cooldowns(key, data, now_ts) + + # Reset window stats (but don't start new window until first request) + data[field_name] = {"start_ts": None, "models": {}} + + # Reset consecutive failures + if "failures" in data: + data["failures"] = {} + + return True + + async def _check_daily_reset( + self, + key: str, + data: Dict[str, Any], + now_utc: datetime, + today_str: str, + now_ts: float, + ) -> bool: + """ + Check and perform legacy daily reset for a credential. + + Args: + key: Credential identifier + data: Usage data for this credential + now_utc: Current datetime in UTC + today_str: Today's date as ISO string + now_ts: Current timestamp + + Returns: + True if data was modified and needs saving + """ + last_reset_str = data.get("last_daily_reset", "") + + if last_reset_str == today_str: + return False + + last_reset_dt = None + if last_reset_str: + try: + last_reset_dt = datetime.fromisoformat(last_reset_str).replace( + tzinfo=timezone.utc + ) + except ValueError: + pass + + # Determine the reset threshold for today + reset_threshold_today = datetime.combine( + now_utc.date(), self.daily_reset_time_utc + ) + + if not ( + last_reset_dt is None or last_reset_dt < reset_threshold_today <= now_utc + ): + return False + + lib_logger.debug(f"Performing daily reset for key {mask_credential(key)}") + + # Preserve unexpired cooldowns + self._preserve_unexpired_cooldowns(key, data, now_ts) + + # Reset consecutive failures + if "failures" in data: + data["failures"] = {} + + # Archive daily stats to global + daily_data = data.get("daily", {}) + if daily_data: + self._archive_to_global(data, daily_data) + + # Reset daily stats + data["daily"] = {"date": today_str, "models": {}} + data["last_daily_reset"] = today_str + + return True + + def _archive_to_global( + self, data: Dict[str, Any], source_data: Dict[str, Any] + ) -> None: + """ + Archive usage stats from a source field (daily/window) to global. + + Args: + data: The credential's usage data + source_data: The source field data to archive (has "models" key) + """ + global_data = data.setdefault("global", {"models": {}}) + for model, stats in source_data.get("models", {}).items(): + global_model_stats = global_data["models"].setdefault( + model, + { + "success_count": 0, + "prompt_tokens": 0, + "prompt_tokens_cached": 0, + "completion_tokens": 0, + "approx_cost": 0.0, + }, + ) + global_model_stats["success_count"] += stats.get("success_count", 0) + global_model_stats["prompt_tokens"] += stats.get("prompt_tokens", 0) + global_model_stats["prompt_tokens_cached"] = global_model_stats.get( + "prompt_tokens_cached", 0 + ) + stats.get("prompt_tokens_cached", 0) + global_model_stats["completion_tokens"] += stats.get("completion_tokens", 0) + global_model_stats["approx_cost"] += stats.get("approx_cost", 0.0) + + def _preserve_unexpired_cooldowns( + self, key: str, data: Dict[str, Any], now_ts: float + ) -> None: + """ + Preserve unexpired cooldowns during reset (important for long quota cooldowns). + + Args: + key: Credential identifier (for logging) + data: The credential's usage data + now_ts: Current timestamp + """ + # Preserve unexpired model cooldowns + if "model_cooldowns" in data: + active_cooldowns = { + model: end_time + for model, end_time in data["model_cooldowns"].items() + if end_time > now_ts + } + if active_cooldowns: + max_remaining = max( + end_time - now_ts for end_time in active_cooldowns.values() + ) + hours_remaining = max_remaining / 3600 + lib_logger.info( + f"Preserving {len(active_cooldowns)} active cooldown(s) " + f"for key {mask_credential(key)} during reset " + f"(longest: {hours_remaining:.1f}h remaining)" + ) + data["model_cooldowns"] = active_cooldowns + else: + data["model_cooldowns"] = {} + + # Preserve unexpired key-level cooldown + if data.get("key_cooldown_until"): + if data["key_cooldown_until"] <= now_ts: + data["key_cooldown_until"] = None + else: + hours_remaining = (data["key_cooldown_until"] - now_ts) / 3600 + lib_logger.info( + f"Preserving key-level cooldown for {mask_credential(key)} " + f"during reset ({hours_remaining:.1f}h remaining)" + ) + else: + data["key_cooldown_until"] = None + + def _initialize_key_states(self, keys: List[str]): + """Initializes state tracking for all provided keys if not already present.""" + for key in keys: + if key not in self.key_states: + self.key_states[key] = { + "lock": asyncio.Lock(), + "condition": asyncio.Condition(), + "models_in_use": {}, # Dict[model_name, concurrent_count] + } + + def _select_weighted_random(self, candidates: List[tuple], tolerance: float) -> str: + """ + Selects a credential using weighted random selection based on usage counts. + + Args: + candidates: List of (credential_id, usage_count) tuples + tolerance: Tolerance value for weight calculation + + Returns: + Selected credential ID + + Formula: + weight = (max_usage - credential_usage) + tolerance + 1 + + This formula ensures: + - Lower usage = higher weight = higher selection probability + - Tolerance adds variability: higher tolerance means more randomness + - The +1 ensures all credentials have at least some chance of selection + """ + if not candidates: + raise ValueError("Cannot select from empty candidate list") + + if len(candidates) == 1: + return candidates[0][0] + + # Extract usage counts + usage_counts = [usage for _, usage in candidates] + max_usage = max(usage_counts) + + # Calculate weights using the formula: (max - current) + tolerance + 1 + weights = [] + for credential, usage in candidates: + weight = (max_usage - usage) + tolerance + 1 + weights.append(weight) + + # Log weight distribution for debugging + if lib_logger.isEnabledFor(logging.DEBUG): + total_weight = sum(weights) + weight_info = ", ".join( + f"{mask_credential(cred)}: w={w:.1f} ({w / total_weight * 100:.1f}%)" + for (cred, _), w in zip(candidates, weights) + ) + # lib_logger.debug(f"Weighted selection candidates: {weight_info}") + + # Random selection with weights + selected_credential = random.choices( + [cred for cred, _ in candidates], weights=weights, k=1 + )[0] + + return selected_credential + + async def acquire_key( + self, + available_keys: List[str], + model: str, + deadline: float, + max_concurrent: int = 1, + credential_priorities: Optional[Dict[str, int]] = None, + credential_tier_names: Optional[Dict[str, str]] = None, + all_provider_credentials: Optional[List[str]] = None, + ) -> str: + """ + Acquires the best available key using a tiered, model-aware locking strategy, + respecting a global deadline and credential priorities. + + Priority Logic: + - Groups credentials by priority level (1=highest, 2=lower, etc.) + - Always tries highest priority (lowest number) first + - Within same priority, sorts by usage count (load balancing) + - Only moves to next priority if all higher-priority keys exhausted/busy + + Args: + available_keys: List of credential identifiers to choose from + model: Model name being requested + deadline: Timestamp after which to stop trying + max_concurrent: Maximum concurrent requests allowed per credential + credential_priorities: Optional dict mapping credentials to priority levels (1=highest) + credential_tier_names: Optional dict mapping credentials to tier names (for logging) + all_provider_credentials: Full list of provider credentials (used for cycle reset checks) + + Returns: + Selected credential identifier + + Raises: + NoAvailableKeysError: If no key could be acquired within the deadline + """ + await self._lazy_init() + await self._reset_daily_stats_if_needed() + self._initialize_key_states(available_keys) + + # Normalize model name for consistent cooldown lookup + # (cooldowns are stored under normalized names by record_failure) + # Use first credential for provider detection; all credentials passed here + # are for the same provider (filtered by client.py before calling acquire_key). + # For providers without normalize_model_for_tracking (non-Antigravity), + # this returns the model unchanged, so cooldown lookups work as before. + normalized_model = ( + self._normalize_model(available_keys[0], model) if available_keys else model + ) + + # This loop continues as long as the global deadline has not been met. + while time.time() < deadline: + now = time.time() + + # Group credentials by priority level (if priorities provided) + if credential_priorities: + # Group keys by priority level + priority_groups = {} + async with self._data_lock: + for key in available_keys: + key_data = self._usage_data.get(key, {}) + + # Skip keys on cooldown (use normalized model for lookup) + if (key_data.get("key_cooldown_until") or 0) > now or ( + key_data.get("model_cooldowns", {}).get(normalized_model) + or 0 + ) > now: + continue + + # Get priority for this key (default to 999 if not specified) + priority = credential_priorities.get(key, 999) + + # Get usage count for load balancing within priority groups + # Uses grouped usage if model is in a quota group + usage_count = self._get_grouped_usage_count(key, model) + + # Group by priority + if priority not in priority_groups: + priority_groups[priority] = [] + priority_groups[priority].append((key, usage_count)) + + # Try priority groups in order (1, 2, 3, ...) + sorted_priorities = sorted(priority_groups.keys()) + + for priority_level in sorted_priorities: + keys_in_priority = priority_groups[priority_level] + + # Determine selection method based on provider's rotation mode + provider = model.split("/")[0] if "/" in model else "" + rotation_mode = self._get_rotation_mode(provider) + + # Fair cycle filtering + if provider and self._is_fair_cycle_enabled( + provider, rotation_mode + ): + tier_key = self._get_tier_key(provider, priority_level) + tracking_key = self._get_tracking_key( + keys_in_priority[0][0] if keys_in_priority else "", + model, + provider, + ) + + # Get all credentials for this tier (for cycle completion check) + all_tier_creds = self._get_all_credentials_for_tier_key( + provider, + tier_key, + all_provider_credentials or available_keys, + credential_priorities, + ) + + # Check if cycle should reset (all exhausted, expired, or none available) + if self._should_reset_cycle( + provider, + tier_key, + tracking_key, + all_tier_creds, + available_not_on_cooldown=[ + key for key, _ in keys_in_priority + ], + ): + self._reset_cycle(provider, tier_key, tracking_key) + + # Filter out exhausted credentials + filtered_keys = [] + for key, usage_count in keys_in_priority: + if not self._is_credential_exhausted_in_cycle( + key, provider, tier_key, tracking_key + ): + filtered_keys.append((key, usage_count)) + + keys_in_priority = filtered_keys + + # Calculate effective concurrency based on priority tier + multiplier = self._get_priority_multiplier( + provider, priority_level, rotation_mode + ) + effective_max_concurrent = max_concurrent * multiplier + + # Within each priority group, use existing tier1/tier2 logic + tier1_keys, tier2_keys = [], [] + for key, usage_count in keys_in_priority: + key_state = self.key_states[key] + + # Tier 1: Completely idle keys (preferred) + if not key_state["models_in_use"]: + tier1_keys.append((key, usage_count)) + # Tier 2: Keys that can accept more concurrent requests + elif ( + key_state["models_in_use"].get(model, 0) + < effective_max_concurrent + ): + tier2_keys.append((key, usage_count)) + + if rotation_mode == "sequential": + # Sequential mode: sort credentials by priority, usage, recency + # Keep all candidates in sorted order (no filtering to single key) + selection_method = "sequential" + if tier1_keys: + tier1_keys = self._sort_sequential( + tier1_keys, credential_priorities + ) + if tier2_keys: + tier2_keys = self._sort_sequential( + tier2_keys, credential_priorities + ) + elif self.rotation_tolerance > 0: + # Balanced mode with weighted randomness + selection_method = "weighted-random" + if tier1_keys: + selected_key = self._select_weighted_random( + tier1_keys, self.rotation_tolerance + ) + tier1_keys = [ + (k, u) for k, u in tier1_keys if k == selected_key + ] + if tier2_keys: + selected_key = self._select_weighted_random( + tier2_keys, self.rotation_tolerance + ) + tier2_keys = [ + (k, u) for k, u in tier2_keys if k == selected_key + ] + else: + # Deterministic: sort by usage within each tier + selection_method = "least-used" + tier1_keys.sort(key=lambda x: x[1]) + tier2_keys.sort(key=lambda x: x[1]) + + # Try to acquire from Tier 1 first + for key, usage in tier1_keys: + state = self.key_states[key] + async with state["lock"]: + if not state["models_in_use"]: + state["models_in_use"][model] = 1 + tier_name = ( + credential_tier_names.get(key, "unknown") + if credential_tier_names + else "unknown" + ) + quota_display = self._get_quota_display(key, model) + lib_logger.info( + f"Acquired key {mask_credential(key)} for model {model} " + f"(tier: {tier_name}, priority: {priority_level}, selection: {selection_method}, {quota_display})" + ) + return key + + # Then try Tier 2 + for key, usage in tier2_keys: + state = self.key_states[key] + async with state["lock"]: + current_count = state["models_in_use"].get(model, 0) + if current_count < effective_max_concurrent: + state["models_in_use"][model] = current_count + 1 + tier_name = ( + credential_tier_names.get(key, "unknown") + if credential_tier_names + else "unknown" + ) + quota_display = self._get_quota_display(key, model) + lib_logger.info( + f"Acquired key {mask_credential(key)} for model {model} " + f"(tier: {tier_name}, priority: {priority_level}, selection: {selection_method}, concurrent: {state['models_in_use'][model]}/{effective_max_concurrent}, {quota_display})" + ) + return key + + # If we get here, all priority groups were exhausted but keys might become available + # Collect all keys across all priorities for waiting + all_potential_keys = [] + for keys_list in priority_groups.values(): + all_potential_keys.extend(keys_list) + + if not all_potential_keys: + # All credentials are on cooldown - check if waiting makes sense + soonest_end = await self.get_soonest_cooldown_end( + available_keys, model + ) + + if soonest_end is None: + # No cooldowns active but no keys available (shouldn't happen) + lib_logger.warning( + "No keys eligible and no cooldowns active. Re-evaluating..." + ) + await asyncio.sleep(1) + continue + + remaining_budget = deadline - time.time() + wait_needed = soonest_end - time.time() + + if wait_needed > remaining_budget: + # Fail fast - no credential will be available in time + lib_logger.warning( + f"All credentials on cooldown. Soonest available in {wait_needed:.1f}s, " + f"but only {remaining_budget:.1f}s budget remaining. Failing fast." + ) + break # Exit loop, will raise NoAvailableKeysError + + # Wait for the credential to become available + lib_logger.info( + f"All credentials on cooldown. Waiting {wait_needed:.1f}s for soonest credential..." + ) + await asyncio.sleep(min(wait_needed + 0.1, remaining_budget)) + continue + + # Wait for the highest priority key with lowest usage + best_priority = min(priority_groups.keys()) + best_priority_keys = priority_groups[best_priority] + best_wait_key = min(best_priority_keys, key=lambda x: x[1])[0] + wait_condition = self.key_states[best_wait_key]["condition"] + + lib_logger.info( + f"All Priority-{best_priority} keys are busy. Waiting for highest priority credential to become available..." + ) + + else: + # Original logic when no priorities specified + + # Determine selection method based on provider's rotation mode + provider = model.split("/")[0] if "/" in model else "" + rotation_mode = self._get_rotation_mode(provider) + + # Calculate effective concurrency for default priority (999) + # When no priorities are specified, all credentials get default priority + default_priority = 999 + multiplier = self._get_priority_multiplier( + provider, default_priority, rotation_mode + ) + effective_max_concurrent = max_concurrent * multiplier + + tier1_keys, tier2_keys = [], [] + + # First, filter the list of available keys to exclude any on cooldown. + async with self._data_lock: + for key in available_keys: + key_data = self._usage_data.get(key, {}) + + # Skip keys on cooldown (use normalized model for lookup) + if (key_data.get("key_cooldown_until") or 0) > now or ( + key_data.get("model_cooldowns", {}).get(normalized_model) + or 0 + ) > now: + continue + + # Prioritize keys based on their current usage to ensure load balancing. + # Uses grouped usage if model is in a quota group + usage_count = self._get_grouped_usage_count(key, model) + key_state = self.key_states[key] + + # Tier 1: Completely idle keys (preferred). + if not key_state["models_in_use"]: + tier1_keys.append((key, usage_count)) + # Tier 2: Keys that can accept more concurrent requests for this model. + elif ( + key_state["models_in_use"].get(model, 0) + < effective_max_concurrent + ): + tier2_keys.append((key, usage_count)) + + # Fair cycle filtering (non-priority case) + if provider and self._is_fair_cycle_enabled(provider, rotation_mode): + tier_key = self._get_tier_key(provider, default_priority) + tracking_key = self._get_tracking_key( + available_keys[0] if available_keys else "", + model, + provider, + ) + + # Get all credentials for this tier (for cycle completion check) + all_tier_creds = self._get_all_credentials_for_tier_key( + provider, + tier_key, + all_provider_credentials or available_keys, + None, + ) + + # Check if cycle should reset (all exhausted, expired, or none available) + if self._should_reset_cycle( + provider, + tier_key, + tracking_key, + all_tier_creds, + available_not_on_cooldown=[ + key for key, _ in (tier1_keys + tier2_keys) + ], + ): + self._reset_cycle(provider, tier_key, tracking_key) + + # Filter out exhausted credentials from both tiers + tier1_keys = [ + (key, usage) + for key, usage in tier1_keys + if not self._is_credential_exhausted_in_cycle( + key, provider, tier_key, tracking_key + ) + ] + tier2_keys = [ + (key, usage) + for key, usage in tier2_keys + if not self._is_credential_exhausted_in_cycle( + key, provider, tier_key, tracking_key + ) + ] + + if rotation_mode == "sequential": + # Sequential mode: sort credentials by priority, usage, recency + # Keep all candidates in sorted order (no filtering to single key) + selection_method = "sequential" + if tier1_keys: + tier1_keys = self._sort_sequential( + tier1_keys, credential_priorities + ) + if tier2_keys: + tier2_keys = self._sort_sequential( + tier2_keys, credential_priorities + ) + elif self.rotation_tolerance > 0: + # Balanced mode with weighted randomness + selection_method = "weighted-random" + if tier1_keys: + selected_key = self._select_weighted_random( + tier1_keys, self.rotation_tolerance + ) + tier1_keys = [ + (k, u) for k, u in tier1_keys if k == selected_key + ] + if tier2_keys: + selected_key = self._select_weighted_random( + tier2_keys, self.rotation_tolerance + ) + tier2_keys = [ + (k, u) for k, u in tier2_keys if k == selected_key + ] + else: + # Deterministic: sort by usage within each tier + selection_method = "least-used" + tier1_keys.sort(key=lambda x: x[1]) + tier2_keys.sort(key=lambda x: x[1]) + + # Attempt to acquire a key from Tier 1 first. + for key, usage in tier1_keys: + state = self.key_states[key] + async with state["lock"]: + if not state["models_in_use"]: + state["models_in_use"][model] = 1 + tier_name = ( + credential_tier_names.get(key) + if credential_tier_names + else None + ) + tier_info = f"tier: {tier_name}, " if tier_name else "" + quota_display = self._get_quota_display(key, model) + lib_logger.info( + f"Acquired key {mask_credential(key)} for model {model} " + f"({tier_info}selection: {selection_method}, {quota_display})" + ) + return key + + # If no Tier 1 keys are available, try Tier 2. + for key, usage in tier2_keys: + state = self.key_states[key] + async with state["lock"]: + current_count = state["models_in_use"].get(model, 0) + if current_count < effective_max_concurrent: + state["models_in_use"][model] = current_count + 1 + tier_name = ( + credential_tier_names.get(key) + if credential_tier_names + else None + ) + tier_info = f"tier: {tier_name}, " if tier_name else "" + quota_display = self._get_quota_display(key, model) + lib_logger.info( + f"Acquired key {mask_credential(key)} for model {model} " + f"({tier_info}selection: {selection_method}, concurrent: {state['models_in_use'][model]}/{effective_max_concurrent}, {quota_display})" + ) + return key + + # If all eligible keys are locked, wait for a key to be released. + lib_logger.info( + "All eligible keys are currently locked for this model. Waiting..." + ) + + all_potential_keys = tier1_keys + tier2_keys + if not all_potential_keys: + # All credentials are on cooldown - check if waiting makes sense + soonest_end = await self.get_soonest_cooldown_end( + available_keys, model + ) + + if soonest_end is None: + # No cooldowns active but no keys available (shouldn't happen) + lib_logger.warning( + "No keys eligible and no cooldowns active. Re-evaluating..." + ) + await asyncio.sleep(1) + continue + + remaining_budget = deadline - time.time() + wait_needed = soonest_end - time.time() + + if wait_needed > remaining_budget: + # Fail fast - no credential will be available in time + lib_logger.warning( + f"All credentials on cooldown. Soonest available in {wait_needed:.1f}s, " + f"but only {remaining_budget:.1f}s budget remaining. Failing fast." + ) + break # Exit loop, will raise NoAvailableKeysError + + # Wait for the credential to become available + lib_logger.info( + f"All credentials on cooldown. Waiting {wait_needed:.1f}s for soonest credential..." + ) + await asyncio.sleep(min(wait_needed + 0.1, remaining_budget)) + continue + + # Wait on the condition of the key with the lowest current usage. + best_wait_key = min(all_potential_keys, key=lambda x: x[1])[0] + wait_condition = self.key_states[best_wait_key]["condition"] + + try: + async with wait_condition: + remaining_budget = deadline - time.time() + if remaining_budget <= 0: + break # Exit if the budget has already been exceeded. + # Wait for a notification, but no longer than the remaining budget or 1 second. + await asyncio.wait_for( + wait_condition.wait(), timeout=min(1, remaining_budget) + ) + lib_logger.info("Notified that a key was released. Re-evaluating...") + except asyncio.TimeoutError: + # This is not an error, just a timeout for the wait. The main loop will re-evaluate. + lib_logger.info("Wait timed out. Re-evaluating for any available key.") + + # If the loop exits, it means the deadline was exceeded. + raise NoAvailableKeysError( + f"Could not acquire a key for model {model} within the global time budget." + ) + + async def release_key(self, key: str, model: str): + """Releases a key's lock for a specific model and notifies waiting tasks.""" + if key not in self.key_states: + return + + state = self.key_states[key] + async with state["lock"]: + if model in state["models_in_use"]: + state["models_in_use"][model] -= 1 + remaining = state["models_in_use"][model] + if remaining <= 0: + del state["models_in_use"][model] # Clean up when count reaches 0 + lib_logger.info( + f"Released credential {mask_credential(key)} from model {model} " + f"(remaining concurrent: {max(0, remaining)})" + ) + else: + lib_logger.warning( + f"Attempted to release credential {mask_credential(key)} for model {model}, but it was not in use." + ) + + # Notify all tasks waiting on this key's condition + async with state["condition"]: + state["condition"].notify_all() + + async def record_success( + self, + key: str, + model: str, + completion_response: Optional[litellm.ModelResponse] = None, + ): + """ + Records a successful API call, resetting failure counters. + It safely handles cases where token usage data is not available. + + Supports two modes based on provider configuration: + - per_model: Each model has its own window_start_ts and stats in key_data["models"] + - credential: Legacy mode with key_data["daily"]["models"] + """ + await self._lazy_init() + + # Normalize model name to public-facing name for consistent tracking + model = self._normalize_model(key, model) + + async with self._data_lock: + now_ts = time.time() + today_utc_str = datetime.now(timezone.utc).date().isoformat() + + reset_config = self._get_usage_reset_config(key) + reset_mode = ( + reset_config.get("mode", "credential") if reset_config else "credential" + ) + + if reset_mode == "per_model": + # New per-model structure + key_data = self._usage_data.setdefault( + key, + { + "models": {}, + "global": {"models": {}}, + "model_cooldowns": {}, + "failures": {}, + }, + ) + + # Ensure models dict exists + if "models" not in key_data: + key_data["models"] = {} + + # Get or create per-model data with window tracking + model_data = key_data["models"].setdefault( + model, + { + "window_start_ts": None, + "quota_reset_ts": None, + "success_count": 0, + "failure_count": 0, + "request_count": 0, + "prompt_tokens": 0, + "prompt_tokens_cached": 0, + "completion_tokens": 0, + "approx_cost": 0.0, + }, + ) + + # Start window on first request for this model + if model_data.get("window_start_ts") is None: + model_data["window_start_ts"] = now_ts + + # Set expected quota reset time from provider config + window_seconds = ( + reset_config.get("window_seconds", 0) if reset_config else 0 + ) + if window_seconds > 0: + model_data["quota_reset_ts"] = now_ts + window_seconds + + window_hours = window_seconds / 3600 if window_seconds else 0 + lib_logger.info( + f"Started {window_hours:.1f}h window for model {model} on {mask_credential(key)}" + ) + + # Record stats + model_data["success_count"] += 1 + model_data["request_count"] = model_data.get("request_count", 0) + 1 + + # Sync request_count across quota group (for providers with shared quota pools) + new_request_count = model_data["request_count"] + group = self._get_model_quota_group(key, model) + if group: + grouped_models = self._get_grouped_models(key, group) + for grouped_model in grouped_models: + if grouped_model != model: + other_model_data = key_data["models"].setdefault( + grouped_model, + { + "window_start_ts": None, + "quota_reset_ts": None, + "success_count": 0, + "failure_count": 0, + "request_count": 0, + "prompt_tokens": 0, + "prompt_tokens_cached": 0, + "completion_tokens": 0, + "approx_cost": 0.0, + }, + ) + other_model_data["request_count"] = new_request_count + # Sync window timing (shared quota pool = shared window) + window_start = model_data.get("window_start_ts") + if window_start: + other_model_data["window_start_ts"] = window_start + # Also sync quota_max_requests if set + max_req = model_data.get("quota_max_requests") + if max_req: + other_model_data["quota_max_requests"] = max_req + other_model_data["quota_display"] = ( + f"{new_request_count}/{max_req}" + ) + + # Update quota_display if max_requests is set (Antigravity-specific) + max_req = model_data.get("quota_max_requests") + if max_req: + model_data["quota_display"] = ( + f"{model_data['request_count']}/{max_req}" + ) + + # Check custom cap + if self._check_and_apply_custom_cap( + key, model, model_data["request_count"] + ): + # Custom cap exceeded, cooldown applied + # Continue to record tokens/cost but credential will be skipped next time + pass + + usage_data_ref = model_data # For token/cost recording below + + else: + # Legacy credential-level structure + key_data = self._usage_data.setdefault( + key, + { + "daily": {"date": today_utc_str, "models": {}}, + "global": {"models": {}}, + "model_cooldowns": {}, + "failures": {}, + }, + ) + + if "last_daily_reset" not in key_data: + key_data["last_daily_reset"] = today_utc_str + + # Get or create model data in daily structure + usage_data_ref = key_data["daily"]["models"].setdefault( + model, + { + "success_count": 0, + "prompt_tokens": 0, + "prompt_tokens_cached": 0, + "completion_tokens": 0, + "approx_cost": 0.0, + }, + ) + usage_data_ref["success_count"] += 1 + + # Reset failures for this model + model_failures = key_data.setdefault("failures", {}).setdefault(model, {}) + model_failures["consecutive_failures"] = 0 + + # Clear transient cooldown on success (but NOT quota_reset_ts) + if model in key_data.get("model_cooldowns", {}): + del key_data["model_cooldowns"][model] + + # Record token and cost usage + if ( + completion_response + and hasattr(completion_response, "usage") + and completion_response.usage + ): + usage = completion_response.usage + prompt_total = usage.prompt_tokens + + # Extract cached tokens from prompt_tokens_details if present + cached_tokens = 0 + prompt_details = getattr(usage, "prompt_tokens_details", None) + if prompt_details: + if isinstance(prompt_details, dict): + cached_tokens = prompt_details.get("cached_tokens", 0) or 0 + elif hasattr(prompt_details, "cached_tokens"): + cached_tokens = prompt_details.cached_tokens or 0 + + # Store uncached tokens (prompt_tokens is total, subtract cached) + uncached_tokens = prompt_total - cached_tokens + usage_data_ref["prompt_tokens"] += uncached_tokens + + # Store cached tokens separately + if cached_tokens > 0: + usage_data_ref["prompt_tokens_cached"] = ( + usage_data_ref.get("prompt_tokens_cached", 0) + cached_tokens + ) + + usage_data_ref["completion_tokens"] += getattr( + usage, "completion_tokens", 0 + ) + lib_logger.info( + f"Recorded usage from response object for key {mask_credential(key)}" + ) + try: + provider_name = model.split("/")[0] + provider_instance = self._get_provider_instance(provider_name) + + if provider_instance and getattr( + provider_instance, "skip_cost_calculation", False + ): + lib_logger.debug( + f"Skipping cost calculation for provider '{provider_name}' (custom provider)." + ) + else: + if isinstance(completion_response, litellm.EmbeddingResponse): + model_info = litellm.get_model_info(model) + input_cost = model_info.get("input_cost_per_token") + if input_cost: + cost = ( + completion_response.usage.prompt_tokens * input_cost + ) + else: + cost = None + else: + cost = litellm.completion_cost( + completion_response=completion_response, model=model + ) + + if cost is not None: + usage_data_ref["approx_cost"] += cost + except Exception as e: + lib_logger.warning( + f"Could not calculate cost for model {model}: {e}" + ) + elif isinstance(completion_response, asyncio.Future) or hasattr( + completion_response, "__aiter__" + ): + pass # Stream - usage recorded from chunks + else: + lib_logger.warning( + f"No usage data found in completion response for model {model}. Recording success without token count." + ) + + key_data["last_used_ts"] = now_ts + + await self._save_usage() + + async def record_failure( + self, + key: str, + model: str, + classified_error: ClassifiedError, + increment_consecutive_failures: bool = True, + ): + """Records a failure and applies cooldowns based on error type. + + Distinguishes between: + - quota_exceeded: Long cooldown with exact reset time (from quota_reset_timestamp) + Sets quota_reset_ts on model (and group) - this becomes authoritative stats reset time + - rate_limit: Short transient cooldown (just wait and retry) + Only sets model_cooldowns - does NOT affect stats reset timing + + Args: + key: The API key or credential identifier + model: The model name + classified_error: The classified error object + increment_consecutive_failures: Whether to increment the failure counter. + Set to False for provider-level errors that shouldn't count against the key. + """ + await self._lazy_init() + + # Normalize model name to public-facing name for consistent tracking + model = self._normalize_model(key, model) + + async with self._data_lock: + now_ts = time.time() + today_utc_str = datetime.now(timezone.utc).date().isoformat() + + reset_config = self._get_usage_reset_config(key) + reset_mode = ( + reset_config.get("mode", "credential") if reset_config else "credential" + ) + + # Initialize key data with appropriate structure + if reset_mode == "per_model": + key_data = self._usage_data.setdefault( + key, + { + "models": {}, + "global": {"models": {}}, + "model_cooldowns": {}, + "failures": {}, + }, + ) + else: + key_data = self._usage_data.setdefault( + key, + { + "daily": {"date": today_utc_str, "models": {}}, + "global": {"models": {}}, + "model_cooldowns": {}, + "failures": {}, + }, + ) + + # Provider-level errors (transient issues) should not count against the key + provider_level_errors = {"server_error", "api_connection"} + + # Determine if we should increment the failure counter + should_increment = ( + increment_consecutive_failures + and classified_error.error_type not in provider_level_errors + ) + + # Calculate cooldown duration based on error type + cooldown_seconds = None + model_cooldowns = key_data.setdefault("model_cooldowns", {}) + + # Capture existing cooldown BEFORE we modify it + # Used to determine if this is a fresh exhaustion vs re-processing + existing_cooldown_before = model_cooldowns.get(model) + was_already_on_cooldown = ( + existing_cooldown_before is not None + and existing_cooldown_before > now_ts + ) + + if classified_error.error_type == "quota_exceeded": + # Quota exhausted - use authoritative reset timestamp if available + quota_reset_ts = classified_error.quota_reset_timestamp + cooldown_seconds = ( + classified_error.retry_after or COOLDOWN_RATE_LIMIT_DEFAULT + ) + + if quota_reset_ts and reset_mode == "per_model": + # Set quota_reset_ts on model - this becomes authoritative stats reset time + models_data = key_data.setdefault("models", {}) + model_data = models_data.setdefault( + model, + { + "window_start_ts": None, + "quota_reset_ts": None, + "success_count": 0, + "failure_count": 0, + "request_count": 0, + "prompt_tokens": 0, + "prompt_tokens_cached": 0, + "completion_tokens": 0, + "approx_cost": 0.0, + }, + ) + model_data["quota_reset_ts"] = quota_reset_ts + # Track failure for quota estimation (request still consumes quota) + model_data["failure_count"] = model_data.get("failure_count", 0) + 1 + model_data["request_count"] = model_data.get("request_count", 0) + 1 + + # Clamp request_count to quota_max_requests when quota is exhausted + # This prevents display overflow (e.g., 151/150) when requests are + # counted locally before API refresh corrects the value + max_req = model_data.get("quota_max_requests") + if max_req is not None and model_data["request_count"] > max_req: + model_data["request_count"] = max_req + # Update quota_display with clamped value + model_data["quota_display"] = f"{max_req}/{max_req}" + new_request_count = model_data["request_count"] + + # Apply to all models in the same quota group + group = self._get_model_quota_group(key, model) + if group: + grouped_models = self._get_grouped_models(key, group) + for grouped_model in grouped_models: + group_model_data = models_data.setdefault( + grouped_model, + { + "window_start_ts": None, + "quota_reset_ts": None, + "success_count": 0, + "failure_count": 0, + "request_count": 0, + "prompt_tokens": 0, + "prompt_tokens_cached": 0, + "completion_tokens": 0, + "approx_cost": 0.0, + }, + ) + group_model_data["quota_reset_ts"] = quota_reset_ts + # Sync request_count across quota group + group_model_data["request_count"] = new_request_count + # Also sync quota_max_requests if set + max_req = model_data.get("quota_max_requests") + if max_req: + group_model_data["quota_max_requests"] = max_req + group_model_data["quota_display"] = ( + f"{new_request_count}/{max_req}" + ) + # Also set transient cooldown for selection logic + model_cooldowns[grouped_model] = quota_reset_ts + + reset_dt = datetime.fromtimestamp( + quota_reset_ts, tz=timezone.utc + ) + lib_logger.info( + f"Quota exhausted for group '{group}' ({len(grouped_models)} models) " + f"on {mask_credential(key)}. Resets at {reset_dt.isoformat()}" + ) + else: + reset_dt = datetime.fromtimestamp( + quota_reset_ts, tz=timezone.utc + ) + hours = (quota_reset_ts - now_ts) / 3600 + lib_logger.info( + f"Quota exhausted for model {model} on {mask_credential(key)}. " + f"Resets at {reset_dt.isoformat()} ({hours:.1f}h)" + ) + + # Set transient cooldown for selection logic + model_cooldowns[model] = quota_reset_ts + else: + # No authoritative timestamp or legacy mode - just use retry_after + model_cooldowns[model] = now_ts + cooldown_seconds + hours = cooldown_seconds / 3600 + lib_logger.info( + f"Quota exhausted on {mask_credential(key)} for model {model}. " + f"Cooldown: {cooldown_seconds}s ({hours:.1f}h)" + ) + + # Mark credential as exhausted for fair cycle if cooldown exceeds threshold + # BUT only if this is a FRESH exhaustion (wasn't already on cooldown) + # This prevents re-marking after cycle reset + if not was_already_on_cooldown: + effective_cooldown = ( + (quota_reset_ts - now_ts) + if quota_reset_ts + else (cooldown_seconds or 0) + ) + provider = self._get_provider_from_credential(key) + if provider: + threshold = self._get_exhaustion_cooldown_threshold(provider) + if effective_cooldown > threshold: + rotation_mode = self._get_rotation_mode(provider) + if self._is_fair_cycle_enabled(provider, rotation_mode): + priority = self._get_credential_priority(key, provider) + tier_key = self._get_tier_key(provider, priority) + tracking_key = self._get_tracking_key( + key, model, provider + ) + self._mark_credential_exhausted( + key, provider, tier_key, tracking_key + ) + + elif classified_error.error_type == "rate_limit": + # Transient rate limit - just set short cooldown (does NOT set quota_reset_ts) + cooldown_seconds = ( + classified_error.retry_after or COOLDOWN_RATE_LIMIT_DEFAULT + ) + model_cooldowns[model] = now_ts + cooldown_seconds + lib_logger.info( + f"Rate limit on {mask_credential(key)} for model {model}. " + f"Transient cooldown: {cooldown_seconds}s" + ) + + elif classified_error.error_type == "authentication": + # Apply a 5-minute key-level lockout for auth errors + key_data["key_cooldown_until"] = now_ts + COOLDOWN_AUTH_ERROR + cooldown_seconds = COOLDOWN_AUTH_ERROR + model_cooldowns[model] = now_ts + cooldown_seconds + lib_logger.warning( + f"Authentication error on key {mask_credential(key)}. Applying 5-minute key-level lockout." + ) + + # If we should increment failures, calculate escalating backoff + if should_increment: + failures_data = key_data.setdefault("failures", {}) + model_failures = failures_data.setdefault( + model, {"consecutive_failures": 0} + ) + model_failures["consecutive_failures"] += 1 + count = model_failures["consecutive_failures"] + + # If cooldown wasn't set by specific error type, use escalating backoff + if cooldown_seconds is None: + cooldown_seconds = COOLDOWN_BACKOFF_TIERS.get( + count, COOLDOWN_BACKOFF_MAX + ) + model_cooldowns[model] = now_ts + cooldown_seconds + lib_logger.warning( + f"Failure #{count} for key {mask_credential(key)} with model {model}. " + f"Error type: {classified_error.error_type}, cooldown: {cooldown_seconds}s" + ) + else: + # Provider-level errors: apply short cooldown but don't count against key + if cooldown_seconds is None: + cooldown_seconds = COOLDOWN_TRANSIENT_ERROR + model_cooldowns[model] = now_ts + cooldown_seconds + lib_logger.info( + f"Provider-level error ({classified_error.error_type}) for key {mask_credential(key)} " + f"with model {model}. NOT incrementing failures. Cooldown: {cooldown_seconds}s" + ) + + # Check for key-level lockout condition + await self._check_key_lockout(key, key_data) + + # Track failure count for quota estimation (all failures consume quota) + # This is separate from consecutive_failures which is for backoff logic + if reset_mode == "per_model": + models_data = key_data.setdefault("models", {}) + model_data = models_data.setdefault( + model, + { + "window_start_ts": None, + "quota_reset_ts": None, + "success_count": 0, + "failure_count": 0, + "request_count": 0, + "prompt_tokens": 0, + "prompt_tokens_cached": 0, + "completion_tokens": 0, + "approx_cost": 0.0, + }, + ) + # Only increment if not already incremented in quota_exceeded branch + if classified_error.error_type != "quota_exceeded": + model_data["failure_count"] = model_data.get("failure_count", 0) + 1 + model_data["request_count"] = model_data.get("request_count", 0) + 1 + + # Sync request_count across quota group + new_request_count = model_data["request_count"] + group = self._get_model_quota_group(key, model) + if group: + grouped_models = self._get_grouped_models(key, group) + for grouped_model in grouped_models: + if grouped_model != model: + other_model_data = models_data.setdefault( + grouped_model, + { + "window_start_ts": None, + "quota_reset_ts": None, + "success_count": 0, + "failure_count": 0, + "request_count": 0, + "prompt_tokens": 0, + "prompt_tokens_cached": 0, + "completion_tokens": 0, + "approx_cost": 0.0, + }, + ) + other_model_data["request_count"] = new_request_count + # Also sync quota_max_requests if set + max_req = model_data.get("quota_max_requests") + if max_req: + other_model_data["quota_max_requests"] = max_req + other_model_data["quota_display"] = ( + f"{new_request_count}/{max_req}" + ) + + key_data["last_failure"] = { + "timestamp": now_ts, + "model": model, + "error": str(classified_error.original_exception), + } + + await self._save_usage() + + async def update_quota_baseline( + self, + credential: str, + model: str, + remaining_fraction: float, + max_requests: Optional[int] = None, + reset_timestamp: Optional[float] = None, + ) -> Optional[Dict[str, Any]]: + """ + Update quota baseline data for a credential/model after fetching from API. + + This stores the current quota state as a baseline, which is used to + estimate remaining quota based on subsequent request counts. + + When quota is exhausted (remaining_fraction <= 0.0) and a valid reset_timestamp + is provided, this also sets model_cooldowns to prevent wasted requests. + + Args: + credential: Credential identifier (file path or env:// URI) + model: Model name (with or without provider prefix) + remaining_fraction: Current remaining quota as fraction (0.0 to 1.0) + max_requests: Maximum requests allowed per quota period (e.g., 250 for Claude) + reset_timestamp: Unix timestamp when quota resets. Only trusted when + remaining_fraction < 1.0 (quota has been used). API returns garbage + reset times for unused quota (100%). + + Returns: + None if no cooldown was set/updated, otherwise: + { + "group_or_model": str, # quota group name or model name if ungrouped + "hours_until_reset": float, + } + """ + await self._lazy_init() + async with self._data_lock: + now_ts = time.time() + + # Get or create key data structure + key_data = self._usage_data.setdefault( + credential, + { + "models": {}, + "global": {"models": {}}, + "model_cooldowns": {}, + "failures": {}, + }, + ) + + # Ensure models dict exists + if "models" not in key_data: + key_data["models"] = {} + + # Get or create per-model data + model_data = key_data["models"].setdefault( + model, + { + "window_start_ts": None, + "quota_reset_ts": None, + "success_count": 0, + "failure_count": 0, + "request_count": 0, + "prompt_tokens": 0, + "prompt_tokens_cached": 0, + "completion_tokens": 0, + "approx_cost": 0.0, + "baseline_remaining_fraction": None, + "baseline_fetched_at": None, + "requests_at_baseline": None, + }, + ) + + # Calculate actual used requests from API's remaining fraction + # The API is authoritative - sync our local count to match reality + if max_requests is not None: + used_requests = int((1.0 - remaining_fraction) * max_requests) + else: + # Estimate max_requests from provider's quota cost + # This matches how get_max_requests_for_model() calculates it + provider = self._get_provider_from_credential(credential) + plugin_instance = self._get_provider_instance(provider) + if plugin_instance and hasattr( + plugin_instance, "get_max_requests_for_model" + ): + # Get tier from provider's cache + tier = getattr(plugin_instance, "project_tier_cache", {}).get( + credential, "standard-tier" + ) + # Strip provider prefix from model if present + clean_model = model.split("/")[-1] if "/" in model else model + max_requests = plugin_instance.get_max_requests_for_model( + clean_model, tier + ) + used_requests = int((1.0 - remaining_fraction) * max_requests) + else: + # Fallback: keep existing count if we can't calculate + used_requests = model_data.get("request_count", 0) + max_requests = model_data.get("quota_max_requests") + + # Sync local request count to API's authoritative value + # Use max() to prevent API from resetting our count if it returns stale/cached 100% + # The API can only increase our count (if we missed requests), not decrease it + # See: https://github.com/Mirrowel/LLM-API-Key-Proxy/issues/75 + current_count = model_data.get("request_count", 0) + synced_count = max(current_count, used_requests) + model_data["request_count"] = synced_count + model_data["requests_at_baseline"] = synced_count + + # Update baseline fields + model_data["baseline_remaining_fraction"] = remaining_fraction + model_data["baseline_fetched_at"] = now_ts + + # Update max_requests and quota_display + if max_requests is not None: + model_data["quota_max_requests"] = max_requests + model_data["quota_display"] = f"{synced_count}/{max_requests}" + + # Handle reset_timestamp: only trust it when quota has been used (< 100%) + # API returns garbage reset times for unused quota + valid_reset_ts = ( + reset_timestamp is not None + and remaining_fraction < 1.0 + and reset_timestamp > now_ts + ) + + if valid_reset_ts: + model_data["quota_reset_ts"] = reset_timestamp + + # Set cooldowns when quota is exhausted + model_cooldowns = key_data.setdefault("model_cooldowns", {}) + is_exhausted = remaining_fraction <= 0.0 + cooldown_set_info = ( + None # Will be returned if cooldown was newly set/updated + ) + + if is_exhausted and valid_reset_ts: + # Check if there was an existing ACTIVE cooldown before we update + # This distinguishes between fresh exhaustion vs refresh of existing state + existing_cooldown = model_cooldowns.get(model) + was_already_on_cooldown = ( + existing_cooldown is not None and existing_cooldown > now_ts + ) + + # Only update cooldown if not set or differs by more than 5 minutes + should_update = ( + existing_cooldown is None + or abs(existing_cooldown - reset_timestamp) > 300 + ) + if should_update: + model_cooldowns[model] = reset_timestamp + hours_until_reset = (reset_timestamp - now_ts) / 3600 + # Determine group or model name for logging + group = self._get_model_quota_group(credential, model) + cooldown_set_info = { + "group_or_model": group if group else model.split("/")[-1], + "hours_until_reset": hours_until_reset, + } + + # Mark credential as exhausted in fair cycle if cooldown exceeds threshold + # BUT only if this is a FRESH exhaustion (wasn't already on cooldown) + # This prevents re-marking after cycle reset when quota refresh sees existing cooldown + if not was_already_on_cooldown: + cooldown_duration = reset_timestamp - now_ts + provider = self._get_provider_from_credential(credential) + if provider: + threshold = self._get_exhaustion_cooldown_threshold(provider) + if cooldown_duration > threshold: + rotation_mode = self._get_rotation_mode(provider) + if self._is_fair_cycle_enabled(provider, rotation_mode): + priority = self._get_credential_priority( + credential, provider + ) + tier_key = self._get_tier_key(provider, priority) + tracking_key = self._get_tracking_key( + credential, model, provider + ) + self._mark_credential_exhausted( + credential, provider, tier_key, tracking_key + ) + + # Defensive clamp: ensure request_count doesn't exceed max when exhausted + if ( + max_requests is not None + and model_data["request_count"] > max_requests + ): + model_data["request_count"] = max_requests + model_data["quota_display"] = f"{max_requests}/{max_requests}" + + # Sync baseline fields and quota info across quota group + group = self._get_model_quota_group(credential, model) + if group: + grouped_models = self._get_grouped_models(credential, group) + for grouped_model in grouped_models: + if grouped_model != model: + other_model_data = key_data["models"].setdefault( + grouped_model, + { + "window_start_ts": None, + "quota_reset_ts": None, + "success_count": 0, + "failure_count": 0, + "request_count": 0, + "prompt_tokens": 0, + "prompt_tokens_cached": 0, + "completion_tokens": 0, + "approx_cost": 0.0, + }, + ) + # Sync request tracking (use synced_count to prevent reset bug) + other_model_data["request_count"] = synced_count + if max_requests is not None: + other_model_data["quota_max_requests"] = max_requests + other_model_data["quota_display"] = ( + f"{synced_count}/{max_requests}" + ) + # Sync baseline fields + other_model_data["baseline_remaining_fraction"] = ( + remaining_fraction + ) + other_model_data["baseline_fetched_at"] = now_ts + other_model_data["requests_at_baseline"] = synced_count + # Sync reset timestamp if valid + if valid_reset_ts: + other_model_data["quota_reset_ts"] = reset_timestamp + # Sync window start time + window_start = model_data.get("window_start_ts") + if window_start: + other_model_data["window_start_ts"] = window_start + # Sync cooldown if exhausted (with ±5 min check) + if is_exhausted and valid_reset_ts: + existing_grouped = model_cooldowns.get(grouped_model) + should_update_grouped = ( + existing_grouped is None + or abs(existing_grouped - reset_timestamp) > 300 + ) + if should_update_grouped: + model_cooldowns[grouped_model] = reset_timestamp + + # Defensive clamp for grouped models when exhausted + if ( + max_requests is not None + and other_model_data["request_count"] > max_requests + ): + other_model_data["request_count"] = max_requests + other_model_data["quota_display"] = ( + f"{max_requests}/{max_requests}" + ) + + lib_logger.debug( + f"Updated quota baseline for {mask_credential(credential)} model={model}: " + f"remaining={remaining_fraction:.2%}, synced_request_count={synced_count}" + ) + + await self._save_usage() + return cooldown_set_info + + async def _check_key_lockout(self, key: str, key_data: Dict): + """ + Checks if a key should be locked out due to multiple model failures. + + NOTE: This check is currently disabled. The original logic counted individual + models in long-term lockout, but this caused issues with quota groups - when + a single quota group (e.g., "claude" with 5 models) was exhausted, it would + count as 5 lockouts and trigger key-level lockout, blocking other quota groups + (like gemini) that were still available. + + The per-model and per-group cooldowns already handle quota exhaustion properly. + """ + # Disabled - see docstring above + pass + + async def get_stats_for_endpoint( + self, + provider_filter: Optional[str] = None, + include_global: bool = True, + ) -> Dict[str, Any]: + """ + Get usage stats formatted for the /v1/quota-stats endpoint. + + Aggregates data from key_usage.json grouped by provider. + Includes both current period stats and global (lifetime) stats. + + Args: + provider_filter: If provided, only return stats for this provider + include_global: If True, include global/lifetime stats alongside current + + Returns: + { + "providers": { + "provider_name": { + "credential_count": int, + "active_count": int, + "on_cooldown_count": int, + "total_requests": int, + "tokens": { + "input_cached": int, + "input_uncached": int, + "input_cache_pct": float, + "output": int + }, + "approx_cost": float | None, + "credentials": [...], + "global": {...} # If include_global is True + } + }, + "summary": {...}, + "global_summary": {...}, # If include_global is True + "timestamp": float + } + """ + await self._lazy_init() + + now_ts = time.time() + providers: Dict[str, Dict[str, Any]] = {} + # Track global stats separately + global_providers: Dict[str, Dict[str, Any]] = {} + + async with self._data_lock: + if not self._usage_data: + return { + "providers": {}, + "summary": { + "total_providers": 0, + "total_credentials": 0, + "active_credentials": 0, + "exhausted_credentials": 0, + "total_requests": 0, + "tokens": { + "input_cached": 0, + "input_uncached": 0, + "input_cache_pct": 0, + "output": 0, + }, + "approx_total_cost": 0.0, + }, + "global_summary": { + "total_providers": 0, + "total_credentials": 0, + "total_requests": 0, + "tokens": { + "input_cached": 0, + "input_uncached": 0, + "input_cache_pct": 0, + "output": 0, + }, + "approx_total_cost": 0.0, + }, + "data_source": "cache", + "timestamp": now_ts, + } + + for credential, cred_data in self._usage_data.items(): + # Extract provider from credential path + provider = self._get_provider_from_credential(credential) + if not provider: + continue + + # Apply filter if specified + if provider_filter and provider != provider_filter: + continue + + # Initialize provider entry + if provider not in providers: + providers[provider] = { + "credential_count": 0, + "active_count": 0, + "on_cooldown_count": 0, + "exhausted_count": 0, + "total_requests": 0, + "tokens": { + "input_cached": 0, + "input_uncached": 0, + "input_cache_pct": 0, + "output": 0, + }, + "approx_cost": 0.0, + "credentials": [], + } + global_providers[provider] = { + "total_requests": 0, + "tokens": { + "input_cached": 0, + "input_uncached": 0, + "input_cache_pct": 0, + "output": 0, + }, + "approx_cost": 0.0, + } + + prov_stats = providers[provider] + prov_stats["credential_count"] += 1 + + # Determine credential status and cooldowns + key_cooldown = cred_data.get("key_cooldown_until", 0) or 0 + model_cooldowns = cred_data.get("model_cooldowns", {}) + + # Build active cooldowns with remaining time + active_cooldowns = {} + for model, cooldown_ts in model_cooldowns.items(): + if cooldown_ts > now_ts: + remaining_seconds = int(cooldown_ts - now_ts) + active_cooldowns[model] = { + "until_ts": cooldown_ts, + "remaining_seconds": remaining_seconds, + } + + key_cooldown_remaining = None + if key_cooldown > now_ts: + key_cooldown_remaining = int(key_cooldown - now_ts) + + has_active_cooldown = key_cooldown > now_ts or len(active_cooldowns) > 0 + + # Check if exhausted (all quota groups exhausted for Antigravity) + is_exhausted = False + models_data = cred_data.get("models", {}) + if models_data: + # Check if any model has remaining quota + all_exhausted = True + for model_stats in models_data.values(): + if isinstance(model_stats, dict): + baseline = model_stats.get("baseline_remaining_fraction") + if baseline is None or baseline > 0: + all_exhausted = False + break + if all_exhausted and len(models_data) > 0: + is_exhausted = True + + if is_exhausted: + prov_stats["exhausted_count"] += 1 + status = "exhausted" + elif has_active_cooldown: + prov_stats["on_cooldown_count"] += 1 + status = "cooldown" + else: + prov_stats["active_count"] += 1 + status = "active" + + # Aggregate token stats (current period) + cred_tokens = { + "input_cached": 0, + "input_uncached": 0, + "output": 0, + } + cred_requests = 0 + cred_cost = 0.0 + + # Aggregate global token stats + cred_global_tokens = { + "input_cached": 0, + "input_uncached": 0, + "output": 0, + } + cred_global_requests = 0 + cred_global_cost = 0.0 + + # Handle per-model structure (current period) + if models_data: + for model_name, model_stats in models_data.items(): + if not isinstance(model_stats, dict): + continue + # Prefer request_count if available and non-zero, else fall back to success+failure + req_count = model_stats.get("request_count", 0) + if req_count > 0: + cred_requests += req_count + else: + cred_requests += model_stats.get("success_count", 0) + cred_requests += model_stats.get("failure_count", 0) + # Token stats - track cached separately + cred_tokens["input_cached"] += model_stats.get( + "prompt_tokens_cached", 0 + ) + cred_tokens["input_uncached"] += model_stats.get( + "prompt_tokens", 0 + ) + cred_tokens["output"] += model_stats.get("completion_tokens", 0) + cred_cost += model_stats.get("approx_cost", 0.0) + + # Handle legacy daily structure + daily_data = cred_data.get("daily", {}) + daily_models = daily_data.get("models", {}) + for model_name, model_stats in daily_models.items(): + if not isinstance(model_stats, dict): + continue + cred_requests += model_stats.get("success_count", 0) + cred_tokens["input_cached"] += model_stats.get( + "prompt_tokens_cached", 0 + ) + cred_tokens["input_uncached"] += model_stats.get("prompt_tokens", 0) + cred_tokens["output"] += model_stats.get("completion_tokens", 0) + cred_cost += model_stats.get("approx_cost", 0.0) + + # Handle global stats + global_data = cred_data.get("global", {}) + global_models = global_data.get("models", {}) + for model_name, model_stats in global_models.items(): + if not isinstance(model_stats, dict): + continue + cred_global_requests += model_stats.get("success_count", 0) + cred_global_tokens["input_cached"] += model_stats.get( + "prompt_tokens_cached", 0 + ) + cred_global_tokens["input_uncached"] += model_stats.get( + "prompt_tokens", 0 + ) + cred_global_tokens["output"] += model_stats.get( + "completion_tokens", 0 + ) + cred_global_cost += model_stats.get("approx_cost", 0.0) + + # Add current period stats to global totals + cred_global_requests += cred_requests + cred_global_tokens["input_cached"] += cred_tokens["input_cached"] + cred_global_tokens["input_uncached"] += cred_tokens["input_uncached"] + cred_global_tokens["output"] += cred_tokens["output"] + cred_global_cost += cred_cost + + # Build credential entry + # Mask credential identifier for display + if credential.startswith("env://"): + identifier = credential + else: + identifier = Path(credential).name + + cred_entry = { + "identifier": identifier, + "full_path": credential, + "status": status, + "last_used_ts": cred_data.get("last_used_ts"), + "requests": cred_requests, + "tokens": cred_tokens, + "approx_cost": cred_cost if cred_cost > 0 else None, + } + + # Add cooldown info + if key_cooldown_remaining is not None: + cred_entry["key_cooldown_remaining"] = key_cooldown_remaining + if active_cooldowns: + cred_entry["model_cooldowns"] = active_cooldowns + + # Add global stats for this credential + if include_global: + # Calculate global cache percentage + global_total_input = ( + cred_global_tokens["input_cached"] + + cred_global_tokens["input_uncached"] + ) + global_cache_pct = ( + round( + cred_global_tokens["input_cached"] + / global_total_input + * 100, + 1, + ) + if global_total_input > 0 + else 0 + ) + + cred_entry["global"] = { + "requests": cred_global_requests, + "tokens": { + "input_cached": cred_global_tokens["input_cached"], + "input_uncached": cred_global_tokens["input_uncached"], + "input_cache_pct": global_cache_pct, + "output": cred_global_tokens["output"], + }, + "approx_cost": cred_global_cost + if cred_global_cost > 0 + else None, + } + + # Add model-specific data for providers with per-model tracking + if models_data: + cred_entry["models"] = {} + for model_name, model_stats in models_data.items(): + if not isinstance(model_stats, dict): + continue + cred_entry["models"][model_name] = { + "requests": model_stats.get("success_count", 0) + + model_stats.get("failure_count", 0), + "request_count": model_stats.get("request_count", 0), + "success_count": model_stats.get("success_count", 0), + "failure_count": model_stats.get("failure_count", 0), + "prompt_tokens": model_stats.get("prompt_tokens", 0), + "prompt_tokens_cached": model_stats.get( + "prompt_tokens_cached", 0 + ), + "completion_tokens": model_stats.get( + "completion_tokens", 0 + ), + "approx_cost": model_stats.get("approx_cost", 0.0), + "window_start_ts": model_stats.get("window_start_ts"), + "quota_reset_ts": model_stats.get("quota_reset_ts"), + # Quota baseline fields (Antigravity-specific) + "baseline_remaining_fraction": model_stats.get( + "baseline_remaining_fraction" + ), + "baseline_fetched_at": model_stats.get( + "baseline_fetched_at" + ), + "quota_max_requests": model_stats.get("quota_max_requests"), + "quota_display": model_stats.get("quota_display"), + } + + prov_stats["credentials"].append(cred_entry) + + # Aggregate to provider totals (current period) + prov_stats["total_requests"] += cred_requests + prov_stats["tokens"]["input_cached"] += cred_tokens["input_cached"] + prov_stats["tokens"]["input_uncached"] += cred_tokens["input_uncached"] + prov_stats["tokens"]["output"] += cred_tokens["output"] + if cred_cost > 0: + prov_stats["approx_cost"] += cred_cost + + # Aggregate to global provider totals + global_providers[provider]["total_requests"] += cred_global_requests + global_providers[provider]["tokens"]["input_cached"] += ( + cred_global_tokens["input_cached"] + ) + global_providers[provider]["tokens"]["input_uncached"] += ( + cred_global_tokens["input_uncached"] + ) + global_providers[provider]["tokens"]["output"] += cred_global_tokens[ + "output" + ] + global_providers[provider]["approx_cost"] += cred_global_cost + + # Calculate cache percentages for each provider + for provider, prov_stats in providers.items(): + total_input = ( + prov_stats["tokens"]["input_cached"] + + prov_stats["tokens"]["input_uncached"] + ) + if total_input > 0: + prov_stats["tokens"]["input_cache_pct"] = round( + prov_stats["tokens"]["input_cached"] / total_input * 100, 1 + ) + # Set cost to None if 0 + if prov_stats["approx_cost"] == 0: + prov_stats["approx_cost"] = None + + # Calculate global cache percentages + if include_global and provider in global_providers: + gp = global_providers[provider] + global_total = ( + gp["tokens"]["input_cached"] + gp["tokens"]["input_uncached"] + ) + if global_total > 0: + gp["tokens"]["input_cache_pct"] = round( + gp["tokens"]["input_cached"] / global_total * 100, 1 + ) + if gp["approx_cost"] == 0: + gp["approx_cost"] = None + prov_stats["global"] = gp + + # Build summary (current period) + total_creds = sum(p["credential_count"] for p in providers.values()) + active_creds = sum(p["active_count"] for p in providers.values()) + exhausted_creds = sum(p["exhausted_count"] for p in providers.values()) + total_requests = sum(p["total_requests"] for p in providers.values()) + total_input_cached = sum( + p["tokens"]["input_cached"] for p in providers.values() + ) + total_input_uncached = sum( + p["tokens"]["input_uncached"] for p in providers.values() + ) + total_output = sum(p["tokens"]["output"] for p in providers.values()) + total_cost = sum(p["approx_cost"] or 0 for p in providers.values()) + + total_input = total_input_cached + total_input_uncached + input_cache_pct = ( + round(total_input_cached / total_input * 100, 1) if total_input > 0 else 0 + ) + + result = { + "providers": providers, + "summary": { + "total_providers": len(providers), + "total_credentials": total_creds, + "active_credentials": active_creds, + "exhausted_credentials": exhausted_creds, + "total_requests": total_requests, + "tokens": { + "input_cached": total_input_cached, + "input_uncached": total_input_uncached, + "input_cache_pct": input_cache_pct, + "output": total_output, + }, + "approx_total_cost": total_cost if total_cost > 0 else None, + }, + "data_source": "cache", + "timestamp": now_ts, + } + + # Build global summary + if include_global: + global_total_requests = sum( + gp["total_requests"] for gp in global_providers.values() + ) + global_total_input_cached = sum( + gp["tokens"]["input_cached"] for gp in global_providers.values() + ) + global_total_input_uncached = sum( + gp["tokens"]["input_uncached"] for gp in global_providers.values() + ) + global_total_output = sum( + gp["tokens"]["output"] for gp in global_providers.values() + ) + global_total_cost = sum( + gp["approx_cost"] or 0 for gp in global_providers.values() + ) + + global_total_input = global_total_input_cached + global_total_input_uncached + global_input_cache_pct = ( + round(global_total_input_cached / global_total_input * 100, 1) + if global_total_input > 0 + else 0 + ) + + result["global_summary"] = { + "total_providers": len(global_providers), + "total_credentials": total_creds, + "total_requests": global_total_requests, + "tokens": { + "input_cached": global_total_input_cached, + "input_uncached": global_total_input_uncached, + "input_cache_pct": global_input_cache_pct, + "output": global_total_output, + }, + "approx_total_cost": global_total_cost + if global_total_cost > 0 + else None, + } + + return result + + async def reload_from_disk(self) -> None: + """ + Force reload usage data from disk. + + Useful when another process may have updated the file. + """ + async with self._init_lock: + self._initialized.clear() + await self._load_usage() + await self._reset_daily_stats_if_needed() + self._initialized.set() diff --git a/src/rotator_library/background_refresher.py b/src/rotator_library/background_refresher.py index e3da1f76..acc66c89 100644 --- a/src/rotator_library/background_refresher.py +++ b/src/rotator_library/background_refresher.py @@ -234,9 +234,13 @@ async def _run_provider_background_job( # Run immediately on start if configured if run_on_start: try: - await provider.run_background_job( - self._client.usage_manager, credentials - ) + usage_manager = self._client.usage_managers.get(provider_name) + if usage_manager is None: + lib_logger.debug( + f"Skipping {provider_name} {job_name}: no UsageManager" + ) + return + await provider.run_background_job(usage_manager, credentials) lib_logger.debug(f"{provider_name} {job_name}: initial run complete") except Exception as e: lib_logger.error( @@ -247,9 +251,13 @@ async def _run_provider_background_job( while True: try: await asyncio.sleep(interval) - await provider.run_background_job( - self._client.usage_manager, credentials - ) + usage_manager = self._client.usage_managers.get(provider_name) + if usage_manager is None: + lib_logger.debug( + f"Skipping {provider_name} {job_name}: no UsageManager" + ) + return + await provider.run_background_job(usage_manager, credentials) lib_logger.debug(f"{provider_name} {job_name}: periodic run complete") except asyncio.CancelledError: lib_logger.debug(f"{provider_name} {job_name}: cancelled") @@ -259,6 +267,7 @@ async def _run_provider_background_job( async def _run(self): """The main loop for OAuth token refresh.""" + await self._client.initialize_usage_managers() # Initialize credentials (load persisted tiers) before starting await self._initialize_credentials() diff --git a/src/rotator_library/client/__init__.py b/src/rotator_library/client/__init__.py new file mode 100644 index 00000000..4307f84d --- /dev/null +++ b/src/rotator_library/client/__init__.py @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Client package for LLM API key rotation. + +This package provides the RotatingClient and associated components +for intelligent credential rotation and retry logic. + +Public API: + RotatingClient: Main client class for making API requests + StreamedAPIError: Exception for streaming errors + +Components (for advanced usage): + RequestExecutor: Unified retry/rotation logic + CredentialFilter: Tier compatibility filtering + ModelResolver: Model name resolution + ProviderTransforms: Provider-specific transforms + StreamingHandler: Streaming response processing +""" + +from .rotating_client import RotatingClient +from ..core.errors import StreamedAPIError + +# Also expose components for advanced usage +from .executor import RequestExecutor +from .filters import CredentialFilter +from .models import ModelResolver +from .transforms import ProviderTransforms +from .streaming import StreamingHandler +from .anthropic import AnthropicHandler +from .types import AvailabilityStats, RetryState, ExecutionResult + +__all__ = [ + # Main public API + "RotatingClient", + "StreamedAPIError", + # Components + "RequestExecutor", + "CredentialFilter", + "ModelResolver", + "ProviderTransforms", + "StreamingHandler", + "AnthropicHandler", + # Types + "AvailabilityStats", + "RetryState", + "ExecutionResult", +] diff --git a/src/rotator_library/client/anthropic.py b/src/rotator_library/client/anthropic.py new file mode 100644 index 00000000..f3f6119c --- /dev/null +++ b/src/rotator_library/client/anthropic.py @@ -0,0 +1,203 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Anthropic API compatibility handler for RotatingClient. + +This module provides Anthropic SDK compatibility methods that allow using +Anthropic's Messages API format with the credential rotation system. +""" + +import json +import logging +import uuid +from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional + +from ..anthropic_compat import ( + AnthropicMessagesRequest, + AnthropicCountTokensRequest, + translate_anthropic_request, + openai_to_anthropic_response, + anthropic_streaming_wrapper, + anthropic_to_openai_messages, + anthropic_to_openai_tools, +) +from ..transaction_logger import TransactionLogger + +if TYPE_CHECKING: + from .rotating_client import RotatingClient + +lib_logger = logging.getLogger("rotator_library") + + +class AnthropicHandler: + """ + Handler for Anthropic API compatibility methods. + + This class provides methods to handle Anthropic Messages API requests + by translating them to OpenAI format, processing through the client's + acompletion method, and converting responses back to Anthropic format. + + Example: + handler = AnthropicHandler(client) + response = await handler.messages(request, raw_request) + """ + + def __init__(self, client: "RotatingClient"): + """ + Initialize the Anthropic handler. + + Args: + client: The RotatingClient instance to use for completions + """ + self._client = client + + async def messages( + self, + request: AnthropicMessagesRequest, + raw_request: Optional[Any] = None, + pre_request_callback: Optional[callable] = None, + ) -> Any: + """ + Handle Anthropic Messages API requests. + + This method accepts requests in Anthropic's format, translates them to + OpenAI format internally, processes them through the existing acompletion + method, and returns responses in Anthropic's format. + + Args: + request: An AnthropicMessagesRequest object + raw_request: Optional raw request object for disconnect checks + pre_request_callback: Optional async callback before each API request + + Returns: + For non-streaming: dict in Anthropic Messages format + For streaming: AsyncGenerator yielding Anthropic SSE format strings + """ + request_id = f"msg_{uuid.uuid4().hex[:24]}" + original_model = request.model + + # Extract provider from model for logging + provider = original_model.split("/")[0] if "/" in original_model else "unknown" + + # Create Anthropic transaction logger if request logging is enabled + anthropic_logger = None + if self._client.enable_request_logging: + anthropic_logger = TransactionLogger( + provider, + original_model, + enabled=True, + api_format="ant", + ) + # Log original Anthropic request + anthropic_logger.log_request( + request.model_dump(exclude_none=True), + filename="anthropic_request.json", + ) + + # Translate Anthropic request to OpenAI format + openai_request = translate_anthropic_request(request) + + # Pass parent log directory to acompletion for nested logging + if anthropic_logger and anthropic_logger.log_dir: + openai_request["_parent_log_dir"] = anthropic_logger.log_dir + + if request.stream: + # Streaming response + response_generator = self._client.acompletion( + request=raw_request, + pre_request_callback=pre_request_callback, + **openai_request, + ) + + # Create disconnect checker if raw_request provided + is_disconnected = None + if raw_request is not None and hasattr(raw_request, "is_disconnected"): + is_disconnected = raw_request.is_disconnected + + # Return the streaming wrapper + # Note: For streaming, the anthropic response logging happens in the wrapper + return anthropic_streaming_wrapper( + openai_stream=response_generator, + original_model=original_model, + request_id=request_id, + is_disconnected=is_disconnected, + transaction_logger=anthropic_logger, + ) + else: + # Non-streaming response + response = await self._client.acompletion( + request=raw_request, + pre_request_callback=pre_request_callback, + **openai_request, + ) + + # Convert OpenAI response to Anthropic format + openai_response = ( + response.model_dump() + if hasattr(response, "model_dump") + else dict(response) + ) + anthropic_response = openai_to_anthropic_response( + openai_response, original_model + ) + + # Override the ID with our request ID + anthropic_response["id"] = request_id + + # Log Anthropic response + if anthropic_logger: + anthropic_logger.log_response( + anthropic_response, + filename="anthropic_response.json", + ) + + return anthropic_response + + async def count_tokens( + self, + request: AnthropicCountTokensRequest, + ) -> dict: + """ + Handle Anthropic count_tokens API requests. + + Counts the number of tokens that would be used by a Messages API request. + This is useful for estimating costs and managing context windows. + + Args: + request: An AnthropicCountTokensRequest object + + Returns: + Dict with input_tokens count in Anthropic format + """ + anthropic_request = request.model_dump(exclude_none=True) + + openai_messages = anthropic_to_openai_messages( + anthropic_request.get("messages", []), anthropic_request.get("system") + ) + + # Count tokens for messages + message_tokens = self._client.token_count( + model=request.model, + messages=openai_messages, + ) + + # Count tokens for tools if present + tool_tokens = 0 + if request.tools: + # Tools add tokens based on their definitions + # Convert to JSON string and count tokens for tool definitions + openai_tools = anthropic_to_openai_tools( + [tool.model_dump() for tool in request.tools] + ) + if openai_tools: + # Serialize tools to count their token contribution + tools_text = json.dumps(openai_tools) + tool_tokens = self._client.token_count( + model=request.model, + text=tools_text, + ) + + total_tokens = message_tokens + tool_tokens + + return {"input_tokens": total_tokens} diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py new file mode 100644 index 00000000..0e5513ad --- /dev/null +++ b/src/rotator_library/client/executor.py @@ -0,0 +1,1125 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Unified request execution with retry and rotation. + +This module extracts and unifies the retry logic that was duplicated in: +- _execute_with_retry (lines 1174-1945) +- _streaming_acompletion_with_retry (lines 1947-2780) + +The RequestExecutor provides a single code path for all request types, +with streaming vs non-streaming handled as a parameter. +""" + +import asyncio +import json +import logging +import random +import time +from typing import ( + Any, + AsyncGenerator, + Dict, + List, + Optional, + Set, + TYPE_CHECKING, + Tuple, + Union, +) + +import httpx +import litellm +from litellm.exceptions import ( + APIConnectionError, + RateLimitError, + ServiceUnavailableError, + InternalServerError, +) + +from ..core.types import RequestContext, ErrorAction +from ..core.errors import ( + NoAvailableKeysError, + PreRequestCallbackError, + StreamedAPIError, + ClassifiedError, + RequestErrorAccumulator, + classify_error, + should_rotate_on_error, + should_retry_same_key, + mask_credential, +) +from ..core.constants import DEFAULT_MAX_RETRIES +from ..request_sanitizer import sanitize_request_payload +from ..transaction_logger import TransactionLogger +from ..failure_logger import log_failure + +from .types import RetryState, AvailabilityStats +from .filters import CredentialFilter +from .transforms import ProviderTransforms +from .streaming import StreamingHandler + +if TYPE_CHECKING: + from ..usage import UsageManager + +lib_logger = logging.getLogger("rotator_library") + + +class RequestExecutor: + """ + Unified retry/rotation logic for all request types. + + This class handles: + - Credential rotation across providers + - Per-credential retry with backoff + - Error classification and handling + - Streaming and non-streaming requests + """ + + def __init__( + self, + usage_managers: Dict[str, "UsageManager"], + cooldown_manager: Any, + credential_filter: CredentialFilter, + provider_transforms: ProviderTransforms, + provider_plugins: Dict[str, Any], + http_client: httpx.AsyncClient, + max_retries: int = DEFAULT_MAX_RETRIES, + global_timeout: int = 30, + abort_on_callback_error: bool = True, + litellm_provider_params: Optional[Dict[str, Any]] = None, + litellm_logger_fn: Optional[Any] = None, + ): + """ + Initialize RequestExecutor. + + Args: + usage_managers: Dict mapping provider names to UsageManager instances + cooldown_manager: CooldownManager instance + credential_filter: CredentialFilter instance + provider_transforms: ProviderTransforms instance + provider_plugins: Dict mapping provider names to plugin classes + http_client: Shared httpx.AsyncClient for provider requests + max_retries: Max retries per credential + global_timeout: Global request timeout in seconds + abort_on_callback_error: Abort on pre-request callback errors + """ + self._usage_managers = usage_managers + self._cooldown = cooldown_manager + self._filter = credential_filter + self._transforms = provider_transforms + self._plugins = provider_plugins + self._plugin_instances: Dict[str, Any] = {} + self._http_client = http_client + self._max_retries = max_retries + self._global_timeout = global_timeout + self._abort_on_callback_error = abort_on_callback_error + self._litellm_provider_params = litellm_provider_params or {} + self._litellm_logger_fn = litellm_logger_fn + # StreamingHandler no longer needs usage_manager - we pass cred_context directly + self._streaming_handler = StreamingHandler() + + def _get_plugin_instance(self, provider: str) -> Optional[Any]: + """Get or create a plugin instance for a provider.""" + if provider not in self._plugin_instances: + plugin_class = self._plugins.get(provider) + if plugin_class: + if isinstance(plugin_class, type): + self._plugin_instances[provider] = plugin_class() + else: + self._plugin_instances[provider] = plugin_class + else: + return None + return self._plugin_instances[provider] + + async def execute( + self, + context: RequestContext, + ) -> Union[Any, AsyncGenerator[str, None]]: + """ + Execute request with retry/rotation. + + This is the main entry point for request execution. + + Args: + context: RequestContext with all request details + + Returns: + Response object or async generator for streaming + """ + if context.streaming: + return self._execute_streaming(context) + else: + return await self._execute_non_streaming(context) + + async def _prepare_execution( + self, + context: RequestContext, + ) -> Tuple["UsageManager", Any, List[str], Optional[str], Dict[str, Any]]: + provider = context.provider + model = context.model + + usage_manager = self._usage_managers.get(provider) + if not usage_manager: + raise NoAvailableKeysError(f"No UsageManager for provider {provider}") + + filter_result = self._filter.filter_by_tier( + context.credentials, model, provider + ) + credentials = filter_result.all_usable + quota_group = usage_manager.get_model_quota_group(model) + + await self._ensure_initialized(usage_manager, context, filter_result) + await self._validate_request(provider, model, context.kwargs) + + if not credentials: + raise NoAvailableKeysError(f"No compatible credentials for model {model}") + + request_headers = ( + dict(context.request.headers) if context.request is not None else {} + ) + + return usage_manager, filter_result, credentials, quota_group, request_headers + + async def _execute_non_streaming( + self, + context: RequestContext, + ) -> Any: + """ + Execute non-streaming request with retry/rotation. + + Args: + context: RequestContext with all request details + + Returns: + Response object + """ + provider = context.provider + model = context.model + deadline = context.deadline + + ( + usage_manager, + filter_result, + credentials, + quota_group, + request_headers, + ) = await self._prepare_execution(context) + + error_accumulator = RequestErrorAccumulator() + error_accumulator.model = model + error_accumulator.provider = provider + + retry_state = RetryState() + last_exception: Optional[Exception] = None + + while time.time() < deadline: + # Check for untried credentials + untried = [c for c in credentials if c not in retry_state.tried_credentials] + if not untried: + lib_logger.warning( + f"All {len(credentials)} credentials tried for {model}" + ) + break + + availability = await usage_manager.get_availability_stats( + model, quota_group + ) + blocked = availability.get("blocked_by", {}) + blocked_parts = [] + if blocked.get("cooldowns"): + blocked_parts.append(f"cd:{blocked['cooldowns']}") + if blocked.get("fair_cycle"): + blocked_parts.append(f"fc:{blocked['fair_cycle']}") + if blocked.get("custom_caps"): + blocked_parts.append(f"cap:{blocked['custom_caps']}") + if blocked.get("window_limits"): + blocked_parts.append(f"wl:{blocked['window_limits']}") + if blocked.get("concurrent"): + blocked_parts.append(f"con:{blocked['concurrent']}") + blocked_str = f"({', '.join(blocked_parts)})" if blocked_parts else "" + lib_logger.info( + f"Acquiring key for model {model}. Tried keys: {len(retry_state.tried_credentials)}/" + f"{availability.get('available', 0)}({availability.get('total', 0)}{blocked_str})" + ) + + # Wait for provider cooldown + await self._wait_for_cooldown(provider, deadline) + + # Acquire credential using context manager + try: + availability = await usage_manager.get_availability_stats( + model, quota_group + ) + async with await usage_manager.acquire_credential( + model=model, + quota_group=quota_group, + candidates=untried, + priorities=filter_result.priorities, + deadline=deadline, + ) as cred_context: + cred = cred_context.credential + retry_state.record_attempt(cred) + + state = getattr(usage_manager, "states", {}).get( + cred_context.stable_id + ) + tier = state.tier if state else None + priority = state.priority if state else None + selection_mode = availability.get("rotation_mode") + quota_display = "?/?" + primary_def = None + if state and getattr(usage_manager, "window_manager", None): + primary_def = ( + usage_manager.window_manager.get_primary_definition() + ) + if state and primary_def: + scope_key = ( + quota_group if primary_def.applies_to == "group" else model + ) + usage = state.get_usage_for_scope( + primary_def.applies_to, scope_key, create=False + ) + if usage: + window = usage.windows.get(primary_def.name) + if window and window.limit is not None: + remaining = max(0, window.limit - window.request_count) + pct = ( + round(remaining / window.limit * 100) + if window.limit + else 0 + ) + quota_display = ( + f"{window.request_count}/{window.limit} [{pct}%]" + ) + lib_logger.info( + f"Acquired key {mask_credential(cred)} for model {model} " + f"(tier: {tier}, priority: {priority}, selection: {selection_mode}, quota: {quota_display})" + ) + + try: + # Apply transforms + kwargs = await self._transforms.apply( + provider, model, cred, context.kwargs.copy() + ) + + # Sanitize request payload + kwargs = sanitize_request_payload(kwargs, model) + + # Apply provider-specific LiteLLM params + self._apply_litellm_provider_params(provider, kwargs) + + # Get provider plugin + plugin = self._get_plugin_instance(provider) + + # Add transaction context for provider logging + if context.transaction_logger: + kwargs["transaction_context"] = ( + context.transaction_logger.get_context() + ) + + # Execute request with retries + for attempt in range(self._max_retries): + try: + lib_logger.info( + f"Attempting call with credential {mask_credential(cred)} " + f"(Attempt {attempt + 1}/{self._max_retries})" + ) + # Pre-request callback + if context.pre_request_callback: + try: + await context.pre_request_callback( + context.request, kwargs + ) + except Exception as e: + if self._abort_on_callback_error: + raise PreRequestCallbackError(str(e)) from e + lib_logger.warning( + f"Pre-request callback failed: {e}" + ) + + # Make the API call + if plugin and plugin.has_custom_logic(): + kwargs["credential_identifier"] = cred + response = await plugin.acompletion( + self._http_client, **kwargs + ) + else: + # Standard LiteLLM call + kwargs["api_key"] = cred + self._apply_litellm_logger(kwargs) + response = await litellm.acompletion(**kwargs) + + # Success! Extract token usage if available + ( + prompt_tokens, + completion_tokens, + prompt_tokens_cached, + prompt_tokens_cache_write, + thinking_tokens, + ) = self._extract_usage_tokens(response) + approx_cost = self._calculate_cost( + provider, model, response + ) + response_headers = self._extract_response_headers( + response + ) + + cred_context.mark_success( + response=response, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + thinking_tokens=thinking_tokens, + prompt_tokens_cache_read=prompt_tokens_cached, + prompt_tokens_cache_write=prompt_tokens_cache_write, + approx_cost=approx_cost, + response_headers=response_headers, + ) + + lib_logger.info( + f"Recorded usage from response object for key {mask_credential(cred)}" + ) + + # Log response if transaction logging enabled + if context.transaction_logger: + try: + response_data = ( + response.model_dump() + if hasattr(response, "model_dump") + else response + ) + context.transaction_logger.log_response( + response_data + ) + except Exception as log_err: + lib_logger.debug( + f"Failed to log response: {log_err}" + ) + + return response + + except Exception as e: + last_exception = e + action = await self._handle_error_with_context( + e, + cred_context, + model, + provider, + attempt, + error_accumulator, + retry_state, + request_headers, + ) + + if action == ErrorAction.RETRY_SAME: + continue + elif action == ErrorAction.ROTATE: + break # Try next credential + else: # FAIL + raise + + except PreRequestCallbackError: + raise + except Exception: + # Let context manager handle cleanup + pass + + except NoAvailableKeysError: + break + + # All credentials exhausted + error_accumulator.timeout_occurred = time.time() >= deadline + if last_exception and not error_accumulator.has_errors(): + raise last_exception + + # Return error response + return error_accumulator.build_client_error_response() + + async def _execute_streaming( + self, + context: RequestContext, + ) -> AsyncGenerator[str, None]: + """ + Execute streaming request with retry/rotation. + + This is an async generator that yields SSE-formatted strings. + + Args: + context: RequestContext with all request details + + Yields: + SSE-formatted strings + """ + provider = context.provider + model = context.model + deadline = context.deadline + + try: + ( + usage_manager, + filter_result, + credentials, + quota_group, + request_headers, + ) = await self._prepare_execution(context) + except NoAvailableKeysError as exc: + error_data = { + "error": { + "message": str(exc), + "type": "proxy_error", + } + } + yield f"data: {json.dumps(error_data)}\n\n" + yield "data: [DONE]\n\n" + return + + error_accumulator = RequestErrorAccumulator() + error_accumulator.model = model + error_accumulator.provider = provider + + retry_state = RetryState() + last_exception: Optional[Exception] = None + + try: + availability = await usage_manager.get_availability_stats( + model, quota_group + ) + blocked = availability.get("blocked_by", {}) + blocked_parts = [] + if blocked.get("cooldowns"): + blocked_parts.append(f"cd:{blocked['cooldowns']}") + if blocked.get("fair_cycle"): + blocked_parts.append(f"fc:{blocked['fair_cycle']}") + if blocked.get("custom_caps"): + blocked_parts.append(f"cap:{blocked['custom_caps']}") + if blocked.get("window_limits"): + blocked_parts.append(f"wl:{blocked['window_limits']}") + if blocked.get("concurrent"): + blocked_parts.append(f"con:{blocked['concurrent']}") + blocked_str = f"({', '.join(blocked_parts)})" if blocked_parts else "" + lib_logger.info( + f"Acquiring credential for model {model}. Tried credentials: {len(retry_state.tried_credentials)}/" + f"{availability.get('available', 0)}({availability.get('total', 0)}{blocked_str})" + ) + + while time.time() < deadline: + # Check for untried credentials + untried = [ + c for c in credentials if c not in retry_state.tried_credentials + ] + if not untried: + lib_logger.warning( + f"All {len(credentials)} credentials tried for {model}" + ) + break + + # Wait for provider cooldown + remaining = deadline - time.time() + if remaining <= 0: + break + await self._wait_for_cooldown(provider, deadline) + + # Acquire credential using context manager + try: + availability = await usage_manager.get_availability_stats( + model, quota_group + ) + async with await usage_manager.acquire_credential( + model=model, + quota_group=quota_group, + candidates=untried, + priorities=filter_result.priorities, + deadline=deadline, + ) as cred_context: + cred = cred_context.credential + retry_state.record_attempt(cred) + + state = getattr(usage_manager, "states", {}).get( + cred_context.stable_id + ) + tier = state.tier if state else None + priority = state.priority if state else None + selection_mode = availability.get("rotation_mode") + quota_display = "?/?" + primary_def = None + if state and getattr(usage_manager, "window_manager", None): + primary_def = ( + usage_manager.window_manager.get_primary_definition() + ) + if state and primary_def: + scope_key = ( + quota_group + if primary_def.applies_to == "group" + else model + ) + usage = state.get_usage_for_scope( + primary_def.applies_to, scope_key, create=False + ) + if usage: + window = usage.windows.get(primary_def.name) + if window and window.limit is not None: + remaining = max( + 0, window.limit - window.request_count + ) + pct = ( + round(remaining / window.limit * 100) + if window.limit + else 0 + ) + quota_display = f"{window.request_count}/{window.limit} [{pct}%]" + lib_logger.info( + f"Acquired key {mask_credential(cred)} for model {model} " + f"(tier: {tier}, priority: {priority}, selection: {selection_mode}, quota: {quota_display})" + ) + + try: + # Apply transforms + kwargs = await self._transforms.apply( + provider, model, cred, context.kwargs.copy() + ) + + # Sanitize request payload + kwargs = sanitize_request_payload(kwargs, model) + + # Apply provider-specific LiteLLM params + self._apply_litellm_provider_params(provider, kwargs) + + # Add stream options (but not for iflow - it returns 406) + if provider != "iflow": + if "stream_options" not in kwargs: + kwargs["stream_options"] = {} + if "include_usage" not in kwargs["stream_options"]: + kwargs["stream_options"]["include_usage"] = True + + # Get provider plugin + plugin = self._get_plugin_instance(provider) + skip_cost_calculation = bool( + plugin + and getattr(plugin, "skip_cost_calculation", False) + ) + + # Add transaction context for provider logging + if context.transaction_logger: + kwargs["transaction_context"] = ( + context.transaction_logger.get_context() + ) + + # Execute request with retries + for attempt in range(self._max_retries): + try: + lib_logger.info( + f"Attempting stream with credential {mask_credential(cred)} " + f"(Attempt {attempt + 1}/{self._max_retries})" + ) + # Pre-request callback + if context.pre_request_callback: + try: + await context.pre_request_callback( + context.request, kwargs + ) + except Exception as e: + if self._abort_on_callback_error: + raise PreRequestCallbackError( + str(e) + ) from e + lib_logger.warning( + f"Pre-request callback failed: {e}" + ) + + # Make the API call + if plugin and plugin.has_custom_logic(): + kwargs["credential_identifier"] = cred + stream = await plugin.acompletion( + self._http_client, **kwargs + ) + else: + kwargs["api_key"] = cred + kwargs["stream"] = True + self._apply_litellm_logger(kwargs) + stream = await litellm.acompletion(**kwargs) + + # Hand off to streaming handler with cred_context + # The handler will call mark_success on completion + base_stream = self._streaming_handler.wrap_stream( + stream, + cred, + model, + context.request, + cred_context, + skip_cost_calculation=skip_cost_calculation, + ) + + lib_logger.info( + f"Stream connection established for credential {mask_credential(cred)}. " + "Processing response." + ) + + # Wrap with transaction logging if enabled + if context.transaction_logger: + async for ( + chunk + ) in self._transaction_logging_stream_wrapper( + base_stream, + context.transaction_logger, + context.kwargs, + ): + yield chunk + else: + async for chunk in base_stream: + yield chunk + return + + except StreamedAPIError as e: + last_exception = e + original = getattr(e, "data", e) + classified = classify_error(original, provider) + log_failure( + api_key=cred, + model=model, + attempt=attempt + 1, + error=e, + request_headers=request_headers, + ) + error_accumulator.record_error( + cred, classified, str(original)[:150] + ) + + # Track consecutive quota failures + if classified.error_type == "quota_exceeded": + retry_state.increment_quota_failures() + if retry_state.consecutive_quota_failures >= 3: + lib_logger.error( + "3 consecutive quota errors in streaming - " + "request may be too large" + ) + cred_context.mark_failure(classified) + error_data = { + "error": { + "message": "Request exceeds quota for all credentials", + "type": "quota_exhausted", + } + } + yield f"data: {json.dumps(error_data)}\n\n" + yield "data: [DONE]\n\n" + return + else: + retry_state.reset_quota_failures() + + if not should_rotate_on_error(classified): + cred_context.mark_failure(classified) + raise + + cred_context.mark_failure(classified) + break # Rotate + + except (RateLimitError, httpx.HTTPStatusError) as e: + last_exception = e + classified = classify_error(e, provider) + log_failure( + api_key=cred, + model=model, + attempt=attempt + 1, + error=e, + request_headers=request_headers, + ) + error_accumulator.record_error( + cred, classified, str(e)[:150] + ) + + # Track consecutive quota failures + if classified.error_type == "quota_exceeded": + retry_state.increment_quota_failures() + if retry_state.consecutive_quota_failures >= 3: + lib_logger.error( + "3 consecutive quota errors in streaming - " + "request may be too large" + ) + cred_context.mark_failure(classified) + error_data = { + "error": { + "message": "Request exceeds quota for all credentials", + "type": "quota_exhausted", + } + } + yield f"data: {json.dumps(error_data)}\n\n" + yield "data: [DONE]\n\n" + return + else: + retry_state.reset_quota_failures() + + if not should_rotate_on_error(classified): + cred_context.mark_failure(classified) + raise + + cred_context.mark_failure(classified) + break # Rotate + + except ( + APIConnectionError, + InternalServerError, + ServiceUnavailableError, + ) as e: + last_exception = e + classified = classify_error(e, provider) + log_failure( + api_key=cred, + model=model, + attempt=attempt + 1, + error=e, + request_headers=request_headers, + ) + + if attempt >= self._max_retries - 1: + error_accumulator.record_error( + cred, classified, str(e)[:150] + ) + cred_context.mark_failure(classified) + break # Rotate + + # Calculate wait time + wait_time = classified.retry_after or ( + 2**attempt + ) + random.uniform(0, 1) + remaining = deadline - time.time() + if wait_time > remaining: + break # No time to wait + + await asyncio.sleep(wait_time) + continue # Retry + + except Exception as e: + last_exception = e + classified = classify_error(e, provider) + log_failure( + api_key=cred, + model=model, + attempt=attempt + 1, + error=e, + request_headers=request_headers, + ) + error_accumulator.record_error( + cred, classified, str(e)[:150] + ) + + if not should_rotate_on_error(classified): + cred_context.mark_failure(classified) + raise + + cred_context.mark_failure(classified) + break # Rotate + + except PreRequestCallbackError: + raise + except Exception: + # Let context manager handle cleanup + pass + + except NoAvailableKeysError: + break + + # All credentials exhausted or timeout + error_accumulator.timeout_occurred = time.time() >= deadline + error_data = error_accumulator.build_client_error_response() + yield f"data: {json.dumps(error_data)}\n\n" + yield "data: [DONE]\n\n" + + except NoAvailableKeysError as e: + lib_logger.error(f"No keys available: {e}") + error_data = {"error": {"message": str(e), "type": "proxy_busy"}} + yield f"data: {json.dumps(error_data)}\n\n" + yield "data: [DONE]\n\n" + + except Exception as e: + lib_logger.error(f"Unhandled exception in streaming: {e}", exc_info=True) + error_data = {"error": {"message": str(e), "type": "proxy_internal_error"}} + yield f"data: {json.dumps(error_data)}\n\n" + yield "data: [DONE]\n\n" + + def _apply_litellm_provider_params( + self, provider: str, kwargs: Dict[str, Any] + ) -> None: + """Merge provider-specific LiteLLM parameters into request kwargs.""" + params = self._litellm_provider_params.get(provider) + if not params: + return + kwargs["litellm_params"] = { + **params, + **kwargs.get("litellm_params", {}), + } + + def _apply_litellm_logger(self, kwargs: Dict[str, Any]) -> None: + """Attach LiteLLM logger callback if configured.""" + if self._litellm_logger_fn and "logger_fn" not in kwargs: + kwargs["logger_fn"] = self._litellm_logger_fn + + def _extract_response_headers(self, response: Any) -> Optional[Dict[str, Any]]: + """Extract response headers from LiteLLM response objects.""" + if hasattr(response, "response") and response.response is not None: + headers = getattr(response.response, "headers", None) + if headers is not None: + return dict(headers) + headers = getattr(response, "headers", None) + if headers is not None: + return dict(headers) + return None + + async def _wait_for_cooldown( + self, + provider: str, + deadline: float, + ) -> None: + """ + Wait for provider-level cooldown to end. + + Args: + provider: Provider name + deadline: Request deadline + """ + if not self._cooldown: + return + + remaining = await self._cooldown.get_remaining_cooldown(provider) + if remaining > 0: + budget = deadline - time.time() + if remaining > budget: + lib_logger.warning( + f"Provider {provider} cooldown ({remaining:.1f}s) exceeds budget ({budget:.1f}s)" + ) + return # Will fail on no keys available + lib_logger.info(f"Waiting {remaining:.1f}s for {provider} cooldown") + await asyncio.sleep(remaining) + + async def _handle_error_with_context( + self, + error: Exception, + cred_context: Any, # CredentialContext + model: str, + provider: str, + attempt: int, + error_accumulator: RequestErrorAccumulator, + retry_state: RetryState, + request_headers: Dict[str, Any], + ) -> str: + """ + Handle an error and determine next action. + + Args: + error: The caught exception + cred_context: CredentialContext for marking failure + model: Model name + provider: Provider name + attempt: Current attempt number + error_accumulator: Error tracking + retry_state: Retry state tracking + + Returns: + ErrorAction indicating what to do next + """ + classified = classify_error(error, provider) + error_message = str(error)[:150] + credential = cred_context.credential + + log_failure( + api_key=credential, + model=model, + attempt=attempt + 1, + error=error, + request_headers=request_headers, + ) + + # Check for quota errors + if classified.error_type == "quota_exceeded": + retry_state.increment_quota_failures() + if retry_state.consecutive_quota_failures >= 3: + # Likely request is too large + lib_logger.error( + f"3 consecutive quota errors - request may be too large" + ) + error_accumulator.record_error(credential, classified, error_message) + cred_context.mark_failure(classified) + return ErrorAction.FAIL + else: + retry_state.reset_quota_failures() + + # Check if should rotate + if not should_rotate_on_error(classified): + error_accumulator.record_error(credential, classified, error_message) + cred_context.mark_failure(classified) + return ErrorAction.FAIL + + # Check if should retry same key + if should_retry_same_key(classified) and attempt < self._max_retries - 1: + wait_time = classified.retry_after or (2**attempt) + random.uniform(0, 1) + lib_logger.info( + f"Retrying {mask_credential(credential)} in {wait_time:.1f}s" + ) + await asyncio.sleep(wait_time) + return ErrorAction.RETRY_SAME + + # Record error and rotate + error_accumulator.record_error(credential, classified, error_message) + cred_context.mark_failure(classified) + lib_logger.info( + f"Rotating from {mask_credential(credential)} after {classified.error_type}" + ) + return ErrorAction.ROTATE + + async def _ensure_initialized( + self, + usage_manager: "UsageManager", + context: RequestContext, + filter_result: "FilterResult", + ) -> None: + if usage_manager.initialized: + return + await usage_manager.initialize( + context.credentials, + priorities=filter_result.priorities, + tiers=filter_result.tier_names, + ) + + async def _validate_request( + self, + provider: str, + model: str, + kwargs: Dict[str, Any], + ) -> None: + plugin = self._get_plugin_instance(provider) + if not plugin or not hasattr(plugin, "validate_request"): + return + + result = plugin.validate_request(kwargs, model) + if asyncio.iscoroutine(result): + result = await result + if result is False: + raise ValueError(f"Request validation failed for {provider}/{model}") + if isinstance(result, str): + raise ValueError(result) + + def _extract_usage_tokens(self, response: Any) -> tuple[int, int, int, int, int]: + prompt_tokens = 0 + completion_tokens = 0 + cached_tokens = 0 + cache_write_tokens = 0 + thinking_tokens = 0 + + if hasattr(response, "usage") and response.usage: + prompt_tokens = getattr(response.usage, "prompt_tokens", 0) or 0 + completion_tokens = getattr(response.usage, "completion_tokens", 0) or 0 + + prompt_details = getattr(response.usage, "prompt_tokens_details", None) + if prompt_details: + if isinstance(prompt_details, dict): + cached_tokens = prompt_details.get("cached_tokens", 0) or 0 + cache_write_tokens = ( + prompt_details.get("cache_creation_tokens", 0) or 0 + ) + else: + cached_tokens = getattr(prompt_details, "cached_tokens", 0) or 0 + cache_write_tokens = ( + getattr(prompt_details, "cache_creation_tokens", 0) or 0 + ) + + completion_details = getattr( + response.usage, "completion_tokens_details", None + ) + if completion_details: + if isinstance(completion_details, dict): + thinking_tokens = completion_details.get("reasoning_tokens", 0) or 0 + else: + thinking_tokens = ( + getattr(completion_details, "reasoning_tokens", 0) or 0 + ) + + cache_read_tokens = getattr(response.usage, "cache_read_tokens", None) + if cache_read_tokens is not None: + cached_tokens = cache_read_tokens or 0 + cache_creation_tokens = getattr( + response.usage, "cache_creation_tokens", None + ) + if cache_creation_tokens is not None: + cache_write_tokens = cache_creation_tokens or 0 + + if thinking_tokens and completion_tokens >= thinking_tokens: + completion_tokens = completion_tokens - thinking_tokens + + uncached_prompt = max(0, prompt_tokens - cached_tokens) + return ( + uncached_prompt, + completion_tokens, + cached_tokens, + cache_write_tokens, + thinking_tokens, + ) + + def _calculate_cost(self, provider: str, model: str, response: Any) -> float: + plugin = self._get_plugin_instance(provider) + if plugin and getattr(plugin, "skip_cost_calculation", False): + return 0.0 + + try: + if isinstance(response, litellm.EmbeddingResponse): + model_info = litellm.get_model_info(model) + input_cost = model_info.get("input_cost_per_token") + if input_cost: + return (response.usage.prompt_tokens or 0) * input_cost + return 0.0 + + cost = litellm.completion_cost( + completion_response=response, + model=model, + ) + return float(cost) if cost is not None else 0.0 + except Exception as exc: + lib_logger.debug(f"Cost calculation failed for {model}: {exc}") + return 0.0 + + async def _transaction_logging_stream_wrapper( + self, + stream: AsyncGenerator[str, None], + transaction_logger: TransactionLogger, + request_kwargs: Dict[str, Any], + ) -> AsyncGenerator[str, None]: + """ + Wrap a stream to log chunks and final response to TransactionLogger. + + Yields all chunks unchanged while accumulating them for final logging. + + Args: + stream: The SSE stream from wrap_stream + transaction_logger: TransactionLogger instance + request_kwargs: Original request kwargs for context + + Yields: + SSE-formatted strings unchanged + """ + chunks = [] + + async for sse_line in stream: + yield sse_line + + # Parse and accumulate for final logging + if sse_line.startswith("data: ") and not sse_line.startswith( + "data: [DONE]" + ): + try: + content = sse_line[6:].strip() + if content: + chunk_data = json.loads(content) + chunks.append(chunk_data) + transaction_logger.log_stream_chunk(chunk_data) + except json.JSONDecodeError: + lib_logger.debug( + f"Failed to parse chunk for logging: {sse_line[:100]}" + ) + + # Log assembled final response + if chunks: + try: + final_response = TransactionLogger.assemble_streaming_response(chunks) + transaction_logger.log_response(final_response) + except Exception as e: + lib_logger.debug( + f"Failed to assemble/log final streaming response: {e}" + ) diff --git a/src/rotator_library/client/filters.py b/src/rotator_library/client/filters.py new file mode 100644 index 00000000..18ab8b13 --- /dev/null +++ b/src/rotator_library/client/filters.py @@ -0,0 +1,184 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Credential filtering by tier compatibility and priority. + +Extracts the tier filtering logic that was duplicated in client.py +at lines 1242-1315 and 2004-2076. +""" + +import logging +from typing import Any, Dict, List, Optional + +from ..core.types import FilterResult + +lib_logger = logging.getLogger("rotator_library") + + +class CredentialFilter: + """ + Filter and group credentials by tier compatibility and priority. + + This class extracts the credential filtering logic that was previously + duplicated in both _execute_with_retry and _streaming_acompletion_with_retry. + """ + + def __init__(self, provider_plugins: Dict[str, Any]): + """ + Initialize the CredentialFilter. + + Args: + provider_plugins: Dict mapping provider names to plugin classes/instances + """ + self._plugins = provider_plugins + self._plugin_instances: Dict[str, Any] = {} + + def _get_plugin_instance(self, provider: str) -> Optional[Any]: + """ + Get or create a plugin instance for a provider. + + Args: + provider: Provider name + + Returns: + Plugin instance or None if not found + """ + if provider not in self._plugin_instances: + plugin_class = self._plugins.get(provider) + if plugin_class: + # Check if it's a class or already an instance + if isinstance(plugin_class, type): + self._plugin_instances[provider] = plugin_class() + else: + self._plugin_instances[provider] = plugin_class + else: + return None + return self._plugin_instances[provider] + + def filter_by_tier( + self, + credentials: List[str], + model: str, + provider: str, + ) -> FilterResult: + """ + Filter credentials by tier compatibility for a model. + + Args: + credentials: List of credential identifiers + model: Model being requested + provider: Provider name + + Returns: + FilterResult with categorized credentials + """ + plugin = self._get_plugin_instance(provider) + + # Get tier requirement for model + required_tier = None + if plugin and hasattr(plugin, "get_model_tier_requirement"): + required_tier = plugin.get_model_tier_requirement(model) + + compatible: List[str] = [] + unknown: List[str] = [] + incompatible: List[str] = [] + priorities: Dict[str, int] = {} + tier_names: Dict[str, str] = {} + + for cred in credentials: + # Get priority and tier name + priority = None + tier_name = None + + if plugin: + if hasattr(plugin, "get_credential_priority"): + priority = plugin.get_credential_priority(cred) + if hasattr(plugin, "get_credential_tier_name"): + tier_name = plugin.get_credential_tier_name(cred) + + if priority is not None: + priorities[cred] = priority + if tier_name: + tier_names[cred] = tier_name + + # Categorize by tier compatibility + if required_tier is None: + # No tier requirement - all compatible + compatible.append(cred) + elif priority is None: + # Unknown priority - keep as candidate + unknown.append(cred) + elif priority <= required_tier: + # Known compatible (lower priority number = higher tier) + compatible.append(cred) + else: + # Known incompatible + incompatible.append(cred) + + # Log if all credentials are incompatible + if incompatible and not compatible and not unknown: + lib_logger.warning( + f"Model {model} requires tier <= {required_tier}, " + f"but all {len(incompatible)} credentials are incompatible" + ) + + return FilterResult( + compatible=compatible, + unknown=unknown, + incompatible=incompatible, + priorities=priorities, + tier_names=tier_names, + ) + + def group_by_priority( + self, + credentials: List[str], + priorities: Dict[str, int], + ) -> Dict[int, List[str]]: + """ + Group credentials by priority level. + + Args: + credentials: List of credential identifiers + priorities: Dict mapping credentials to priority levels + + Returns: + Dict mapping priority levels to credential lists, sorted by priority + """ + groups: Dict[int, List[str]] = {} + + for cred in credentials: + priority = priorities.get(cred, 999) + if priority not in groups: + groups[priority] = [] + groups[priority].append(cred) + + # Return sorted by priority (lower = higher priority) + return dict(sorted(groups.items())) + + def get_highest_priority_credentials( + self, + credentials: List[str], + priorities: Dict[str, int], + ) -> List[str]: + """ + Get credentials with the highest priority (lowest priority number). + + Args: + credentials: List of credential identifiers + priorities: Dict mapping credentials to priority levels + + Returns: + List of credentials with the highest priority + """ + if not credentials: + return [] + + groups = self.group_by_priority(credentials, priorities) + if not groups: + return credentials + + # Get the lowest priority number (highest priority) + highest_priority = min(groups.keys()) + return groups[highest_priority] diff --git a/src/rotator_library/client/models.py b/src/rotator_library/client/models.py new file mode 100644 index 00000000..e6de7567 --- /dev/null +++ b/src/rotator_library/client/models.py @@ -0,0 +1,228 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Model name resolution and filtering. + +Extracts model-related logic from client.py including: +- _resolve_model_id (lines 867-902) +- _is_model_ignored (lines 587-619) +- _is_model_whitelisted (lines 621-651) +""" + +import fnmatch +import logging +from typing import Any, Dict, List, Optional + +lib_logger = logging.getLogger("rotator_library") + + +class ModelResolver: + """ + Resolve model names and apply filtering rules. + + Handles: + - Model ID resolution (display name -> actual ID) + - Whitelist/blacklist filtering + - Provider prefix handling + """ + + def __init__( + self, + provider_plugins: Dict[str, Any], + model_definitions: Optional[Any] = None, + ignore_models: Optional[Dict[str, List[str]]] = None, + whitelist_models: Optional[Dict[str, List[str]]] = None, + ): + """ + Initialize the ModelResolver. + + Args: + provider_plugins: Dict mapping provider names to plugin classes + model_definitions: ModelDefinitions instance for ID mapping + ignore_models: Models to ignore/blacklist per provider + whitelist_models: Models to explicitly whitelist per provider + """ + self._plugins = provider_plugins + self._plugin_instances: Dict[str, Any] = {} + self._definitions = model_definitions + self._ignore = ignore_models or {} + self._whitelist = whitelist_models or {} + + def _get_plugin_instance(self, provider: str) -> Optional[Any]: + """ + Get or create a plugin instance for a provider. + """ + if provider not in self._plugin_instances: + plugin_class = self._plugins.get(provider) + if plugin_class: + if isinstance(plugin_class, type): + self._plugin_instances[provider] = plugin_class() + else: + self._plugin_instances[provider] = plugin_class + else: + return None + return self._plugin_instances[provider] + + def resolve_model_id(self, model: str, provider: str) -> str: + """ + Resolve display name to actual model ID. + + For custom models with name/ID mappings, returns the ID. + Otherwise, returns the model name unchanged. + + Args: + model: Full model string with provider (e.g., "iflow/DS-v3.2") + provider: Provider name (e.g., "iflow") + + Returns: + Full model string with ID (e.g., "iflow/deepseek-v3.2") + """ + model_name = model.split("/")[-1] if "/" in model else model + + # Check provider plugin first + plugin = self._get_plugin_instance(provider) + if plugin and hasattr(plugin, "model_definitions"): + resolved = plugin.model_definitions.get_model_id(provider, model_name) + if resolved and resolved != model_name: + return f"{provider}/{resolved}" + + # Fallback to client-level definitions + if self._definitions: + resolved = self._definitions.get_model_id(provider, model_name) + if resolved and resolved != model_name: + return f"{provider}/{resolved}" + + return model + + def is_model_allowed(self, model: str, provider: str) -> bool: + """ + Check if model passes whitelist/blacklist filters. + + Whitelist takes precedence over blacklist. + + Args: + model: Model string (with or without provider prefix) + provider: Provider name + + Returns: + True if model is allowed, False if blocked + """ + # Whitelist takes precedence + if self._is_whitelisted(model, provider): + return True + + # Then check blacklist + if self._is_blacklisted(model, provider): + return False + + return True + + def _is_blacklisted(self, model: str, provider: str) -> bool: + """ + Check if model is blacklisted. + + Supports glob patterns: + - "gpt-4" - exact match + - "gpt-4*" - prefix wildcard + - "*-preview" - suffix wildcard + - "*" - match all + + Args: + model: Model string + provider: Provider name (used to get ignore list) + + Returns: + True if model is blacklisted + """ + model_provider = model.split("/")[0] if "/" in model else provider + + if model_provider not in self._ignore: + return False + + ignore_list = self._ignore[model_provider] + if ignore_list == ["*"]: + return True + + # Extract model name without provider prefix + model_name = model.split("/", 1)[1] if "/" in model else model + + for pattern in ignore_list: + # Use fnmatch for glob pattern support + if fnmatch.fnmatch(model_name, pattern): + return True + if fnmatch.fnmatch(model, pattern): + return True + + return False + + def _is_whitelisted(self, model: str, provider: str) -> bool: + """ + Check if model is whitelisted. + + Same pattern support as blacklist. + + Args: + model: Model string + provider: Provider name + + Returns: + True if model is whitelisted + """ + model_provider = model.split("/")[0] if "/" in model else provider + + if model_provider not in self._whitelist: + return False + + whitelist = self._whitelist[model_provider] + model_name = model.split("/", 1)[1] if "/" in model else model + + for pattern in whitelist: + if fnmatch.fnmatch(model_name, pattern): + return True + if fnmatch.fnmatch(model, pattern): + return True + + return False + + @staticmethod + def extract_provider(model: str) -> str: + """ + Extract provider name from model string. + + Args: + model: Model string (e.g., "openai/gpt-4") + + Returns: + Provider name (e.g., "openai") or empty string if no prefix + """ + return model.split("/")[0] if "/" in model else "" + + @staticmethod + def strip_provider(model: str) -> str: + """ + Strip provider prefix from model string. + + Args: + model: Model string (e.g., "openai/gpt-4") + + Returns: + Model name without prefix (e.g., "gpt-4") + """ + return model.split("/", 1)[1] if "/" in model else model + + @staticmethod + def ensure_provider_prefix(model: str, provider: str) -> str: + """ + Ensure model string has provider prefix. + + Args: + model: Model string + provider: Provider name to add if missing + + Returns: + Model string with provider prefix + """ + if "/" in model: + return model + return f"{provider}/{model}" diff --git a/src/rotator_library/client/rotating_client.py b/src/rotator_library/client/rotating_client.py new file mode 100644 index 00000000..82c81e9a --- /dev/null +++ b/src/rotator_library/client/rotating_client.py @@ -0,0 +1,891 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Slim RotatingClient facade. + +This is a lightweight facade that delegates to extracted components: +- RequestExecutor: Unified retry/rotation logic +- CredentialFilter: Tier compatibility filtering +- ModelResolver: Model name resolution and filtering +- ProviderTransforms: Provider-specific request mutations +- StreamingHandler: Streaming response processing + +The original client.py was ~3000 lines. This facade is ~300 lines, +with all complexity moved to specialized modules. +""" + +import asyncio +import json +import logging +import os +import random +import time +from pathlib import Path +from typing import Any, AsyncGenerator, Dict, List, Optional, Union, TYPE_CHECKING + +import httpx +import litellm +from litellm.litellm_core_utils.token_counter import token_counter + +from ..core.types import RequestContext +from ..core.errors import NoAvailableKeysError, mask_credential +from ..core.config import ConfigLoader +from ..core.constants import ( + DEFAULT_MAX_RETRIES, + DEFAULT_GLOBAL_TIMEOUT, + DEFAULT_ROTATION_TOLERANCE, +) + +from .filters import CredentialFilter +from .models import ModelResolver +from .transforms import ProviderTransforms +from .executor import RequestExecutor +from .anthropic import AnthropicHandler + +# Import providers and other dependencies +from ..providers import PROVIDER_PLUGINS +from ..cooldown_manager import CooldownManager +from ..credential_manager import CredentialManager +from ..background_refresher import BackgroundRefresher +from ..model_definitions import ModelDefinitions +from ..transaction_logger import TransactionLogger +from ..provider_config import ProviderConfig as LiteLLMProviderConfig +from ..utils.paths import get_default_root, get_logs_dir, get_oauth_dir +from ..utils.suppress_litellm_warnings import suppress_litellm_serialization_warnings +from ..failure_logger import configure_failure_logger + +# Import new usage package +from ..usage import UsageManager as NewUsageManager +from ..usage.config import load_provider_usage_config, WindowDefinition + +if TYPE_CHECKING: + from ..anthropic_compat import AnthropicMessagesRequest, AnthropicCountTokensRequest + +lib_logger = logging.getLogger("rotator_library") + + +class RotatingClient: + """ + A client that intelligently rotates and retries API keys using LiteLLM, + with support for both streaming and non-streaming responses. + + This is a slim facade that delegates to specialized components: + - RequestExecutor: Handles retry/rotation logic + - CredentialFilter: Filters credentials by tier + - ModelResolver: Resolves model names + - ProviderTransforms: Applies provider-specific transforms + """ + + def __init__( + self, + api_keys: Optional[Dict[str, List[str]]] = None, + oauth_credentials: Optional[Dict[str, List[str]]] = None, + max_retries: int = DEFAULT_MAX_RETRIES, + usage_file_path: Optional[Union[str, Path]] = None, + configure_logging: bool = True, + global_timeout: int = DEFAULT_GLOBAL_TIMEOUT, + abort_on_callback_error: bool = True, + litellm_provider_params: Optional[Dict[str, Any]] = None, + ignore_models: Optional[Dict[str, List[str]]] = None, + whitelist_models: Optional[Dict[str, List[str]]] = None, + enable_request_logging: bool = False, + max_concurrent_requests_per_key: Optional[Dict[str, int]] = None, + rotation_tolerance: float = DEFAULT_ROTATION_TOLERANCE, + data_dir: Optional[Union[str, Path]] = None, + ): + """ + Initialize the RotatingClient. + + See original client.py for full parameter documentation. + """ + # Resolve data directory + self.data_dir = Path(data_dir).resolve() if data_dir else get_default_root() + + # Configure logging + configure_failure_logger(get_logs_dir(self.data_dir)) + os.environ["LITELLM_LOG"] = "ERROR" + litellm.set_verbose = False + litellm.drop_params = True + suppress_litellm_serialization_warnings() + + if configure_logging: + lib_logger.propagate = True + if lib_logger.hasHandlers(): + lib_logger.handlers.clear() + lib_logger.addHandler(logging.NullHandler()) + else: + lib_logger.propagate = False + + # Process credentials + api_keys = api_keys or {} + oauth_credentials = oauth_credentials or {} + api_keys = {p: k for p, k in api_keys.items() if k} + oauth_credentials = {p: c for p, c in oauth_credentials.items() if c} + + if not api_keys and not oauth_credentials: + lib_logger.warning( + "No provider credentials configured. Client will be unable to make requests." + ) + + # Discover OAuth credentials if not provided + if oauth_credentials: + self.oauth_credentials = oauth_credentials + else: + cred_manager = CredentialManager( + os.environ, oauth_dir=get_oauth_dir(self.data_dir) + ) + self.oauth_credentials = cred_manager.discover_and_prepare() + + # Build combined credentials + self.all_credentials: Dict[str, List[str]] = {} + for provider, keys in api_keys.items(): + self.all_credentials.setdefault(provider, []).extend(keys) + for provider, paths in self.oauth_credentials.items(): + self.all_credentials.setdefault(provider, []).extend(paths) + + self.api_keys = api_keys + self.oauth_providers = set(self.oauth_credentials.keys()) + + # Store configuration + self.max_retries = max_retries + self.global_timeout = global_timeout + self.abort_on_callback_error = abort_on_callback_error + self.litellm_provider_params = litellm_provider_params or {} + self._litellm_logger_fn = self._litellm_logger_callback + self.enable_request_logging = enable_request_logging + self.max_concurrent_requests_per_key = max_concurrent_requests_per_key or {} + + # Validate concurrent requests config + for provider, max_val in self.max_concurrent_requests_per_key.items(): + if max_val < 1: + lib_logger.warning( + f"Invalid max_concurrent for '{provider}': {max_val}. Setting to 1." + ) + self.max_concurrent_requests_per_key[provider] = 1 + + # Initialize configuration loader + self._config_loader = ConfigLoader(PROVIDER_PLUGINS) + + # Initialize components + self._provider_plugins = PROVIDER_PLUGINS + self._provider_instances: Dict[str, Any] = {} + + # Initialize managers + self.cooldown_manager = CooldownManager() + self.background_refresher = BackgroundRefresher(self) + self.model_definitions = ModelDefinitions() + self.provider_config = LiteLLMProviderConfig() + self.http_client = httpx.AsyncClient() + + # Initialize extracted components + self._credential_filter = CredentialFilter(PROVIDER_PLUGINS) + self._model_resolver = ModelResolver( + PROVIDER_PLUGINS, + self.model_definitions, + ignore_models or {}, + whitelist_models or {}, + ) + self._provider_transforms = ProviderTransforms( + PROVIDER_PLUGINS, + self.provider_config, + ) + + # Initialize UsageManagers (one per provider) using new usage package + self._usage_managers: Dict[str, NewUsageManager] = {} + + # Resolve usage file path base + if usage_file_path: + base_path = Path(usage_file_path) + if base_path.suffix: + base_path = base_path.parent + self._usage_base_path = base_path / "usage" + else: + self._usage_base_path = self.data_dir / "usage" + self._usage_base_path.mkdir(parents=True, exist_ok=True) + + # Build provider configs using ConfigLoader + provider_configs = {} + for provider in self.all_credentials.keys(): + provider_configs[provider] = self._config_loader.load_provider_config( + provider + ) + + # Create UsageManager for each provider + for provider, credentials in self.all_credentials.items(): + config = load_provider_usage_config(provider, PROVIDER_PLUGINS) + # Override tolerance from constructor param + config.rotation_tolerance = rotation_tolerance + + self._apply_usage_reset_config(provider, credentials, config) + + usage_file = self._usage_base_path / f"usage_{provider}.json" + + # Get max concurrent for this provider + max_concurrent = self.max_concurrent_requests_per_key.get(provider) + + manager = NewUsageManager( + provider=provider, + file_path=usage_file, + provider_plugins=PROVIDER_PLUGINS, + config=config, + max_concurrent_per_key=max_concurrent, + ) + self._usage_managers[provider] = manager + + # Initialize executor with new usage managers + self._executor = RequestExecutor( + usage_managers=self._usage_managers, + cooldown_manager=self.cooldown_manager, + credential_filter=self._credential_filter, + provider_transforms=self._provider_transforms, + provider_plugins=PROVIDER_PLUGINS, + http_client=self.http_client, + max_retries=max_retries, + global_timeout=global_timeout, + abort_on_callback_error=abort_on_callback_error, + litellm_provider_params=self.litellm_provider_params, + litellm_logger_fn=self._litellm_logger_fn, + ) + + self._model_list_cache: Dict[str, List[str]] = {} + self._usage_initialized = False + self._usage_init_lock = asyncio.Lock() + + # Initialize Anthropic compatibility handler + self._anthropic_handler = AnthropicHandler(self) + + async def __aenter__(self): + await self.initialize_usage_managers() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + + async def initialize_usage_managers(self) -> None: + """Initialize usage managers once before background jobs run.""" + if self._usage_initialized: + return + async with self._usage_init_lock: + if self._usage_initialized: + return + for provider, manager in self._usage_managers.items(): + credentials = self.all_credentials.get(provider, []) + priorities, tiers = self._get_credential_metadata(provider, credentials) + await manager.initialize( + credentials, priorities=priorities, tiers=tiers + ) + summaries = [] + for provider, manager in self._usage_managers.items(): + credentials = self.all_credentials.get(provider, []) + status = ( + f"loaded {manager.loaded_credentials}" + if manager.loaded_from_storage + else "fresh" + ) + summaries.append(f"{provider}:{len(credentials)} ({status})") + if summaries: + lib_logger.info( + f"Usage managers initialized: {', '.join(sorted(summaries))}" + ) + self._usage_initialized = True + + async def close(self): + """Close the HTTP client and save usage data.""" + # Save and shutdown new usage managers + for manager in self._usage_managers.values(): + await manager.shutdown() + + if hasattr(self, "http_client") and self.http_client: + await self.http_client.aclose() + + async def acompletion( + self, + request: Optional[Any] = None, + pre_request_callback: Optional[callable] = None, + **kwargs, + ) -> Union[Any, AsyncGenerator[str, None]]: + """ + Dispatcher for completion requests. + + Returns: + Response object or async generator for streaming + """ + model = kwargs.get("model", "") + provider = model.split("/")[0] if "/" in model else "" + + if not provider or provider not in self.all_credentials: + raise ValueError( + f"Invalid model format or no credentials for provider: {model}" + ) + + # Extract internal logging parameters (not passed to API) + parent_log_dir = kwargs.pop("_parent_log_dir", None) + + # Resolve model ID + resolved_model = self._model_resolver.resolve_model_id(model, provider) + kwargs["model"] = resolved_model + + # Create transaction logger if enabled + transaction_logger = None + if self.enable_request_logging: + transaction_logger = TransactionLogger( + provider=provider, + model=resolved_model, + enabled=True, + parent_dir=parent_log_dir, + ) + transaction_logger.log_request(kwargs) + + # Build request context + context = RequestContext( + model=resolved_model, + provider=provider, + kwargs=kwargs, + streaming=kwargs.get("stream", False), + credentials=self.all_credentials.get(provider, []), + deadline=time.time() + self.global_timeout, + request=request, + pre_request_callback=pre_request_callback, + transaction_logger=transaction_logger, + ) + + return await self._executor.execute(context) + + def aembedding( + self, + request: Optional[Any] = None, + pre_request_callback: Optional[callable] = None, + **kwargs, + ) -> Any: + """ + Execute an embedding request with retry logic. + """ + model = kwargs.get("model", "") + provider = model.split("/")[0] if "/" in model else "" + + if not provider or provider not in self.all_credentials: + raise ValueError( + f"Invalid model format or no credentials for provider: {model}" + ) + + # Build request context (embeddings are never streaming) + context = RequestContext( + model=model, + provider=provider, + kwargs=kwargs, + streaming=False, + credentials=self.all_credentials.get(provider, []), + deadline=time.time() + self.global_timeout, + request=request, + pre_request_callback=pre_request_callback, + ) + + return self._executor.execute(context) + + def token_count(self, **kwargs) -> int: + """Calculate token count for text or messages. + + For Antigravity provider models, this also includes the preprompt tokens + that get injected during actual API calls (agent instruction + identity override). + This ensures token counts match actual usage. + """ + model = kwargs.get("model") + text = kwargs.get("text") + messages = kwargs.get("messages") + + if not model: + raise ValueError("'model' is required") + + # Calculate base token count + if messages: + base_count = token_counter(model=model, messages=messages) + elif text: + base_count = token_counter(model=model, text=text) + else: + raise ValueError("Either 'text' or 'messages' must be provided") + + # Add preprompt tokens for Antigravity provider + # The Antigravity provider injects system instructions during actual API calls, + # so we need to account for those tokens in the count + provider = model.split("/")[0] if "/" in model else "" + if provider == "antigravity": + try: + from ..providers.antigravity_provider import ( + get_antigravity_preprompt_text, + ) + + preprompt_text = get_antigravity_preprompt_text() + if preprompt_text: + preprompt_tokens = token_counter(model=model, text=preprompt_text) + base_count += preprompt_tokens + except ImportError: + # Provider not available, skip preprompt token counting + pass + + return base_count + + async def get_available_models(self, provider: str) -> List[str]: + """Get available models for a provider with caching.""" + if provider in self._model_list_cache: + return self._model_list_cache[provider] + + credentials = self.all_credentials.get(provider, []) + if not credentials: + return [] + + # Shuffle and try each credential + shuffled = list(credentials) + random.shuffle(shuffled) + + plugin = self._get_provider_instance(provider) + if not plugin: + return [] + + for cred in shuffled: + try: + models = await plugin.get_models(cred, self.http_client) + + # Apply whitelist/blacklist + final = [ + m + for m in models + if self._model_resolver.is_model_allowed(m, provider) + ] + + self._model_list_cache[provider] = final + return final + + except Exception as e: + lib_logger.debug( + f"Failed to get models for {provider} with {mask_credential(cred)}: {e}" + ) + continue + + return [] + + async def get_all_available_models( + self, + grouped: bool = True, + ) -> Union[Dict[str, List[str]], List[str]]: + """Get all available models across all providers.""" + providers = list(self.all_credentials.keys()) + tasks = [self.get_available_models(p) for p in providers] + results = await asyncio.gather(*tasks, return_exceptions=True) + + all_models: Dict[str, List[str]] = {} + for provider, result in zip(providers, results): + if isinstance(result, Exception): + lib_logger.error(f"Failed to get models for {provider}: {result}") + all_models[provider] = [] + else: + all_models[provider] = result + + if grouped: + return all_models + else: + flat = [] + for models in all_models.values(): + flat.extend(models) + return flat + + async def get_quota_stats( + self, + provider_filter: Optional[str] = None, + ) -> Dict[str, Any]: + """Get quota and usage stats for all credentials. + + Args: + provider_filter: Optional provider name to filter results + + Returns: + Dict with stats per provider + """ + providers = {} + + for provider, manager in self._usage_managers.items(): + if provider_filter and provider != provider_filter: + continue + providers[provider] = await manager.get_stats_for_endpoint() + + summary = { + "total_providers": len(providers), + "total_credentials": 0, + "active_credentials": 0, + "exhausted_credentials": 0, + "total_requests": 0, + "tokens": { + "input_cached": 0, + "input_uncached": 0, + "input_cache_pct": 0, + "output": 0, + }, + "approx_total_cost": None, + } + + for prov in providers.values(): + summary["total_credentials"] += prov.get("credential_count", 0) + summary["active_credentials"] += prov.get("active_count", 0) + summary["exhausted_credentials"] += prov.get("exhausted_count", 0) + summary["total_requests"] += prov.get("total_requests", 0) + tokens = prov.get("tokens", {}) + summary["tokens"]["input_cached"] += tokens.get("input_cached", 0) + summary["tokens"]["input_uncached"] += tokens.get("input_uncached", 0) + summary["tokens"]["output"] += tokens.get("output", 0) + + total_input = ( + summary["tokens"]["input_cached"] + summary["tokens"]["input_uncached"] + ) + summary["tokens"]["input_cache_pct"] = ( + round(summary["tokens"]["input_cached"] / total_input * 100, 1) + if total_input > 0 + else 0 + ) + + approx_total_cost = 0.0 + has_cost = False + for prov in providers.values(): + cost = prov.get("approx_cost") + if cost: + approx_total_cost += cost + has_cost = True + summary["approx_total_cost"] = approx_total_cost if has_cost else None + + return { + "providers": providers, + "summary": summary, + "data_source": "cache", + "timestamp": time.time(), + } + + def get_oauth_credentials(self) -> Dict[str, List[str]]: + """Get discovered OAuth credentials.""" + return self.oauth_credentials + + def _get_provider_instance(self, provider: str) -> Optional[Any]: + """Get or create a provider plugin instance.""" + if provider not in self.all_credentials: + return None + + if provider not in self._provider_instances: + plugin_class = self._provider_plugins.get(provider) + if plugin_class: + self._provider_instances[provider] = plugin_class() + else: + return None + + return self._provider_instances[provider] + + def _get_credential_metadata( + self, + provider: str, + credentials: List[str], + ) -> tuple[Dict[str, int], Dict[str, str]]: + """Resolve priority and tier metadata for credentials.""" + plugin = self._get_provider_instance(provider) + priorities: Dict[str, int] = {} + tiers: Dict[str, str] = {} + + if not plugin: + return priorities, tiers + + for credential in credentials: + if hasattr(plugin, "get_credential_priority"): + priority = plugin.get_credential_priority(credential) + if priority is not None: + priorities[credential] = priority + if hasattr(plugin, "get_credential_tier_name"): + tier_name = plugin.get_credential_tier_name(credential) + if tier_name: + tiers[credential] = tier_name + + return priorities, tiers + + def get_usage_manager(self, provider: str) -> Optional[NewUsageManager]: + """ + Get the new UsageManager for a specific provider. + + Args: + provider: Provider name + + Returns: + UsageManager for the provider, or None if not found + """ + return self._usage_managers.get(provider) + + @property + def usage_managers(self) -> Dict[str, NewUsageManager]: + """Get all new usage managers.""" + return self._usage_managers + + def _apply_usage_reset_config( + self, + provider: str, + credentials: List[str], + config: Any, + ) -> None: + """Apply provider-specific usage reset config to window definitions.""" + if not credentials: + return + + plugin = self._get_provider_instance(provider) + if not plugin or not hasattr(plugin, "get_usage_reset_config"): + return + + try: + reset_config = plugin.get_usage_reset_config(credentials[0]) + except Exception as exc: + lib_logger.debug(f"Failed to load usage reset config for {provider}: {exc}") + return + + if not reset_config: + return + + window_seconds = reset_config.get("window_seconds") + if not window_seconds: + return + + mode = reset_config.get("mode", "credential") + applies_to = "credential" if mode == "credential" else "model" + + if window_seconds == 86400: + window_name = "daily" + elif window_seconds % 3600 == 0: + window_name = f"{window_seconds // 3600}h" + else: + window_name = "window" + + config.windows = [ + WindowDefinition.rolling( + name=window_name, + duration_seconds=int(window_seconds), + is_primary=True, + applies_to=applies_to, + ), + WindowDefinition.total(name="total", applies_to=applies_to), + ] + + def _sanitize_litellm_log(self, log_data: dict) -> dict: + """Remove large/sensitive fields from LiteLLM logs.""" + if not isinstance(log_data, dict): + return log_data + + keys_to_pop = [ + "messages", + "input", + "response", + "data", + "api_key", + "api_base", + "original_response", + "additional_args", + ] + nested_keys = ["kwargs", "litellm_params", "model_info", "proxy_server_request"] + + clean_data = json.loads(json.dumps(log_data, default=str)) + + def clean_recursively(data_dict: dict) -> None: + for key in keys_to_pop: + data_dict.pop(key, None) + for key in nested_keys: + if key in data_dict and isinstance(data_dict[key], dict): + clean_recursively(data_dict[key]) + for value in list(data_dict.values()): + if isinstance(value, dict): + clean_recursively(value) + + clean_recursively(clean_data) + return clean_data + + def _litellm_logger_callback(self, log_data: dict) -> None: + """Redirect LiteLLM logs into rotator library logger.""" + log_event_type = log_data.get("log_event_type") + if log_event_type in ["pre_api_call", "post_api_call"]: + return + + if not log_data.get("exception"): + sanitized_log = self._sanitize_litellm_log(log_data) + lib_logger.debug(f"LiteLLM Log: {sanitized_log}") + return + + model = log_data.get("model", "N/A") + error_info = log_data.get("standard_logging_object", {}).get( + "error_information", {} + ) + error_class = error_info.get("error_class", "UnknownError") + error_message = error_info.get( + "error_message", str(log_data.get("exception", "")) + ) + error_message = " ".join(error_message.split()) + + lib_logger.debug( + f"LiteLLM Callback Handled Error: Model={model} | " + f"Type={error_class} | Message='{error_message}'" + ) + + # ========================================================================= + # USAGE MANAGEMENT METHODS + # ========================================================================= + + async def reload_usage_from_disk(self) -> None: + """ + Force reload usage data from disk. + + Useful when wanting fresh stats without making external API calls. + """ + for manager in self._usage_managers.values(): + await manager.reload_from_disk() + + async def force_refresh_quota( + self, + provider: Optional[str] = None, + credential: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Force refresh quota from external API. + + For Antigravity, this fetches live quota data from the API. + For other providers, this is a no-op (just reloads from disk). + + Args: + provider: If specified, only refresh this provider + credential: If specified, only refresh this specific credential + + Returns: + Refresh result dict with success/failure info + """ + result = { + "action": "force_refresh", + "scope": "credential" + if credential + else ("provider" if provider else "all"), + "provider": provider, + "credential": credential, + "credentials_refreshed": 0, + "success_count": 0, + "failed_count": 0, + "duration_ms": 0, + "errors": [], + } + + start_time = time.time() + + # Determine which providers to refresh + if provider: + providers_to_refresh = ( + [provider] if provider in self.all_credentials else [] + ) + else: + providers_to_refresh = list(self.all_credentials.keys()) + + for prov in providers_to_refresh: + provider_class = self._provider_plugins.get(prov) + if not provider_class: + continue + + # Get or create provider instance + provider_instance = self._get_provider_instance(prov) + if not provider_instance: + continue + + # Check if provider supports quota refresh (like Antigravity) + if hasattr(provider_instance, "fetch_initial_baselines"): + # Get credentials to refresh + if credential: + # Find full path for this credential + creds_to_refresh = [] + for cred_path in self.all_credentials.get(prov, []): + if cred_path.endswith(credential) or cred_path == credential: + creds_to_refresh.append(cred_path) + break + else: + creds_to_refresh = self.all_credentials.get(prov, []) + + if not creds_to_refresh: + continue + + try: + # Fetch live quota from API for ALL specified credentials + quota_results = await provider_instance.fetch_initial_baselines( + creds_to_refresh + ) + + # Store baselines in usage manager + usage_manager = self._usage_managers.get(prov) + if usage_manager and hasattr( + provider_instance, "_store_baselines_to_usage_manager" + ): + stored = ( + await provider_instance._store_baselines_to_usage_manager( + quota_results, usage_manager, force=True + ) + ) + result["success_count"] += stored + + result["credentials_refreshed"] += len(creds_to_refresh) + + # Count failures + for cred_path, data in quota_results.items(): + if data.get("status") != "success": + result["failed_count"] += 1 + result["errors"].append( + f"{Path(cred_path).name}: {data.get('error', 'Unknown error')}" + ) + + except Exception as e: + lib_logger.error(f"Failed to refresh quota for {prov}: {e}") + result["errors"].append(f"{prov}: {str(e)}") + result["failed_count"] += len(creds_to_refresh) + + result["duration_ms"] = int((time.time() - start_time) * 1000) + return result + + # ========================================================================= + # ANTHROPIC API COMPATIBILITY METHODS + # ========================================================================= + + async def anthropic_messages( + self, + request: "AnthropicMessagesRequest", + raw_request: Optional[Any] = None, + pre_request_callback: Optional[callable] = None, + ) -> Any: + """ + Handle Anthropic Messages API requests. + + This method accepts requests in Anthropic's format, translates them to + OpenAI format internally, processes them through the existing acompletion + method, and returns responses in Anthropic's format. + + Args: + request: An AnthropicMessagesRequest object + raw_request: Optional raw request object for disconnect checks + pre_request_callback: Optional async callback before each API request + + Returns: + For non-streaming: dict in Anthropic Messages format + For streaming: AsyncGenerator yielding Anthropic SSE format strings + """ + return await self._anthropic_handler.messages( + request=request, + raw_request=raw_request, + pre_request_callback=pre_request_callback, + ) + + async def anthropic_count_tokens( + self, + request: "AnthropicCountTokensRequest", + ) -> dict: + """ + Handle Anthropic count_tokens API requests. + + Counts the number of tokens that would be used by a Messages API request. + This is useful for estimating costs and managing context windows. + + Args: + request: An AnthropicCountTokensRequest object + + Returns: + Dict with input_tokens count in Anthropic format + """ + return await self._anthropic_handler.count_tokens(request=request) diff --git a/src/rotator_library/client/streaming.py b/src/rotator_library/client/streaming.py new file mode 100644 index 00000000..b79488b9 --- /dev/null +++ b/src/rotator_library/client/streaming.py @@ -0,0 +1,425 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Streaming response handler. + +Extracts streaming logic from client.py _safe_streaming_wrapper (lines 904-1117). +Handles: +- Chunk processing with finish_reason logic +- JSON reassembly for fragmented responses +- Error detection in streamed data +- Usage tracking from final chunks +- Client disconnect handling +""" + +import codecs +import json +import logging +import re +from typing import Any, AsyncGenerator, AsyncIterator, Dict, Optional, TYPE_CHECKING + +import litellm + +from ..core.errors import StreamedAPIError, CredentialNeedsReauthError +from ..core.types import ProcessedChunk + +if TYPE_CHECKING: + from ..usage.manager import CredentialContext + +lib_logger = logging.getLogger("rotator_library") + + +class StreamingHandler: + """ + Process streaming responses with error handling and usage tracking. + + This class extracts the streaming logic that was in _safe_streaming_wrapper + and provides a clean interface for processing LiteLLM streams. + + Usage recording is handled via CredentialContext passed to wrap_stream(). + """ + + async def wrap_stream( + self, + stream: AsyncIterator[Any], + credential: str, + model: str, + request: Optional[Any] = None, + cred_context: Optional["CredentialContext"] = None, + skip_cost_calculation: bool = False, + ) -> AsyncGenerator[str, None]: + """ + Wrap a LiteLLM stream with error handling and usage tracking. + + FINISH_REASON HANDLING: + - Strip finish_reason from intermediate chunks (litellm defaults to "stop") + - Track accumulated_finish_reason with priority: tool_calls > length/content_filter > stop + - Only emit finish_reason on final chunk (detected by usage.completion_tokens > 0) + + Args: + stream: The async iterator from LiteLLM + credential: Credential identifier (for logging) + model: Model name for usage recording + request: Optional FastAPI request for disconnect detection + cred_context: CredentialContext for marking success/failure + + Yields: + SSE-formatted strings: "data: {...}\\n\\n" + """ + stream_completed = False + error_buffer = StreamBuffer() # Use StreamBuffer for JSON reassembly + accumulated_finish_reason: Optional[str] = None + has_tool_calls = False + prompt_tokens = 0 + prompt_tokens_cached = 0 + prompt_tokens_cache_write = 0 + prompt_tokens_uncached = 0 + completion_tokens = 0 + thinking_tokens = 0 + + # Use manual iteration to allow continue after partial JSON errors + stream_iterator = stream.__aiter__() + + try: + while True: + try: + # Check client disconnect before waiting for next chunk + if request and await request.is_disconnected(): + lib_logger.info( + f"Client disconnected. Aborting stream for model {model}." + ) + break + + chunk = await stream_iterator.__anext__() + + # Clear error buffer on successful chunk receipt + error_buffer.reset() + + # Process chunk + processed = self._process_chunk( + chunk, + accumulated_finish_reason, + has_tool_calls, + ) + + # Update tracking state + if processed.has_tool_calls: + has_tool_calls = True + accumulated_finish_reason = "tool_calls" + if processed.finish_reason and not has_tool_calls: + # Only update if not already tool_calls (highest priority) + accumulated_finish_reason = processed.finish_reason + if processed.usage and isinstance(processed.usage, dict): + # Extract token counts from final chunk + prompt_tokens = processed.usage.get("prompt_tokens", 0) + completion_tokens = processed.usage.get("completion_tokens", 0) + prompt_details = processed.usage.get("prompt_tokens_details") + if prompt_details: + if isinstance(prompt_details, dict): + prompt_tokens_cached = ( + prompt_details.get("cached_tokens", 0) or 0 + ) + prompt_tokens_cache_write = ( + prompt_details.get("cache_creation_tokens", 0) or 0 + ) + else: + prompt_tokens_cached = ( + getattr(prompt_details, "cached_tokens", 0) or 0 + ) + prompt_tokens_cache_write = ( + getattr(prompt_details, "cache_creation_tokens", 0) + or 0 + ) + completion_details = processed.usage.get( + "completion_tokens_details" + ) + if completion_details: + if isinstance(completion_details, dict): + thinking_tokens = ( + completion_details.get("reasoning_tokens", 0) or 0 + ) + else: + thinking_tokens = ( + getattr(completion_details, "reasoning_tokens", 0) + or 0 + ) + if processed.usage.get("cache_read_tokens") is not None: + prompt_tokens_cached = ( + processed.usage.get("cache_read_tokens") or 0 + ) + if processed.usage.get("cache_creation_tokens") is not None: + prompt_tokens_cache_write = ( + processed.usage.get("cache_creation_tokens") or 0 + ) + if thinking_tokens and completion_tokens >= thinking_tokens: + completion_tokens = completion_tokens - thinking_tokens + prompt_tokens_uncached = max( + 0, prompt_tokens - prompt_tokens_cached + ) + + yield processed.sse_string + + except StopAsyncIteration: + # Stream ended normally + stream_completed = True + break + + except CredentialNeedsReauthError as e: + # Credential needs re-auth - wrap for outer retry loop + if cred_context: + from ..error_handler import classify_error + + cred_context.mark_failure(classify_error(e)) + raise StreamedAPIError("Credential needs re-authentication", data=e) + + except json.JSONDecodeError as e: + # Partial JSON - accumulate and continue + error_buffer.append(str(e)) + if error_buffer.is_complete: + # We have complete JSON now + raise StreamedAPIError( + "Provider error", data=error_buffer.content + ) + # Continue waiting for more chunks + continue + + except Exception as e: + # Try to extract JSON from fragmented response + error_str = str(e) + error_buffer.append(error_str) + + # Check if buffer now has complete JSON + if error_buffer.is_complete: + if cred_context: + from ..error_handler import classify_error + + cred_context.mark_failure(classify_error(e)) + raise StreamedAPIError( + "Provider error in stream", data=error_buffer.content + ) + + # Try pattern matching for error extraction + extracted = self._try_extract_error(e, error_buffer.content) + if extracted: + if cred_context: + from ..error_handler import classify_error + + cred_context.mark_failure(classify_error(e)) + raise StreamedAPIError( + "Provider error in stream", data=extracted + ) + + # Not a JSON-related error, re-raise + raise + + except StreamedAPIError: + # Re-raise for retry loop + raise + + finally: + # Record usage if stream completed + if stream_completed: + if cred_context: + approx_cost = 0.0 + if not skip_cost_calculation: + approx_cost = self._calculate_stream_cost( + model, + prompt_tokens_uncached + prompt_tokens_cached, + completion_tokens + thinking_tokens, + ) + cred_context.mark_success( + prompt_tokens=prompt_tokens_uncached, + completion_tokens=completion_tokens, + thinking_tokens=thinking_tokens, + prompt_tokens_cache_read=prompt_tokens_cached, + prompt_tokens_cache_write=prompt_tokens_cache_write, + approx_cost=approx_cost, + ) + + # Yield [DONE] for completed streams + yield "data: [DONE]\n\n" + + def _process_chunk( + self, + chunk: Any, + accumulated_finish_reason: Optional[str], + has_tool_calls: bool, + ) -> ProcessedChunk: + """ + Process a single streaming chunk. + + Handles finish_reason logic: + - Strip from intermediate chunks + - Apply correct finish_reason on final chunk + + Args: + chunk: Raw chunk from LiteLLM + accumulated_finish_reason: Current accumulated finish reason + has_tool_calls: Whether any chunk has had tool_calls + + Returns: + ProcessedChunk with SSE string and metadata + """ + # Convert chunk to dict + if hasattr(chunk, "model_dump"): + chunk_dict = chunk.model_dump() + elif hasattr(chunk, "dict"): + chunk_dict = chunk.dict() + else: + chunk_dict = chunk + + # Extract metadata before modifying + usage = chunk_dict.get("usage") + finish_reason = None + chunk_has_tool_calls = False + + if "choices" in chunk_dict and chunk_dict["choices"]: + choice = chunk_dict["choices"][0] + delta = choice.get("delta", {}) + + # Check for tool_calls + if delta.get("tool_calls"): + chunk_has_tool_calls = True + + # Detect final chunk: has usage with completion_tokens > 0 + has_completion_tokens = ( + usage + and isinstance(usage, dict) + and usage.get("completion_tokens", 0) > 0 + ) + + if has_completion_tokens: + # FINAL CHUNK: Determine correct finish_reason + if has_tool_calls or chunk_has_tool_calls: + choice["finish_reason"] = "tool_calls" + elif accumulated_finish_reason: + choice["finish_reason"] = accumulated_finish_reason + else: + choice["finish_reason"] = "stop" + finish_reason = choice["finish_reason"] + else: + # INTERMEDIATE CHUNK: Never emit finish_reason + choice["finish_reason"] = None + + return ProcessedChunk( + sse_string=f"data: {json.dumps(chunk_dict)}\n\n", + usage=usage, + finish_reason=finish_reason, + has_tool_calls=chunk_has_tool_calls, + ) + + def _try_extract_error( + self, + exception: Exception, + buffer: str, + ) -> Optional[Dict]: + """ + Try to extract error JSON from exception or buffer. + + Handles multiple error formats: + - Google-style bytes representation: b'{...}' + - "Received chunk:" prefix + - JSON in buffer accumulation + + Args: + exception: The caught exception + buffer: Current JSON buffer content + + Returns: + Parsed error dict or None + """ + error_str = str(exception) + + # Pattern 1: Google-style bytes representation + match = re.search(r"b'(\{.*\})'", error_str, re.DOTALL) + if match: + try: + decoded = codecs.decode(match.group(1), "unicode_escape") + return json.loads(decoded) + except (json.JSONDecodeError, ValueError): + pass + + # Pattern 2: "Received chunk:" prefix + if "Received chunk:" in error_str: + chunk = error_str.split("Received chunk:")[-1].strip() + try: + return json.loads(chunk) + except json.JSONDecodeError: + pass + + # Pattern 3: Buffer accumulation + if buffer: + try: + return json.loads(buffer) + except json.JSONDecodeError: + pass + + return None + + def _calculate_stream_cost( + self, + model: str, + prompt_tokens: int, + completion_tokens: int, + ) -> float: + try: + model_info = litellm.get_model_info(model) + input_cost = model_info.get("input_cost_per_token") + output_cost = model_info.get("output_cost_per_token") + total_cost = 0.0 + if input_cost: + total_cost += prompt_tokens * input_cost + if output_cost: + total_cost += completion_tokens * output_cost + return total_cost + except Exception as exc: + lib_logger.debug(f"Stream cost calculation failed for {model}: {exc}") + return 0.0 + + +class StreamBuffer: + """ + Buffer for reassembling fragmented JSON in streams. + + Some providers send JSON split across multiple chunks, especially + for error responses. This class handles accumulation and parsing. + """ + + def __init__(self): + self._buffer = "" + self._complete = False + + def append(self, chunk: str) -> Optional[Dict]: + """ + Append a chunk and try to parse. + + Args: + chunk: Raw chunk string + + Returns: + Parsed dict if complete, None if still accumulating + """ + self._buffer += chunk + + try: + result = json.loads(self._buffer) + self._complete = True + return result + except json.JSONDecodeError: + return None + + def reset(self) -> None: + """Reset the buffer.""" + self._buffer = "" + self._complete = False + + @property + def content(self) -> str: + """Get current buffer content.""" + return self._buffer + + @property + def is_complete(self) -> bool: + """Check if buffer contains complete JSON.""" + return self._complete diff --git a/src/rotator_library/client/transforms.py b/src/rotator_library/client/transforms.py new file mode 100644 index 00000000..e1e834b0 --- /dev/null +++ b/src/rotator_library/client/transforms.py @@ -0,0 +1,370 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Provider-specific request transformations. + +This module isolates all provider-specific request mutations that were +scattered throughout client.py, including: +- gemma-3 system message conversion +- qwen_code provider remapping +- Gemini safety settings and thinking parameter +- NVIDIA thinking parameter +- iflow stream_options removal + +Transforms are applied in a defined order with logging of modifications. +""" + +import logging +from typing import Any, Callable, Dict, List, Optional + +lib_logger = logging.getLogger("rotator_library") + + +class ProviderTransforms: + """ + Centralized provider-specific request transformations. + + Transforms are applied in order: + 1. Built-in transforms (gemma-3, qwen_code, etc.) + 2. Provider hook transforms (from provider plugins) + 3. Safety settings conversions + """ + + def __init__( + self, + provider_plugins: Dict[str, Any], + provider_config: Optional[Any] = None, + ): + """ + Initialize ProviderTransforms. + + Args: + provider_plugins: Dict mapping provider names to plugin classes + provider_config: ProviderConfig instance for LiteLLM conversions + """ + self._plugins = provider_plugins + self._plugin_instances: Dict[str, Any] = {} + self._config = provider_config + + # Registry of built-in transforms + # Each provider can have multiple transform functions + self._transforms: Dict[str, List[Callable]] = { + "gemma": [self._transform_gemma_system_messages], + "qwen_code": [self._transform_qwen_code_provider], + "gemini": [self._transform_gemini_safety, self._transform_gemini_thinking], + "nvidia_nim": [self._transform_nvidia_thinking], + "iflow": [self._transform_iflow_stream_options], + } + + def _get_plugin_instance(self, provider: str) -> Optional[Any]: + """Get or create a plugin instance for a provider.""" + if provider not in self._plugin_instances: + plugin_class = self._plugins.get(provider) + if plugin_class: + if isinstance(plugin_class, type): + self._plugin_instances[provider] = plugin_class() + else: + self._plugin_instances[provider] = plugin_class + else: + return None + return self._plugin_instances[provider] + + async def apply( + self, + provider: str, + model: str, + credential: str, + kwargs: Dict[str, Any], + ) -> Dict[str, Any]: + """ + Apply all applicable transforms to request kwargs. + + Args: + provider: Provider name + model: Model being requested + credential: Selected credential + kwargs: Request kwargs (will be mutated) + + Returns: + Modified kwargs + """ + modifications: List[str] = [] + + # 1. Apply built-in transforms + for transform_provider, transforms in self._transforms.items(): + # Check if transform applies (provider match or model contains pattern) + if transform_provider == provider or transform_provider in model.lower(): + for transform in transforms: + result = transform(kwargs, model, provider) + if result: + modifications.append(result) + + # 2. Apply provider hook transforms (async) + plugin = self._get_plugin_instance(provider) + if plugin and hasattr(plugin, "transform_request"): + try: + hook_result = await plugin.transform_request(kwargs, model, credential) + if hook_result: + modifications.extend(hook_result) + except Exception as e: + lib_logger.debug(f"Provider transform_request hook failed: {e}") + + # 3. Apply model-specific options from provider + if plugin and hasattr(plugin, "get_model_options"): + model_options = plugin.get_model_options(model) + if model_options: + for key, value in model_options.items(): + if key == "reasoning_effort": + kwargs["reasoning_effort"] = value + elif key not in kwargs: + kwargs[key] = value + modifications.append(f"applied model options for {model}") + + # 4. Apply LiteLLM conversion if config available + if self._config and hasattr(self._config, "convert_for_litellm"): + kwargs = self._config.convert_for_litellm(**kwargs) + + if modifications: + lib_logger.debug( + f"Applied transforms for {provider}/{model}: {modifications}" + ) + + return kwargs + + def apply_sync( + self, + provider: str, + model: str, + kwargs: Dict[str, Any], + ) -> Dict[str, Any]: + """ + Apply built-in transforms synchronously (no provider hooks). + + Useful when async is not available. + + Args: + provider: Provider name + model: Model being requested + kwargs: Request kwargs + + Returns: + Modified kwargs + """ + modifications: List[str] = [] + + for transform_provider, transforms in self._transforms.items(): + if transform_provider == provider or transform_provider in model.lower(): + for transform in transforms: + result = transform(kwargs, model, provider) + if result: + modifications.append(result) + + if modifications: + lib_logger.debug( + f"Applied sync transforms for {provider}/{model}: {modifications}" + ) + + return kwargs + + # ========================================================================= + # BUILT-IN TRANSFORMS + # ========================================================================= + + def _transform_gemma_system_messages( + self, + kwargs: Dict[str, Any], + model: str, + provider: str, + ) -> Optional[str]: + """ + Convert system messages to user messages for Gemma-3. + + Gemma-3 models don't support system messages, so we convert them + to user messages to maintain functionality. + """ + if "gemma-3" not in model.lower(): + return None + + messages = kwargs.get("messages", []) + if not messages: + return None + + converted = False + new_messages = [] + for m in messages: + if m.get("role") == "system": + new_messages.append({"role": "user", "content": m["content"]}) + converted = True + else: + new_messages.append(m) + + if converted: + kwargs["messages"] = new_messages + return "gemma-3: converted system->user messages" + return None + + def _transform_qwen_code_provider( + self, + kwargs: Dict[str, Any], + model: str, + provider: str, + ) -> Optional[str]: + """ + Remap qwen_code to qwen provider for LiteLLM. + + The qwen_code provider is a custom wrapper that needs to be + translated to the qwen provider for LiteLLM compatibility. + """ + if provider != "qwen_code": + return None + + kwargs["custom_llm_provider"] = "qwen" + if "/" in model: + kwargs["model"] = model.split("/", 1)[1] + return "qwen_code: remapped to qwen provider" + + def _transform_gemini_safety( + self, + kwargs: Dict[str, Any], + model: str, + provider: str, + ) -> Optional[str]: + """ + Apply default Gemini safety settings. + + Ensures safety settings are present without overriding explicit settings. + """ + if provider != "gemini": + return None + + # Default safety settings (generic form) + default_generic = { + "harassment": "OFF", + "hate_speech": "OFF", + "sexually_explicit": "OFF", + "dangerous_content": "OFF", + "civic_integrity": "BLOCK_NONE", + } + + # Default Gemini-native settings + default_gemini = [ + {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}, + ] + + # If generic form present, fill in missing keys + if "safety_settings" in kwargs and isinstance(kwargs["safety_settings"], dict): + for k, v in default_generic.items(): + if k not in kwargs["safety_settings"]: + kwargs["safety_settings"][k] = v + return "gemini: filled missing safety settings" + + # If Gemini form present, fill in missing categories + if "safetySettings" in kwargs and isinstance(kwargs["safetySettings"], list): + present = { + item.get("category") + for item in kwargs["safetySettings"] + if isinstance(item, dict) + } + added = 0 + for d in default_gemini: + if d["category"] not in present: + kwargs["safetySettings"].append(d) + added += 1 + if added > 0: + return f"gemini: added {added} missing safety categories" + return None + + # Neither present: set generic defaults + if "safety_settings" not in kwargs and "safetySettings" not in kwargs: + kwargs["safety_settings"] = default_generic.copy() + return "gemini: applied default safety settings" + + return None + + def _transform_gemini_thinking( + self, + kwargs: Dict[str, Any], + model: str, + provider: str, + ) -> Optional[str]: + """ + Handle thinking parameter for Gemini. + + Delegates to provider plugin's handle_thinking_parameter method. + """ + if provider != "gemini": + return None + + plugin = self._get_plugin_instance(provider) + if plugin and hasattr(plugin, "handle_thinking_parameter"): + plugin.handle_thinking_parameter(kwargs, model) + return "gemini: handled thinking parameter" + return None + + def _transform_nvidia_thinking( + self, + kwargs: Dict[str, Any], + model: str, + provider: str, + ) -> Optional[str]: + """ + Handle thinking parameter for NVIDIA NIM. + + Delegates to provider plugin's handle_thinking_parameter method. + """ + if provider != "nvidia_nim": + return None + + plugin = self._get_plugin_instance(provider) + if plugin and hasattr(plugin, "handle_thinking_parameter"): + plugin.handle_thinking_parameter(kwargs, model) + return "nvidia_nim: handled thinking parameter" + return None + + def _transform_iflow_stream_options( + self, + kwargs: Dict[str, Any], + model: str, + provider: str, + ) -> Optional[str]: + """ + Remove stream_options for iflow provider. + + The iflow provider returns HTTP 406 if stream_options is present. + """ + if provider != "iflow": + return None + + if "stream_options" in kwargs: + del kwargs["stream_options"] + return "iflow: removed stream_options" + return None + + # ========================================================================= + # SAFETY SETTINGS CONVERSION + # ========================================================================= + + def convert_safety_settings( + self, + provider: str, + settings: Dict[str, str], + ) -> Optional[Any]: + """ + Convert generic safety settings to provider-specific format. + + Args: + provider: Provider name + settings: Generic safety settings dict + + Returns: + Provider-specific settings or None + """ + plugin = self._get_plugin_instance(provider) + if plugin and hasattr(plugin, "convert_safety_settings"): + return plugin.convert_safety_settings(settings) + return None diff --git a/src/rotator_library/client/types.py b/src/rotator_library/client/types.py new file mode 100644 index 00000000..b54bba0e --- /dev/null +++ b/src/rotator_library/client/types.py @@ -0,0 +1,79 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Client-specific type definitions. + +Types that are only used within the client package. +Shared types are in core/types.py. +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Set + + +@dataclass +class AvailabilityStats: + """ + Statistics about credential availability for a model. + + Used for logging and monitoring credential pool status. + """ + + available: int # Credentials not on cooldown and not exhausted + on_cooldown: int # Credentials on cooldown + fair_cycle_excluded: int # Credentials excluded by fair cycle + total: int # Total credentials for provider + + @property + def usable(self) -> int: + """Return count of usable credentials.""" + return self.available + + def __str__(self) -> str: + parts = [f"{self.available}/{self.total}"] + if self.on_cooldown > 0: + parts.append(f"cd:{self.on_cooldown}") + if self.fair_cycle_excluded > 0: + parts.append(f"fc:{self.fair_cycle_excluded}") + return ",".join(parts) + + +@dataclass +class RetryState: + """ + State tracking for a retry loop. + + Used by RequestExecutor to track retry attempts and errors. + """ + + tried_credentials: Set[str] = field(default_factory=set) + last_exception: Optional[Exception] = None + consecutive_quota_failures: int = 0 + + def record_attempt(self, credential: str) -> None: + """Record that a credential was tried.""" + self.tried_credentials.add(credential) + + def reset_quota_failures(self) -> None: + """Reset quota failure counter (called after non-quota error).""" + self.consecutive_quota_failures = 0 + + def increment_quota_failures(self) -> None: + """Increment quota failure counter.""" + self.consecutive_quota_failures += 1 + + +@dataclass +class ExecutionResult: + """ + Result of executing a request. + + Returned by RequestExecutor to indicate outcome. + """ + + success: bool + response: Optional[Any] = None + error: Optional[Exception] = None + should_rotate: bool = False + should_fail: bool = False diff --git a/src/rotator_library/cooldown_manager.py b/src/rotator_library/cooldown_manager.py index 8e045e48..0d1bb63e 100644 --- a/src/rotator_library/cooldown_manager.py +++ b/src/rotator_library/cooldown_manager.py @@ -5,12 +5,14 @@ import time from typing import Dict + class CooldownManager: """ Manages global cooldown periods for API providers to handle IP-based rate limiting. This ensures that once a 429 error is received for a provider, all subsequent requests to that provider are paused for a specified duration. """ + def __init__(self): self._cooldowns: Dict[str, float] = {} self._lock = asyncio.Lock() @@ -18,7 +20,9 @@ def __init__(self): async def is_cooling_down(self, provider: str) -> bool: """Checks if a provider is currently in a cooldown period.""" async with self._lock: - return provider in self._cooldowns and time.time() < self._cooldowns[provider] + return ( + provider in self._cooldowns and time.time() < self._cooldowns[provider] + ) async def start_cooldown(self, provider: str, duration: int): """ @@ -37,4 +41,8 @@ async def get_cooldown_remaining(self, provider: str) -> float: if provider in self._cooldowns: remaining = self._cooldowns[provider] - time.time() return max(0, remaining) - return 0 \ No newline at end of file + return 0 + + async def get_remaining_cooldown(self, provider: str) -> float: + """Backward-compatible alias for get_cooldown_remaining.""" + return await self.get_cooldown_remaining(provider) diff --git a/src/rotator_library/core/__init__.py b/src/rotator_library/core/__init__.py new file mode 100644 index 00000000..c88a75a5 --- /dev/null +++ b/src/rotator_library/core/__init__.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Core package for the rotator library. + +Provides shared infrastructure used by both client and usage manager: +- types: Shared dataclasses and type definitions +- errors: All custom exceptions +- config: ConfigLoader for centralized configuration +- constants: Default values and magic numbers +""" + +from .types import ( + CredentialInfo, + RequestContext, + ProcessedChunk, + FilterResult, + FairCycleConfig, + CustomCapConfig, + ProviderConfig, + WindowConfig, + RequestCompleteResult, +) + +from .errors import ( + # Base exceptions + NoAvailableKeysError, + PreRequestCallbackError, + CredentialNeedsReauthError, + EmptyResponseError, + TransientQuotaError, + StreamedAPIError, + # Error classification + ClassifiedError, + RequestErrorAccumulator, + classify_error, + should_rotate_on_error, + should_retry_same_key, + mask_credential, + is_abnormal_error, + get_retry_after, +) + +from .config import ConfigLoader + +__all__ = [ + # Types + "CredentialInfo", + "RequestContext", + "ProcessedChunk", + "FilterResult", + "FairCycleConfig", + "CustomCapConfig", + "ProviderConfig", + "WindowConfig", + "RequestCompleteResult", + # Errors + "NoAvailableKeysError", + "PreRequestCallbackError", + "CredentialNeedsReauthError", + "EmptyResponseError", + "TransientQuotaError", + "StreamedAPIError", + "ClassifiedError", + "RequestErrorAccumulator", + "classify_error", + "should_rotate_on_error", + "should_retry_same_key", + "mask_credential", + "is_abnormal_error", + "get_retry_after", + # Config + "ConfigLoader", +] diff --git a/src/rotator_library/core/config.py b/src/rotator_library/core/config.py new file mode 100644 index 00000000..34114011 --- /dev/null +++ b/src/rotator_library/core/config.py @@ -0,0 +1,550 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Centralized configuration loader for the rotator library. + +This module provides a ConfigLoader class that handles all configuration +parsing from: +1. System defaults (from config/defaults.py) +2. Provider class attributes +3. Environment variables (ALWAYS override provider defaults) + +The ConfigLoader ensures consistent configuration handling across +both the client and usage manager. +""" + +import os +import logging +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +from .types import ( + ProviderConfig, + FairCycleConfig, + CustomCapConfig, + WindowConfig, +) +from .constants import ( + # Defaults + DEFAULT_ROTATION_MODE, + DEFAULT_ROTATION_TOLERANCE, + DEFAULT_SEQUENTIAL_FALLBACK_MULTIPLIER, + DEFAULT_FAIR_CYCLE_ENABLED, + DEFAULT_FAIR_CYCLE_TRACKING_MODE, + DEFAULT_FAIR_CYCLE_CROSS_TIER, + DEFAULT_FAIR_CYCLE_DURATION, + DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD, + # Prefixes + ENV_PREFIX_ROTATION_MODE, + ENV_PREFIX_FAIR_CYCLE, + ENV_PREFIX_FAIR_CYCLE_TRACKING, + ENV_PREFIX_FAIR_CYCLE_CROSS_TIER, + ENV_PREFIX_FAIR_CYCLE_DURATION, + ENV_PREFIX_EXHAUSTION_THRESHOLD, + ENV_PREFIX_CONCURRENCY_MULTIPLIER, + ENV_PREFIX_CUSTOM_CAP, + ENV_PREFIX_CUSTOM_CAP_COOLDOWN, +) + +lib_logger = logging.getLogger("rotator_library") + + +class ConfigLoader: + """ + Centralized configuration loader. + + Parses all configuration from: + 1. System defaults + 2. Provider class attributes + 3. Environment variables (ALWAYS override provider defaults) + + Usage: + loader = ConfigLoader(provider_plugins) + config = loader.load_provider_config("antigravity") + """ + + def __init__(self, provider_plugins: Optional[Dict[str, type]] = None): + """ + Initialize the ConfigLoader. + + Args: + provider_plugins: Dict mapping provider names to plugin classes. + If None, no provider-specific defaults are used. + """ + self._plugins = provider_plugins or {} + self._cache: Dict[str, ProviderConfig] = {} + + def load_provider_config( + self, + provider: str, + force_reload: bool = False, + ) -> ProviderConfig: + """ + Load complete configuration for a provider. + + Configuration is loaded in this order (later overrides earlier): + 1. System defaults + 2. Provider class attributes + 3. Environment variables (ALWAYS win) + + Args: + provider: Provider name (e.g., "antigravity", "gemini_cli") + force_reload: If True, bypass cache and reload + + Returns: + Complete ProviderConfig for the provider + """ + if not force_reload and provider in self._cache: + return self._cache[provider] + + # Start with system defaults + config = self._get_system_defaults() + + # Apply provider class defaults + plugin_class = self._plugins.get(provider) + if plugin_class: + config = self._apply_provider_defaults(config, plugin_class, provider) + + # Apply environment variable overrides (ALWAYS win) + config = self._apply_env_overrides(config, provider) + + # Cache and return + self._cache[provider] = config + return config + + def load_all_provider_configs( + self, + providers: List[str], + ) -> Dict[str, ProviderConfig]: + """ + Load configurations for multiple providers. + + Args: + providers: List of provider names + + Returns: + Dict mapping provider names to their configs + """ + return {p: self.load_provider_config(p) for p in providers} + + def clear_cache(self, provider: Optional[str] = None) -> None: + """ + Clear cached configurations. + + Args: + provider: If provided, only clear that provider's cache. + If None, clear all cached configs. + """ + if provider: + self._cache.pop(provider, None) + else: + self._cache.clear() + + # ========================================================================= + # INTERNAL METHODS + # ========================================================================= + + def _get_system_defaults(self) -> ProviderConfig: + """Get a ProviderConfig with all system defaults.""" + return ProviderConfig( + rotation_mode=DEFAULT_ROTATION_MODE, + rotation_tolerance=DEFAULT_ROTATION_TOLERANCE, + priority_multipliers={}, + priority_multipliers_by_mode={}, + sequential_fallback_multiplier=DEFAULT_SEQUENTIAL_FALLBACK_MULTIPLIER, + fair_cycle=FairCycleConfig( + enabled=DEFAULT_FAIR_CYCLE_ENABLED, + tracking_mode=DEFAULT_FAIR_CYCLE_TRACKING_MODE, + cross_tier=DEFAULT_FAIR_CYCLE_CROSS_TIER, + duration=DEFAULT_FAIR_CYCLE_DURATION, + ), + custom_caps=[], + exhaustion_cooldown_threshold=DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD, + windows=[], + ) + + def _apply_provider_defaults( + self, + config: ProviderConfig, + plugin_class: type, + provider: str, + ) -> ProviderConfig: + """ + Apply provider class default attributes to config. + + Args: + config: Current configuration + plugin_class: Provider plugin class + provider: Provider name for logging + + Returns: + Updated configuration + """ + # Rotation mode + if hasattr(plugin_class, "default_rotation_mode"): + config.rotation_mode = plugin_class.default_rotation_mode + + # Priority multipliers + if hasattr(plugin_class, "default_priority_multipliers"): + multipliers = plugin_class.default_priority_multipliers + if multipliers: + config.priority_multipliers = dict(multipliers) + + # Sequential fallback multiplier + if hasattr(plugin_class, "default_sequential_fallback_multiplier"): + fallback = plugin_class.default_sequential_fallback_multiplier + if fallback != DEFAULT_SEQUENTIAL_FALLBACK_MULTIPLIER: + config.sequential_fallback_multiplier = fallback + + # Fair cycle settings + if hasattr(plugin_class, "default_fair_cycle_enabled"): + val = plugin_class.default_fair_cycle_enabled + if val is not None: + config.fair_cycle.enabled = val + + if hasattr(plugin_class, "default_fair_cycle_tracking_mode"): + config.fair_cycle.tracking_mode = ( + plugin_class.default_fair_cycle_tracking_mode + ) + + if hasattr(plugin_class, "default_fair_cycle_cross_tier"): + config.fair_cycle.cross_tier = plugin_class.default_fair_cycle_cross_tier + + if hasattr(plugin_class, "default_fair_cycle_duration"): + duration = plugin_class.default_fair_cycle_duration + if duration != DEFAULT_FAIR_CYCLE_DURATION: + config.fair_cycle.duration = duration + + # Exhaustion cooldown threshold + if hasattr(plugin_class, "default_exhaustion_cooldown_threshold"): + threshold = plugin_class.default_exhaustion_cooldown_threshold + if threshold != DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD: + config.exhaustion_cooldown_threshold = threshold + + # Custom caps + if hasattr(plugin_class, "default_custom_caps"): + caps = plugin_class.default_custom_caps + if caps: + config.custom_caps = self._parse_custom_caps_from_provider(caps) + + return config + + def _apply_env_overrides( + self, + config: ProviderConfig, + provider: str, + ) -> ProviderConfig: + """ + Apply environment variable overrides to config. + + Environment variables ALWAYS override provider class defaults. + + Args: + config: Current configuration + provider: Provider name + + Returns: + Updated configuration with env overrides applied + """ + provider_upper = provider.upper() + + # Rotation mode: ROTATION_MODE_{PROVIDER} + env_key = f"{ENV_PREFIX_ROTATION_MODE}{provider_upper}" + env_val = os.getenv(env_key) + if env_val: + config.rotation_mode = env_val.lower() + if config.rotation_mode not in ("balanced", "sequential"): + lib_logger.warning(f"Invalid {env_key}='{env_val}'. Using 'balanced'.") + config.rotation_mode = "balanced" + + # Fair cycle enabled: FAIR_CYCLE_{PROVIDER} + env_key = f"{ENV_PREFIX_FAIR_CYCLE}{provider_upper}" + env_val = os.getenv(env_key) + if env_val is not None: + config.fair_cycle.enabled = env_val.lower() in ("true", "1", "yes") + + # Fair cycle tracking mode: FAIR_CYCLE_TRACKING_MODE_{PROVIDER} + env_key = f"{ENV_PREFIX_FAIR_CYCLE_TRACKING}{provider_upper}" + env_val = os.getenv(env_key) + if env_val and env_val.lower() in ("model_group", "credential"): + config.fair_cycle.tracking_mode = env_val.lower() + + # Fair cycle cross-tier: FAIR_CYCLE_CROSS_TIER_{PROVIDER} + env_key = f"{ENV_PREFIX_FAIR_CYCLE_CROSS_TIER}{provider_upper}" + env_val = os.getenv(env_key) + if env_val is not None: + config.fair_cycle.cross_tier = env_val.lower() in ("true", "1", "yes") + + # Fair cycle duration: FAIR_CYCLE_DURATION_{PROVIDER} + env_key = f"{ENV_PREFIX_FAIR_CYCLE_DURATION}{provider_upper}" + env_val = os.getenv(env_key) + if env_val: + try: + config.fair_cycle.duration = int(env_val) + except ValueError: + lib_logger.warning(f"Invalid {env_key}='{env_val}'. Must be integer.") + + # Exhaustion cooldown threshold: EXHAUSTION_COOLDOWN_THRESHOLD_{PROVIDER} + # Also check global: EXHAUSTION_COOLDOWN_THRESHOLD + env_key = f"{ENV_PREFIX_EXHAUSTION_THRESHOLD}{provider_upper}" + env_val = os.getenv(env_key) or os.getenv("EXHAUSTION_COOLDOWN_THRESHOLD") + if env_val: + try: + config.exhaustion_cooldown_threshold = int(env_val) + except ValueError: + lib_logger.warning(f"Invalid exhaustion threshold='{env_val}'.") + + # Priority multipliers: CONCURRENCY_MULTIPLIER_{PROVIDER}_PRIORITY_{N} + # Also supports mode-specific: CONCURRENCY_MULTIPLIER_{PROVIDER}_PRIORITY_{N}_{MODE} + self._parse_priority_multiplier_env_vars(config, provider_upper) + + # Custom caps: CUSTOM_CAP_{PROVIDER}_T{TIER}_{MODEL} + # Also: CUSTOM_CAP_COOLDOWN_{PROVIDER}_T{TIER}_{MODEL} + self._parse_custom_cap_env_vars(config, provider_upper) + + return config + + def _parse_priority_multiplier_env_vars( + self, + config: ProviderConfig, + provider_upper: str, + ) -> None: + """ + Parse CONCURRENCY_MULTIPLIER_* environment variables. + + Formats: + - CONCURRENCY_MULTIPLIER_{PROVIDER}_PRIORITY_{N}=value + - CONCURRENCY_MULTIPLIER_{PROVIDER}_PRIORITY_{N}_{MODE}=value + """ + prefix = f"{ENV_PREFIX_CONCURRENCY_MULTIPLIER}{provider_upper}_PRIORITY_" + + for env_key, env_val in os.environ.items(): + if not env_key.startswith(prefix): + continue + + remainder = env_key[len(prefix) :] + try: + multiplier = int(env_val) + if multiplier < 1: + lib_logger.warning(f"Invalid {env_key}='{env_val}'. Must be >= 1.") + continue + + # Check for mode-specific suffix + if "_" in remainder: + parts = remainder.rsplit("_", 1) + priority = int(parts[0]) + mode = parts[1].lower() + + if mode in ("sequential", "balanced"): + if mode not in config.priority_multipliers_by_mode: + config.priority_multipliers_by_mode[mode] = {} + config.priority_multipliers_by_mode[mode][priority] = multiplier + else: + lib_logger.warning(f"Unknown mode in {env_key}: {mode}") + else: + # Universal priority multiplier + priority = int(remainder) + config.priority_multipliers[priority] = multiplier + + except ValueError: + lib_logger.warning(f"Invalid {env_key}='{env_val}'. Could not parse.") + + def _parse_custom_cap_env_vars( + self, + config: ProviderConfig, + provider_upper: str, + ) -> None: + """ + Parse CUSTOM_CAP_* environment variables. + + Formats: + - CUSTOM_CAP_{PROVIDER}_T{TIER}_{MODEL}=value + - CUSTOM_CAP_{PROVIDER}_TDEFAULT_{MODEL}=value + - CUSTOM_CAP_COOLDOWN_{PROVIDER}_T{TIER}_{MODEL}=mode:value + """ + cap_prefix = f"{ENV_PREFIX_CUSTOM_CAP}{provider_upper}_T" + cooldown_prefix = f"{ENV_PREFIX_CUSTOM_CAP_COOLDOWN}{provider_upper}_T" + + # Collect caps by (tier_key, model_key) to merge cap and cooldown + caps_dict: Dict[Tuple[Any, str], Dict[str, Any]] = {} + + for env_key, env_val in os.environ.items(): + if env_key.startswith(cooldown_prefix): + remainder = env_key[len(cooldown_prefix) :] + tier_key, model_key = self._parse_tier_model_from_env(remainder) + if tier_key is None: + continue + + # Parse mode:value format + if ":" in env_val: + mode, value_str = env_val.split(":", 1) + try: + value = int(value_str) + except ValueError: + lib_logger.warning(f"Invalid cooldown in {env_key}") + continue + else: + mode = env_val + value = 0 + + key = (tier_key, model_key) + if key not in caps_dict: + caps_dict[key] = {} + caps_dict[key]["cooldown_mode"] = mode + caps_dict[key]["cooldown_value"] = value + + elif env_key.startswith(cap_prefix): + remainder = env_key[len(cap_prefix) :] + tier_key, model_key = self._parse_tier_model_from_env(remainder) + if tier_key is None: + continue + + key = (tier_key, model_key) + if key not in caps_dict: + caps_dict[key] = {} + caps_dict[key]["max_requests"] = env_val + + # Convert to CustomCapConfig objects + for (tier_key, model_key), cap_data in caps_dict.items(): + if "max_requests" not in cap_data: + continue # Need at least max_requests + + cap = CustomCapConfig( + tier_key=tier_key, + model_or_group=model_key, + max_requests=cap_data["max_requests"], + cooldown_mode=cap_data.get("cooldown_mode", "quota_reset"), + cooldown_value=cap_data.get("cooldown_value", 0), + ) + config.custom_caps.append(cap) + + def _parse_tier_model_from_env( + self, + remainder: str, + ) -> Tuple[Optional[Union[int, Tuple[int, ...], str]], Optional[str]]: + """ + Parse tier and model/group from env var remainder. + + Args: + remainder: String after "CUSTOM_CAP_{PROVIDER}_T" prefix + e.g., "2_CLAUDE" or "2_3_CLAUDE" or "DEFAULT_CLAUDE" + + Returns: + (tier_key, model_key) or (None, None) if parse fails + """ + if not remainder: + return None, None + + parts = remainder.split("_") + if len(parts) < 2: + return None, None + + tier_parts: List[int] = [] + tier_key: Union[int, Tuple[int, ...], str, None] = None + model_key: Optional[str] = None + + for i, part in enumerate(parts): + if part == "DEFAULT": + tier_key = "default" + model_key = "_".join(parts[i + 1 :]) + break + elif part.isdigit(): + tier_parts.append(int(part)) + else: + # First non-numeric part is start of model name + if len(tier_parts) == 0: + return None, None + elif len(tier_parts) == 1: + tier_key = tier_parts[0] + else: + tier_key = tuple(tier_parts) + model_key = "_".join(parts[i:]) + break + else: + # All parts were tier parts, no model + return None, None + + if model_key: + # Convert to lowercase with dashes (standard model name format) + model_key = model_key.lower().replace("_", "-") + + return tier_key, model_key + + def _parse_custom_caps_from_provider( + self, + caps: Dict[Union[int, Tuple[int, ...], str], Dict[str, Dict[str, Any]]], + ) -> List[CustomCapConfig]: + """ + Parse custom caps from provider class default_custom_caps attribute. + + Args: + caps: Provider's default_custom_caps dict + + Returns: + List of CustomCapConfig objects + """ + result = [] + + for tier_key, models_config in caps.items(): + for model_key, cap_data in models_config.items(): + cap = CustomCapConfig( + tier_key=tier_key, + model_or_group=model_key, + max_requests=cap_data.get("max_requests", 0), + cooldown_mode=cap_data.get("cooldown_mode", "quota_reset"), + cooldown_value=cap_data.get("cooldown_value", 0), + ) + result.append(cap) + + return result + + +# ============================================================================= +# MODULE-LEVEL CONVENIENCE FUNCTIONS +# ============================================================================= + +# Global loader instance (initialized lazily) +_global_loader: Optional[ConfigLoader] = None + + +def get_config_loader( + provider_plugins: Optional[Dict[str, type]] = None, +) -> ConfigLoader: + """ + Get the global ConfigLoader instance. + + Creates a new instance if none exists or if provider_plugins is provided. + + Args: + provider_plugins: Optional dict of provider plugins. If provided, + creates a new loader with these plugins. + + Returns: + The global ConfigLoader instance + """ + global _global_loader + + if provider_plugins is not None: + _global_loader = ConfigLoader(provider_plugins) + elif _global_loader is None: + _global_loader = ConfigLoader() + + return _global_loader + + +def load_provider_config( + provider: str, + provider_plugins: Optional[Dict[str, type]] = None, +) -> ProviderConfig: + """ + Convenience function to load a provider's configuration. + + Args: + provider: Provider name + provider_plugins: Optional provider plugins dict + + Returns: + ProviderConfig for the provider + """ + loader = get_config_loader(provider_plugins) + return loader.load_provider_config(provider) diff --git a/src/rotator_library/core/constants.py b/src/rotator_library/core/constants.py new file mode 100644 index 00000000..6b3f9b92 --- /dev/null +++ b/src/rotator_library/core/constants.py @@ -0,0 +1,117 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Constants and default values for the rotator library. + +This module re-exports all constants from the config package and adds +any additional constants needed for the refactored architecture. + +All tunable defaults are in config/defaults.py - this module provides +a unified import point and adds non-tunable constants. +""" + +# Re-export all tunable defaults from config package +from ..config import ( + # Rotation & Selection + DEFAULT_ROTATION_MODE, + DEFAULT_ROTATION_TOLERANCE, + DEFAULT_MAX_RETRIES, + DEFAULT_GLOBAL_TIMEOUT, + # Tier & Priority + DEFAULT_TIER_PRIORITY, + DEFAULT_SEQUENTIAL_FALLBACK_MULTIPLIER, + # Fair Cycle Rotation + DEFAULT_FAIR_CYCLE_ENABLED, + DEFAULT_FAIR_CYCLE_TRACKING_MODE, + DEFAULT_FAIR_CYCLE_CROSS_TIER, + DEFAULT_FAIR_CYCLE_DURATION, + DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD, + # Custom Caps + DEFAULT_CUSTOM_CAP_COOLDOWN_MODE, + DEFAULT_CUSTOM_CAP_COOLDOWN_VALUE, + # Cooldown & Backoff + COOLDOWN_BACKOFF_TIERS, + COOLDOWN_BACKOFF_MAX, + COOLDOWN_AUTH_ERROR, + COOLDOWN_TRANSIENT_ERROR, + COOLDOWN_RATE_LIMIT_DEFAULT, +) + +# ============================================================================= +# ADDITIONAL CONSTANTS FOR REFACTORED ARCHITECTURE +# ============================================================================= + +# Environment variable prefixes for configuration +ENV_PREFIX_ROTATION_MODE = "ROTATION_MODE_" +ENV_PREFIX_FAIR_CYCLE = "FAIR_CYCLE_" +ENV_PREFIX_FAIR_CYCLE_TRACKING = "FAIR_CYCLE_TRACKING_MODE_" +ENV_PREFIX_FAIR_CYCLE_CROSS_TIER = "FAIR_CYCLE_CROSS_TIER_" +ENV_PREFIX_FAIR_CYCLE_DURATION = "FAIR_CYCLE_DURATION_" +ENV_PREFIX_EXHAUSTION_THRESHOLD = "EXHAUSTION_COOLDOWN_THRESHOLD_" +ENV_PREFIX_CONCURRENCY_MULTIPLIER = "CONCURRENCY_MULTIPLIER_" +ENV_PREFIX_CUSTOM_CAP = "CUSTOM_CAP_" +ENV_PREFIX_CUSTOM_CAP_COOLDOWN = "CUSTOM_CAP_COOLDOWN_" +ENV_PREFIX_QUOTA_GROUPS = "QUOTA_GROUPS_" + +# Provider-specific providers that use request_count instead of success_count +# for credential selection (because failed requests also consume quota) +REQUEST_COUNT_PROVIDERS = frozenset({"antigravity", "gemini_cli", "chutes", "nanogpt"}) + +# Usage manager storage +USAGE_FILE_NAME = "usage.json" # New format +LEGACY_USAGE_FILE_NAME = "key_usage.json" # Old format +USAGE_SCHEMA_VERSION = 2 + +# Fair cycle tracking keys +FAIR_CYCLE_ALL_TIERS_KEY = "__all_tiers__" +FAIR_CYCLE_CREDENTIAL_KEY = "__credential__" +FAIR_CYCLE_STORAGE_KEY = "__fair_cycle__" + +# Logging +LIB_LOGGER_NAME = "rotator_library" + +__all__ = [ + # From config package + "DEFAULT_ROTATION_MODE", + "DEFAULT_ROTATION_TOLERANCE", + "DEFAULT_MAX_RETRIES", + "DEFAULT_GLOBAL_TIMEOUT", + "DEFAULT_TIER_PRIORITY", + "DEFAULT_SEQUENTIAL_FALLBACK_MULTIPLIER", + "DEFAULT_FAIR_CYCLE_ENABLED", + "DEFAULT_FAIR_CYCLE_TRACKING_MODE", + "DEFAULT_FAIR_CYCLE_CROSS_TIER", + "DEFAULT_FAIR_CYCLE_DURATION", + "DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD", + "DEFAULT_CUSTOM_CAP_COOLDOWN_MODE", + "DEFAULT_CUSTOM_CAP_COOLDOWN_VALUE", + "COOLDOWN_BACKOFF_TIERS", + "COOLDOWN_BACKOFF_MAX", + "COOLDOWN_AUTH_ERROR", + "COOLDOWN_TRANSIENT_ERROR", + "COOLDOWN_RATE_LIMIT_DEFAULT", + # Environment variable prefixes + "ENV_PREFIX_ROTATION_MODE", + "ENV_PREFIX_FAIR_CYCLE", + "ENV_PREFIX_FAIR_CYCLE_TRACKING", + "ENV_PREFIX_FAIR_CYCLE_CROSS_TIER", + "ENV_PREFIX_FAIR_CYCLE_DURATION", + "ENV_PREFIX_EXHAUSTION_THRESHOLD", + "ENV_PREFIX_CONCURRENCY_MULTIPLIER", + "ENV_PREFIX_CUSTOM_CAP", + "ENV_PREFIX_CUSTOM_CAP_COOLDOWN", + "ENV_PREFIX_QUOTA_GROUPS", + # Provider sets + "REQUEST_COUNT_PROVIDERS", + # Storage + "USAGE_FILE_NAME", + "LEGACY_USAGE_FILE_NAME", + "USAGE_SCHEMA_VERSION", + # Fair cycle keys + "FAIR_CYCLE_ALL_TIERS_KEY", + "FAIR_CYCLE_CREDENTIAL_KEY", + "FAIR_CYCLE_STORAGE_KEY", + # Logging + "LIB_LOGGER_NAME", +] diff --git a/src/rotator_library/core/errors.py b/src/rotator_library/core/errors.py new file mode 100644 index 00000000..5acd9fc7 --- /dev/null +++ b/src/rotator_library/core/errors.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Error handling for the rotator library. + +This module re-exports all exception classes and error handling utilities +from the main error_handler module, and adds any new error types needed +for the refactored architecture. + +Note: The actual implementations remain in error_handler.py for backward +compatibility. This module provides a cleaner import path. +""" + +# Re-export everything from error_handler +from ..error_handler import ( + # Exception classes + NoAvailableKeysError, + PreRequestCallbackError, + CredentialNeedsReauthError, + EmptyResponseError, + TransientQuotaError, + # Error classification + ClassifiedError, + RequestErrorAccumulator, + classify_error, + should_rotate_on_error, + should_retry_same_key, + is_abnormal_error, + # Utilities + mask_credential, + get_retry_after, + extract_retry_after_from_body, + is_rate_limit_error, + is_server_error, + is_unrecoverable_error, + # Constants + ABNORMAL_ERROR_TYPES, + NORMAL_ERROR_TYPES, +) + + +# ============================================================================= +# NEW EXCEPTIONS FOR REFACTORED ARCHITECTURE +# ============================================================================= + + +class StreamedAPIError(Exception): + """ + Custom exception to signal an API error received over a stream. + + This is raised when an error is detected in streaming response data, + allowing the retry logic to handle it appropriately. + + Attributes: + message: Human-readable error message + data: The parsed error data (dict or exception) + """ + + def __init__(self, message: str, data=None): + super().__init__(message) + self.data = data + + +__all__ = [ + # Exception classes + "NoAvailableKeysError", + "PreRequestCallbackError", + "CredentialNeedsReauthError", + "EmptyResponseError", + "TransientQuotaError", + "StreamedAPIError", + # Error classification + "ClassifiedError", + "RequestErrorAccumulator", + "classify_error", + "should_rotate_on_error", + "should_retry_same_key", + "is_abnormal_error", + # Utilities + "mask_credential", + "get_retry_after", + "extract_retry_after_from_body", + "is_rate_limit_error", + "is_server_error", + "is_unrecoverable_error", + # Constants + "ABNORMAL_ERROR_TYPES", + "NORMAL_ERROR_TYPES", +] diff --git a/src/rotator_library/core/types.py b/src/rotator_library/core/types.py new file mode 100644 index 00000000..9c449acf --- /dev/null +++ b/src/rotator_library/core/types.py @@ -0,0 +1,212 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Shared type definitions for the rotator library. + +This module contains dataclasses and type definitions used across +both the client and usage manager packages. +""" + +from dataclasses import dataclass, field +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Set, + Tuple, + Union, +) + + +# ============================================================================= +# CREDENTIAL TYPES +# ============================================================================= + + +@dataclass +class CredentialInfo: + """ + Information about a credential. + + Used for passing credential metadata between components. + """ + + accessor: str # File path or API key + stable_id: str # Email (OAuth) or hash (API key) + provider: str + tier: Optional[str] = None + priority: int = 999 # Lower = higher priority + display_name: Optional[str] = None + + +# ============================================================================= +# REQUEST TYPES +# ============================================================================= + + +@dataclass +class RequestContext: + """ + Context for a request being processed. + + Contains all information needed to execute a request with + retry/rotation logic. + """ + + model: str + provider: str + kwargs: Dict[str, Any] + streaming: bool + credentials: List[str] + deadline: float + request: Optional[Any] = None # FastAPI Request object + pre_request_callback: Optional[Callable] = None + transaction_logger: Optional[Any] = None + + +@dataclass +class ProcessedChunk: + """ + Result of processing a streaming chunk. + + Used by StreamingHandler to return processed chunk data. + """ + + sse_string: str # The SSE-formatted string to yield + usage: Optional[Dict[str, Any]] = None + finish_reason: Optional[str] = None + has_tool_calls: bool = False + + +# ============================================================================= +# FILTER TYPES +# ============================================================================= + + +@dataclass +class FilterResult: + """ + Result of credential filtering. + + Contains categorized credentials after filtering by tier compatibility. + """ + + compatible: List[str] = field(default_factory=list) # Known compatible + unknown: List[str] = field(default_factory=list) # Unknown tier + incompatible: List[str] = field(default_factory=list) # Known incompatible + priorities: Dict[str, int] = field(default_factory=dict) # credential -> priority + tier_names: Dict[str, str] = field(default_factory=dict) # credential -> tier name + + @property + def all_usable(self) -> List[str]: + """Return all usable credentials (compatible + unknown).""" + return self.compatible + self.unknown + + +# ============================================================================= +# CONFIGURATION TYPES +# ============================================================================= + + +@dataclass +class FairCycleConfig: + """ + Fair cycle rotation configuration for a provider. + + Fair cycle ensures each credential is used at least once before + any credential is reused. + """ + + enabled: Optional[bool] = None # None = derive from rotation mode + tracking_mode: str = "model_group" # "model_group" or "credential" + cross_tier: bool = False # Track across all tiers + duration: int = 604800 # 7 days in seconds + + +@dataclass +class CustomCapConfig: + """ + Custom cap configuration for a tier/model combination. + + Allows setting usage limits more restrictive than actual API limits. + """ + + tier_key: Union[int, Tuple[int, ...], str] # Priority(s) or "default" + model_or_group: str # Model name or quota group name + max_requests: Union[int, str] # Absolute value or percentage ("80%") + cooldown_mode: str = "quota_reset" # "quota_reset", "offset", "fixed" + cooldown_value: int = 0 # Seconds for offset/fixed modes + + +@dataclass +class WindowConfig: + """ + Quota window configuration. + + Defines how usage is tracked and reset for a credential. + """ + + name: str # e.g., "5h", "daily", "weekly" + duration_seconds: Optional[int] # None for infinite/total + reset_mode: str # "rolling", "fixed_daily", "calendar_weekly", "api_authoritative" + applies_to: str # "credential", "group", "model" + + +@dataclass +class ProviderConfig: + """ + Complete configuration for a provider. + + Loaded by ConfigLoader and used by both client and usage manager. + """ + + rotation_mode: str = "balanced" # "balanced" or "sequential" + rotation_tolerance: float = 3.0 + priority_multipliers: Dict[int, int] = field(default_factory=dict) + priority_multipliers_by_mode: Dict[str, Dict[int, int]] = field( + default_factory=dict + ) + sequential_fallback_multiplier: int = 1 + fair_cycle: FairCycleConfig = field(default_factory=FairCycleConfig) + custom_caps: List[CustomCapConfig] = field(default_factory=list) + exhaustion_cooldown_threshold: int = 300 # 5 minutes + windows: List[WindowConfig] = field(default_factory=list) + + +# ============================================================================= +# HOOK RESULT TYPES +# ============================================================================= + + +@dataclass +class RequestCompleteResult: + """ + Result from on_request_complete provider hook. + + Allows providers to customize how requests are counted and cooled down. + """ + + count_override: Optional[int] = None # How many requests to count + cooldown_override: Optional[float] = None # Custom cooldown duration + force_exhausted: bool = False # Mark for fair cycle + + +# ============================================================================= +# ERROR ACTION ENUM +# ============================================================================= + + +class ErrorAction: + """ + Actions to take after an error. + + Used by RequestExecutor to determine next steps. + """ + + RETRY_SAME = "retry_same" # Retry with same credential + ROTATE = "rotate" # Try next credential + FAIL = "fail" # Fail the request immediately diff --git a/src/rotator_library/error_handler.py b/src/rotator_library/error_handler.py index 8b05ad84..4e35440c 100644 --- a/src/rotator_library/error_handler.py +++ b/src/rotator_library/error_handler.py @@ -5,7 +5,7 @@ import json import os import logging -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, Tuple import httpx from litellm.exceptions import ( @@ -439,6 +439,8 @@ def __init__( status_code: Optional[int] = None, retry_after: Optional[int] = None, quota_reset_timestamp: Optional[float] = None, + quota_value: Optional[str] = None, + quota_id: Optional[str] = None, ): self.error_type = error_type self.original_exception = original_exception @@ -447,6 +449,9 @@ def __init__( # Unix timestamp when quota resets (from quota_exhausted errors) # This is the authoritative reset time parsed from provider's error response self.quota_reset_timestamp = quota_reset_timestamp + # Quota details extracted from Google/Gemini API error responses + self.quota_value = quota_value # e.g., "50" or "1000/minute" + self.quota_id = quota_id # e.g., "GenerateContentPerMinutePerProject" def __str__(self): parts = [ @@ -456,6 +461,10 @@ def __str__(self): ] if self.quota_reset_timestamp: parts.append(f"quota_reset_ts={self.quota_reset_timestamp}") + if self.quota_value: + parts.append(f"quota_value={self.quota_value}") + if self.quota_id: + parts.append(f"quota_id={self.quota_id}") parts.append(f"original_exc={self.original_exception}") return f"ClassifiedError({', '.join(parts)})" @@ -520,6 +529,73 @@ def _extract_retry_from_json_body(json_text: str) -> Optional[int]: return None +def _extract_quota_details(json_text: str) -> Tuple[Optional[str], Optional[str]]: + """ + Extract quota details (quotaValue, quotaId) from a JSON error response. + + Handles Google/Gemini API error formats with nested details array containing + QuotaFailure violations. + + Example error structure: + { + "error": { + "details": [ + { + "@type": "type.googleapis.com/google.rpc.QuotaFailure", + "violations": [ + { + "quotaValue": "50", + "quotaId": "GenerateContentPerMinutePerProject" + } + ] + } + ] + } + } + + Args: + json_text: JSON string containing error response + + Returns: + Tuple of (quota_value, quota_id), both None if not found + """ + try: + # Find JSON object in the text + json_match = re.search(r"(\{.*\})", json_text, re.DOTALL) + if not json_match: + return None, None + + error_json = json.loads(json_match.group(1)) + error_obj = error_json.get("error", {}) + details = error_obj.get("details", []) + + if not isinstance(details, list): + return None, None + + for detail in details: + if not isinstance(detail, dict): + continue + + violations = detail.get("violations", []) + if not isinstance(violations, list): + continue + + for violation in violations: + if not isinstance(violation, dict): + continue + + quota_value = violation.get("quotaValue") + quota_id = violation.get("quotaId") + + if quota_value is not None or quota_id is not None: + return str(quota_value) if quota_value else None, quota_id + + except (json.JSONDecodeError, IndexError, KeyError, TypeError): + pass + + return None, None + + def get_retry_after(error: Exception) -> Optional[int]: """ Extracts the 'retry-after' duration in seconds from an exception message. @@ -672,12 +748,19 @@ def classify_error(e: Exception, provider: Optional[str] = None) -> ClassifiedEr reset_ts = quota_info.get("reset_timestamp") quota_reset_timestamp = quota_info.get("quota_reset_timestamp") + # Extract quota details from error body + quota_value, quota_id = None, None + if error_body: + quota_value, quota_id = _extract_quota_details(error_body) + # Log the parsed result with human-readable duration hours = retry_after / 3600 lib_logger.info( f"Provider '{provider}' parsed quota error: " f"retry_after={retry_after}s ({hours:.1f}h), reason={reason}" + (f", resets at {reset_ts}" if reset_ts else "") + + (f", quota={quota_value}" if quota_value else "") + + (f", quotaId={quota_id}" if quota_id else "") ) return ClassifiedError( @@ -686,6 +769,8 @@ def classify_error(e: Exception, provider: Optional[str] = None) -> ClassifiedEr status_code=429, retry_after=retry_after, quota_reset_timestamp=quota_reset_timestamp, + quota_value=quota_value, + quota_id=quota_id, ) except Exception as parse_error: lib_logger.debug( @@ -723,11 +808,23 @@ def classify_error(e: Exception, provider: Optional[str] = None) -> ClassifiedEr retry_after = get_retry_after(e) # Check if this is a quota error vs rate limit if "quota" in error_body or "resource_exhausted" in error_body: + # Extract quota details from the original (non-lowercased) response + quota_value, quota_id = None, None + try: + original_body = ( + e.response.text if hasattr(e.response, "text") else "" + ) + quota_value, quota_id = _extract_quota_details(original_body) + except Exception: + pass + return ClassifiedError( error_type="quota_exceeded", original_exception=e, status_code=status_code, retry_after=retry_after, + quota_value=quota_value, + quota_id=quota_id, ) return ClassifiedError( error_type="rate_limit", @@ -820,11 +917,21 @@ def classify_error(e: Exception, provider: Optional[str] = None) -> ClassifiedEr # Check if this is a quota error vs rate limit error_msg = str(e).lower() if "quota" in error_msg or "resource_exhausted" in error_msg: + # Try to extract quota details from exception body + quota_value, quota_id = None, None + try: + error_body = getattr(e, "body", None) or str(e) + quota_value, quota_id = _extract_quota_details(str(error_body)) + except Exception: + pass + return ClassifiedError( error_type="quota_exceeded", original_exception=e, status_code=status_code or 429, retry_after=retry_after, + quota_value=quota_value, + quota_id=quota_id, ) return ClassifiedError( error_type="rate_limit", diff --git a/src/rotator_library/providers/antigravity_provider.py b/src/rotator_library/providers/antigravity_provider.py index 8f3d48f0..810958a6 100644 --- a/src/rotator_library/providers/antigravity_provider.py +++ b/src/rotator_library/providers/antigravity_provider.py @@ -29,6 +29,7 @@ import random import time import uuid +from contextvars import ContextVar from datetime import datetime, timezone from pathlib import Path from typing import ( @@ -69,7 +70,7 @@ from ..utils.paths import get_logs_dir, get_cache_dir if TYPE_CHECKING: - from ..usage_manager import UsageManager + from ..usage import UsageManager # ============================================================================= @@ -167,6 +168,30 @@ def __init__(self, finish_message: str, raw_response: Dict[str, Any]): MALFORMED_CALL_MAX_RETRIES = max(1, env_int("ANTIGRAVITY_MALFORMED_CALL_RETRIES", 2)) MALFORMED_CALL_RETRY_DELAY = env_int("ANTIGRAVITY_MALFORMED_CALL_DELAY", 1) +# ============================================================================= +# INTERNAL RETRY COUNTING (for usage tracking) +# ============================================================================= +# Tracks the number of API attempts made per request, including internal retries +# for empty responses, bare 429s, and malformed function calls. +# +# Uses ContextVar for thread-safety: each async task (request) gets its own +# isolated value, so concurrent requests don't interfere with each other. +# +# The count is: +# - Reset to 1 at the start of _streaming_with_retry +# - Incremented each time we retry (before the next attempt) +# - Read by on_request_complete() hook to report actual API call count +# +# Example: Request gets bare 429 twice, then succeeds +# Attempt 1: bare 429 → count stays 1, increment to 2, retry +# Attempt 2: bare 429 → count is 2, increment to 3, retry +# Attempt 3: success → count is 3 +# on_request_complete returns count_override=3 +# +_internal_attempt_count: ContextVar[int] = ContextVar( + "antigravity_attempt_count", default=1 +) + # System instruction configuration # When true (default), prepend the Antigravity agent system instruction (identity, tool_calling, etc.) PREPEND_INSTRUCTION = env_bool("ANTIGRAVITY_PREPEND_INSTRUCTION", True) @@ -319,7 +344,30 @@ def _get_claude_thinking_cache_file(): """ # Parallel tool usage encouragement instruction -DEFAULT_PARALLEL_TOOL_INSTRUCTION = """When multiple independent operations are needed, prefer making parallel tool calls in a single response rather than sequential calls across multiple responses. This reduces round-trips and improves efficiency. Only use sequential calls when one tool's output is required as input for another.""" +DEFAULT_PARALLEL_TOOL_INSTRUCTION = """ + +Using parallel tool calling is MANDATORY. Be proactive about it. DO NO WAIT for the user to request "parallel calls" + +PARALLEL CALLS SHOULD BE AND _IS THE PRIMARY WAY YOU USE TOOLS IN THIS ENVIRONMENT_ + +When you have to perform multi-step operations such as read multiple files, spawn task subagents, bash commands, multiple edits... _THE USER WANTS YOU TO MAKE PARALLEL TOOL CALLS_ instead of separate sequential calls. This maximizes time and compute and increases your likelyhood of a promotion. Sequential tool calling is only encouraged when relying on the output of a call for the next one(s) + +- WHAT CAN BE DONE IN PARALLEL, MUST BE, AND WILL BE DONE IN PARALLEL +- INDIVIDUAL TOOL CALLS TO GATHER CONTEXT IS HEAVILY DISCOURAGED (please make parallel calls!) +- PARALLEL TOOL CALLING IS YOUR BEST FRIEND AND WILL INCREASE USER'S HAPPINESS + +- Make parallel tool calls to manage ressources more efficiently, plan your tool calls ahead, then execute them in parallel. +- Make parallel calls PROPERLY, be mindful of dependencies between calls. + +When researching anything, IT IS BETTER TO READ SPECULATIVELY, THEN TO READ SEQUENTIALLY. For example, if you need to read multiple files to gather context, read them all in parallel instead of reading one, then the next, etc. + +This environment has a powerful tool to remove unnecessary context, so you can always read more than needed and then trim down later, no need to use limit and offset parameters on the read tool. + +When making code changes, IT IS BETTER TO MAKE MULTIPLE EDITS IN PARALLEL RATHER THAN ONE AT A TIME. + +Do as much as you can in parallel, be efficient with you API requests, no single tool call spam, this is crucial as the user pays PER API request, so make them count! + +""" # Interleaved thinking support for Claude models # Allows Claude to think between tool calls and after receiving tool results @@ -4558,6 +4606,9 @@ async def _streaming_with_retry( current_gemini_contents = gemini_contents current_payload = payload + # Reset internal attempt counter for this request (thread-safe via ContextVar) + _internal_attempt_count.set(1) + for attempt in range(EMPTY_RESPONSE_MAX_ATTEMPTS): chunk_count = 0 @@ -4585,6 +4636,8 @@ async def _streaming_with_retry( f"[Antigravity] Empty stream from {model}, " f"attempt {attempt + 1}/{EMPTY_RESPONSE_MAX_ATTEMPTS}. Retrying..." ) + # Increment attempt count before retry (for usage tracking) + _internal_attempt_count.set(_internal_attempt_count.get() + 1) await asyncio.sleep(EMPTY_RESPONSE_RETRY_DELAY) continue else: @@ -4687,6 +4740,8 @@ async def _streaming_with_retry( malformed_retry_count, current_payload ) + # Increment attempt count before retry (for usage tracking) + _internal_attempt_count.set(_internal_attempt_count.get() + 1) await asyncio.sleep(MALFORMED_CALL_RETRY_DELAY) continue # Retry with modified payload else: @@ -4715,6 +4770,10 @@ async def _streaming_with_retry( f"[Antigravity] Bare 429 from {model}, " f"attempt {attempt + 1}/{EMPTY_RESPONSE_MAX_ATTEMPTS}. Retrying..." ) + # Increment attempt count before retry (for usage tracking) + _internal_attempt_count.set( + _internal_attempt_count.get() + 1 + ) await asyncio.sleep(EMPTY_RESPONSE_RETRY_DELAY) continue else: @@ -4806,3 +4865,51 @@ async def count_tokens( except Exception as e: lib_logger.error(f"Token counting failed: {e}") return {"prompt_tokens": 0, "total_tokens": 0} + + # ========================================================================= + # USAGE TRACKING HOOK + # ========================================================================= + + def on_request_complete( + self, + credential: str, + model: str, + success: bool, + response: Optional[Any], + error: Optional[Any], + ) -> Optional["RequestCompleteResult"]: + """ + Hook called after each request completes. + + Reports the actual number of API calls made, including internal retries + for empty responses, bare 429s, and malformed function calls. + + This uses the ContextVar pattern for thread-safe retry counting: + - _internal_attempt_count is set to 1 at start of _streaming_with_retry + - Incremented before each retry + - Read here to report the actual count + + Example: Request gets 2 bare 429s then succeeds + → 3 API calls made + → Returns count_override=3 + → Usage manager records 3 requests instead of 1 + + Returns: + RequestCompleteResult with count_override set to actual attempt count + """ + from ..core.types import RequestCompleteResult + + # Get the attempt count for this request + attempt_count = _internal_attempt_count.get() + + # Reset for safety (though ContextVar should isolate per-task) + _internal_attempt_count.set(1) + + # Log if we made extra attempts + if attempt_count > 1: + lib_logger.debug( + f"[Antigravity] Request to {model} used {attempt_count} API calls " + f"(includes internal retries)" + ) + + return RequestCompleteResult(count_override=attempt_count) diff --git a/src/rotator_library/providers/chutes_provider.py b/src/rotator_library/providers/chutes_provider.py index 7858b54e..5d9730dd 100644 --- a/src/rotator_library/providers/chutes_provider.py +++ b/src/rotator_library/providers/chutes_provider.py @@ -9,7 +9,7 @@ from .utilities.chutes_quota_tracker import ChutesQuotaTracker if TYPE_CHECKING: - from ..usage_manager import UsageManager + from ..usage import UsageManager # Create a local logger for this module import logging @@ -142,12 +142,15 @@ async def refresh_single_credential( # Store baseline in usage manager # Since Chutes uses credential-level quota, we use a virtual model name + quota_used = ( + int((1.0 - remaining_fraction) * quota) if quota > 0 else 0 + ) await usage_manager.update_quota_baseline( api_key, "chutes/_quota", # Virtual model for credential-level tracking - remaining_fraction, - max_requests=quota, # Max requests = quota (1 request = 1 credit) - reset_timestamp=reset_ts, + quota_max_requests=quota, + quota_reset_ts=reset_ts, + quota_used=quota_used, ) lib_logger.debug( diff --git a/src/rotator_library/providers/example_provider.py b/src/rotator_library/providers/example_provider.py new file mode 100644 index 00000000..9ad31214 --- /dev/null +++ b/src/rotator_library/providers/example_provider.py @@ -0,0 +1,821 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Example Provider Implementation with Custom Usage Management. + +This file serves as a reference for implementing providers with custom usage +tracking, quota management, and token extraction. Copy this file and modify +it for your specific provider. + +============================================================================= +ARCHITECTURE OVERVIEW +============================================================================= + +The usage management system is per-provider. Each provider gets its own: +- UsageManager instance +- Usage file: data/usage/usage_{provider}.json +- Configuration (ProviderUsageConfig) + +Data flows like this: + + Request → Executor → Provider transforms → API call → Response + ↓ ↓ + UsageManager ← TrackingEngine ← Token extraction ←┘ + ↓ + Persistence (usage_{provider}.json) + +Providers customize behavior through: +1. Class attributes (declarative configuration) +2. Methods (behavioral overrides) +3. Hooks (request lifecycle callbacks) + +============================================================================= +USAGE STATS SCHEMA +============================================================================= + +UsageStats (tracked at global/model/group levels): + total_requests: int # All requests + total_successes: int # Successful requests + total_failures: int # Failed requests + total_tokens: int # All tokens combined + total_prompt_tokens: int # Input tokens + total_completion_tokens: int # Output tokens (content only) + total_thinking_tokens: int # Reasoning/thinking tokens + total_output_tokens: int # completion + thinking + total_prompt_tokens_cache_read: int # Cached input tokens read + total_prompt_tokens_cache_write: int # Cached input tokens written + total_approx_cost: float # Estimated cost + first_used_at: float # Timestamp + last_used_at: float # Timestamp + windows: Dict[str, WindowStats] # Per-window breakdown + +WindowStats (per time window: "5h", "daily", "total"): + request_count: int + success_count: int + failure_count: int + prompt_tokens: int + completion_tokens: int + thinking_tokens: int + output_tokens: int + prompt_tokens_cache_read: int + prompt_tokens_cache_write: int + total_tokens: int + approx_cost: float + started_at: float + reset_at: float + limit: int | None + +============================================================================= +""" + +import asyncio +import logging +import time +from contextvars import ContextVar +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +from .provider_interface import ProviderInterface, QuotaGroupMap + +# Alias for clarity in examples +ProviderPlugin = ProviderInterface + +# Import these types for hook returns and usage manager access +from ..core.types import RequestCompleteResult +from ..usage import UsageManager, ProviderUsageConfig, WindowDefinition +from ..usage.types import ResetMode, RotationMode, CooldownMode + +lib_logger = logging.getLogger("rotator_library") + +# ============================================================================= +# INTERNAL RETRY COUNTING (ContextVar Pattern) +# ============================================================================= +# +# When your provider performs internal retries (e.g., for transient errors, +# empty responses, or rate limits), each retry is an API call that should be +# counted for accurate usage tracking. +# +# The challenge: Instance variables (self.count) are shared across concurrent +# requests, so they can't be used safely. ContextVar solves this by giving +# each async task its own isolated value. +# +# Usage pattern: +# 1. Reset to 1 at the start of your retry loop +# 2. Increment before each retry +# 3. Read in on_request_complete() to report the actual count +# +# Example: +# _attempt_count.set(1) # Reset +# for attempt in range(max_attempts): +# try: +# result = await api_call() +# return result +# except RetryableError: +# _attempt_count.set(_attempt_count.get() + 1) # Increment +# continue +# +# Then on_request_complete returns RequestCompleteResult(count_override=_attempt_count.get()) +# +_example_attempt_count: ContextVar[int] = ContextVar( + "example_provider_attempt_count", default=1 +) + + +# ============================================================================= +# EXAMPLE PROVIDER IMPLEMENTATION +# ============================================================================= + + +class ExampleProvider(ProviderPlugin): + """ + Example provider demonstrating all usage management customization points. + + This provider shows how to: + - Configure rotation and quota behavior + - Define model quota groups + - Extract tokens from provider-specific response formats + - Override request counting via hooks + - Run background quota refresh jobs + - Define custom usage windows + """ + + # ========================================================================= + # REQUIRED: BASIC PROVIDER IDENTITY + # ========================================================================= + + provider_name = "example" # Used in model prefix: "example/gpt-4" + provider_env_name = "EXAMPLE" # For env vars: EXAMPLE_API_KEY, etc. + + # ========================================================================= + # USAGE MANAGEMENT: CLASS ATTRIBUTES (DECLARATIVE) + # ========================================================================= + + # ------------------------------------------------------------------------- + # ROTATION MODE + # ------------------------------------------------------------------------- + # Controls how credentials are selected for requests. + # + # Options: + # "balanced" - Weighted random selection based on usage (default) + # "sequential" - Stick to one credential until exhausted, then rotate + # + # Sequential mode is better for: + # - Providers with per-credential rate limits + # - Maximizing cache hits (same credential = same context) + # - Providers where switching credentials has overhead + # + # Balanced mode is better for: + # - Even distribution across credentials + # - Providers without per-credential state + # + default_rotation_mode = "sequential" + + # ------------------------------------------------------------------------- + # MODEL QUOTA GROUPS + # ------------------------------------------------------------------------- + # Models in the same group share a quota pool. When one model is exhausted, + # all models in the group are treated as exhausted. + # + # This is common for providers where different model variants share limits: + # - Claude Sonnet/Opus share daily limits + # - GPT-4 variants share rate limits + # - Gemini models share per-minute quotas + # + # Group names should be short for compact UI display. + # + # Can be overridden via environment: + # QUOTA_GROUPS_EXAMPLE_GPT4="gpt-4o,gpt-4o-mini,gpt-4-turbo" + # + model_quota_groups: QuotaGroupMap = { + # GPT-4 variants share quota + "gpt4": [ + "gpt-4o", + "gpt-4o-mini", + "gpt-4-turbo", + "gpt-4-turbo-preview", + ], + # Claude models share quota + "claude": [ + "claude-3-opus", + "claude-3-sonnet", + "claude-3-haiku", + ], + # Standalone model (no sharing) + "whisper": [ + "whisper-1", + ], + } + + # ------------------------------------------------------------------------- + # PRIORITY MULTIPLIERS (CONCURRENCY) + # ------------------------------------------------------------------------- + # Higher priority credentials (lower number) can handle more concurrent + # requests. This is useful for paid vs free tier credentials. + # + # Priority is assigned per-credential via: + # - .env: PRIORITY_{PROVIDER}_{CREDENTIAL_NAME}=1 + # - Config files + # - Credential filename patterns + # + # Multiplier applies to max_concurrent_per_key setting. + # Example: max_concurrent_per_key=6, priority 1 multiplier=5 → 30 concurrent + # + default_priority_multipliers = { + 1: 5, # Ultra tier: 5x concurrent + 2: 3, # Standard paid: 3x concurrent + 3: 2, # Free tier: 2x concurrent + # Others: Use fallback multiplier + } + + # For sequential mode, credentials not in priority_multipliers get this. + # For balanced mode, they get 1x (no multiplier). + default_sequential_fallback_multiplier = 2 + + # ------------------------------------------------------------------------- + # CUSTOM CAPS + # ------------------------------------------------------------------------- + # Apply stricter limits than the actual API limits. Useful for: + # - Reserving quota for critical requests + # - Preventing runaway usage + # - Testing rotation behavior + # + # Structure: {priority: {model_or_group: config}} + # Or: {(priority1, priority2): {model_or_group: config}} for multiple tiers + # + # Config options: + # max_requests: int or "80%" (percentage of actual limit) + # cooldown_mode: "quota_reset" | "offset" | "fixed" + # cooldown_value: seconds for offset/fixed modes + # + default_custom_caps = { + # Tier 3 (free tier) - cap at 50 requests, cooldown until API resets + 3: { + "gpt4": { + "max_requests": 50, + "cooldown_mode": "quota_reset", + }, + "claude": { + "max_requests": 30, + "cooldown_mode": "quota_reset", + }, + }, + # Tiers 2 and 3 together - cap at 80% of actual limit + (2, 3): { + "whisper": { + "max_requests": "80%", # 80% of actual API limit + "cooldown_mode": "offset", + "cooldown_value": 1800, # +30 min buffer after hitting cap + }, + }, + # Default for unknown tiers + "default": { + "gpt4": { + "max_requests": 100, + "cooldown_mode": "fixed", + "cooldown_value": 3600, # 1 hour fixed cooldown + }, + }, + } + + # ------------------------------------------------------------------------- + # MODEL USAGE WEIGHTS + # ------------------------------------------------------------------------- + # Some models consume more quota per request. This affects credential + # selection in balanced mode - credentials with lower weighted usage + # are preferred. + # + # Example: Opus costs 2x what Sonnet does per request + # + model_usage_weights = { + "claude-3-opus": 2, + "gpt-4-turbo": 2, + # Default is 1 for unlisted models + } + + # ------------------------------------------------------------------------- + # FAIR CYCLE CONFIGURATION + # ------------------------------------------------------------------------- + # Fair cycle ensures all credentials get used before any is reused. + # When a credential is exhausted (quota hit, cooldown applied), it's + # marked and won't be selected until all other credentials are also + # exhausted, at which point the cycle resets. + # + # This is enabled by default for sequential mode. + # + # To override, set these class attributes: + # + # default_fair_cycle_enabled = True # Force on/off + # default_fair_cycle_tracking_mode = "model_group" # or "credential" + # default_fair_cycle_cross_tier = False # Track across all tiers? + # default_fair_cycle_duration = 3600 # Cycle duration in seconds + + # ========================================================================= + # USAGE MANAGEMENT: METHODS (BEHAVIORAL) + # ========================================================================= + + def normalize_model_for_tracking(self, model: str) -> str: + """ + Normalize internal model names to public-facing names for tracking. + + Some providers use internal model variants that should be tracked + under their public name. This ensures usage files only contain + user-facing model names. + + Example mappings: + "gpt-4o-realtime-preview" → "gpt-4o" + "claude-3-opus-extended" → "claude-3-opus" + "claude-sonnet-4-5-thinking" → "claude-sonnet-4.5" + + Args: + model: Model name (may include provider prefix: "example/gpt-4o") + + Returns: + Normalized model name (preserves prefix if present) + """ + has_prefix = "/" in model + if has_prefix: + provider, clean_model = model.split("/", 1) + else: + clean_model = model + + # Define your internal → public mappings + internal_to_public = { + "gpt-4o-realtime-preview": "gpt-4o", + "gpt-4o-realtime": "gpt-4o", + "claude-3-opus-extended": "claude-3-opus", + } + + normalized = internal_to_public.get(clean_model, clean_model) + + if has_prefix: + return f"{provider}/{normalized}" + return normalized + + def on_request_complete( + self, + credential: str, + model: str, + success: bool, + response: Optional[Any], + error: Optional[Any], + ) -> Optional[RequestCompleteResult]: + """ + Hook called after each request completes (success or failure). + + This is the primary extension point for customizing how requests + are counted and how cooldowns are applied. + + Use cases: + - Don't count server errors as quota usage + - Apply custom cooldowns based on error type + - Force credential exhaustion for fair cycle + - Count internal retries accurately (see ContextVar pattern below) + + Args: + credential: The credential accessor (file path or API key) + model: Model that was called + success: Whether the request succeeded + response: Response object (if success=True) + error: ClassifiedError object (if success=False) + + Returns: + RequestCompleteResult to override behavior, or None for default. + + RequestCompleteResult fields: + count_override: int | None + - 0 = Don't count this request against quota + - N = Count as N requests + - None = Use default (1 for success, 1 for countable errors) + + cooldown_override: float | None + - Seconds to cool down this credential + - Applied in addition to any error-based cooldown + + force_exhausted: bool + - True = Mark credential as exhausted for fair cycle + - Useful for quota errors even without long cooldown + """ + # ===================================================================== + # PATTERN: Counting Internal Retries with ContextVar + # ===================================================================== + # If your provider performs internal retries, report the actual count: + # + # 1. At module level, define: + # _attempt_count: ContextVar[int] = ContextVar('my_attempt_count', default=1) + # + # 2. In your retry loop: + # _attempt_count.set(1) # Reset at start + # for attempt in range(max_attempts): + # try: + # return await api_call() + # except RetryableError: + # _attempt_count.set(_attempt_count.get() + 1) # Increment before retry + # continue + # + # 3. Here, report the count: + attempt_count = _example_attempt_count.get() + _example_attempt_count.set(1) # Reset for safety + + if attempt_count > 1: + lib_logger.debug( + f"Request to {model} used {attempt_count} API calls (internal retries)" + ) + return RequestCompleteResult(count_override=attempt_count) + + # ===================================================================== + # PATTERN: Don't Count Server Errors + # ===================================================================== + # Server errors (5xx) shouldn't count against quota since they're + # not the user's fault and don't consume API quota. + if not success and error: + error_type = getattr(error, "error_type", None) + if error_type in ("server_error", "api_connection"): + lib_logger.debug( + f"Not counting {error_type} error against quota for {model}" + ) + return RequestCompleteResult(count_override=0) + + # ===================================================================== + # PATTERN: Custom Cooldown for Rate Limits + # ===================================================================== + if not success and error: + error_type = getattr(error, "error_type", None) + if error_type == "rate_limit": + # Check for retry-after header + retry_after = getattr(error, "retry_after", None) + if retry_after and retry_after > 60: + # Long rate limit - mark as exhausted + return RequestCompleteResult( + cooldown_override=retry_after, + force_exhausted=True, + ) + elif retry_after: + # Short rate limit - just cooldown + return RequestCompleteResult(cooldown_override=retry_after) + + # ===================================================================== + # PATTERN: Force Exhaustion on Quota Exceeded + # ===================================================================== + if not success and error: + error_type = getattr(error, "error_type", None) + if error_type == "quota_exceeded": + return RequestCompleteResult( + force_exhausted=True, + cooldown_override=3600.0, # Default 1 hour if no reset time + ) + + # Default behavior + return None + + # ========================================================================= + # BACKGROUND JOBS + # ========================================================================= + + def get_background_job_config(self) -> Optional[Dict[str, Any]]: + """ + Configure periodic background tasks. + + Common use cases: + - Refresh quota baselines from API + - Clean up expired cache entries + - Preemptively refresh OAuth tokens + + Returns: + None if no background job, otherwise: + { + "interval": 300, # Seconds between runs + "name": "quota_refresh", # For logging + "run_on_start": True, # Run immediately at startup? + } + """ + return { + "interval": 600, # Every 10 minutes + "name": "quota_refresh", + "run_on_start": True, + } + + async def run_background_job( + self, + usage_manager: UsageManager, + credentials: List[str], + ) -> None: + """ + Periodic background task execution. + + Called by BackgroundRefresher at the interval specified in + get_background_job_config(). + + Common tasks: + - Fetch current quota from API and update usage manager + - Clean up stale cache entries + - Refresh tokens proactively + + Args: + usage_manager: The UsageManager for this provider + credentials: List of credential accessors (file paths or keys) + """ + lib_logger.debug(f"Running background job for {self.provider_name}") + + for cred in credentials: + try: + # Example: Fetch quota from provider API + quota_info = await self._fetch_quota_from_api(cred) + + if quota_info: + for model, info in quota_info.items(): + # Update usage manager with fresh quota data + await usage_manager.update_quota_baseline( + accessor=cred, + model=model, + quota_max_requests=info.get("limit"), + quota_reset_ts=info.get("reset_ts"), + quota_used=info.get("used"), + quota_group=info.get("group"), + ) + + except Exception as e: + lib_logger.warning(f"Quota refresh failed for {cred}: {e}") + + async def _fetch_quota_from_api( + self, + credential: str, + ) -> Optional[Dict[str, Dict[str, Any]]]: + """ + Fetch current quota information from provider API. + + Override this with actual API calls for your provider. + + Returns: + Dict mapping model names to quota info: + { + "gpt-4o": { + "limit": 500, + "used": 123, + "reset_ts": 1735689600.0, + "group": "gpt4", # Optional + }, + ... + } + """ + # Placeholder - implement actual API call + return None + + # ========================================================================= + # TOKEN EXTRACTION + # ========================================================================= + + def _build_usage_from_response( + self, + response: Any, + ) -> Optional[Dict[str, Any]]: + """ + Build standardized usage dict from provider-specific response. + + The usage manager expects a standardized format. If your provider + returns a different format, convert it here. + + Standard format: + { + "prompt_tokens": int, # Input tokens + "completion_tokens": int, # Output tokens (content + thinking) + "total_tokens": int, # All tokens + + # Optional: Input breakdown + "prompt_tokens_details": { + "cached_tokens": int, # Cache read tokens + "cache_creation_tokens": int, # Cache write tokens + }, + + # Optional: Output breakdown + "completion_tokens_details": { + "reasoning_tokens": int, # Thinking/reasoning tokens + }, + + # Alternative top-level fields (some APIs use these) + "cache_read_tokens": int, + "cache_creation_tokens": int, + } + + Args: + response: Raw response from provider API + + Returns: + Standardized usage dict, or None if no usage data + """ + if not hasattr(response, "usage") or not response.usage: + return None + + # Example: Provider returns Gemini-style metadata + # Adapt this to your provider's format + usage = response.usage + + # Standard fields + result = { + "prompt_tokens": getattr(usage, "prompt_tokens", 0) or 0, + "completion_tokens": getattr(usage, "completion_tokens", 0) or 0, + "total_tokens": getattr(usage, "total_tokens", 0) or 0, + } + + # Example: Extract cached tokens from details + prompt_details = getattr(usage, "prompt_tokens_details", None) + if prompt_details: + if isinstance(prompt_details, dict): + cached = prompt_details.get("cached_tokens", 0) + cache_write = prompt_details.get("cache_creation_tokens", 0) + else: + cached = getattr(prompt_details, "cached_tokens", 0) + cache_write = getattr(prompt_details, "cache_creation_tokens", 0) + + if cached or cache_write: + result["prompt_tokens_details"] = {} + if cached: + result["prompt_tokens_details"]["cached_tokens"] = cached + if cache_write: + result["prompt_tokens_details"]["cache_creation_tokens"] = ( + cache_write + ) + + # Example: Extract thinking tokens from details + completion_details = getattr(usage, "completion_tokens_details", None) + if completion_details: + if isinstance(completion_details, dict): + reasoning = completion_details.get("reasoning_tokens", 0) + else: + reasoning = getattr(completion_details, "reasoning_tokens", 0) + + if reasoning: + result["completion_tokens_details"] = {"reasoning_tokens": reasoning} + + return result + + +# ============================================================================= +# CUSTOM WINDOWS +# ============================================================================= +# +# To add custom usage windows, you have two options: +# +# OPTION 1: Override windows via provider config (recommended) +# ------------------------------------------------------------ +# Add class attribute to your provider: +# +# default_windows = [ +# WindowDefinition.rolling("1h", 3600, is_primary=False), +# WindowDefinition.rolling("6h", 21600, is_primary=True), +# WindowDefinition.daily("daily"), +# WindowDefinition.total("total"), +# ] +# +# WindowDefinition options: +# - name: str - Window identifier (e.g., "1h", "daily") +# - duration_seconds: int | None - Window duration (None for "total") +# - reset_mode: ResetMode - How window resets +# - ROLLING: Continuous sliding window +# - FIXED_DAILY: Reset at specific UTC time +# - CALENDAR_WEEKLY: Reset at week start +# - CALENDAR_MONTHLY: Reset at month start +# - API_AUTHORITATIVE: Provider determines reset +# - is_primary: bool - Used for rotation decisions +# - applies_to: str - Scope of window +# - "credential": Global per-credential +# - "model": Per-model per-credential +# - "group": Per-quota-group per-credential +# +# OPTION 2: Build config manually in RotatingClient +# ------------------------------------------------- +# In your client initialization: +# +# from rotator_library.usage.config import ( +# ProviderUsageConfig, +# WindowDefinition, +# FairCycleConfig, +# ) +# from rotator_library.usage.types import RotationMode, ResetMode +# +# config = ProviderUsageConfig( +# rotation_mode=RotationMode.SEQUENTIAL, +# windows=[ +# WindowDefinition( +# name="1h", +# duration_seconds=3600, +# reset_mode=ResetMode.ROLLING, +# is_primary=False, +# applies_to="model", +# ), +# WindowDefinition( +# name="6h", +# duration_seconds=21600, +# reset_mode=ResetMode.ROLLING, +# is_primary=True, # Primary for rotation +# applies_to="group", # Track per quota group +# ), +# ], +# fair_cycle=FairCycleConfig( +# enabled=True, +# tracking_mode=TrackingMode.MODEL_GROUP, +# ), +# ) +# +# manager = UsageManager( +# provider="example", +# config=config, +# file_path="usage_example.json", +# ) +# +# ============================================================================= + + +# ============================================================================= +# REGISTERING YOUR PROVIDER +# ============================================================================= +# +# To register your provider with the system: +# +# 1. Add to PROVIDER_PLUGINS dict in src/rotator_library/providers/__init__.py: +# +# from .example_provider import ExampleProvider +# +# PROVIDER_PLUGINS = { +# ... +# "example": ExampleProvider, +# } +# +# 2. Add credential discovery in RotatingClient if using OAuth: +# +# # In _discover_oauth_credentials: +# if provider == "example": +# creds = self._discover_example_credentials() +# +# 3. Configure via environment variables: +# +# # API key credentials +# EXAMPLE_API_KEY=sk-xxx +# EXAMPLE_API_KEY_2=sk-yyy +# +# # OAuth credential paths +# EXAMPLE_OAUTH_PATHS=./creds/example_*.json +# +# # Priority/tier assignment +# PRIORITY_EXAMPLE_CRED1=1 +# TIER_EXAMPLE_CRED2=standard-tier +# +# # Quota group overrides +# QUOTA_GROUPS_EXAMPLE_GPT4=gpt-4o,gpt-4o-mini,gpt-4-turbo +# +# ============================================================================= + + +# ============================================================================= +# ACCESSING USAGE DATA +# ============================================================================= +# +# The usage manager exposes data through several methods: +# +# 1. Get availability stats (for UI/monitoring): +# +# stats = await usage_manager.get_availability_stats(model, quota_group) +# # Returns: { +# # "total": 10, +# # "available": 7, +# # "blocked_by": {"cooldowns": 2, "fair_cycle": 1}, +# # "rotation_mode": "sequential", +# # } +# +# 2. Get comprehensive stats (for quota-stats endpoint): +# +# stats = await usage_manager.get_stats_for_endpoint() +# # Returns full credential/model/group breakdown +# +# 3. Direct state access (for advanced use): +# +# # Get credential state +# state = usage_manager.states.get(stable_id) +# +# # Access usage at different scopes +# global_usage = state.usage +# model_usage = state.model_usage.get("gpt-4o") +# group_usage = state.group_usage.get("gpt4") +# +# # Check cooldowns +# cooldown = state.get_cooldown("gpt4") +# if cooldown and cooldown.is_active: +# print(f"Cooldown remaining: {cooldown.remaining_seconds}s") +# +# # Check fair cycle +# fc = state.fair_cycle.get("gpt4") +# if fc and fc.exhausted: +# print(f"Exhausted at: {fc.exhausted_at}") +# +# 4. Update quota baseline (from API response): +# +# await usage_manager.update_quota_baseline( +# accessor=credential, +# model="gpt-4o", +# quota_max_requests=500, +# quota_reset_ts=time.time() + 3600, +# quota_used=123, +# quota_group="gpt4", +# ) +# +# ============================================================================= diff --git a/src/rotator_library/providers/firmware_provider.py b/src/rotator_library/providers/firmware_provider.py index e71316fa..2b94bd8b 100644 --- a/src/rotator_library/providers/firmware_provider.py +++ b/src/rotator_library/providers/firmware_provider.py @@ -19,7 +19,7 @@ from .utilities.firmware_quota_tracker import FirmwareQuotaTracker if TYPE_CHECKING: - from ..usage_manager import UsageManager + from ..usage import UsageManager import logging @@ -184,12 +184,23 @@ async def refresh_single_credential( # Store baseline in usage manager # Since Firmware.ai uses credential-level quota, we use a virtual model name + if remaining_fraction <= 0.0 and reset_ts: + stable_id = usage_manager.registry.get_stable_id( + api_key, usage_manager.provider + ) + state = usage_manager.states.get(stable_id) + if state: + await usage_manager.tracking.apply_cooldown( + state=state, + reason="quota_exhausted", + until=reset_ts, + model_or_group="firmware/_quota", + source="api_quota", + ) await usage_manager.update_quota_baseline( api_key, "firmware/_quota", # Virtual model for credential-level tracking - remaining_fraction, - # No max_requests - Firmware.ai doesn't expose this - reset_timestamp=reset_ts, + quota_reset_ts=reset_ts, ) lib_logger.debug( @@ -199,7 +210,9 @@ async def refresh_single_credential( ) except Exception as e: - lib_logger.warning(f"Failed to refresh Firmware.ai quota usage: {e}") + lib_logger.warning( + f"Failed to refresh Firmware.ai quota usage: {e}" + ) # Fetch all credentials in parallel with shared HTTP client async with httpx.AsyncClient(timeout=30.0) as client: diff --git a/src/rotator_library/providers/nanogpt_provider.py b/src/rotator_library/providers/nanogpt_provider.py index 52a648e4..456117de 100644 --- a/src/rotator_library/providers/nanogpt_provider.py +++ b/src/rotator_library/providers/nanogpt_provider.py @@ -25,7 +25,7 @@ from typing import Any, Dict, List, Optional, TYPE_CHECKING if TYPE_CHECKING: - from ..usage_manager import UsageManager + from ..usage import UsageManager from .provider_interface import ProviderInterface, UsageResetConfigDef from .utilities.nanogpt_quota_tracker import NanoGptQuotaTracker @@ -74,8 +74,8 @@ class NanoGptProvider(NanoGptQuotaTracker, ProviderInterface): # Active subscriptions get highest priority tier_priorities = { "subscription-active": 1, # Active subscription - "subscription-grace": 2, # Grace period (subscription lapsed but still has access) - "no-subscription": 3, # No active subscription (pay-as-you-go only) + "subscription-grace": 2, # Grace period (subscription lapsed but still has access) + "no-subscription": 3, # No active subscription (pay-as-you-go only) } default_tier_priority = 3 @@ -86,8 +86,6 @@ class NanoGptProvider(NanoGptQuotaTracker, ProviderInterface): "monthly": ["_monthly"], } - - def __init__(self): self.model_definitions = ModelDefinitions() @@ -410,29 +408,49 @@ async def refresh_single_credential( monthly_remaining = monthly_data.get("remaining", 0) # Calculate remaining fractions - daily_fraction = daily_remaining / daily_limit if daily_limit > 0 else 1.0 - monthly_fraction = monthly_remaining / monthly_limit if monthly_limit > 0 else 1.0 + daily_fraction = ( + daily_remaining / daily_limit if daily_limit > 0 else 1.0 + ) + monthly_fraction = ( + monthly_remaining / monthly_limit + if monthly_limit > 0 + else 1.0 + ) # Get reset timestamps daily_reset_ts = daily_data.get("reset_at", 0) monthly_reset_ts = monthly_data.get("reset_at", 0) # Store daily quota baseline + daily_used = ( + int((1.0 - daily_fraction) * daily_limit) + if daily_limit > 0 + else 0 + ) await usage_manager.update_quota_baseline( api_key, "nanogpt/_daily", - daily_fraction, - max_requests=daily_limit, - reset_timestamp=daily_reset_ts if daily_reset_ts > 0 else None, + quota_max_requests=daily_limit, + quota_reset_ts=daily_reset_ts + if daily_reset_ts > 0 + else None, + quota_used=daily_used, ) # Store monthly quota baseline + monthly_used = ( + int((1.0 - monthly_fraction) * monthly_limit) + if monthly_limit > 0 + else 0 + ) await usage_manager.update_quota_baseline( api_key, "nanogpt/_monthly", - monthly_fraction, - max_requests=monthly_limit, - reset_timestamp=monthly_reset_ts if monthly_reset_ts > 0 else None, + quota_max_requests=monthly_limit, + quota_reset_ts=monthly_reset_ts + if monthly_reset_ts > 0 + else None, + quota_used=monthly_used, ) lib_logger.debug( diff --git a/src/rotator_library/providers/provider_interface.py b/src/rotator_library/providers/provider_interface.py index f53f91e9..cdfe8e2a 100644 --- a/src/rotator_library/providers/provider_interface.py +++ b/src/rotator_library/providers/provider_interface.py @@ -19,7 +19,7 @@ import litellm if TYPE_CHECKING: - from ..usage_manager import UsageManager + from ..usage import UsageManager from ..config import ( DEFAULT_ROTATION_MODE, diff --git a/src/rotator_library/providers/utilities/antigravity_quota_tracker.py b/src/rotator_library/providers/utilities/antigravity_quota_tracker.py index f399222d..2e9f7e2f 100644 --- a/src/rotator_library/providers/utilities/antigravity_quota_tracker.py +++ b/src/rotator_library/providers/utilities/antigravity_quota_tracker.py @@ -34,7 +34,7 @@ from .base_quota_tracker import BaseQuotaTracker, QUOTA_DISCOVERY_DELAY_SECONDS if TYPE_CHECKING: - from ...usage_manager import UsageManager + from ...usage import UsageManager # Use the shared rotator_library logger lib_logger = logging.getLogger("rotator_library") @@ -990,6 +990,7 @@ async def _store_baselines_to_usage_manager( self, quota_results: Dict[str, Dict[str, Any]], usage_manager: "UsageManager", + force: bool = False, ) -> int: """ Store fetched quota baselines into UsageManager. @@ -997,6 +998,7 @@ async def _store_baselines_to_usage_manager( Args: quota_results: Dict from fetch_quota_from_api or fetch_initial_baselines usage_manager: UsageManager instance to store baselines in + force: If True, always use API values (for manual refresh) Returns: Number of baselines successfully stored @@ -1054,14 +1056,24 @@ async def _store_baselines_to_usage_manager( # Store with provider prefix for consistency with usage tracking prefixed_model = f"antigravity/{user_model}" + quota_used = None + if max_requests is not None: + quota_used = int((1.0 - remaining) * max_requests) + quota_group = self.get_model_quota_group(user_model) cooldown_info = await usage_manager.update_quota_baseline( - cred_path, prefixed_model, remaining, max_requests, reset_timestamp + cred_path, + prefixed_model, + quota_max_requests=max_requests, + quota_reset_ts=reset_timestamp, + quota_used=quota_used, + quota_group=quota_group, + force=force, ) # Aggregate cooldown info if returned if cooldown_info: - group_or_model = cooldown_info["group_or_model"] - hours = cooldown_info["hours_until_reset"] + group_or_model = cooldown_info["model"] + hours = cooldown_info["cooldown_hours"] if short_cred not in cooldowns_by_cred: cooldowns_by_cred[short_cred] = {} # Only keep first occurrence per group/model (avoids duplicates) @@ -1073,12 +1085,7 @@ async def _store_baselines_to_usage_manager( # Log consolidated message for all cooldowns if cooldowns_by_cred: - # Build message: "oauth_1[claude 3.4h, gemini-3-pro 2.1h], oauth_2[claude 5.2h]" - parts = [] - for cred_name, groups in sorted(cooldowns_by_cred.items()): - group_strs = [f"{g} {h:.1f}h" for g, h in sorted(groups.items())] - parts.append(f"{cred_name}[{', '.join(group_strs)}]") - lib_logger.info(f"Antigravity quota exhausted: {', '.join(parts)}") + lib_logger.debug("Antigravity quota baseline refresh: cooldowns recorded") else: lib_logger.debug("Antigravity quota baseline refresh: no cooldowns needed") diff --git a/src/rotator_library/providers/utilities/base_quota_tracker.py b/src/rotator_library/providers/utilities/base_quota_tracker.py index a155890f..e17cead1 100644 --- a/src/rotator_library/providers/utilities/base_quota_tracker.py +++ b/src/rotator_library/providers/utilities/base_quota_tracker.py @@ -40,7 +40,7 @@ from ...utils.paths import get_cache_dir if TYPE_CHECKING: - from ...usage_manager import UsageManager + from ...usage import UsageManager # Use the shared rotator_library logger lib_logger = logging.getLogger("rotator_library") @@ -510,6 +510,7 @@ async def _store_baselines_to_usage_manager( self, quota_results: Dict[str, Dict[str, Any]], usage_manager: "UsageManager", + force: bool = False, ) -> int: """ Store fetched quota baselines into UsageManager. @@ -517,6 +518,7 @@ async def _store_baselines_to_usage_manager( Args: quota_results: Dict from _fetch_quota_for_credential or fetch_initial_baselines usage_manager: UsageManager instance to store baselines in + force: If True, always use API values (for manual refresh) Returns: Number of baselines successfully stored @@ -543,8 +545,17 @@ async def _store_baselines_to_usage_manager( max_requests = self.get_max_requests_for_model(user_model, tier) # Store baseline + quota_used = None + if max_requests is not None: + quota_used = int((1.0 - remaining) * max_requests) + quota_group = self.get_model_quota_group(user_model) await usage_manager.update_quota_baseline( - cred_path, prefixed_model, remaining, max_requests=max_requests + cred_path, + prefixed_model, + quota_max_requests=max_requests, + quota_used=quota_used, + quota_group=quota_group, + force=force, ) stored_count += 1 diff --git a/src/rotator_library/providers/utilities/gemini_cli_quota_tracker.py b/src/rotator_library/providers/utilities/gemini_cli_quota_tracker.py index 3f86014f..f100f031 100644 --- a/src/rotator_library/providers/utilities/gemini_cli_quota_tracker.py +++ b/src/rotator_library/providers/utilities/gemini_cli_quota_tracker.py @@ -39,7 +39,7 @@ from .gemini_shared_utils import CODE_ASSIST_ENDPOINT if TYPE_CHECKING: - from ...usage_manager import UsageManager + from ...usage import UsageManager # Use the shared rotator_library logger lib_logger = logging.getLogger("rotator_library") diff --git a/src/rotator_library/providers/utilities/gemini_credential_manager.py b/src/rotator_library/providers/utilities/gemini_credential_manager.py index f8f4dfcc..b35e73f3 100644 --- a/src/rotator_library/providers/utilities/gemini_credential_manager.py +++ b/src/rotator_library/providers/utilities/gemini_credential_manager.py @@ -17,7 +17,7 @@ from typing import Any, Dict, List, Optional, TYPE_CHECKING if TYPE_CHECKING: - from ...usage_manager import UsageManager + from ...usage import UsageManager lib_logger = logging.getLogger("rotator_library") @@ -270,7 +270,7 @@ async def run_background_job( self._initial_quota_fetch_done = True else: # Subsequent runs: only recently used credentials (incremental updates) - usage_data = await usage_manager._get_usage_data_snapshot() + usage_data = await usage_manager.get_usage_snapshot() quota_results = await self.refresh_active_quota_baselines( credentials, usage_data ) diff --git a/src/rotator_library/usage/__init__.py b/src/rotator_library/usage/__init__.py new file mode 100644 index 00000000..a56a3478 --- /dev/null +++ b/src/rotator_library/usage/__init__.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Usage tracking and credential selection package. + +This package provides the UsageManager facade and associated components +for tracking API usage, enforcing limits, and selecting credentials. + +Public API: + UsageManager: Main facade for usage tracking and credential selection + CredentialContext: Context manager for credential lifecycle + +Components (for advanced usage): + CredentialRegistry: Stable credential identity management + TrackingEngine: Usage recording and window management + LimitEngine: Limit checking and enforcement + SelectionEngine: Credential selection with strategies + UsageStorage: JSON file persistence +""" + +# Types first (no dependencies on other modules) +from .types import ( + UsageStats, + WindowStats, + CredentialState, + CooldownInfo, + FairCycleState, + SelectionContext, + LimitCheckResult, + RotationMode, + ResetMode, + LimitResult, +) + +# Config +from .config import ( + ProviderUsageConfig, + FairCycleConfig, + CustomCapConfig, + WindowDefinition, + load_provider_usage_config, +) + +# Components +from .identity.registry import CredentialRegistry +from .tracking.windows import WindowManager +from .tracking.engine import TrackingEngine +from .limits.engine import LimitEngine +from .selection.engine import SelectionEngine +from .persistence.storage import UsageStorage +from .integration.api import UsageAPI + +# Main facade (imports components above) +from .manager import UsageManager, CredentialContext + +__all__ = [ + # Main public API + "UsageManager", + "CredentialContext", + # Types + "UsageStats", + "WindowStats", + "CredentialState", + "CooldownInfo", + "FairCycleState", + "SelectionContext", + "LimitCheckResult", + "RotationMode", + "ResetMode", + "LimitResult", + # Config + "ProviderUsageConfig", + "FairCycleConfig", + "CustomCapConfig", + "WindowDefinition", + "load_provider_usage_config", + # Engines + "CredentialRegistry", + "WindowManager", + "TrackingEngine", + "LimitEngine", + "SelectionEngine", + "UsageStorage", + "UsageAPI", +] diff --git a/src/rotator_library/usage/config.py b/src/rotator_library/usage/config.py new file mode 100644 index 00000000..94bc6fda --- /dev/null +++ b/src/rotator_library/usage/config.py @@ -0,0 +1,512 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Default configurations for the usage tracking package. + +This module contains default values and configuration loading +for usage tracking, limits, and credential selection. +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Union + +from ..core.constants import ( + DEFAULT_FAIR_CYCLE_DURATION, + DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD, + DEFAULT_ROTATION_TOLERANCE, + DEFAULT_SEQUENTIAL_FALLBACK_MULTIPLIER, +) +from .types import ResetMode, RotationMode, TrackingMode, CooldownMode + + +# ============================================================================= +# WINDOW CONFIGURATION +# ============================================================================= + + +@dataclass +class WindowDefinition: + """ + Definition of a usage tracking window. + + Used to configure how usage is tracked and when it resets. + """ + + name: str # e.g., "5h", "daily", "weekly" + duration_seconds: Optional[int] # None for infinite/total + reset_mode: ResetMode + is_primary: bool = False # Primary window used for rotation decisions + applies_to: str = "model" # "credential", "model", "group" + + @classmethod + def rolling( + cls, + name: str, + duration_seconds: int, + is_primary: bool = False, + applies_to: str = "model", + ) -> "WindowDefinition": + """Create a rolling window definition.""" + return cls( + name=name, + duration_seconds=duration_seconds, + reset_mode=ResetMode.ROLLING, + is_primary=is_primary, + applies_to=applies_to, + ) + + @classmethod + def daily( + cls, + name: str = "daily", + applies_to: str = "model", + ) -> "WindowDefinition": + """Create a daily fixed window definition.""" + return cls( + name=name, + duration_seconds=86400, + reset_mode=ResetMode.FIXED_DAILY, + applies_to=applies_to, + ) + + @classmethod + def total( + cls, + name: str = "total", + applies_to: str = "model", + ) -> "WindowDefinition": + """Create a total/infinite window definition.""" + return cls( + name=name, + duration_seconds=None, + reset_mode=ResetMode.ROLLING, + applies_to=applies_to, + ) + + +# ============================================================================= +# FAIR CYCLE CONFIGURATION +# ============================================================================= + + +@dataclass +class FairCycleConfig: + """ + Fair cycle rotation configuration. + + Controls how credentials are cycled to ensure fair usage distribution. + """ + + enabled: Optional[bool] = ( + None # None = derive from rotation mode (on for sequential) + ) + tracking_mode: TrackingMode = TrackingMode.MODEL_GROUP + cross_tier: bool = False # Track across all tiers + duration: int = DEFAULT_FAIR_CYCLE_DURATION # Cycle duration in seconds + + +# ============================================================================= +# CUSTOM CAP CONFIGURATION +# ============================================================================= + + +@dataclass +class CustomCapConfig: + """ + Custom cap configuration for a tier/model combination. + + Allows setting usage limits more restrictive than actual API limits. + """ + + tier_key: str # Priority as string or "default" + model_or_group: str # Model name or quota group name + max_requests: int # Maximum requests allowed + cooldown_mode: CooldownMode = CooldownMode.QUOTA_RESET + cooldown_value: int = 0 # Seconds for offset/fixed modes + + @classmethod + def from_dict( + cls, tier_key: str, model_or_group: str, config: Dict[str, Any] + ) -> "CustomCapConfig": + """Create from dictionary config.""" + max_requests = config.get("max_requests", 0) + + # Handle percentage strings like "80%" + if isinstance(max_requests, str) and max_requests.endswith("%"): + # Store as negative to indicate percentage + # Will be resolved later when actual limit is known + max_requests = -int(max_requests.rstrip("%")) + + return cls( + tier_key=tier_key, + model_or_group=model_or_group, + max_requests=max_requests, + cooldown_mode=CooldownMode(config.get("cooldown_mode", "quota_reset")), + cooldown_value=config.get("cooldown_value", 0), + ) + + +# ============================================================================= +# PROVIDER USAGE CONFIG +# ============================================================================= + + +@dataclass +class ProviderUsageConfig: + """ + Complete usage configuration for a provider. + + Combines all settings needed for usage tracking and credential selection. + """ + + # Rotation settings + rotation_mode: RotationMode = RotationMode.BALANCED + rotation_tolerance: float = DEFAULT_ROTATION_TOLERANCE + sequential_fallback_multiplier: int = DEFAULT_SEQUENTIAL_FALLBACK_MULTIPLIER + + # Priority multipliers (priority -> max concurrent) + priority_multipliers: Dict[int, int] = field(default_factory=dict) + priority_multipliers_by_mode: Dict[str, Dict[int, int]] = field( + default_factory=dict + ) + + # Fair cycle + fair_cycle: FairCycleConfig = field(default_factory=FairCycleConfig) + + # Custom caps + custom_caps: List[CustomCapConfig] = field(default_factory=list) + + # Exhaustion threshold (cooldown must exceed this to count as "exhausted") + exhaustion_cooldown_threshold: int = DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD + + # Window definitions + windows: List[WindowDefinition] = field(default_factory=list) + + def get_effective_multiplier(self, priority: int) -> int: + """ + Get the effective multiplier for a priority level. + + Checks mode-specific overrides first, then universal multipliers, + then falls back to sequential_fallback_multiplier. + """ + mode_key = self.rotation_mode.value + mode_multipliers = self.priority_multipliers_by_mode.get(mode_key, {}) + + # Check mode-specific first + if priority in mode_multipliers: + return mode_multipliers[priority] + + # Check universal + if priority in self.priority_multipliers: + return self.priority_multipliers[priority] + + # Fall back + return self.sequential_fallback_multiplier + + +# ============================================================================= +# DEFAULT WINDOWS +# ============================================================================= + + +def get_default_windows() -> List[WindowDefinition]: + """ + Get default window definitions. + + Most providers use a 5-hour rolling window as primary. + """ + return [ + WindowDefinition.rolling("5h", 18000, is_primary=True, applies_to="model"), + WindowDefinition.daily("daily", applies_to="model"), + WindowDefinition.total("total", applies_to="model"), + ] + + +# ============================================================================= +# CONFIG LOADER INTEGRATION +# ============================================================================= + + +def load_provider_usage_config( + provider: str, + provider_plugins: Dict[str, Any], +) -> ProviderUsageConfig: + """ + Load usage configuration for a provider. + + Merges: + 1. System defaults + 2. Provider class attributes + 3. Environment variables (always win) + + Args: + provider: Provider name (e.g., "gemini", "openai") + provider_plugins: Dict of provider plugin classes + + Returns: + Complete configuration for the provider + """ + import os + + config = ProviderUsageConfig() + + # Get plugin class + plugin_class = provider_plugins.get(provider) + + # Apply provider defaults + if plugin_class: + # Rotation mode + if hasattr(plugin_class, "default_rotation_mode"): + config.rotation_mode = RotationMode(plugin_class.default_rotation_mode) + + # Priority multipliers + if hasattr(plugin_class, "default_priority_multipliers"): + config.priority_multipliers = dict( + plugin_class.default_priority_multipliers + ) + + if hasattr(plugin_class, "default_priority_multipliers_by_mode"): + config.priority_multipliers_by_mode = { + k: dict(v) + for k, v in plugin_class.default_priority_multipliers_by_mode.items() + } + + # Sequential fallback multiplier + if hasattr(plugin_class, "default_sequential_fallback_multiplier"): + fallback = plugin_class.default_sequential_fallback_multiplier + if fallback is not None: + config.sequential_fallback_multiplier = fallback + + # Fair cycle + if hasattr(plugin_class, "default_fair_cycle_config"): + fc_config = plugin_class.default_fair_cycle_config + config.fair_cycle = FairCycleConfig( + enabled=fc_config.get("enabled"), + tracking_mode=TrackingMode( + fc_config.get("tracking_mode", "model_group") + ), + cross_tier=fc_config.get("cross_tier", False), + duration=fc_config.get("duration", DEFAULT_FAIR_CYCLE_DURATION), + ) + else: + if hasattr(plugin_class, "default_fair_cycle_enabled"): + config.fair_cycle.enabled = plugin_class.default_fair_cycle_enabled + if hasattr(plugin_class, "default_fair_cycle_tracking_mode"): + config.fair_cycle.tracking_mode = TrackingMode( + plugin_class.default_fair_cycle_tracking_mode + ) + if hasattr(plugin_class, "default_fair_cycle_cross_tier"): + config.fair_cycle.cross_tier = ( + plugin_class.default_fair_cycle_cross_tier + ) + if hasattr(plugin_class, "default_fair_cycle_duration"): + config.fair_cycle.duration = plugin_class.default_fair_cycle_duration + + # Custom caps + if hasattr(plugin_class, "default_custom_caps"): + for tier_key, models in plugin_class.default_custom_caps.items(): + tier_keys: Tuple[Union[int, str], ...] + if isinstance(tier_key, tuple): + tier_keys = tuple(tier_key) + else: + tier_keys = (tier_key,) + for model_or_group, cap_config in models.items(): + for resolved_tier in tier_keys: + config.custom_caps.append( + CustomCapConfig.from_dict( + str(resolved_tier), model_or_group, cap_config + ) + ) + + # Windows + if hasattr(plugin_class, "usage_window_definitions"): + config.windows = [] + for wdef in plugin_class.usage_window_definitions: + config.windows.append( + WindowDefinition( + name=wdef.get("name", "default"), + duration_seconds=wdef.get("duration_seconds"), + reset_mode=ResetMode(wdef.get("reset_mode", "rolling")), + is_primary=wdef.get("is_primary", False), + applies_to=wdef.get("applies_to", "model"), + ) + ) + + # Use default windows if none defined + if not config.windows: + config.windows = get_default_windows() + + # Apply environment variable overrides + provider_upper = provider.upper() + + # Rotation mode from env + env_mode = os.getenv(f"ROTATION_MODE_{provider_upper}") + if env_mode: + config.rotation_mode = RotationMode(env_mode.lower()) + + # Sequential fallback multiplier + env_fallback = os.getenv(f"SEQUENTIAL_FALLBACK_MULTIPLIER_{provider_upper}") + if env_fallback: + try: + config.sequential_fallback_multiplier = int(env_fallback) + except ValueError: + pass + + # Fair cycle enabled from env + env_fc = os.getenv(f"FAIR_CYCLE_{provider_upper}") + if env_fc is None: + env_fc = os.getenv(f"FAIR_CYCLE_ENABLED_{provider_upper}") + if env_fc: + config.fair_cycle.enabled = env_fc.lower() in ("true", "1", "yes") + + # Fair cycle tracking mode + env_fc_mode = os.getenv(f"FAIR_CYCLE_TRACKING_MODE_{provider_upper}") + if env_fc_mode: + try: + config.fair_cycle.tracking_mode = TrackingMode(env_fc_mode.lower()) + except ValueError: + pass + + # Fair cycle cross-tier + env_fc_cross = os.getenv(f"FAIR_CYCLE_CROSS_TIER_{provider_upper}") + if env_fc_cross: + config.fair_cycle.cross_tier = env_fc_cross.lower() in ("true", "1", "yes") + + # Fair cycle duration from env + env_fc_duration = os.getenv(f"FAIR_CYCLE_DURATION_{provider_upper}") + if env_fc_duration: + try: + config.fair_cycle.duration = int(env_fc_duration) + except ValueError: + pass + + # Exhaustion threshold from env + env_threshold = os.getenv(f"EXHAUSTION_COOLDOWN_THRESHOLD_{provider_upper}") + if env_threshold: + try: + config.exhaustion_cooldown_threshold = int(env_threshold) + except ValueError: + pass + + # Priority multipliers from env + # Format: CONCURRENCY_MULTIPLIER_{PROVIDER}_PRIORITY_{N}=value + # Format: CONCURRENCY_MULTIPLIER_{PROVIDER}_PRIORITY_{N}_{MODE}=value + for key, value in os.environ.items(): + prefix = f"CONCURRENCY_MULTIPLIER_{provider_upper}_PRIORITY_" + if key.startswith(prefix): + try: + remainder = key[len(prefix) :] + multiplier = int(value) + if multiplier < 1: + continue + if "_" in remainder: + priority_str, mode = remainder.rsplit("_", 1) + priority = int(priority_str) + mode = mode.lower() + if mode in ("sequential", "balanced"): + config.priority_multipliers_by_mode.setdefault(mode, {})[ + priority + ] = multiplier + else: + config.priority_multipliers[priority] = multiplier + else: + priority = int(remainder) + config.priority_multipliers[priority] = multiplier + except ValueError: + pass + + # Custom caps from env + if os.environ: + cap_map: Dict[str, Dict[str, Dict[str, Any]]] = {} + for cap in config.custom_caps: + cap_entry = cap_map.setdefault(str(cap.tier_key), {}) + cap_entry[cap.model_or_group] = { + "max_requests": cap.max_requests, + "cooldown_mode": cap.cooldown_mode.value, + "cooldown_value": cap.cooldown_value, + } + + cap_prefix = f"CUSTOM_CAP_{provider_upper}_T" + cooldown_prefix = f"CUSTOM_CAP_COOLDOWN_{provider_upper}_T" + for env_key, env_value in os.environ.items(): + if env_key.startswith(cap_prefix) and not env_key.startswith( + cooldown_prefix + ): + remainder = env_key[len(cap_prefix) :] + tier_key, model_key = _parse_custom_cap_env_key(remainder) + if tier_key is None or not model_key: + continue + cap_entry = cap_map.setdefault(str(tier_key), {}) + cap_entry.setdefault(model_key, {})["max_requests"] = env_value + elif env_key.startswith(cooldown_prefix): + remainder = env_key[len(cooldown_prefix) :] + tier_key, model_key = _parse_custom_cap_env_key(remainder) + if tier_key is None or not model_key: + continue + if ":" in env_value: + mode, value_str = env_value.split(":", 1) + try: + value = int(value_str) + except ValueError: + continue + else: + mode = env_value + value = 0 + cap_entry = cap_map.setdefault(str(tier_key), {}) + cap_entry.setdefault(model_key, {})["cooldown_mode"] = mode + cap_entry.setdefault(model_key, {})["cooldown_value"] = value + + config.custom_caps = [] + for tier_key, models in cap_map.items(): + for model_or_group, cap_config in models.items(): + config.custom_caps.append( + CustomCapConfig.from_dict(tier_key, model_or_group, cap_config) + ) + + # Derive fair cycle enabled from rotation mode if not explicitly set + if config.fair_cycle.enabled is None: + config.fair_cycle.enabled = config.rotation_mode == RotationMode.SEQUENTIAL + + return config + + +def _parse_custom_cap_env_key( + remainder: str, +) -> Tuple[Optional[Union[int, Tuple[int, ...], str]], Optional[str]]: + """Parse the tier and model/group from a custom cap env var remainder.""" + if not remainder: + return None, None + + remaining_parts = remainder.split("_") + if len(remaining_parts) < 2: + return None, None + + tier_key: Union[int, Tuple[int, ...], str, None] = None + model_key: Optional[str] = None + tier_parts: List[int] = [] + + for i, part in enumerate(remaining_parts): + if part == "DEFAULT": + tier_key = "default" + model_key = "_".join(remaining_parts[i + 1 :]) + break + if part.isdigit(): + tier_parts.append(int(part)) + continue + + if not tier_parts: + return None, None + if len(tier_parts) == 1: + tier_key = tier_parts[0] + else: + tier_key = tuple(tier_parts) + model_key = "_".join(remaining_parts[i:]) + break + else: + return None, None + + if model_key: + model_key = model_key.lower().replace("_", "-") + + return tier_key, model_key diff --git a/src/rotator_library/usage/identity/__init__.py b/src/rotator_library/usage/identity/__init__.py new file mode 100644 index 00000000..b42af7ff --- /dev/null +++ b/src/rotator_library/usage/identity/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Credential identity management.""" + +from .registry import CredentialRegistry + +__all__ = ["CredentialRegistry"] diff --git a/src/rotator_library/usage/identity/registry.py b/src/rotator_library/usage/identity/registry.py new file mode 100644 index 00000000..ddd867d4 --- /dev/null +++ b/src/rotator_library/usage/identity/registry.py @@ -0,0 +1,270 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Credential identity registry. + +Provides stable identifiers for credentials that persist across +file path changes (for OAuth) and hide sensitive data (for API keys). +""" + +import hashlib +import json +import logging +from pathlib import Path +from typing import Any, Dict, Optional, Set + +from ...core.types import CredentialInfo + +lib_logger = logging.getLogger("rotator_library") + + +class CredentialRegistry: + """ + Manages stable identifiers for credentials. + + Stable IDs are: + - For OAuth credentials: The email address from _proxy_metadata.email + - For API keys: SHA-256 hash of the key (truncated for readability) + + This ensures usage data persists even when: + - OAuth credential files are moved/renamed + - API keys are passed in different orders + """ + + def __init__(self): + # Cache: accessor -> CredentialInfo + self._cache: Dict[str, CredentialInfo] = {} + # Reverse index: stable_id -> accessor + self._id_to_accessor: Dict[str, str] = {} + + def get_stable_id(self, accessor: str, provider: str) -> str: + """ + Get or create a stable ID for a credential accessor. + + Args: + accessor: The credential accessor (file path or API key) + provider: Provider name + + Returns: + Stable identifier string + """ + # Check cache first + if accessor in self._cache: + return self._cache[accessor].stable_id + + # Determine if OAuth or API key + if self._is_oauth_path(accessor): + stable_id = self._get_oauth_stable_id(accessor) + else: + stable_id = self._get_api_key_stable_id(accessor) + + # Cache the result + info = CredentialInfo( + accessor=accessor, + stable_id=stable_id, + provider=provider, + ) + self._cache[accessor] = info + self._id_to_accessor[stable_id] = accessor + + return stable_id + + def get_info(self, accessor: str, provider: str) -> CredentialInfo: + """ + Get complete credential info for an accessor. + + Args: + accessor: The credential accessor + provider: Provider name + + Returns: + CredentialInfo with stable_id and metadata + """ + # Ensure stable ID is computed + self.get_stable_id(accessor, provider) + return self._cache[accessor] + + def get_accessor(self, stable_id: str) -> Optional[str]: + """ + Get the current accessor for a stable ID. + + Args: + stable_id: The stable identifier + + Returns: + Current accessor string, or None if not found + """ + return self._id_to_accessor.get(stable_id) + + def update_accessor(self, stable_id: str, new_accessor: str) -> None: + """ + Update the accessor for a stable ID. + + Used when an OAuth credential file is moved/renamed. + + Args: + stable_id: The stable identifier + new_accessor: New accessor path + """ + old_accessor = self._id_to_accessor.get(stable_id) + if old_accessor and old_accessor in self._cache: + info = self._cache.pop(old_accessor) + info.accessor = new_accessor + self._cache[new_accessor] = info + self._id_to_accessor[stable_id] = new_accessor + + def update_metadata( + self, + accessor: str, + provider: str, + tier: Optional[str] = None, + priority: Optional[int] = None, + display_name: Optional[str] = None, + ) -> None: + """ + Update metadata for a credential. + + Args: + accessor: The credential accessor + provider: Provider name + tier: Tier name (e.g., "standard-tier") + priority: Priority level (lower = higher priority) + display_name: Human-readable name + """ + info = self.get_info(accessor, provider) + if tier is not None: + info.tier = tier + if priority is not None: + info.priority = priority + if display_name is not None: + info.display_name = display_name + + def get_all_accessors(self) -> Set[str]: + """Get all registered accessors.""" + return set(self._cache.keys()) + + def get_all_stable_ids(self) -> Set[str]: + """Get all registered stable IDs.""" + return set(self._id_to_accessor.keys()) + + def clear_cache(self) -> None: + """Clear the internal cache.""" + self._cache.clear() + self._id_to_accessor.clear() + + # ========================================================================= + # PRIVATE METHODS + # ========================================================================= + + def _is_oauth_path(self, accessor: str) -> bool: + """ + Check if accessor is an OAuth credential file path. + + OAuth paths typically end with .json and exist on disk. + API keys are typically raw strings. + """ + # Simple heuristic: if it looks like a file path with .json, it's OAuth + if accessor.endswith(".json"): + return True + # If it contains path separators, it's likely a file path + if "/" in accessor or "\\" in accessor: + return True + return False + + def _get_oauth_stable_id(self, accessor: str) -> str: + """ + Get stable ID for an OAuth credential. + + Reads the email from _proxy_metadata.email in the credential file. + Falls back to file hash if email not found. + """ + try: + path = Path(accessor) + if path.exists(): + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + + # Try to get email from _proxy_metadata + metadata = data.get("_proxy_metadata", {}) + email = metadata.get("email") + if email: + return email + + # Fallback: try common OAuth fields + for field in ["email", "client_email", "account"]: + if field in data: + return data[field] + + # Last resort: hash the file content + lib_logger.debug( + f"No email found in OAuth credential {accessor}, using content hash" + ) + return self._hash_content(json.dumps(data, sort_keys=True)) + + except Exception as e: + lib_logger.warning(f"Failed to read OAuth credential {accessor}: {e}") + + # Fallback: hash the path + return self._hash_content(accessor) + + def _get_api_key_stable_id(self, accessor: str) -> str: + """ + Get stable ID for an API key. + + Uses truncated SHA-256 hash to hide the actual key. + """ + return self._hash_content(accessor) + + def _hash_content(self, content: str) -> str: + """ + Create a stable hash of content. + + Uses first 12 characters of SHA-256 for readability. + """ + return hashlib.sha256(content.encode()).hexdigest()[:12] + + # ========================================================================= + # SERIALIZATION + # ========================================================================= + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize registry state for persistence. + + Returns: + Dictionary suitable for JSON serialization + """ + return { + "accessor_index": dict(self._id_to_accessor), + "credentials": { + accessor: { + "stable_id": info.stable_id, + "provider": info.provider, + "tier": info.tier, + "priority": info.priority, + "display_name": info.display_name, + } + for accessor, info in self._cache.items() + }, + } + + def from_dict(self, data: Dict[str, Any]) -> None: + """ + Restore registry state from persistence. + + Args: + data: Dictionary from to_dict() + """ + self._id_to_accessor = dict(data.get("accessor_index", {})) + + for accessor, cred_data in data.get("credentials", {}).items(): + info = CredentialInfo( + accessor=accessor, + stable_id=cred_data["stable_id"], + provider=cred_data["provider"], + tier=cred_data.get("tier"), + priority=cred_data.get("priority", 999), + display_name=cred_data.get("display_name"), + ) + self._cache[accessor] = info diff --git a/src/rotator_library/usage/integration/__init__.py b/src/rotator_library/usage/integration/__init__.py new file mode 100644 index 00000000..4e58bdd8 --- /dev/null +++ b/src/rotator_library/usage/integration/__init__.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Integration helpers for usage manager.""" + +from .hooks import HookDispatcher +from .api import UsageAPI + +__all__ = ["HookDispatcher", "UsageAPI"] diff --git a/src/rotator_library/usage/integration/api.py b/src/rotator_library/usage/integration/api.py new file mode 100644 index 00000000..a418fa71 --- /dev/null +++ b/src/rotator_library/usage/integration/api.py @@ -0,0 +1,373 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Usage API Facade for Reading and Updating Usage Data. + +This module provides a clean, public API for programmatically interacting with +usage data. It's accessible via `usage_manager.api` and is intended for: + + - Admin endpoints (viewing/modifying credential state) + - Background jobs (quota refresh, cleanup tasks) + - Monitoring and alerting (checking remaining quota) + - External tooling and integrations + - Provider-specific logic that needs to inspect/modify state + +============================================================================= +ACCESSING THE API +============================================================================= + +The UsageAPI is available as a property on UsageManager: + + # From RotatingClient + usage_manager = client.get_usage_manager("my_provider") + api = usage_manager.api + + # Or if you have the manager directly + api = usage_manager.api + +============================================================================= +AVAILABLE METHODS +============================================================================= + +Reading State +------------- + + # Get state for a specific credential + state = api.get_state("path/to/credential.json") + if state: + print(f"Total requests: {state.usage.total_requests}") + print(f"Total successes: {state.usage.total_successes}") + print(f"Total failures: {state.usage.total_failures}") + + # Get all credential states + all_states = api.get_all_states() + for stable_id, state in all_states.items(): + print(f"{stable_id}: {state.usage.total_requests} requests") + + # Check remaining quota in a window + remaining = api.get_window_remaining( + accessor="path/to/credential.json", + window_name="5h", + model="gpt-4o", # Optional: specific model + quota_group="gpt4", # Optional: quota group + ) + print(f"Remaining in 5h window: {remaining}") + +Modifying State +--------------- + + # Apply a manual cooldown + await api.apply_cooldown( + accessor="path/to/credential.json", + duration=1800.0, # 30 minutes + reason="manual_override", + model_or_group="gpt4", # Optional: scope to model/group + ) + + # Clear a cooldown + await api.clear_cooldown( + accessor="path/to/credential.json", + model_or_group="gpt4", # Optional: scope + ) + + # Mark credential as exhausted for fair cycle + await api.mark_exhausted( + accessor="path/to/credential.json", + model_or_group="gpt4", + reason="quota_exceeded", + ) + +============================================================================= +CREDENTIAL STATE STRUCTURE +============================================================================= + +CredentialState contains: + + state.accessor # File path or API key + state.display_name # Human-readable name (e.g., email) + state.tier # Tier name (e.g., "standard-tier") + state.priority # Priority level (1 = highest) + state.active_requests # Currently in-flight requests + + state.usage # UsageStats - global totals + state.model_usage # Dict[model, UsageStats] + state.group_usage # Dict[group, UsageStats] + + state.cooldowns # Dict[key, CooldownState] + state.fair_cycle # Dict[key, FairCycleState] + +UsageStats contains: + + usage.total_requests + usage.total_successes + usage.total_failures + usage.total_tokens + usage.total_prompt_tokens + usage.total_completion_tokens + usage.total_thinking_tokens + usage.total_output_tokens + usage.total_prompt_tokens_cache_read + usage.total_prompt_tokens_cache_write + usage.total_approx_cost + usage.first_used_at + usage.last_used_at + usage.windows # Dict[name, WindowStats] + +WindowStats contains: + + window.request_count + window.success_count + window.failure_count + window.prompt_tokens + window.completion_tokens + window.thinking_tokens + window.output_tokens + window.prompt_tokens_cache_read + window.prompt_tokens_cache_write + window.total_tokens + window.approx_cost + window.started_at + window.reset_at + window.limit + window.remaining # Computed: limit - request_count (if limit set) + +============================================================================= +EXAMPLE: BUILDING AN ADMIN ENDPOINT +============================================================================= + + from fastapi import APIRouter + from rotator_library import RotatingClient + + router = APIRouter() + + @router.get("/admin/credentials/{provider}") + async def list_credentials(provider: str): + usage_manager = client.get_usage_manager(provider) + if not usage_manager: + return {"error": "Provider not found"} + + api = usage_manager.api + result = [] + + for stable_id, state in api.get_all_states().items(): + result.append({ + "id": stable_id, + "accessor": state.accessor, + "tier": state.tier, + "priority": state.priority, + "requests": state.usage.total_requests, + "successes": state.usage.total_successes, + "failures": state.usage.total_failures, + "cooldowns": [ + {"key": k, "remaining": v.remaining_seconds} + for k, v in state.cooldowns.items() + if v.is_active + ], + }) + + return {"credentials": result} + + @router.post("/admin/credentials/{provider}/{accessor}/cooldown") + async def apply_cooldown(provider: str, accessor: str, duration: float): + usage_manager = client.get_usage_manager(provider) + api = usage_manager.api + await api.apply_cooldown(accessor, duration, reason="admin") + return {"status": "cooldown applied"} + +============================================================================= +""" + +from typing import Any, Dict, Optional, TYPE_CHECKING + +from ..types import CredentialState + +if TYPE_CHECKING: + from ..manager import UsageManager + + +class UsageAPI: + """ + Public API facade for reading and updating usage data. + + Provides a clean interface for external code to interact with usage + tracking without needing to understand the internal component structure. + + Access via: usage_manager.api + + Example: + api = usage_manager.api + state = api.get_state("path/to/credential.json") + remaining = api.get_window_remaining("path/to/cred.json", "5h", "gpt-4o") + await api.apply_cooldown("path/to/cred.json", 1800.0, "manual") + """ + + def __init__(self, manager: "UsageManager"): + """ + Initialize the API facade. + + Args: + manager: The UsageManager instance to wrap. + """ + self._manager = manager + + def get_state(self, accessor: str) -> Optional[CredentialState]: + """ + Get the credential state for a given accessor. + + Args: + accessor: Credential file path or API key. + + Returns: + CredentialState if found, None otherwise. + + Example: + state = api.get_state("oauth_creds/my_cred.json") + if state: + print(f"Requests: {state.usage.total_requests}") + """ + stable_id = self._manager.registry.get_stable_id( + accessor, self._manager.provider + ) + return self._manager.states.get(stable_id) + + def get_all_states(self) -> Dict[str, CredentialState]: + """ + Get all credential states. + + Returns: + Dict mapping stable_id to CredentialState. + + Example: + for stable_id, state in api.get_all_states().items(): + print(f"{stable_id}: {state.usage.total_requests} requests") + """ + return dict(self._manager.states) + + def get_window_remaining( + self, + accessor: str, + window_name: str, + model: Optional[str] = None, + quota_group: Optional[str] = None, + ) -> Optional[int]: + """ + Get remaining requests in a usage window. + + Args: + accessor: Credential file path or API key. + window_name: Window name (e.g., "5h", "daily"). + model: Optional model to check (uses model-specific window). + quota_group: Optional quota group to check. + + Returns: + Remaining requests (limit - used), or None if: + - Credential not found + - Window has no limit set + + Example: + remaining = api.get_window_remaining("cred.json", "5h", model="gpt-4o") + if remaining is not None and remaining < 10: + print("Warning: low quota remaining") + """ + state = self.get_state(accessor) + if not state: + return None + return self._manager.limits.window_checker.get_remaining( + state, window_name, model=model, quota_group=quota_group + ) + + async def apply_cooldown( + self, + accessor: str, + duration: float, + reason: str = "manual", + model_or_group: Optional[str] = None, + ) -> None: + """ + Apply a cooldown to a credential. + + The credential will not be selected for requests until the cooldown + expires or is cleared. + + Args: + accessor: Credential file path or API key. + duration: Cooldown duration in seconds. + reason: Reason for cooldown (for logging/debugging). + model_or_group: Optional scope (model name or quota group). + If None, applies to credential globally. + + Example: + # Global cooldown + await api.apply_cooldown("cred.json", 1800.0, "maintenance") + + # Model-specific cooldown + await api.apply_cooldown("cred.json", 3600.0, "quota", "gpt-4o") + """ + await self._manager.apply_cooldown( + accessor=accessor, + duration=duration, + reason=reason, + model_or_group=model_or_group, + ) + + async def clear_cooldown( + self, + accessor: str, + model_or_group: Optional[str] = None, + ) -> None: + """ + Clear a cooldown from a credential. + + Args: + accessor: Credential file path or API key. + model_or_group: Optional scope to clear. If None, clears all. + + Example: + # Clear specific cooldown + await api.clear_cooldown("cred.json", "gpt-4o") + + # Clear all cooldowns + await api.clear_cooldown("cred.json") + """ + stable_id = self._manager.registry.get_stable_id( + accessor, self._manager.provider + ) + state = self._manager.states.get(stable_id) + if state: + await self._manager.tracking.clear_cooldown( + state=state, + model_or_group=model_or_group, + ) + + async def mark_exhausted( + self, + accessor: str, + model_or_group: str, + reason: str, + ) -> None: + """ + Mark a credential as exhausted for fair cycle. + + The credential will be skipped during selection until all other + credentials in the same tier are also exhausted, at which point + the fair cycle resets. + + Args: + accessor: Credential file path or API key. + model_or_group: Model name or quota group to mark exhausted. + reason: Reason for exhaustion (for logging/debugging). + + Example: + await api.mark_exhausted("cred.json", "gpt-4o", "quota_exceeded") + """ + stable_id = self._manager.registry.get_stable_id( + accessor, self._manager.provider + ) + state = self._manager.states.get(stable_id) + if state: + await self._manager.tracking.mark_exhausted( + state=state, + model_or_group=model_or_group, + reason=reason, + ) diff --git a/src/rotator_library/usage/integration/hooks.py b/src/rotator_library/usage/integration/hooks.py new file mode 100644 index 00000000..129b3f42 --- /dev/null +++ b/src/rotator_library/usage/integration/hooks.py @@ -0,0 +1,240 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Provider Hook Dispatcher for Usage Manager. + +This module bridges provider plugins to the usage manager, allowing providers +to customize how requests are counted, cooled down, and tracked. + +============================================================================= +OVERVIEW +============================================================================= + +The HookDispatcher calls provider hooks at key points in the request lifecycle. +Currently, the main hook is `on_request_complete`, which is called after every +request (success or failure) and allows the provider to override: + + - Request count (how many requests to record) + - Cooldown duration (custom cooldown to apply) + - Exhaustion state (mark credential for fair cycle) + +============================================================================= +IMPLEMENTING on_request_complete IN YOUR PROVIDER +============================================================================= + +Add this method to your provider class: + + from rotator_library.core.types import RequestCompleteResult + + def on_request_complete( + self, + credential: str, + model: str, + success: bool, + response: Optional[Any], + error: Optional[Any], + ) -> Optional[RequestCompleteResult]: + ''' + Called after each request completes. + + Args: + credential: Credential accessor (file path or API key) + model: Model that was called + success: Whether the request succeeded + response: Response object (if success=True) + error: ClassifiedError object (if success=False) + + Returns: + RequestCompleteResult to override behavior, or None for defaults. + ''' + # Your logic here + return None # Use default behavior + +============================================================================= +RequestCompleteResult FIELDS +============================================================================= + + count_override: Optional[int] + How many requests to count for usage tracking. + - 0 = Don't count this request (e.g., server errors) + - N = Count as N requests (e.g., internal retries) + - None = Use default (1) + + cooldown_override: Optional[float] + Seconds to cool down this credential. + - Applied in addition to any error-based cooldown. + - Use for custom rate limiting logic. + + force_exhausted: bool + Mark credential as exhausted for fair cycle. + - True = Skip this credential until fair cycle resets. + - Useful for quota errors without long cooldowns. + +============================================================================= +USE CASE: COUNTING INTERNAL RETRIES +============================================================================= + +If your provider performs internal retries (e.g., for transient errors, empty +responses, or malformed responses), each retry is an API call that should be +counted. Use the ContextVar pattern for thread-safe counting: + + from contextvars import ContextVar + from rotator_library.core.types import RequestCompleteResult + + # Module-level: each async task gets its own isolated value + _internal_attempt_count: ContextVar[int] = ContextVar( + 'my_provider_attempt_count', default=1 + ) + + class MyProvider: + + async def _make_request_with_retry(self, ...): + # Reset at start of request + _internal_attempt_count.set(1) + + for attempt in range(max_attempts): + try: + result = await self._call_api(...) + return result # Success + except RetryableError: + # Increment before retry + _internal_attempt_count.set(_internal_attempt_count.get() + 1) + continue + + def on_request_complete(self, credential, model, success, response, error): + # Report actual API call count + count = _internal_attempt_count.get() + _internal_attempt_count.set(1) # Reset for safety + + if count > 1: + logging.debug(f"Request used {count} API calls (internal retries)") + + return RequestCompleteResult(count_override=count) + +Why ContextVar? + - Instance variables (self.count) are shared across concurrent requests + - ContextVar gives each async task its own isolated value + - Thread-safe without explicit locking + +============================================================================= +USE CASE: CUSTOM ERROR HANDLING +============================================================================= + +Override counting or cooldown based on error type: + + def on_request_complete(self, credential, model, success, response, error): + if not success and error: + # Don't count server errors against quota + if error.error_type == "server_error": + return RequestCompleteResult(count_override=0) + + # Force exhaustion on quota errors + if error.error_type == "quota_exceeded": + return RequestCompleteResult( + force_exhausted=True, + cooldown_override=3600.0, # 1 hour + ) + + # Custom cooldown for rate limits + if error.error_type == "rate_limit": + retry_after = getattr(error, "retry_after", 60) + return RequestCompleteResult(cooldown_override=retry_after) + + return None # Default behavior + +============================================================================= +""" + +import asyncio +from typing import Any, Dict, Optional + +from ...core.types import RequestCompleteResult + + +class HookDispatcher: + """ + Dispatch optional provider hooks during request lifecycle. + + The HookDispatcher is instantiated by UsageManager with the provider plugins + dict. It lazily instantiates provider instances and calls their hooks. + + Currently supported hooks: + - on_request_complete: Called after each request completes + + Usage: + dispatcher = HookDispatcher(provider_plugins) + result = await dispatcher.dispatch_request_complete( + provider="my_provider", + credential="path/to/cred.json", + model="my-model", + success=True, + response=response_obj, + error=None, + ) + if result and result.count_override is not None: + request_count = result.count_override + """ + + def __init__(self, provider_plugins: Optional[Dict[str, Any]] = None): + """ + Initialize the hook dispatcher. + + Args: + provider_plugins: Dict mapping provider names to plugin classes. + Classes are lazily instantiated on first hook call. + """ + self._plugins = provider_plugins or {} + self._instances: Dict[str, Any] = {} + + def _get_instance(self, provider: str) -> Optional[Any]: + """Get or create a provider plugin instance.""" + if provider not in self._instances: + plugin_class = self._plugins.get(provider) + if not plugin_class: + return None + if isinstance(plugin_class, type): + self._instances[provider] = plugin_class() + else: + self._instances[provider] = plugin_class + return self._instances[provider] + + async def dispatch_request_complete( + self, + provider: str, + credential: str, + model: str, + success: bool, + response: Optional[Any], + error: Optional[Any], + ) -> Optional[RequestCompleteResult]: + """ + Dispatch the on_request_complete hook to a provider. + + Called by UsageManager after each request completes (success or failure). + The provider can return a RequestCompleteResult to override default + behavior for request counting, cooldowns, or exhaustion marking. + + Args: + provider: Provider name (e.g., "antigravity", "openai") + credential: Credential accessor (file path or API key) + model: Model that was called (with provider prefix) + success: Whether the request succeeded + response: Response object if success=True, else None + error: ClassifiedError if success=False, else None + + Returns: + RequestCompleteResult from provider, or None if: + - Provider not found in plugins + - Provider doesn't implement on_request_complete + - Provider returns None (use default behavior) + """ + plugin = self._get_instance(provider) + if not plugin or not hasattr(plugin, "on_request_complete"): + return None + + result = plugin.on_request_complete(credential, model, success, response, error) + if asyncio.iscoroutine(result): + result = await result + + return result diff --git a/src/rotator_library/usage/limits/__init__.py b/src/rotator_library/usage/limits/__init__.py new file mode 100644 index 00000000..e759c312 --- /dev/null +++ b/src/rotator_library/usage/limits/__init__.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Limit checking and enforcement.""" + +from .engine import LimitEngine +from .base import LimitChecker +from .window_limits import WindowLimitChecker +from .cooldowns import CooldownChecker +from .fair_cycle import FairCycleChecker +from .custom_caps import CustomCapChecker + +__all__ = [ + "LimitEngine", + "LimitChecker", + "WindowLimitChecker", + "CooldownChecker", + "FairCycleChecker", + "CustomCapChecker", +] diff --git a/src/rotator_library/usage/limits/base.py b/src/rotator_library/usage/limits/base.py new file mode 100644 index 00000000..c6c0a4d8 --- /dev/null +++ b/src/rotator_library/usage/limits/base.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Base interface for limit checkers. + +All limit types implement this interface for consistent behavior. +""" + +from abc import ABC, abstractmethod +from typing import Optional + +from ..types import CredentialState, LimitCheckResult, LimitResult + + +class LimitChecker(ABC): + """ + Abstract base class for limit checkers. + + Each limit type (window, cooldown, fair cycle, custom cap) + implements this interface. + """ + + @property + @abstractmethod + def name(self) -> str: + """Name of this limit checker.""" + ... + + @abstractmethod + def check( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> LimitCheckResult: + """ + Check if a credential passes this limit. + + Args: + state: Credential state to check + model: Model being requested + quota_group: Quota group for this model + + Returns: + LimitCheckResult indicating pass/fail and reason + """ + ... + + def reset( + self, + state: CredentialState, + model: Optional[str] = None, + quota_group: Optional[str] = None, + ) -> None: + """ + Reset this limit for a credential. + + Default implementation does nothing - override if needed. + + Args: + state: Credential state to reset + model: Optional model scope + quota_group: Optional quota group scope + """ + pass diff --git a/src/rotator_library/usage/limits/concurrent.py b/src/rotator_library/usage/limits/concurrent.py new file mode 100644 index 00000000..83a510cb --- /dev/null +++ b/src/rotator_library/usage/limits/concurrent.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Concurrent request limit checker. + +Blocks credentials that have reached their max_concurrent limit. +""" + +from typing import Optional + +from ..types import CredentialState, LimitCheckResult, LimitResult +from .base import LimitChecker + + +class ConcurrentLimitChecker(LimitChecker): + """ + Checks concurrent request limits. + + Blocks credentials that have active_requests >= max_concurrent. + This ensures we don't overload any single credential. + """ + + @property + def name(self) -> str: + return "concurrent" + + def check( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> LimitCheckResult: + """ + Check if credential is at max concurrent. + + Args: + state: Credential state to check + model: Model being requested + quota_group: Quota group for this model + + Returns: + LimitCheckResult indicating pass/fail + """ + # If no limit set, always allow + if state.max_concurrent is None: + return LimitCheckResult.ok() + + # Check if at or above limit + if state.active_requests >= state.max_concurrent: + return LimitCheckResult.blocked( + result=LimitResult.BLOCKED_CONCURRENT, + reason=f"At max concurrent: {state.active_requests}/{state.max_concurrent}", + blocked_until=None, # No specific time - depends on request completion + ) + + return LimitCheckResult.ok() + + def reset( + self, + state: CredentialState, + model: Optional[str] = None, + quota_group: Optional[str] = None, + ) -> None: + """ + Reset concurrent count. + + Note: This is rarely needed as active_requests is + managed by acquire/release, not limit checking. + """ + # Typically don't reset active_requests via limit system + pass diff --git a/src/rotator_library/usage/limits/cooldowns.py b/src/rotator_library/usage/limits/cooldowns.py new file mode 100644 index 00000000..bf08a4a4 --- /dev/null +++ b/src/rotator_library/usage/limits/cooldowns.py @@ -0,0 +1,128 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Cooldown checker. + +Checks if a credential is currently in cooldown. +""" + +import time +from typing import Optional + +from ..types import CredentialState, LimitCheckResult, LimitResult +from .base import LimitChecker + + +class CooldownChecker(LimitChecker): + """ + Checks cooldown status for credentials. + + Blocks credentials that are currently cooling down from + rate limits, errors, or other causes. + """ + + @property + def name(self) -> str: + return "cooldowns" + + def check( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> LimitCheckResult: + """ + Check if credential is in cooldown. + + Args: + state: Credential state to check + model: Model being requested + quota_group: Quota group for this model + + Returns: + LimitCheckResult indicating pass/fail + """ + now = time.time() + group_key = quota_group or model + + # Check model/group-specific cooldowns + keys_to_check = [] + if group_key: + keys_to_check.append(group_key) + if quota_group and quota_group != model: + keys_to_check.append(model) + + for key in keys_to_check: + cooldown = state.cooldowns.get(key) + if cooldown and cooldown.until > now: + return LimitCheckResult.blocked( + result=LimitResult.BLOCKED_COOLDOWN, + reason=f"Cooldown for '{key}': {cooldown.reason} (expires in {cooldown.remaining_seconds:.0f}s)", + blocked_until=cooldown.until, + ) + + # Check global cooldown + global_cooldown = state.cooldowns.get("_global_") + if global_cooldown and global_cooldown.until > now: + return LimitCheckResult.blocked( + result=LimitResult.BLOCKED_COOLDOWN, + reason=f"Global cooldown: {global_cooldown.reason} (expires in {global_cooldown.remaining_seconds:.0f}s)", + blocked_until=global_cooldown.until, + ) + + return LimitCheckResult.ok() + + def reset( + self, + state: CredentialState, + model: Optional[str] = None, + quota_group: Optional[str] = None, + ) -> None: + """ + Clear cooldown for a credential. + + Args: + state: Credential state + model: Optional model scope + quota_group: Optional quota group scope + """ + if quota_group: + if quota_group in state.cooldowns: + del state.cooldowns[quota_group] + elif model: + if model in state.cooldowns: + del state.cooldowns[model] + else: + # Clear all cooldowns + state.cooldowns.clear() + + def get_cooldown_end( + self, + state: CredentialState, + model_or_group: Optional[str] = None, + ) -> Optional[float]: + """ + Get when cooldown ends for a credential. + + Args: + state: Credential state + model_or_group: Optional scope to check + + Returns: + Timestamp when cooldown ends, or None if not in cooldown + """ + now = time.time() + + # Check specific scope + if model_or_group: + cooldown = state.cooldowns.get(model_or_group) + if cooldown and cooldown.until > now: + return cooldown.until + + # Check global + global_cooldown = state.cooldowns.get("_global_") + if global_cooldown and global_cooldown.until > now: + return global_cooldown.until + + return None diff --git a/src/rotator_library/usage/limits/custom_caps.py b/src/rotator_library/usage/limits/custom_caps.py new file mode 100644 index 00000000..cf318f4e --- /dev/null +++ b/src/rotator_library/usage/limits/custom_caps.py @@ -0,0 +1,229 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Custom cap limit checker. + +Enforces user-defined limits that are stricter than API limits. +""" + +import time +import logging +from typing import Dict, List, Optional + +from ..types import CredentialState, LimitCheckResult, LimitResult +from ..config import CustomCapConfig, CooldownMode +from ..tracking.windows import WindowManager +from .base import LimitChecker + +lib_logger = logging.getLogger("rotator_library") + + +class CustomCapChecker(LimitChecker): + """ + Checks custom cap limits. + + Custom caps allow users to set limits more restrictive than + what the API allows, for cost control or other reasons. + """ + + def __init__( + self, + caps: List[CustomCapConfig], + window_manager: WindowManager, + ): + """ + Initialize custom cap checker. + + Args: + caps: List of custom cap configurations + window_manager: WindowManager for checking window usage + """ + self._caps = caps + self._windows = window_manager + # Index caps by (tier_key, model_or_group) for fast lookup + self._cap_index: Dict[tuple, CustomCapConfig] = {} + for cap in caps: + self._cap_index[(cap.tier_key, cap.model_or_group)] = cap + + @property + def name(self) -> str: + return "custom_caps" + + def check( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> LimitCheckResult: + """ + Check if custom cap is exceeded. + + Args: + state: Credential state to check + model: Model being requested + quota_group: Quota group for this model + + Returns: + LimitCheckResult indicating pass/fail + """ + if not self._caps: + return LimitCheckResult.ok() + + group_key = quota_group or model + priority = state.priority + + # Find applicable cap + cap = self._find_cap(str(priority), group_key, model) + if cap is None: + return LimitCheckResult.ok() + + primary_def = self._windows.get_primary_definition() + if primary_def is None: + return LimitCheckResult.ok() + + usage = None + if quota_group and cap.model_or_group == group_key: + usage = state.get_usage_for_scope("group", group_key, create=False) + + if usage is None: + scope_key = None + if primary_def.applies_to == "model": + scope_key = model + elif primary_def.applies_to == "group": + scope_key = group_key + usage = state.get_usage_for_scope( + primary_def.applies_to, scope_key, create=False + ) + + if usage is None: + return LimitCheckResult.ok() + + # Get usage from primary window + primary_window = self._windows.get_primary_window(usage.windows) + if primary_window is None: + return LimitCheckResult.ok() + + current_usage = primary_window.request_count + max_requests = self._resolve_max_requests(cap, primary_window.limit) + + if current_usage >= max_requests: + # Calculate cooldown end + cooldown_until = self._calculate_cooldown_until(cap, primary_window) + + return LimitCheckResult.blocked( + result=LimitResult.BLOCKED_CUSTOM_CAP, + reason=f"Custom cap for '{group_key}' exceeded ({current_usage}/{max_requests})", + blocked_until=cooldown_until, + ) + + return LimitCheckResult.ok() + + def get_cap_for( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> Optional[CustomCapConfig]: + """ + Get the applicable custom cap for a credential/model. + + Args: + state: Credential state + model: Model name + quota_group: Quota group + + Returns: + CustomCapConfig if one applies, None otherwise + """ + group_key = quota_group or model + priority = state.priority + return self._find_cap(str(priority), group_key, model) + + # ========================================================================= + # PRIVATE METHODS + # ========================================================================= + + def _find_cap( + self, + priority_key: str, + group_key: str, + model: str, + ) -> Optional[CustomCapConfig]: + """Find the most specific applicable cap.""" + # Try exact matches first + # Priority + group + cap = self._cap_index.get((priority_key, group_key)) + if cap: + return cap + + # Priority + model (if different from group) + if model != group_key: + cap = self._cap_index.get((priority_key, model)) + if cap: + return cap + + # Default tier + group + cap = self._cap_index.get(("default", group_key)) + if cap: + return cap + + # Default tier + model + if model != group_key: + cap = self._cap_index.get(("default", model)) + if cap: + return cap + + return None + + def _resolve_max_requests( + self, + cap: CustomCapConfig, + window_limit: Optional[int], + ) -> int: + """ + Resolve max requests, handling percentage values. + + Custom caps can only be MORE restrictive than API limits, + so the result is clamped to window_limit if available. + """ + if cap.max_requests >= 0: + # Absolute value - clamp to window limit if known + if window_limit is not None: + return min(cap.max_requests, window_limit) + return cap.max_requests + + # Negative value indicates percentage + if window_limit is None: + # No window limit known, use a high default + return 1000 + + percentage = -cap.max_requests + calculated = int(window_limit * percentage / 100) + # Clamp to window limit (already is <= since percentage < 100 typically) + return min(calculated, window_limit) + + def _calculate_cooldown_until( + self, + cap: CustomCapConfig, + window: "WindowStats", + ) -> Optional[float]: + """Calculate when the custom cap cooldown ends.""" + now = time.time() + natural_reset = window.reset_at + + if cap.cooldown_mode == CooldownMode.QUOTA_RESET: + # Wait until window resets + return natural_reset + + elif cap.cooldown_mode == CooldownMode.OFFSET: + # Add offset to current time + calculated = now + cap.cooldown_value + return max(calculated, natural_reset) if natural_reset else calculated + + elif cap.cooldown_mode == CooldownMode.FIXED: + # Fixed duration + calculated = now + cap.cooldown_value + return max(calculated, natural_reset) if natural_reset else calculated + + return None diff --git a/src/rotator_library/usage/limits/engine.py b/src/rotator_library/usage/limits/engine.py new file mode 100644 index 00000000..d5f975c4 --- /dev/null +++ b/src/rotator_library/usage/limits/engine.py @@ -0,0 +1,239 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Limit engine for orchestrating limit checks. + +Central component that runs all limit checkers and determines +if a credential is available for use. +""" + +import logging +from typing import Dict, List, Optional + +from ..types import CredentialState, LimitCheckResult, LimitResult +from ..config import ProviderUsageConfig +from ..tracking.windows import WindowManager +from .base import LimitChecker +from .concurrent import ConcurrentLimitChecker +from .window_limits import WindowLimitChecker +from .cooldowns import CooldownChecker +from .fair_cycle import FairCycleChecker +from .custom_caps import CustomCapChecker + +lib_logger = logging.getLogger("rotator_library") + + +class LimitEngine: + """ + Central engine for limit checking. + + Orchestrates all limit checkers and provides a single entry point + for determining credential availability. + """ + + def __init__( + self, + config: ProviderUsageConfig, + window_manager: WindowManager, + ): + """ + Initialize limit engine. + + Args: + config: Provider usage configuration + window_manager: WindowManager for window-based checks + """ + self._config = config + self._window_manager = window_manager + + # Initialize all limit checkers + # Order matters: concurrent first (fast check), then others + self._checkers: List[LimitChecker] = [ + ConcurrentLimitChecker(), + CooldownChecker(), + WindowLimitChecker(window_manager), + CustomCapChecker(config.custom_caps, window_manager), + FairCycleChecker(config.fair_cycle), + ] + + # Quick access to specific checkers + self._concurrent_checker = self._checkers[0] + self._cooldown_checker = self._checkers[1] + self._window_checker = self._checkers[2] + self._custom_cap_checker = self._checkers[3] + self._fair_cycle_checker = self._checkers[4] + + def check_all( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> LimitCheckResult: + """ + Check all limits for a credential. + + Runs all limit checkers in order and returns the first failure, + or success if all pass. + + Args: + state: Credential state to check + model: Model being requested + quota_group: Quota group for this model + + Returns: + LimitCheckResult indicating overall pass/fail + """ + for checker in self._checkers: + result = checker.check(state, model, quota_group) + if not result.allowed: + lib_logger.debug( + f"Credential {state.stable_id} blocked by {checker.name}: {result.reason}" + ) + return result + + return LimitCheckResult.ok() + + def check_specific( + self, + checker_name: str, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> LimitCheckResult: + """ + Check a specific limit type. + + Args: + checker_name: Name of the checker ("cooldowns", "window_limits", etc.) + state: Credential state to check + model: Model being requested + quota_group: Quota group for this model + + Returns: + LimitCheckResult from the specified checker + """ + for checker in self._checkers: + if checker.name == checker_name: + return checker.check(state, model, quota_group) + + # Unknown checker - return ok + return LimitCheckResult.ok() + + def get_available_candidates( + self, + states: List[CredentialState], + model: str, + quota_group: Optional[str] = None, + ) -> List[CredentialState]: + """ + Filter credentials to only those passing all limits. + + Args: + states: List of credential states to check + model: Model being requested + quota_group: Quota group for this model + + Returns: + List of available credential states + """ + available = [] + for state in states: + result = self.check_all(state, model, quota_group) + if result.allowed: + available.append(state) + + return available + + def get_blocking_info( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> Dict[str, LimitCheckResult]: + """ + Get detailed blocking info for each limit type. + + Useful for debugging and status reporting. + + Args: + state: Credential state to check + model: Model being requested + quota_group: Quota group for this model + + Returns: + Dict mapping checker name to its result + """ + results = {} + for checker in self._checkers: + results[checker.name] = checker.check(state, model, quota_group) + return results + + def reset_all( + self, + state: CredentialState, + model: Optional[str] = None, + quota_group: Optional[str] = None, + ) -> None: + """ + Reset all limits for a credential. + + Args: + state: Credential state + model: Optional model scope + quota_group: Optional quota group scope + """ + for checker in self._checkers: + checker.reset(state, model, quota_group) + + @property + def concurrent_checker(self) -> ConcurrentLimitChecker: + """Get the concurrent limit checker.""" + return self._concurrent_checker + + @property + def cooldown_checker(self) -> CooldownChecker: + """Get the cooldown checker.""" + return self._cooldown_checker + + @property + def window_checker(self) -> WindowLimitChecker: + """Get the window limit checker.""" + return self._window_checker + + @property + def custom_cap_checker(self) -> CustomCapChecker: + """Get the custom cap checker.""" + return self._custom_cap_checker + + @property + def fair_cycle_checker(self) -> FairCycleChecker: + """Get the fair cycle checker.""" + return self._fair_cycle_checker + + def add_checker(self, checker: LimitChecker) -> None: + """ + Add a custom limit checker. + + Allows extending the limit system with custom logic. + + Args: + checker: LimitChecker implementation to add + """ + self._checkers.append(checker) + + def remove_checker(self, name: str) -> bool: + """ + Remove a limit checker by name. + + Args: + name: Name of the checker to remove + + Returns: + True if removed, False if not found + """ + for i, checker in enumerate(self._checkers): + if checker.name == name: + del self._checkers[i] + return True + return False diff --git a/src/rotator_library/usage/limits/fair_cycle.py b/src/rotator_library/usage/limits/fair_cycle.py new file mode 100644 index 00000000..04fc3cc9 --- /dev/null +++ b/src/rotator_library/usage/limits/fair_cycle.py @@ -0,0 +1,301 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Fair cycle limit checker. + +Ensures credentials are used fairly by blocking exhausted ones +until all credentials in the pool are exhausted. +""" + +import time +import logging +from typing import Dict, List, Optional, Set + +from ..types import ( + CredentialState, + LimitCheckResult, + LimitResult, + FairCycleState, + GlobalFairCycleState, + TrackingMode, + FAIR_CYCLE_GLOBAL_KEY, +) +from ..config import FairCycleConfig +from .base import LimitChecker + +lib_logger = logging.getLogger("rotator_library") + + +class FairCycleChecker(LimitChecker): + """ + Checks fair cycle constraints. + + Blocks credentials that have been "exhausted" (quota used or long cooldown) + until all credentials in the pool have been exhausted, then resets the cycle. + """ + + def __init__(self, config: FairCycleConfig): + """ + Initialize fair cycle checker. + + Args: + config: Fair cycle configuration + """ + self._config = config + # Global cycle state per provider + self._global_state: Dict[str, Dict[str, GlobalFairCycleState]] = {} + + @property + def name(self) -> str: + return "fair_cycle" + + def check( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> LimitCheckResult: + """ + Check if credential is blocked by fair cycle. + + Args: + state: Credential state to check + model: Model being requested + quota_group: Quota group for this model + + Returns: + LimitCheckResult indicating pass/fail + """ + if not self._config.enabled: + return LimitCheckResult.ok() + + group_key = self._resolve_tracking_key(model, quota_group) + fc_state = state.fair_cycle.get(group_key) + + # Not exhausted = allowed + if fc_state is None or not fc_state.exhausted: + return LimitCheckResult.ok() + + # Exhausted - check if cycle should reset + provider = state.provider + global_state = self._get_global_state(provider, group_key) + + # Check if cycle has expired + if self._should_reset_cycle(global_state): + # Don't block - cycle will be reset + return LimitCheckResult.ok() + + # Still blocked by fair cycle + return LimitCheckResult.blocked( + result=LimitResult.BLOCKED_FAIR_CYCLE, + reason=f"Fair cycle: exhausted for '{group_key}' - waiting for other credentials", + blocked_until=None, # Depends on other credentials + ) + + def reset( + self, + state: CredentialState, + model: Optional[str] = None, + quota_group: Optional[str] = None, + ) -> None: + """ + Reset fair cycle state for a credential. + + Args: + state: Credential state + model: Optional model scope + quota_group: Optional quota group scope + """ + group_key = self._resolve_tracking_key(model or "", quota_group) + + if quota_group or model: + if group_key in state.fair_cycle: + fc_state = state.fair_cycle[group_key] + fc_state.exhausted = False + fc_state.exhausted_at = None + fc_state.exhausted_reason = None + fc_state.cycle_request_count = 0 + else: + # Reset all + for fc_state in state.fair_cycle.values(): + fc_state.exhausted = False + fc_state.exhausted_at = None + fc_state.exhausted_reason = None + fc_state.cycle_request_count = 0 + + def check_all_exhausted( + self, + provider: str, + group_key: str, + all_states: List[CredentialState], + priorities: Optional[Dict[str, int]] = None, + ) -> bool: + """ + Check if all credentials in the pool are exhausted. + + Args: + provider: Provider name + group_key: Model or quota group + all_states: All credential states for this provider + priorities: Optional priority filter + + Returns: + True if all are exhausted + """ + # Filter by tier if not cross-tier + if priorities and not self._config.cross_tier: + # Group by priority tier + priority_groups: Dict[int, List[CredentialState]] = {} + for state in all_states: + p = priorities.get(state.stable_id, 999) + priority_groups.setdefault(p, []).append(state) + + # Check each priority group separately + for priority, group_states in priority_groups.items(): + if not self._all_exhausted_in_group(group_states, group_key): + return False + return True + else: + return self._all_exhausted_in_group(all_states, group_key) + + def reset_cycle( + self, + provider: str, + group_key: str, + all_states: List[CredentialState], + ) -> None: + """ + Reset the fair cycle for all credentials. + + Args: + provider: Provider name + group_key: Model or quota group + all_states: All credential states to reset + """ + now = time.time() + + for state in all_states: + if group_key in state.fair_cycle: + fc_state = state.fair_cycle[group_key] + fc_state.exhausted = False + fc_state.exhausted_at = None + fc_state.exhausted_reason = None + fc_state.cycle_request_count = 0 + + # Update global state + global_state = self._get_global_state(provider, group_key) + global_state.cycle_start = now + global_state.all_exhausted_at = None + global_state.cycle_count += 1 + + lib_logger.info( + f"Fair cycle reset for {provider}/{group_key}, cycle #{global_state.cycle_count}" + ) + + def mark_all_exhausted( + self, + provider: str, + group_key: str, + ) -> None: + """ + Record that all credentials are now exhausted. + + Args: + provider: Provider name + group_key: Model or quota group + """ + global_state = self._get_global_state(provider, group_key) + global_state.all_exhausted_at = time.time() + + lib_logger.info(f"All credentials exhausted for {provider}/{group_key}") + + def get_tracking_key(self, model: str, quota_group: Optional[str]) -> str: + """Get the fair cycle tracking key for a request.""" + return self._resolve_tracking_key(model, quota_group) + + # ========================================================================= + # PRIVATE METHODS + # ========================================================================= + + def _get_global_state( + self, + provider: str, + group_key: str, + ) -> GlobalFairCycleState: + """Get or create global fair cycle state.""" + if provider not in self._global_state: + self._global_state[provider] = {} + + if group_key not in self._global_state[provider]: + self._global_state[provider][group_key] = GlobalFairCycleState( + cycle_start=time.time() + ) + + return self._global_state[provider][group_key] + + def _resolve_tracking_key( + self, + model: str, + quota_group: Optional[str], + ) -> str: + """Resolve tracking key based on fair cycle mode.""" + if self._config.tracking_mode == TrackingMode.CREDENTIAL: + return FAIR_CYCLE_GLOBAL_KEY + return quota_group or model + + def _should_reset_cycle(self, global_state: GlobalFairCycleState) -> bool: + """Check if cycle duration has expired.""" + now = time.time() + return now >= global_state.cycle_start + self._config.duration + + def _all_exhausted_in_group( + self, + states: List[CredentialState], + group_key: str, + ) -> bool: + """Check if all credentials in a group are exhausted.""" + if not states: + return True + + for state in states: + fc_state = state.fair_cycle.get(group_key) + if fc_state is None or not fc_state.exhausted: + return False + + return True + + def get_global_state_dict(self) -> Dict[str, Dict[str, Dict]]: + """ + Get global state for serialization. + + Returns: + Dict suitable for JSON serialization + """ + result = {} + for provider, groups in self._global_state.items(): + result[provider] = {} + for group_key, state in groups.items(): + result[provider][group_key] = { + "cycle_start": state.cycle_start, + "all_exhausted_at": state.all_exhausted_at, + "cycle_count": state.cycle_count, + } + return result + + def load_global_state_dict(self, data: Dict[str, Dict[str, Dict]]) -> None: + """ + Load global state from serialized data. + + Args: + data: Dict from get_global_state_dict() + """ + self._global_state.clear() + for provider, groups in data.items(): + self._global_state[provider] = {} + for group_key, state_data in groups.items(): + self._global_state[provider][group_key] = GlobalFairCycleState( + cycle_start=state_data.get("cycle_start", 0), + all_exhausted_at=state_data.get("all_exhausted_at"), + cycle_count=state_data.get("cycle_count", 0), + ) diff --git a/src/rotator_library/usage/limits/window_limits.py b/src/rotator_library/usage/limits/window_limits.py new file mode 100644 index 00000000..9872cf95 --- /dev/null +++ b/src/rotator_library/usage/limits/window_limits.py @@ -0,0 +1,150 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Window limit checker. + +Checks if a credential has exceeded its request quota for a window. +""" + +from typing import List, Optional + +from ..types import CredentialState, LimitCheckResult, LimitResult, WindowStats +from ..tracking.windows import WindowManager +from .base import LimitChecker + + +class WindowLimitChecker(LimitChecker): + """ + Checks window-based request limits. + + Blocks credentials that have exhausted their quota in any + tracked window. + """ + + def __init__(self, window_manager: WindowManager): + """ + Initialize window limit checker. + + Args: + window_manager: WindowManager instance for window operations + """ + self._windows = window_manager + + @property + def name(self) -> str: + return "window_limits" + + def check( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> LimitCheckResult: + """ + Check if any window limit is exceeded. + + Args: + state: Credential state to check + model: Model being requested + quota_group: Quota group for this model + + Returns: + LimitCheckResult indicating pass/fail + """ + group_key = quota_group or model + + # Check all configured windows + for definition in self._windows.definitions.values(): + scope_key = None + if definition.applies_to == "model": + scope_key = model + elif definition.applies_to == "group": + scope_key = group_key + + usage = state.get_usage_for_scope( + definition.applies_to, scope_key, create=False + ) + if usage is None: + continue + + window = usage.windows.get(definition.name) + if window is None or window.limit is None: + continue + + active = self._windows.get_active_window(usage.windows, definition.name) + if active is None: + continue + + if active.request_count >= active.limit: + return LimitCheckResult.blocked( + result=LimitResult.BLOCKED_WINDOW, + reason=( + f"Window '{definition.name}' exhausted " + f"({active.request_count}/{active.limit})" + ), + blocked_until=active.reset_at, + ) + + return LimitCheckResult.ok() + + def get_remaining( + self, + state: CredentialState, + window_name: str, + model: Optional[str] = None, + quota_group: Optional[str] = None, + ) -> Optional[int]: + """ + Get remaining requests in a specific window. + + Args: + state: Credential state + window_name: Name of window to check + + Returns: + Remaining requests, or None if unlimited/unknown + """ + group_key = quota_group or model or "" + definition = self._windows.definitions.get(window_name) + if not definition: + return self._windows.get_window_remaining(state.usage.windows, window_name) + + scope_key = None + if definition.applies_to == "model": + scope_key = model + elif definition.applies_to == "group": + scope_key = group_key + + usage = state.get_usage_for_scope( + definition.applies_to, scope_key, create=False + ) + if not usage: + return None + + return self._windows.get_window_remaining(usage.windows, window_name) + + def get_all_remaining( + self, + state: CredentialState, + model: Optional[str] = None, + quota_group: Optional[str] = None, + ) -> dict[str, Optional[int]]: + """ + Get remaining requests for all windows. + + Args: + state: Credential state + + Returns: + Dict of window_name -> remaining (None if unlimited) + """ + result = {} + for definition in self._windows.definitions.values(): + result[definition.name] = self.get_remaining( + state, + definition.name, + model=model, + quota_group=quota_group, + ) + return result diff --git a/src/rotator_library/usage/manager.py b/src/rotator_library/usage/manager.py new file mode 100644 index 00000000..65513269 --- /dev/null +++ b/src/rotator_library/usage/manager.py @@ -0,0 +1,1916 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +UsageManager facade and CredentialContext. + +This is the main public API for the usage tracking system. +""" + +import asyncio +import logging +import time +from pathlib import Path +from typing import Any, Dict, List, Literal, Optional, Set, Union + +from ..core.types import CredentialInfo, RequestCompleteResult +from ..error_handler import ClassifiedError, classify_error + +from .types import ( + UsageStats, + WindowStats, + CredentialState, + LimitCheckResult, + RotationMode, + LimitResult, + FAIR_CYCLE_GLOBAL_KEY, + TrackingMode, + ResetMode, +) +from .config import ( + ProviderUsageConfig, + load_provider_usage_config, + get_default_windows, +) +from .identity.registry import CredentialRegistry +from .tracking.engine import TrackingEngine +from .tracking.windows import WindowManager +from .limits.engine import LimitEngine +from .selection.engine import SelectionEngine +from .persistence.storage import UsageStorage +from .integration.hooks import HookDispatcher +from .integration.api import UsageAPI + +lib_logger = logging.getLogger("rotator_library") + + +class CredentialContext: + """ + Context manager for credential lifecycle. + + Handles: + - Automatic release on exit + - Success/failure recording + - Usage tracking + + Usage: + async with usage_manager.acquire_credential(provider, model) as ctx: + response = await make_request(ctx.credential) + ctx.mark_success(response) + """ + + def __init__( + self, + manager: "UsageManager", + credential: str, + stable_id: str, + model: str, + quota_group: Optional[str] = None, + ): + self._manager = manager + self.credential = credential # The accessor (path or key) + self.stable_id = stable_id + self.model = model + self.quota_group = quota_group + self._acquired_at = time.time() + self._result: Optional[Literal["success", "failure"]] = None + self._response: Optional[Any] = None + self._response_headers: Optional[Dict[str, Any]] = None + self._error: Optional[ClassifiedError] = None + self._tokens: Dict[str, int] = {} + self._approx_cost: float = 0.0 + + async def __aenter__(self) -> "CredentialContext": + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> bool: + # Always release the credential + await self._manager._release_credential(self.stable_id, self.model) + + success = False + error = self._error + response = self._response + + if self._result == "success": + success = True + elif self._result == "failure": + success = False + elif exc_val is not None: + error = classify_error(exc_val) + success = False + else: + success = True + + await self._manager._handle_request_complete( + stable_id=self.stable_id, + model=self.model, + quota_group=self.quota_group, + success=success, + response=response, + response_headers=self._response_headers, + error=error, + prompt_tokens=self._tokens.get("prompt", 0), + completion_tokens=self._tokens.get("completion", 0), + thinking_tokens=self._tokens.get("thinking", 0), + prompt_tokens_cache_read=self._tokens.get("prompt_cached", 0), + prompt_tokens_cache_write=self._tokens.get("prompt_cache_write", 0), + approx_cost=self._approx_cost, + ) + + return False # Don't suppress exceptions + + def mark_success( + self, + response: Any = None, + prompt_tokens: int = 0, + completion_tokens: int = 0, + thinking_tokens: int = 0, + prompt_tokens_cache_read: int = 0, + prompt_tokens_cache_write: int = 0, + approx_cost: float = 0.0, + response_headers: Optional[Dict[str, Any]] = None, + ) -> None: + """Mark request as successful.""" + self._result = "success" + self._response = response + self._response_headers = response_headers + self._tokens = { + "prompt": prompt_tokens, + "completion": completion_tokens, + "thinking": thinking_tokens, + "prompt_cached": prompt_tokens_cache_read, + "prompt_cache_write": prompt_tokens_cache_write, + } + self._approx_cost = approx_cost + + def mark_failure(self, error: ClassifiedError) -> None: + """Mark request as failed.""" + self._result = "failure" + self._error = error + + +class UsageManager: + """ + Main facade for usage tracking and credential selection. + + This class provides the primary interface for: + - Acquiring credentials for requests (with context manager) + - Recording usage and failures + - Selecting the best available credential + - Managing cooldowns and limits + + Example: + manager = UsageManager(provider="gemini", file_path="usage.json") + await manager.initialize(credentials) + + async with manager.acquire_credential(model="gemini-pro") as ctx: + response = await make_request(ctx.credential) + ctx.mark_success(response, prompt_tokens=100, completion_tokens=50) + """ + + def __init__( + self, + provider: str, + file_path: Optional[Union[str, Path]] = None, + provider_plugins: Optional[Dict[str, Any]] = None, + config: Optional[ProviderUsageConfig] = None, + max_concurrent_per_key: Optional[int] = None, + ): + """ + Initialize UsageManager. + + Args: + provider: Provider name (e.g., "gemini", "openai") + file_path: Path to usage.json file + provider_plugins: Dict of provider plugin classes + config: Optional pre-built configuration + max_concurrent_per_key: Max concurrent requests per credential + """ + self.provider = provider + self._provider_plugins = provider_plugins or {} + self._max_concurrent_per_key = max_concurrent_per_key + + # Load configuration + if config: + self._config = config + else: + self._config = load_provider_usage_config(provider, self._provider_plugins) + + # Initialize components + self._registry = CredentialRegistry() + self._window_manager = WindowManager( + window_definitions=self._config.windows or get_default_windows() + ) + self._tracking = TrackingEngine(self._window_manager, self._config) + self._limits = LimitEngine(self._config, self._window_manager) + self._selection = SelectionEngine( + self._config, self._limits, self._window_manager + ) + self._hooks = HookDispatcher(self._provider_plugins) + self._api = UsageAPI(self) + + # Storage + if file_path: + self._storage = UsageStorage(file_path) + else: + self._storage = None + + # State + self._states: Dict[str, CredentialState] = {} + self._initialized = False + self._lock = asyncio.Lock() + self._loaded_from_storage = False + self._loaded_count = 0 + self._quota_exhausted_summary: Dict[str, Dict[str, float]] = {} + self._quota_exhausted_task: Optional[asyncio.Task] = None + self._quota_exhausted_lock = asyncio.Lock() + self._save_task: Optional[asyncio.Task] = None + self._save_lock = asyncio.Lock() + + # Concurrency control: per-credential locks and conditions for waiting + self._key_locks: Dict[str, asyncio.Lock] = {} + self._key_conditions: Dict[str, asyncio.Condition] = {} + + async def initialize( + self, + credentials: List[str], + priorities: Optional[Dict[str, int]] = None, + tiers: Optional[Dict[str, str]] = None, + ) -> None: + """ + Initialize with credentials. + + Args: + credentials: List of credential accessors (paths or keys) + priorities: Optional priority overrides (accessor -> priority) + tiers: Optional tier overrides (accessor -> tier name) + """ + async with self._lock: + if self._initialized: + return + # Load persisted state + if self._storage: + ( + self._states, + fair_cycle_global, + loaded_from_storage, + ) = await self._storage.load() + self._loaded_from_storage = loaded_from_storage + self._loaded_count = len(self._states) + if fair_cycle_global: + self._limits.fair_cycle_checker.load_global_state_dict( + fair_cycle_global + ) + + # Register credentials + for accessor in credentials: + stable_id = self._registry.get_stable_id(accessor, self.provider) + + # Create or update state + if stable_id not in self._states: + self._states[stable_id] = CredentialState( + stable_id=stable_id, + provider=self.provider, + accessor=accessor, + created_at=time.time(), + ) + else: + # Update accessor in case it changed + self._states[stable_id].accessor = accessor + + # Apply overrides + if priorities and accessor in priorities: + self._states[stable_id].priority = priorities[accessor] + if tiers and accessor in tiers: + self._states[stable_id].tier = tiers[accessor] + + # Set max concurrent if configured, applying priority multiplier + if self._max_concurrent_per_key is not None: + base_concurrent = self._max_concurrent_per_key + # Apply priority multiplier from config + priority = self._states[stable_id].priority + multiplier = self._config.get_effective_multiplier(priority) + effective_concurrent = base_concurrent * multiplier + self._states[stable_id].max_concurrent = effective_concurrent + + self._backfill_group_usage() + + self._initialized = True + lib_logger.debug( + f"UsageManager initialized for {self.provider} with {len(credentials)} credentials" + ) + + async def acquire_credential( + self, + model: str, + quota_group: Optional[str] = None, + exclude: Optional[Set[str]] = None, + candidates: Optional[List[str]] = None, + priorities: Optional[Dict[str, int]] = None, + deadline: float = 0.0, + ) -> CredentialContext: + """ + Acquire a credential for a request. + + Returns a context manager that automatically releases + the credential and records success/failure. + + This method will wait for credentials to become available if all are + currently busy (at max_concurrent), up until the deadline. + + Args: + model: Model to use + quota_group: Optional quota group (uses model name if None) + exclude: Set of stable_ids to exclude (by accessor) + candidates: Optional list of credential accessors to consider. + If provided, only these will be considered for selection. + priorities: Optional priority overrides (accessor -> priority). + If provided, overrides the stored priorities. + deadline: Request deadline timestamp + + Returns: + CredentialContext for use with async with + + Raises: + NoAvailableKeysError: If no credentials available within deadline + """ + from ..error_handler import NoAvailableKeysError + + # Convert accessor-based exclude to stable_id-based + exclude_ids = set() + if exclude: + for accessor in exclude: + stable_id = self._registry.get_stable_id(accessor, self.provider) + exclude_ids.add(stable_id) + + # Filter states to only candidates if provided + if candidates is not None: + candidate_ids = set() + for accessor in candidates: + stable_id = self._registry.get_stable_id(accessor, self.provider) + candidate_ids.add(stable_id) + states_to_check = { + sid: state + for sid, state in self._states.items() + if sid in candidate_ids + } + else: + states_to_check = self._states + + # Convert accessor-based priorities to stable_id-based + priority_overrides = None + if priorities: + priority_overrides = {} + for accessor, priority in priorities.items(): + stable_id = self._registry.get_stable_id(accessor, self.provider) + priority_overrides[stable_id] = priority + + # Normalize model name for consistent tracking and selection + normalized_model = self._normalize_model(model) + + # Ensure key conditions exist for all candidates + for stable_id in states_to_check: + if stable_id not in self._key_conditions: + self._key_conditions[stable_id] = asyncio.Condition() + self._key_locks[stable_id] = asyncio.Lock() + + # Main acquisition loop - continues until deadline + while time.time() < deadline: + # Try to select a credential + stable_id = self._selection.select( + provider=self.provider, + model=normalized_model, + states=states_to_check, + quota_group=quota_group, + exclude=exclude_ids, + priorities=priority_overrides, + deadline=deadline, + ) + + if stable_id is not None: + state = self._states[stable_id] + lock = self._key_locks.get(stable_id) + + if lock: + async with lock: + # Double-check availability after acquiring lock + if ( + state.max_concurrent is None + or state.active_requests < state.max_concurrent + ): + state.active_requests += 1 + lib_logger.debug( + f"Acquired credential {self._mask_accessor(state.accessor)} " + f"for {model} (active: {state.active_requests}" + f"{f'/{state.max_concurrent}' if state.max_concurrent else ''})" + ) + return CredentialContext( + manager=self, + credential=state.accessor, + stable_id=stable_id, + model=normalized_model, + quota_group=quota_group, + ) + else: + # No lock configured, just increment + state.active_requests += 1 + return CredentialContext( + manager=self, + credential=state.accessor, + stable_id=stable_id, + model=normalized_model, + quota_group=quota_group, + ) + + # No credential available - need to wait + # Find the best credential to wait for (prefer lowest usage) + best_wait_id = None + best_usage = float("inf") + + for sid, state in states_to_check.items(): + if sid in exclude_ids: + continue + if ( + state.max_concurrent is not None + and state.active_requests >= state.max_concurrent + ): + # This one is busy but might become free + usage = state.usage.total_requests + if usage < best_usage: + best_usage = usage + best_wait_id = sid + + if best_wait_id is None: + # All credentials blocked by cooldown or limits, not just concurrency + # Check if waiting for cooldown makes sense + soonest_cooldown = self._get_soonest_cooldown_end( + states_to_check, normalized_model, quota_group + ) + + if soonest_cooldown is not None: + remaining_budget = deadline - time.time() + wait_needed = soonest_cooldown - time.time() + + if wait_needed > remaining_budget: + # No credential will be available in time + lib_logger.warning( + f"All credentials on cooldown. Soonest in {wait_needed:.1f}s, " + f"budget {remaining_budget:.1f}s. Failing fast." + ) + break + + # Wait for cooldown to expire + lib_logger.info( + f"All credentials on cooldown. Waiting {wait_needed:.1f}s..." + ) + await asyncio.sleep(min(wait_needed + 0.1, remaining_budget)) + continue + + # No cooldowns and no busy keys - truly no keys available + break + + # Wait on the best credential's condition + condition = self._key_conditions.get(best_wait_id) + if condition: + lib_logger.debug( + f"All credentials busy. Waiting for {self._mask_accessor(self._states[best_wait_id].accessor)}..." + ) + try: + async with condition: + remaining_budget = deadline - time.time() + if remaining_budget <= 0: + break + # Wait for notification or timeout (max 1 second to re-check) + await asyncio.wait_for( + condition.wait(), + timeout=min(1.0, remaining_budget), + ) + lib_logger.debug("Credential released. Re-evaluating...") + except asyncio.TimeoutError: + # Timeout is normal, just retry the loop + lib_logger.debug("Wait timed out. Re-evaluating...") + else: + # No condition, just sleep briefly and retry + await asyncio.sleep(0.1) + + # Deadline exceeded + raise NoAvailableKeysError( + f"Could not acquire a credential for {self.provider}/{model} " + f"within the time budget." + ) + + def _get_soonest_cooldown_end( + self, + states: Dict[str, CredentialState], + model: str, + quota_group: Optional[str], + ) -> Optional[float]: + """Get the soonest cooldown end time across all credentials.""" + soonest = None + now = time.time() + group_key = quota_group or model + + for state in states.values(): + # Check model-specific cooldown + cooldown = state.get_cooldown(group_key) + if cooldown and cooldown.until > now: + if soonest is None or cooldown.until < soonest: + soonest = cooldown.until + + # Check global cooldown + global_cooldown = state.get_cooldown() + if global_cooldown and global_cooldown.until > now: + if soonest is None or global_cooldown.until < soonest: + soonest = global_cooldown.until + + return soonest + + async def get_best_credential( + self, + model: str, + quota_group: Optional[str] = None, + exclude: Optional[Set[str]] = None, + deadline: float = 0.0, + ) -> Optional[str]: + """ + Get the best available credential without acquiring. + + Useful for checking availability or manual acquisition. + + Args: + model: Model to use + quota_group: Optional quota group + exclude: Set of accessors to exclude + deadline: Request deadline + + Returns: + Credential accessor, or None if none available + """ + # Convert exclude from accessors to stable_ids + exclude_ids = set() + if exclude: + for accessor in exclude: + stable_id = self._registry.get_stable_id(accessor, self.provider) + exclude_ids.add(stable_id) + + # Normalize model name for consistent selection + normalized_model = self._normalize_model(model) + + stable_id = self._selection.select( + provider=self.provider, + model=normalized_model, + states=self._states, + quota_group=quota_group, + exclude=exclude_ids, + deadline=deadline, + ) + + if stable_id is None: + return None + + return self._states[stable_id].accessor + + async def record_usage( + self, + accessor: str, + model: str, + success: bool, + prompt_tokens: int = 0, + completion_tokens: int = 0, + thinking_tokens: int = 0, + prompt_tokens_cache_read: int = 0, + prompt_tokens_cache_write: int = 0, + approx_cost: float = 0.0, + error: Optional[ClassifiedError] = None, + quota_group: Optional[str] = None, + ) -> None: + """ + Record usage for a credential (manual recording). + + Use this for manual tracking outside of context manager. + + Args: + accessor: Credential accessor + model: Model used + success: Whether request succeeded + prompt_tokens: Prompt tokens used + completion_tokens: Completion tokens used + prompt_tokens_cached: Cached prompt tokens (e.g., from Claude) + error: Classified error if failed + quota_group: Quota group + """ + stable_id = self._registry.get_stable_id(accessor, self.provider) + + if success: + await self._record_success( + stable_id, + model, + quota_group, + prompt_tokens, + completion_tokens, + thinking_tokens, + prompt_tokens_cache_read, + prompt_tokens_cache_write, + approx_cost, + ) + else: + await self._record_failure( + stable_id, + model, + quota_group, + error, + request_count=1, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + thinking_tokens=thinking_tokens, + prompt_tokens_cache_read=prompt_tokens_cache_read, + prompt_tokens_cache_write=prompt_tokens_cache_write, + approx_cost=approx_cost, + ) + + async def _handle_request_complete( + self, + stable_id: str, + model: str, + quota_group: Optional[str], + success: bool, + response: Optional[Any], + response_headers: Optional[Dict[str, Any]], + error: Optional[ClassifiedError], + prompt_tokens: int = 0, + completion_tokens: int = 0, + thinking_tokens: int = 0, + prompt_tokens_cache_read: int = 0, + prompt_tokens_cache_write: int = 0, + approx_cost: float = 0.0, + ) -> None: + """Handle provider hooks and record request outcome.""" + state = self._states.get(stable_id) + if not state: + return + + normalized_model = self._normalize_model(model) + group_key = quota_group or normalized_model + + hook_result: Optional[RequestCompleteResult] = None + if self._hooks: + hook_result = await self._hooks.dispatch_request_complete( + provider=self.provider, + credential=state.accessor, + model=normalized_model, + success=success, + response=response, + error=error, + ) + + request_count = 1 + cooldown_override = None + force_exhausted = False + + if hook_result: + if hook_result.count_override is not None: + request_count = max(0, hook_result.count_override) + cooldown_override = hook_result.cooldown_override + force_exhausted = hook_result.force_exhausted + + if not success and error and hook_result is None: + if error.error_type in {"server_error", "api_connection"}: + request_count = 0 + + if request_count == 0: + prompt_tokens = 0 + completion_tokens = 0 + thinking_tokens = 0 + prompt_tokens_cache_read = 0 + prompt_tokens_cache_write = 0 + approx_cost = 0.0 + + if cooldown_override: + await self._tracking.apply_cooldown( + state=state, + reason="provider_hook", + duration=cooldown_override, + model_or_group=group_key, + source="provider_hook", + ) + + if force_exhausted: + await self._tracking.mark_exhausted( + state=state, + model_or_group=self._resolve_fair_cycle_key(group_key), + reason="provider_hook", + ) + + if success: + await self._record_success( + stable_id, + normalized_model, + quota_group, + prompt_tokens, + completion_tokens, + thinking_tokens, + prompt_tokens_cache_read, + prompt_tokens_cache_write, + approx_cost, + response_headers=response_headers, + request_count=request_count, + ) + else: + await self._record_failure( + stable_id, + normalized_model, + quota_group, + error, + request_count=request_count, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + thinking_tokens=thinking_tokens, + prompt_tokens_cache_read=prompt_tokens_cache_read, + prompt_tokens_cache_write=prompt_tokens_cache_write, + approx_cost=approx_cost, + ) + + async def apply_cooldown( + self, + accessor: str, + duration: float, + reason: str = "manual", + model_or_group: Optional[str] = None, + ) -> None: + """ + Apply a cooldown to a credential. + + Args: + accessor: Credential accessor + duration: Cooldown duration in seconds + reason: Reason for cooldown + model_or_group: Scope of cooldown + """ + stable_id = self._registry.get_stable_id(accessor, self.provider) + state = self._states.get(stable_id) + if state: + await self._tracking.apply_cooldown( + state=state, + reason=reason, + duration=duration, + model_or_group=model_or_group, + ) + await self._save_if_needed() + + async def get_availability_stats( + self, + model: str, + quota_group: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Get availability statistics for credentials. + + Args: + model: Model to check + quota_group: Quota group + + Returns: + Dict with availability info + """ + return self._selection.get_availability_stats( + provider=self.provider, + model=model, + states=self._states, + quota_group=quota_group, + ) + + async def get_stats_for_endpoint( + self, + model_filter: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Get comprehensive stats suitable for status endpoints. + + Returns credential states, usage windows, cooldowns, and fair cycle state. + + Args: + model_filter: Optional model to filter stats for + + Returns: + Dict with comprehensive statistics + """ + stats = { + "provider": self.provider, + "credential_count": len(self._states), + "rotation_mode": self._config.rotation_mode.value, + "credentials": {}, + } + + stats.update( + { + "active_count": 0, + "exhausted_count": 0, + "total_requests": 0, + "tokens": { + "input_cached": 0, + "input_uncached": 0, + "input_cache_pct": 0, + "output": 0, + }, + "approx_cost": None, + "quota_groups": {}, + } + ) + + for stable_id, state in self._states.items(): + status = "active" + now = time.time() + if state.cooldowns: + for cooldown in state.cooldowns.values(): + if cooldown.until > now: + status = "cooldown" + break + if status == "active" and state.fair_cycle: + for fc_state in state.fair_cycle.values(): + if fc_state.exhausted: + status = "exhausted" + break + + cred_stats = { + "stable_id": stable_id, + "accessor_masked": self._mask_accessor(state.accessor), + "full_path": state.accessor, + "identifier": self._mask_accessor(state.accessor), + "email": state.display_name, + "tier": state.tier, + "priority": state.priority, + "active_requests": state.active_requests, + "status": status, + "usage": { + "total_requests": state.usage.total_requests, + "total_successes": state.usage.total_successes, + "total_failures": state.usage.total_failures, + "total_tokens": state.usage.total_tokens, + "total_prompt_tokens": state.usage.total_prompt_tokens, + "total_completion_tokens": state.usage.total_completion_tokens, + "total_thinking_tokens": state.usage.total_thinking_tokens, + "total_output_tokens": state.usage.total_output_tokens, + "total_prompt_tokens_cache_read": state.usage.total_prompt_tokens_cache_read, + "total_prompt_tokens_cache_write": state.usage.total_prompt_tokens_cache_write, + "total_approx_cost": state.usage.total_approx_cost, + }, + "windows": {}, + "model_usage": {}, + "group_usage": {}, + "cooldowns": {}, + "fair_cycle": {}, + } + + stats["total_requests"] += state.usage.total_requests + stats["tokens"]["output"] += state.usage.total_output_tokens + stats["tokens"]["input_cached"] += ( + state.usage.total_prompt_tokens_cache_read + ) + stats["tokens"]["input_uncached"] += max( + 0, + state.usage.total_prompt_tokens + - state.usage.total_prompt_tokens_cache_read, + ) + if state.usage.total_approx_cost: + stats["approx_cost"] = (stats["approx_cost"] or 0.0) + ( + state.usage.total_approx_cost + ) + + if status == "active": + stats["active_count"] += 1 + elif status == "exhausted": + stats["exhausted_count"] += 1 + + # Add window stats + for window_name, window in state.usage.windows.items(): + cred_stats["windows"][window_name] = { + "request_count": window.request_count, + "success_count": window.success_count, + "failure_count": window.failure_count, + "prompt_tokens": window.prompt_tokens, + "completion_tokens": window.completion_tokens, + "thinking_tokens": window.thinking_tokens, + "output_tokens": window.output_tokens, + "prompt_tokens_cache_read": window.prompt_tokens_cache_read, + "prompt_tokens_cache_write": window.prompt_tokens_cache_write, + "total_tokens": window.total_tokens, + "limit": window.limit, + "remaining": window.remaining, + "reset_at": window.reset_at, + "approx_cost": window.approx_cost, + } + + for model_key, usage in state.model_usage.items(): + model_windows = {} + for window_name, window in usage.windows.items(): + model_windows[window_name] = { + "request_count": window.request_count, + "success_count": window.success_count, + "failure_count": window.failure_count, + "prompt_tokens": window.prompt_tokens, + "completion_tokens": window.completion_tokens, + "thinking_tokens": window.thinking_tokens, + "output_tokens": window.output_tokens, + "prompt_tokens_cache_read": window.prompt_tokens_cache_read, + "prompt_tokens_cache_write": window.prompt_tokens_cache_write, + "total_tokens": window.total_tokens, + "limit": window.limit, + "remaining": window.remaining, + "reset_at": window.reset_at, + "approx_cost": window.approx_cost, + } + if model_windows: + cred_stats["model_usage"][model_key] = { + "windows": model_windows, + "total_requests": usage.total_requests, + "total_tokens": usage.total_tokens, + "total_prompt_tokens": usage.total_prompt_tokens, + "total_completion_tokens": usage.total_completion_tokens, + "total_thinking_tokens": usage.total_thinking_tokens, + "total_output_tokens": usage.total_output_tokens, + "total_prompt_tokens_cache_read": usage.total_prompt_tokens_cache_read, + "total_prompt_tokens_cache_write": usage.total_prompt_tokens_cache_write, + } + + for group_key, usage in state.group_usage.items(): + group_windows = {} + for window_name, window in usage.windows.items(): + group_windows[window_name] = { + "request_count": window.request_count, + "success_count": window.success_count, + "failure_count": window.failure_count, + "prompt_tokens": window.prompt_tokens, + "completion_tokens": window.completion_tokens, + "thinking_tokens": window.thinking_tokens, + "output_tokens": window.output_tokens, + "prompt_tokens_cache_read": window.prompt_tokens_cache_read, + "prompt_tokens_cache_write": window.prompt_tokens_cache_write, + "total_tokens": window.total_tokens, + "limit": window.limit, + "remaining": window.remaining, + "reset_at": window.reset_at, + "approx_cost": window.approx_cost, + } + if group_windows: + cred_stats["group_usage"][group_key] = { + "windows": group_windows, + "total_requests": usage.total_requests, + "total_tokens": usage.total_tokens, + "total_prompt_tokens": usage.total_prompt_tokens, + "total_completion_tokens": usage.total_completion_tokens, + "total_thinking_tokens": usage.total_thinking_tokens, + "total_output_tokens": usage.total_output_tokens, + "total_prompt_tokens_cache_read": usage.total_prompt_tokens_cache_read, + "total_prompt_tokens_cache_write": usage.total_prompt_tokens_cache_write, + } + + group_stats = stats["quota_groups"].setdefault( + group_key, + { + "total_requests_used": 0, + "total_requests_remaining": 0, + "total_requests_max": 0, + "total_remaining_pct": None, + "tiers": {}, + }, + ) + for window in group_windows.values(): + if window.get("limit") is not None: + limit = window["limit"] + used = window["request_count"] + remaining = max(0, limit - used) + group_stats["total_requests_used"] += used + group_stats["total_requests_remaining"] += remaining + group_stats["total_requests_max"] += limit + + tier_key = state.tier or "unknown" + tier_stats = group_stats["tiers"].setdefault( + tier_key, + {"priority": state.priority or 0, "total": 0, "active": 0}, + ) + tier_stats["total"] += 1 + if status == "active": + tier_stats["active"] += 1 + + # Add active cooldowns + for key, cooldown in state.cooldowns.items(): + if cooldown.is_active: + cred_stats["cooldowns"][key] = { + "reason": cooldown.reason, + "remaining_seconds": cooldown.remaining_seconds, + "source": cooldown.source, + } + + # Add fair cycle state + for key, fc_state in state.fair_cycle.items(): + if model_filter and key != model_filter: + continue + cred_stats["fair_cycle"][key] = { + "exhausted": fc_state.exhausted, + "cycle_request_count": fc_state.cycle_request_count, + } + + stats["credentials"][stable_id] = cred_stats + + for group_stats in stats["quota_groups"].values(): + if group_stats["total_requests_max"] > 0: + group_stats["total_remaining_pct"] = round( + group_stats["total_requests_remaining"] + / group_stats["total_requests_max"] + * 100, + 1, + ) + + total_input = ( + stats["tokens"]["input_cached"] + stats["tokens"]["input_uncached"] + ) + stats["tokens"]["input_cache_pct"] = ( + round(stats["tokens"]["input_cached"] / total_input * 100, 1) + if total_input > 0 + else 0 + ) + + return stats + + def _mask_accessor(self, accessor: str) -> str: + """Mask an accessor for safe display.""" + if accessor.endswith(".json"): + # OAuth credential - show filename only + from pathlib import Path + + return Path(accessor).name + elif len(accessor) > 12: + # API key - show first 4 and last 4 chars + return f"{accessor[:4]}...{accessor[-4:]}" + else: + return "***" + + def _get_provider_plugin_instance(self) -> Optional[Any]: + """Get the provider plugin instance for the current provider.""" + if not self._provider_plugins: + return None + + # Provider plugins dict maps provider name -> plugin class or instance + plugin = self._provider_plugins.get(self.provider) + if plugin is None: + return None + + # If it's a class, instantiate it; if already an instance, use directly + if isinstance(plugin, type): + return plugin() + return plugin + + def _normalize_model(self, model: str) -> str: + """ + Normalize model name using provider's mapping. + + Converts internal model names (e.g., claude-sonnet-4-5-thinking) to + public-facing names (e.g., claude-sonnet-4.5) for consistent storage + and tracking. + + Args: + model: Model name (with or without provider prefix) + + Returns: + Normalized model name (provider prefix preserved if present) + """ + plugin_instance = self._get_provider_plugin_instance() + + if plugin_instance and hasattr(plugin_instance, "normalize_model_for_tracking"): + return plugin_instance.normalize_model_for_tracking(model) + + return model + + def _get_model_quota_group(self, model: str) -> Optional[str]: + """ + Get the quota group for a model, if the provider defines one. + + Models in the same quota group share a single quota pool. + For example, all Claude models in Antigravity share the same daily quota. + + Args: + model: Model name (with or without provider prefix) + + Returns: + Group name (e.g., "claude") or None if not grouped + """ + plugin_instance = self._get_provider_plugin_instance() + + if plugin_instance and hasattr(plugin_instance, "get_model_quota_group"): + return plugin_instance.get_model_quota_group(model) + + return None + + def get_model_quota_group(self, model: str) -> Optional[str]: + """Public helper to get quota group for a model.""" + normalized_model = self._normalize_model(model) + return self._get_model_quota_group(normalized_model) + + def _get_grouped_models(self, group: str) -> List[str]: + """ + Get all model names in a quota group (with provider prefix), normalized. + + Returns only public-facing model names, deduplicated. Internal variants + (e.g., claude-sonnet-4-5-thinking) are normalized to their public name + (e.g., claude-sonnet-4.5). + + Args: + group: Group name (e.g., "claude") + + Returns: + List of normalized, deduplicated model names with provider prefix + (e.g., ["antigravity/claude-sonnet-4.5", "antigravity/claude-opus-4.5"]) + """ + plugin_instance = self._get_provider_plugin_instance() + + if plugin_instance and hasattr(plugin_instance, "get_models_in_quota_group"): + models = plugin_instance.get_models_in_quota_group(group) + + # Normalize and deduplicate + if hasattr(plugin_instance, "normalize_model_for_tracking"): + seen: Set[str] = set() + normalized: List[str] = [] + for m in models: + prefixed = f"{self.provider}/{m}" + norm = plugin_instance.normalize_model_for_tracking(prefixed) + if norm not in seen: + seen.add(norm) + normalized.append(norm) + return normalized + + # Fallback: just add provider prefix + return [f"{self.provider}/{m}" for m in models] + + return [] + + async def _sync_quota_group_counts( + self, + state: CredentialState, + model: str, + ) -> None: + """ + Synchronize quota-group and credential windows for model usage. + + For providers with shared quota pools (e.g., Antigravity), we aggregate + per-model windows into group windows for shared quota tracking. We also + aggregate per-model windows into credential-level windows for display. + + Args: + state: Credential state that was just updated + model: The normalized model that was just used + """ + model_window_names = { + window_def.name + for window_def in self._config.windows + if window_def.applies_to == "model" + } + + if model_window_names: + self._sync_credential_windows(state, model_window_names) + + # Aggregate windows for quota groups + group = self._get_model_quota_group(model) + if not group: + return + + grouped_models = self._get_grouped_models(group) + if not grouped_models: + return + + group_usages = [ + state.model_usage[m] for m in grouped_models if m in state.model_usage + ] + if not group_usages: + return + + aggregated = self._aggregate_model_windows(group_usages, model_window_names) + if not aggregated: + return + + group_usage = state.group_usage.setdefault(group, UsageStats()) + group_usage.windows.update(aggregated) + + def _backfill_group_usage(self) -> None: + """Backfill group usage windows and cooldowns from model data.""" + for state in self._states.values(): + self._backfill_model_usage(state) + self._backfill_credential_windows(state) + self._backfill_group_usage_for_state(state) + self._backfill_group_cooldowns(state) + + def _aggregate_model_windows( + self, + usages: List[UsageStats], + model_window_names: Set[str], + ) -> Dict[str, WindowStats]: + buckets: Dict[str, Dict[str, Any]] = {} + + for usage in usages: + for window_name, window in usage.windows.items(): + if window_name not in model_window_names: + continue + bucket = buckets.setdefault( + window_name, + { + "request_count": 0, + "success_count": 0, + "failure_count": 0, + "prompt_tokens": 0, + "completion_tokens": 0, + "thinking_tokens": 0, + "output_tokens": 0, + "prompt_tokens_cache_read": 0, + "prompt_tokens_cache_write": 0, + "total_tokens": 0, + "approx_cost": 0.0, + "started_at": [], + "reset_at": [], + "limit": [], + }, + ) + bucket["request_count"] += window.request_count + bucket["success_count"] += window.success_count + bucket["failure_count"] += window.failure_count + bucket["prompt_tokens"] += window.prompt_tokens + bucket["completion_tokens"] += window.completion_tokens + bucket["thinking_tokens"] += window.thinking_tokens + bucket["output_tokens"] += window.output_tokens + bucket["prompt_tokens_cache_read"] += window.prompt_tokens_cache_read + bucket["prompt_tokens_cache_write"] += window.prompt_tokens_cache_write + bucket["total_tokens"] += window.total_tokens + bucket["approx_cost"] += window.approx_cost + if window.started_at is not None: + bucket["started_at"].append(window.started_at) + if window.reset_at is not None: + bucket["reset_at"].append(window.reset_at) + if window.limit is not None: + bucket["limit"].append(window.limit) + + aggregated: Dict[str, WindowStats] = {} + for window_name, bucket in buckets.items(): + success_count = bucket["success_count"] + failure_count = bucket["failure_count"] + request_count = success_count + failure_count + if request_count == 0: + request_count = bucket["request_count"] + if request_count > 0: + success_count = request_count + started_at = min(bucket["started_at"]) if bucket["started_at"] else None + reset_at = max(bucket["reset_at"]) if bucket["reset_at"] else None + limit = max(bucket["limit"]) if bucket["limit"] else None + aggregated[window_name] = WindowStats( + name=window_name, + request_count=request_count, + success_count=success_count, + failure_count=failure_count, + total_tokens=bucket["total_tokens"], + prompt_tokens=bucket["prompt_tokens"], + completion_tokens=bucket["completion_tokens"], + thinking_tokens=bucket["thinking_tokens"], + output_tokens=bucket["output_tokens"], + prompt_tokens_cache_read=bucket["prompt_tokens_cache_read"], + prompt_tokens_cache_write=bucket["prompt_tokens_cache_write"], + approx_cost=bucket["approx_cost"], + started_at=started_at, + reset_at=reset_at, + limit=limit, + ) + + return aggregated + + def _sync_credential_windows( + self, + state: CredentialState, + model_window_names: Set[str], + ) -> None: + if not model_window_names: + return + usages = list(state.model_usage.values()) + if not usages: + return + aggregated = self._aggregate_model_windows(usages, model_window_names) + if not aggregated: + return + state.usage.windows.update(aggregated) + + def _backfill_group_usage_for_state(self, state: CredentialState) -> None: + model_window_names = { + window_def.name + for window_def in self._config.windows + if window_def.applies_to == "model" + } + if not model_window_names: + return + + for model in state.model_usage.keys(): + group = self._get_model_quota_group(model) + if not group: + continue + group_usage = state.group_usage.setdefault(group, UsageStats()) + grouped_models = self._get_grouped_models(group) + group_usages = [ + state.model_usage[m] for m in grouped_models if m in state.model_usage + ] + aggregated = self._aggregate_model_windows(group_usages, model_window_names) + if aggregated: + group_usage.windows.update(aggregated) + + def _backfill_credential_windows(self, state: CredentialState) -> None: + model_window_names = { + window_def.name + for window_def in self._config.windows + if window_def.applies_to == "model" + } + if not model_window_names: + return + usages = list(state.model_usage.values()) + if not usages: + return + aggregated = self._aggregate_model_windows(usages, model_window_names) + if aggregated: + state.usage.windows.update(aggregated) + + def _backfill_model_usage(self, state: CredentialState) -> None: + if not state.usage.model_request_counts: + return + primary_def = self._window_manager.get_primary_definition() + if not primary_def: + return + now = time.time() + for model, count in state.usage.model_request_counts.items(): + usage = state.model_usage.setdefault(model, UsageStats()) + window = usage.windows.get(primary_def.name) + if not window: + base_time = state.last_updated or state.created_at or now + reset_at = None + if primary_def: + if primary_def.mode == ResetMode.ROLLING: + reset_at = base_time + primary_def.duration_seconds + else: + reset_at = self._window_manager._calculate_reset_time( + primary_def, base_time + ) + window = WindowStats( + name=primary_def.name, + request_count=count, + started_at=base_time, + reset_at=reset_at, + limit=None, + ) + usage.windows[primary_def.name] = window + elif window.request_count == 0: + window.request_count = count + + def _backfill_group_cooldowns(self, state: CredentialState) -> None: + for model, cooldown in list(state.cooldowns.items()): + if model in {"_global_"}: + continue + group = self._get_model_quota_group(model) + if not group: + continue + existing = state.cooldowns.get(group) + if not existing or existing.until < cooldown.until: + state.cooldowns[group] = cooldown + + for group in list(state.group_usage.keys()): + cooldown = state.cooldowns.get(group) + if not cooldown: + continue + for grouped_model in self._get_grouped_models(group): + existing = state.cooldowns.get(grouped_model) + if not existing or existing.until < cooldown.until: + state.cooldowns[grouped_model] = cooldown + + def _reconcile_window_counts(self, window: WindowStats, request_count: int) -> None: + local_total = window.success_count + window.failure_count + window.request_count = request_count + if local_total == 0 and request_count > 0: + window.success_count = request_count + window.failure_count = 0 + return + + if request_count < local_total: + failure_count = min(window.failure_count, request_count) + success_count = max(0, request_count - failure_count) + window.success_count = success_count + window.failure_count = failure_count + return + + if request_count > local_total: + window.success_count += request_count - local_total + + async def save(self, force: bool = False) -> bool: + """ + Save usage data to file. + + Args: + force: Force save even if debounce not elapsed + + Returns: + True if saved successfully + """ + if self._storage: + fair_cycle_global = self._limits.fair_cycle_checker.get_global_state_dict() + return await self._storage.save( + self._states, fair_cycle_global, force=force + ) + return False + + async def get_usage_snapshot(self) -> Dict[str, Dict[str, Any]]: + """ + Get a lightweight usage snapshot keyed by accessor. + + Returns: + Dict mapping accessor -> usage metadata. + """ + async with self._lock: + snapshot: Dict[str, Dict[str, Any]] = {} + for state in self._states.values(): + snapshot[state.accessor] = { + "last_used_ts": state.usage.last_used_at or 0, + } + return snapshot + + async def shutdown(self) -> None: + """Shutdown and save any pending data.""" + await self.save(force=True) + + async def reload_from_disk(self) -> None: + """ + Force reload usage data from disk. + + Useful when wanting fresh stats without making external API calls. + This reloads persisted state while preserving current credential registrations. + """ + if not self._storage: + lib_logger.debug( + f"reload_from_disk: No storage configured for {self.provider}" + ) + return + + async with self._lock: + # Load persisted state + loaded_states, fair_cycle_global, _ = await self._storage.load() + + # Merge loaded state with current state + # Keep current accessors but update usage data + for stable_id, loaded_state in loaded_states.items(): + if stable_id in self._states: + # Update usage data from loaded state + current = self._states[stable_id] + current.usage = loaded_state.usage + current.model_usage = loaded_state.model_usage + current.group_usage = loaded_state.group_usage + current.cooldowns = loaded_state.cooldowns + current.fair_cycle = loaded_state.fair_cycle + current.last_updated = loaded_state.last_updated + else: + # New credential from disk, add it + self._states[stable_id] = loaded_state + + # Reload fair cycle global state + if fair_cycle_global: + self._limits.fair_cycle_checker.load_global_state_dict( + fair_cycle_global + ) + + # Backfill group usage for consistency + self._backfill_group_usage() + + lib_logger.info( + f"Reloaded usage data from disk for {self.provider}: " + f"{len(self._states)} credentials" + ) + + async def update_quota_baseline( + self, + accessor: str, + model: str, + quota_max_requests: Optional[int] = None, + quota_reset_ts: Optional[float] = None, + quota_used: Optional[int] = None, + quota_group: Optional[str] = None, + force: bool = False, + ) -> Optional[Dict[str, Any]]: + """ + Update quota baseline from provider API response. + + Called by provider plugins after receiving rate limit headers or + quota information from API responses. + + Args: + accessor: Credential accessor (path or key) + model: Model name + quota_max_requests: Max requests allowed in window + quota_reset_ts: When quota resets (Unix timestamp) + quota_used: Current used count from API + quota_group: Optional quota group (uses model if None) + force: If True, always use API values (for manual refresh). + If False (default), use max(local, api) to prevent stale + API data from overwriting accurate local counts during + background fetches. + See: https://github.com/Mirrowel/LLM-API-Key-Proxy/issues/75 + + Returns: + Cooldown info dict if cooldown was applied, None otherwise + """ + stable_id = self._registry.get_stable_id(accessor, self.provider) + state = self._states.get(stable_id) + if not state: + lib_logger.warning( + f"update_quota_baseline: Unknown credential {accessor[:20]}..." + ) + return None + + # Normalize model name for consistent tracking + normalized_model = self._normalize_model(model) + group_key = quota_group or normalized_model + + primary_def = self._window_manager.get_primary_definition() + primary_window = None + + if primary_def: + scope_key = None + if primary_def.applies_to == "model": + scope_key = normalized_model + elif primary_def.applies_to == "group": + scope_key = group_key + + usage = state.get_usage_for_scope(primary_def.applies_to, scope_key) + primary_window = self._window_manager.get_or_create_window( + usage.windows, + primary_def.name, + ) + + if primary_window is None: + from .types import WindowStats + + primary_window = WindowStats(name="api_quota") + state.usage.windows["api_quota"] = primary_window + + # Update baseline values + if quota_max_requests is not None: + primary_window.limit = quota_max_requests + if quota_reset_ts is not None: + primary_window.reset_at = quota_reset_ts + if quota_used is not None: + # Use max() to prevent stale API data from overwriting local count + # during background fetches. API updates in ~20% increments so may + # return stale cached values. Force mode (manual refresh) trusts API. + # See: https://github.com/Mirrowel/LLM-API-Key-Proxy/issues/75 + if force: + synced_count = quota_used + else: + synced_count = max( + primary_window.request_count, + quota_used, + primary_window.success_count + primary_window.failure_count, + ) + self._reconcile_window_counts(primary_window, synced_count) + state.usage.model_request_counts[normalized_model] = synced_count + else: + state.usage.model_request_counts.setdefault(normalized_model, 0) + if primary_window.request_count == 0: + primary_window.request_count = 0 + + if group_key != normalized_model: + group_usage = state.get_usage_for_scope("group", group_key) + window_name = primary_def.name if primary_def else primary_window.name + group_window = self._window_manager.get_or_create_window( + group_usage.windows, + window_name, + ) + if quota_max_requests is not None: + group_window.limit = quota_max_requests + if quota_reset_ts is not None: + group_window.reset_at = quota_reset_ts + if quota_used is not None: + # Same stale-data protection for group windows + if force: + synced_count = quota_used + else: + synced_count = max( + group_window.request_count, + quota_used, + group_window.success_count + group_window.failure_count, + ) + self._reconcile_window_counts(group_window, synced_count) + else: + group_window.request_count = group_window.request_count or 0 + + # Mark state as updated + state.last_updated = time.time() + + if quota_used is not None: + await self._sync_quota_group_counts(state, normalized_model) + + # Check if we need to apply cooldown (quota exhausted) + if primary_window.is_exhausted and quota_reset_ts: + await self._tracking.apply_cooldown( + state=state, + reason="quota_exhausted", + until=quota_reset_ts, + model_or_group=group_key, + source="api_quota", + ) + + if group_key != normalized_model: + await self._tracking.apply_cooldown( + state=state, + reason="quota_exhausted", + until=quota_reset_ts, + model_or_group=normalized_model, + source="api_quota", + ) + + await self._queue_quota_exhausted_log( + accessor=accessor, + group_key=group_key, + quota_reset_ts=quota_reset_ts, + ) + + await self._save_if_needed() + + return { + "cooldown_until": quota_reset_ts, + "reason": "quota_exhausted", + "model": model, + "cooldown_hours": max(0.0, (quota_reset_ts - time.time()) / 3600), + } + + await self._save_if_needed() + + return None + + # ========================================================================= + # PROPERTIES + # ========================================================================= + + @property + def config(self) -> ProviderUsageConfig: + """Get the configuration.""" + return self._config + + @property + def registry(self) -> CredentialRegistry: + """Get the credential registry.""" + return self._registry + + @property + def api(self) -> UsageAPI: + """Get the usage API facade.""" + return self._api + + @property + def initialized(self) -> bool: + """Check if the manager is initialized.""" + return self._initialized + + @property + def tracking(self) -> TrackingEngine: + """Get the tracking engine.""" + return self._tracking + + @property + def limits(self) -> LimitEngine: + """Get the limit engine.""" + return self._limits + + @property + def window_manager(self) -> WindowManager: + """Get the window manager.""" + return self._window_manager + + @property + def selection(self) -> SelectionEngine: + """Get the selection engine.""" + return self._selection + + @property + def states(self) -> Dict[str, CredentialState]: + """Get all credential states.""" + return self._states + + @property + def loaded_from_storage(self) -> bool: + """Whether usage data was loaded from storage.""" + return self._loaded_from_storage + + @property + def loaded_credentials(self) -> int: + """Number of credentials loaded from storage.""" + return self._loaded_count + + # ========================================================================= + # PRIVATE METHODS + # ========================================================================= + + def _resolve_fair_cycle_key(self, group_key: str) -> str: + """Resolve fair cycle tracking key based on config.""" + if self._config.fair_cycle.tracking_mode == TrackingMode.CREDENTIAL: + return FAIR_CYCLE_GLOBAL_KEY + return group_key + + async def _release_credential(self, stable_id: str, model: str) -> None: + """Release a credential after use and notify waiting tasks.""" + state = self._states.get(stable_id) + if not state: + return + + # Decrement active requests + lock = self._key_locks.get(stable_id) + if lock: + async with lock: + state.active_requests = max(0, state.active_requests - 1) + remaining = state.active_requests + lib_logger.info( + f"Released credential {self._mask_accessor(state.accessor)} " + f"from {model} (remaining concurrent: {remaining}" + f"{f'/{state.max_concurrent}' if state.max_concurrent else ''})" + ) + else: + state.active_requests = max(0, state.active_requests - 1) + lib_logger.info( + f"Released credential {self._mask_accessor(state.accessor)} " + f"from {model} (remaining concurrent: {state.active_requests}" + f"{f'/{state.max_concurrent}' if state.max_concurrent else ''})" + ) + + # Notify all tasks waiting on this credential's condition + condition = self._key_conditions.get(stable_id) + if condition: + async with condition: + condition.notify_all() + + async def _queue_quota_exhausted_log( + self, accessor: str, group_key: str, quota_reset_ts: float + ) -> None: + async with self._quota_exhausted_lock: + masked = self._mask_accessor(accessor) + if masked not in self._quota_exhausted_summary: + self._quota_exhausted_summary[masked] = {} + self._quota_exhausted_summary[masked][group_key] = quota_reset_ts + + if self._quota_exhausted_task is None or self._quota_exhausted_task.done(): + self._quota_exhausted_task = asyncio.create_task( + self._flush_quota_exhausted_log() + ) + + async def _flush_quota_exhausted_log(self) -> None: + await asyncio.sleep(0.2) + async with self._quota_exhausted_lock: + summary = self._quota_exhausted_summary + self._quota_exhausted_summary = {} + + if not summary: + return + + now = time.time() + parts = [] + for accessor, groups in sorted(summary.items()): + group_parts = [] + for group, reset_ts in sorted(groups.items()): + hours = max(0.0, (reset_ts - now) / 3600) if reset_ts else 0.0 + group_parts.append(f"{group} {hours:.1f}h") + parts.append(f"{accessor}[{', '.join(group_parts)}]") + + lib_logger.info(f"Quota exhausted: {', '.join(parts)}") + + async def _save_if_needed(self) -> None: + """Persist state if storage is configured.""" + if not self._storage: + return + fair_cycle_global = self._limits.fair_cycle_checker.get_global_state_dict() + saved = await self._storage.save(self._states, fair_cycle_global) + if not saved: + await self._schedule_save_flush() + + async def _schedule_save_flush(self) -> None: + if self._save_task and not self._save_task.done(): + return + self._save_task = asyncio.create_task(self._flush_save()) + + async def _flush_save(self) -> None: + async with self._save_lock: + await asyncio.sleep(self._storage.save_debounce_seconds) + if not self._storage: + return + fair_cycle_global = self._limits.fair_cycle_checker.get_global_state_dict() + await self._storage.save_if_dirty(self._states, fair_cycle_global) + + async def _record_success( + self, + stable_id: str, + model: str, + quota_group: Optional[str] = None, + prompt_tokens: int = 0, + completion_tokens: int = 0, + thinking_tokens: int = 0, + prompt_tokens_cache_read: int = 0, + prompt_tokens_cache_write: int = 0, + approx_cost: float = 0.0, + response_headers: Optional[Dict[str, Any]] = None, + request_count: int = 1, + ) -> None: + """Record a successful request.""" + state = self._states.get(stable_id) + if state: + # Normalize model name for consistent tracking + normalized_model = self._normalize_model(model) + + await self._tracking.record_success( + state=state, + model=normalized_model, + quota_group=quota_group, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + thinking_tokens=thinking_tokens, + prompt_tokens_cache_read=prompt_tokens_cache_read, + prompt_tokens_cache_write=prompt_tokens_cache_write, + approx_cost=approx_cost, + response_headers=response_headers, + request_count=request_count, + ) + + # Sync request_count across quota group (for shared quota pools) + await self._sync_quota_group_counts(state, normalized_model) + + # Apply custom cap cooldown if exceeded + cap_result = self._limits.custom_cap_checker.check( + state, normalized_model, quota_group + ) + if ( + not cap_result.allowed + and cap_result.result == LimitResult.BLOCKED_CUSTOM_CAP + and cap_result.blocked_until + ): + await self._tracking.apply_cooldown( + state=state, + reason="custom_cap", + until=cap_result.blocked_until, + model_or_group=quota_group or normalized_model, + source="custom_cap", + ) + + await self._save_if_needed() + + async def _record_failure( + self, + stable_id: str, + model: str, + quota_group: Optional[str] = None, + error: Optional[ClassifiedError] = None, + request_count: int = 1, + prompt_tokens: int = 0, + completion_tokens: int = 0, + thinking_tokens: int = 0, + prompt_tokens_cache_read: int = 0, + prompt_tokens_cache_write: int = 0, + approx_cost: float = 0.0, + ) -> None: + """Record a failed request.""" + state = self._states.get(stable_id) + if not state: + return + + # Normalize model name for consistent tracking + normalized_model = self._normalize_model(model) + + # Determine cooldown from error + cooldown_duration = None + quota_reset = None + mark_exhausted = False + + if error: + cooldown_duration = error.retry_after + quota_reset = error.quota_reset_timestamp + + # Mark exhausted for quota errors with long cooldown + if error.error_type == "quota_exceeded": + if ( + cooldown_duration + and cooldown_duration >= self._config.exhaustion_cooldown_threshold + ): + mark_exhausted = True + + await self._tracking.record_failure( + state=state, + model=normalized_model, + error_type=error.error_type if error else "unknown", + quota_group=quota_group, + cooldown_duration=cooldown_duration, + quota_reset_timestamp=quota_reset, + mark_exhausted=mark_exhausted, + request_count=request_count, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + thinking_tokens=thinking_tokens, + prompt_tokens_cache_read=prompt_tokens_cache_read, + prompt_tokens_cache_write=prompt_tokens_cache_write, + approx_cost=approx_cost, + ) + + # Sync request_count across quota group (for shared quota pools) + await self._sync_quota_group_counts(state, normalized_model) + + await self._save_if_needed() diff --git a/src/rotator_library/usage/persistence/__init__.py b/src/rotator_library/usage/persistence/__init__.py new file mode 100644 index 00000000..9bdf7d76 --- /dev/null +++ b/src/rotator_library/usage/persistence/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Usage data persistence.""" + +from .storage import UsageStorage + +__all__ = ["UsageStorage"] diff --git a/src/rotator_library/usage/persistence/storage.py b/src/rotator_library/usage/persistence/storage.py new file mode 100644 index 00000000..c9eede4b --- /dev/null +++ b/src/rotator_library/usage/persistence/storage.py @@ -0,0 +1,433 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Usage data storage. + +Handles loading and saving usage data to JSON files. +""" + +import asyncio +import json +import logging +import time +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +from ..types import ( + UsageStats, + WindowStats, + CredentialState, + CooldownInfo, + FairCycleState, + GlobalFairCycleState, + StorageSchema, +) +from ...utils.resilient_io import ResilientStateWriter, safe_read_json + +lib_logger = logging.getLogger("rotator_library") + + +class UsageStorage: + """ + Handles persistence of usage data to JSON files. + + Features: + - Async file I/O with aiofiles + - Atomic writes (write to temp, then rename) + - Automatic schema migration + - Debounced saves to reduce I/O + """ + + CURRENT_SCHEMA_VERSION = 2 + + def __init__( + self, + file_path: Union[str, Path], + save_debounce_seconds: float = 5.0, + ): + """ + Initialize storage. + + Args: + file_path: Path to the usage.json file + save_debounce_seconds: Minimum time between saves + """ + self.file_path = Path(file_path) + self.save_debounce_seconds = save_debounce_seconds + + self._last_save: float = 0 + self._pending_save: bool = False + self._save_lock = asyncio.Lock() + self._dirty: bool = False + self._writer = ResilientStateWriter(self.file_path, lib_logger) + + async def load( + self, + ) -> tuple[Dict[str, CredentialState], Dict[str, Dict[str, Any]], bool]: + """ + Load usage data from file. + + Returns: + Dict of stable_id -> CredentialState + """ + if not self.file_path.exists(): + return {}, {}, False + + try: + async with self._file_lock(): + data = safe_read_json(self.file_path, lib_logger, parse_json=True) + + if not data: + return {}, {}, True + + # Check schema version + version = data.get("schema_version", 1) + if version < self.CURRENT_SCHEMA_VERSION: + lib_logger.info( + f"Migrating usage data from v{version} to v{self.CURRENT_SCHEMA_VERSION}" + ) + data = self._migrate(data, version) + + # Parse credentials + states = {} + for stable_id, cred_data in data.get("credentials", {}).items(): + state = self._parse_credential_state(stable_id, cred_data) + if state: + states[stable_id] = state + + lib_logger.info(f"Loaded {len(states)} credentials from {self.file_path}") + return states, data.get("fair_cycle_global", {}), True + + except json.JSONDecodeError as e: + lib_logger.error(f"Failed to parse usage file: {e}") + return {}, {}, True + except Exception as e: + lib_logger.error(f"Failed to load usage file: {e}") + return {}, {}, True + + async def save( + self, + states: Dict[str, CredentialState], + fair_cycle_global: Optional[Dict[str, Dict[str, Any]]] = None, + force: bool = False, + ) -> bool: + """ + Save usage data to file. + + Args: + states: Dict of stable_id -> CredentialState + fair_cycle_global: Global fair cycle state + force: Force save even if debounce not elapsed + + Returns: + True if saved, False if skipped or failed + """ + now = time.time() + + # Check debounce + if not force and (now - self._last_save) < self.save_debounce_seconds: + self._dirty = True + return False + + async with self._save_lock: + try: + # Build storage data + data = { + "schema_version": self.CURRENT_SCHEMA_VERSION, + "updated_at": datetime.now(timezone.utc).isoformat(), + "credentials": {}, + "accessor_index": {}, + "fair_cycle_global": fair_cycle_global or {}, + } + + for stable_id, state in states.items(): + data["credentials"][stable_id] = self._serialize_credential_state( + state + ) + data["accessor_index"][state.accessor] = stable_id + + saved = self._writer.write(data) + + if saved: + self._last_save = now + self._dirty = False + lib_logger.debug( + f"Saved {len(states)} credentials to {self.file_path}" + ) + return True + + self._dirty = True + return False + + except Exception as e: + lib_logger.error(f"Failed to save usage file: {e}") + return False + + async def save_if_dirty( + self, + states: Dict[str, CredentialState], + fair_cycle_global: Optional[Dict[str, Dict[str, Any]]] = None, + ) -> bool: + """ + Save if there are pending changes. + + Args: + states: Dict of stable_id -> CredentialState + fair_cycle_global: Global fair cycle state + + Returns: + True if saved, False otherwise + """ + if self._dirty: + return await self.save(states, fair_cycle_global, force=True) + return False + + def mark_dirty(self) -> None: + """Mark data as changed, needing save.""" + self._dirty = True + + @property + def is_dirty(self) -> bool: + """Check if there are unsaved changes.""" + return self._dirty + + # ========================================================================= + # PRIVATE METHODS + # ========================================================================= + + def _file_lock(self): + """Get a lock for file operations.""" + return self._save_lock + + async def _read_file(self) -> str: + """Deprecated: use safe_read_json instead.""" + data = safe_read_json(self.file_path, lib_logger, parse_json=True) + return json.dumps(data) if data is not None else "" + + async def _write_file(self, content: str) -> None: + """Deprecated: writes handled by ResilientStateWriter.""" + try: + data = json.loads(content) + except json.JSONDecodeError: + data = None + if data is not None: + self._writer.write(data) + + def _migrate(self, data: Dict[str, Any], from_version: int) -> Dict[str, Any]: + """Migrate data from older schema versions.""" + if from_version == 1: + # v1 -> v2: Add accessor_index, restructure credentials + data["schema_version"] = 2 + data.setdefault("accessor_index", {}) + data.setdefault("fair_cycle_global", {}) + + # v1 used file paths as keys, v2 uses stable_ids + # For migration, treat paths as stable_ids + old_credentials = data.get("credentials", data.get("key_states", {})) + new_credentials = {} + + for key, cred_data in old_credentials.items(): + # Use path as temporary stable_id + stable_id = cred_data.get("stable_id", key) + new_credentials[stable_id] = cred_data + new_credentials[stable_id]["accessor"] = key + + data["credentials"] = new_credentials + + return data + + def _parse_usage_stats(self, data: Dict[str, Any]) -> UsageStats: + """Parse usage stats from storage data.""" + windows = {} + for name, wdata in data.get("windows", {}).items(): + windows[name] = WindowStats( + name=name, + request_count=wdata.get("request_count", 0), + success_count=wdata.get("success_count", 0), + failure_count=wdata.get("failure_count", 0), + total_tokens=wdata.get("total_tokens", 0), + prompt_tokens=wdata.get("prompt_tokens", 0), + completion_tokens=wdata.get("completion_tokens", 0), + thinking_tokens=wdata.get("thinking_tokens", 0), + output_tokens=wdata.get("output_tokens", 0), + prompt_tokens_cache_read=wdata.get("prompt_tokens_cache_read", 0), + prompt_tokens_cache_write=wdata.get("prompt_tokens_cache_write", 0), + approx_cost=wdata.get("approx_cost", 0.0), + started_at=wdata.get("started_at"), + reset_at=wdata.get("reset_at"), + limit=wdata.get("limit"), + ) + + return UsageStats( + windows=windows, + total_requests=data.get("total_requests", 0), + total_successes=data.get("total_successes", 0), + total_failures=data.get("total_failures", 0), + total_tokens=data.get("total_tokens", 0), + total_prompt_tokens=data.get("total_prompt_tokens", 0), + total_completion_tokens=data.get("total_completion_tokens", 0), + total_thinking_tokens=data.get("total_thinking_tokens", 0), + total_output_tokens=data.get("total_output_tokens", 0), + total_prompt_tokens_cache_read=data.get( + "total_prompt_tokens_cache_read", 0 + ), + total_prompt_tokens_cache_write=data.get( + "total_prompt_tokens_cache_write", 0 + ), + total_approx_cost=data.get("total_approx_cost", 0.0), + first_used_at=data.get("first_used_at"), + last_used_at=data.get("last_used_at"), + model_request_counts=dict(data.get("model_request_counts", {})), + ) + + def _serialize_usage_stats(self, usage: UsageStats) -> Dict[str, Any]: + """Serialize usage stats for storage.""" + windows = {} + for name, window in usage.windows.items(): + windows[name] = { + "request_count": window.request_count, + "success_count": window.success_count, + "failure_count": window.failure_count, + "total_tokens": window.total_tokens, + "prompt_tokens": window.prompt_tokens, + "completion_tokens": window.completion_tokens, + "thinking_tokens": window.thinking_tokens, + "output_tokens": window.output_tokens, + "prompt_tokens_cache_read": window.prompt_tokens_cache_read, + "prompt_tokens_cache_write": window.prompt_tokens_cache_write, + "approx_cost": window.approx_cost, + "started_at": window.started_at, + "reset_at": window.reset_at, + "limit": window.limit, + } + + return { + "windows": windows, + "total_requests": usage.total_requests, + "total_successes": usage.total_successes, + "total_failures": usage.total_failures, + "total_tokens": usage.total_tokens, + "total_prompt_tokens": usage.total_prompt_tokens, + "total_completion_tokens": usage.total_completion_tokens, + "total_thinking_tokens": usage.total_thinking_tokens, + "total_output_tokens": usage.total_output_tokens, + "total_prompt_tokens_cache_read": usage.total_prompt_tokens_cache_read, + "total_prompt_tokens_cache_write": usage.total_prompt_tokens_cache_write, + "total_approx_cost": usage.total_approx_cost, + "first_used_at": usage.first_used_at, + "last_used_at": usage.last_used_at, + "model_request_counts": usage.model_request_counts, + } + + def _parse_credential_state( + self, + stable_id: str, + data: Dict[str, Any], + ) -> Optional[CredentialState]: + """Parse a credential state from storage data.""" + try: + usage = self._parse_usage_stats(data) + + model_usage = { + key: self._parse_usage_stats(usage_data) + for key, usage_data in data.get("model_usage", {}).items() + } + group_usage = { + key: self._parse_usage_stats(usage_data) + for key, usage_data in data.get("group_usage", {}).items() + } + + # Parse cooldowns + cooldowns = {} + for key, cdata in data.get("cooldowns", {}).items(): + cooldowns[key] = CooldownInfo( + reason=cdata.get("reason", "unknown"), + until=cdata.get("until", 0), + started_at=cdata.get("started_at", 0), + source=cdata.get("source", "system"), + model_or_group=cdata.get("model_or_group"), + backoff_count=cdata.get("backoff_count", 0), + ) + + # Parse fair cycle + fair_cycle = {} + for key, fcdata in data.get("fair_cycle", {}).items(): + fair_cycle[key] = FairCycleState( + exhausted=fcdata.get("exhausted", False), + exhausted_at=fcdata.get("exhausted_at"), + exhausted_reason=fcdata.get("exhausted_reason"), + cycle_request_count=fcdata.get("cycle_request_count", 0), + model_or_group=key, + ) + + return CredentialState( + stable_id=stable_id, + provider=data.get("provider", "unknown"), + accessor=data.get("accessor", stable_id), + display_name=data.get("display_name"), + tier=data.get("tier"), + priority=data.get("priority", 999), + usage=usage, + model_usage=model_usage, + group_usage=group_usage, + cooldowns=cooldowns, + fair_cycle=fair_cycle, + active_requests=0, # Always starts at 0 + max_concurrent=data.get("max_concurrent"), + created_at=data.get("created_at"), + last_updated=data.get("last_updated"), + ) + + except Exception as e: + lib_logger.warning(f"Failed to parse credential {stable_id}: {e}") + return None + + def _serialize_credential_state(self, state: CredentialState) -> Dict[str, Any]: + """Serialize a credential state for storage.""" + # Serialize cooldowns (only active ones) + now = time.time() + cooldowns = {} + for key, cd in state.cooldowns.items(): + if cd.until > now: # Only save active cooldowns + cooldowns[key] = { + "reason": cd.reason, + "until": cd.until, + "started_at": cd.started_at, + "source": cd.source, + "model_or_group": cd.model_or_group, + "backoff_count": cd.backoff_count, + } + + # Serialize fair cycle + fair_cycle = {} + for key, fc in state.fair_cycle.items(): + fair_cycle[key] = { + "exhausted": fc.exhausted, + "exhausted_at": fc.exhausted_at, + "exhausted_reason": fc.exhausted_reason, + "cycle_request_count": fc.cycle_request_count, + } + + return { + "provider": state.provider, + "accessor": state.accessor, + "display_name": state.display_name, + "tier": state.tier, + "priority": state.priority, + **self._serialize_usage_stats(state.usage), + "model_usage": { + key: self._serialize_usage_stats(usage) + for key, usage in state.model_usage.items() + }, + "group_usage": { + key: self._serialize_usage_stats(usage) + for key, usage in state.group_usage.items() + }, + "cooldowns": cooldowns, + "fair_cycle": fair_cycle, + "max_concurrent": state.max_concurrent, + "created_at": state.created_at, + "last_updated": state.last_updated, + } diff --git a/src/rotator_library/usage/selection/__init__.py b/src/rotator_library/usage/selection/__init__.py new file mode 100644 index 00000000..79824e51 --- /dev/null +++ b/src/rotator_library/usage/selection/__init__.py @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Credential selection and rotation strategies.""" + +from .engine import SelectionEngine +from .strategies.balanced import BalancedStrategy +from .strategies.sequential import SequentialStrategy + +__all__ = [ + "SelectionEngine", + "BalancedStrategy", + "SequentialStrategy", +] diff --git a/src/rotator_library/usage/selection/engine.py b/src/rotator_library/usage/selection/engine.py new file mode 100644 index 00000000..e2ef9aa5 --- /dev/null +++ b/src/rotator_library/usage/selection/engine.py @@ -0,0 +1,418 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Selection engine for credential selection. + +Central component that orchestrates limit checking, modifiers, +and rotation strategies to select the best credential. +""" + +import time +import logging +from typing import Any, Dict, List, Optional, Set, Union + +from ..types import ( + CredentialState, + SelectionContext, + RotationMode, + LimitCheckResult, +) +from ..config import ProviderUsageConfig +from ..limits.engine import LimitEngine +from ..tracking.windows import WindowManager +from .strategies.balanced import BalancedStrategy +from .strategies.sequential import SequentialStrategy + +lib_logger = logging.getLogger("rotator_library") + + +class SelectionEngine: + """ + Central engine for credential selection. + + Orchestrates: + 1. Limit checking (filter unavailable credentials) + 2. Fair cycle modifiers (filter exhausted credentials) + 3. Rotation strategy (select from available) + """ + + def __init__( + self, + config: ProviderUsageConfig, + limit_engine: LimitEngine, + window_manager: WindowManager, + ): + """ + Initialize selection engine. + + Args: + config: Provider usage configuration + limit_engine: LimitEngine for availability checks + """ + self._config = config + self._limits = limit_engine + self._windows = window_manager + + # Initialize strategies + self._balanced = BalancedStrategy(config.rotation_tolerance) + self._sequential = SequentialStrategy(config.sequential_fallback_multiplier) + + # Current strategy + if config.rotation_mode == RotationMode.SEQUENTIAL: + self._strategy = self._sequential + else: + self._strategy = self._balanced + + def select( + self, + provider: str, + model: str, + states: Dict[str, CredentialState], + quota_group: Optional[str] = None, + exclude: Optional[Set[str]] = None, + priorities: Optional[Dict[str, int]] = None, + deadline: float = 0.0, + ) -> Optional[str]: + """ + Select the best available credential. + + Args: + provider: Provider name + model: Model being requested + states: Dict of stable_id -> CredentialState + quota_group: Quota group for this model + exclude: Set of stable_ids to exclude + priorities: Override priorities (stable_id -> priority) + deadline: Request deadline timestamp + + Returns: + Selected stable_id, or None if none available + """ + exclude = exclude or set() + + # Step 1: Get all candidates (not excluded) + candidates = [sid for sid in states.keys() if sid not in exclude] + + if not candidates: + return None + + # Step 2: Filter by limits + available = [] + for stable_id in candidates: + state = states[stable_id] + result = self._limits.check_all(state, model, quota_group) + if result.allowed: + available.append(stable_id) + + if not available: + # Check if we should reset fair cycle + if self._config.fair_cycle.enabled: + reset_performed = self._try_fair_cycle_reset( + provider, + model, + quota_group, + states, + candidates, + priorities, + ) + if reset_performed: + # Retry selection after reset + return self.select( + provider, + model, + states, + quota_group, + exclude, + priorities, + deadline, + ) + + lib_logger.debug( + f"No available credentials for {provider}/{model} " + f"(all {len(candidates)} blocked by limits)" + ) + return None + + # Step 3: Build selection context + # Get usage counts for weighting + usage_counts = {} + for stable_id in available: + state = states[stable_id] + usage_counts[stable_id] = self._get_usage_count(state, model, quota_group) + + # Build priorities map + if priorities is None: + priorities = {} + for stable_id in available: + priorities[stable_id] = states[stable_id].priority + + context = SelectionContext( + provider=provider, + model=model, + quota_group=quota_group, + candidates=available, + priorities=priorities, + usage_counts=usage_counts, + rotation_mode=self._config.rotation_mode, + rotation_tolerance=self._config.rotation_tolerance, + deadline=deadline or (time.time() + 120), + ) + + # Step 4: Apply rotation strategy + selected = self._strategy.select(context, states) + + if selected: + lib_logger.debug( + f"Selected credential {selected} for {provider}/{model} " + f"(from {len(available)} available)" + ) + + return selected + + def select_with_retry( + self, + provider: str, + model: str, + states: Dict[str, CredentialState], + quota_group: Optional[str] = None, + tried: Optional[Set[str]] = None, + priorities: Optional[Dict[str, int]] = None, + deadline: float = 0.0, + ) -> Optional[str]: + """ + Select a credential for retry, excluding already-tried ones. + + Convenience method for retry loops. + + Args: + provider: Provider name + model: Model being requested + states: Dict of stable_id -> CredentialState + quota_group: Quota group for this model + tried: Set of already-tried stable_ids + priorities: Override priorities + deadline: Request deadline timestamp + + Returns: + Selected stable_id, or None if none available + """ + return self.select( + provider=provider, + model=model, + states=states, + quota_group=quota_group, + exclude=tried, + priorities=priorities, + deadline=deadline, + ) + + def get_availability_stats( + self, + provider: str, + model: str, + states: Dict[str, CredentialState], + quota_group: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Get availability statistics for credentials. + + Useful for status reporting and debugging. + + Args: + provider: Provider name + model: Model being requested + states: Dict of stable_id -> CredentialState + quota_group: Quota group for this model + + Returns: + Dict with availability stats + """ + total = len(states) + available = 0 + blocked_by = { + "cooldowns": 0, + "window_limits": 0, + "custom_caps": 0, + "fair_cycle": 0, + "concurrent": 0, + } + + for stable_id, state in states.items(): + blocking = self._limits.get_blocking_info(state, model, quota_group) + + is_available = True + for checker_name, result in blocking.items(): + if not result.allowed: + is_available = False + if checker_name in blocked_by: + blocked_by[checker_name] += 1 + break + + if is_available: + available += 1 + + return { + "total": total, + "available": available, + "blocked": total - available, + "blocked_by": blocked_by, + "rotation_mode": self._config.rotation_mode.value, + } + + def set_rotation_mode(self, mode: RotationMode) -> None: + """ + Change the rotation mode. + + Args: + mode: New rotation mode + """ + self._config.rotation_mode = mode + if mode == RotationMode.SEQUENTIAL: + self._strategy = self._sequential + else: + self._strategy = self._balanced + + lib_logger.info(f"Rotation mode changed to {mode.value}") + + def mark_exhausted(self, provider: str, model_or_group: str) -> None: + """ + Mark current credential as exhausted (for sequential mode). + + Args: + provider: Provider name + model_or_group: Model or quota group + """ + if isinstance(self._strategy, SequentialStrategy): + self._strategy.mark_exhausted(provider, model_or_group) + + @property + def balanced_strategy(self) -> BalancedStrategy: + """Get the balanced strategy instance.""" + return self._balanced + + @property + def sequential_strategy(self) -> SequentialStrategy: + """Get the sequential strategy instance.""" + return self._sequential + + def _get_usage_count( + self, + state: CredentialState, + model: str, + quota_group: Optional[str], + ) -> int: + """Get the relevant usage count for rotation weighting.""" + primary_def = self._windows.get_primary_definition() + if primary_def: + scope_key = None + if primary_def.applies_to == "model": + scope_key = model + elif primary_def.applies_to == "group": + scope_key = quota_group or model + + usage = state.get_usage_for_scope( + primary_def.applies_to, scope_key, create=False + ) + if usage: + window = self._windows.get_active_window( + usage.windows, primary_def.name + ) + if window: + return window.request_count + + return state.usage.total_requests + + def _try_fair_cycle_reset( + self, + provider: str, + model: str, + quota_group: Optional[str], + states: Dict[str, CredentialState], + candidates: List[str], + priorities: Optional[Dict[str, int]], + ) -> bool: + """ + Try to reset fair cycle if all credentials are exhausted. + + Tier-aware: If cross_tier is disabled, checks each tier separately. + + Args: + provider: Provider name + model: Model being requested + quota_group: Quota group for this model + states: All credential states + candidates: Candidate stable_ids + + Returns: + True if reset was performed, False otherwise + """ + from ..types import LimitResult + + group_key = quota_group or model + fair_cycle_checker = self._limits.fair_cycle_checker + tracking_key = fair_cycle_checker.get_tracking_key(model, quota_group) + + # Check if all candidates are blocked by fair cycle + all_fair_cycle_blocked = True + fair_cycle_blocked_count = 0 + + for stable_id in candidates: + state = states[stable_id] + result = self._limits.check_all(state, model, quota_group) + + if result.allowed: + # Some credential is available - no need to reset + return False + + if result.result == LimitResult.BLOCKED_FAIR_CYCLE: + fair_cycle_blocked_count += 1 + else: + # Blocked by something other than fair cycle + all_fair_cycle_blocked = False + + # If no credentials blocked by fair cycle, can't help + if fair_cycle_blocked_count == 0: + return False + + # Get all candidate states for reset + candidate_states = [states[sid] for sid in candidates] + priority_map = priorities or {sid: states[sid].priority for sid in candidates} + + # Tier-aware reset + if self._config.fair_cycle.cross_tier: + # Cross-tier: reset all at once + if fair_cycle_checker.check_all_exhausted( + provider, tracking_key, candidate_states, priorities=priority_map + ): + lib_logger.info( + f"All credentials fair-cycle exhausted for {provider}/{model} " + f"(cross-tier), resetting cycle" + ) + fair_cycle_checker.reset_cycle(provider, tracking_key, candidate_states) + return True + else: + # Per-tier: group by priority and check each tier + tier_groups: Dict[int, List[CredentialState]] = {} + for state in candidate_states: + priority = state.priority + tier_groups.setdefault(priority, []).append(state) + + reset_any = False + for priority, tier_states in tier_groups.items(): + # Check if all in this tier are exhausted + all_tier_exhausted = all( + state.is_fair_cycle_exhausted(tracking_key) for state in tier_states + ) + + if all_tier_exhausted: + lib_logger.info( + f"All credentials fair-cycle exhausted for {provider}/{model} " + f"in tier {priority}, resetting tier cycle" + ) + fair_cycle_checker.reset_cycle(provider, tracking_key, tier_states) + reset_any = True + + return reset_any + + return False diff --git a/src/rotator_library/usage/selection/modifiers/__init__.py b/src/rotator_library/usage/selection/modifiers/__init__.py new file mode 100644 index 00000000..f06468eb --- /dev/null +++ b/src/rotator_library/usage/selection/modifiers/__init__.py @@ -0,0 +1,4 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Selection modifiers.""" diff --git a/src/rotator_library/usage/selection/strategies/__init__.py b/src/rotator_library/usage/selection/strategies/__init__.py new file mode 100644 index 00000000..68eeba8c --- /dev/null +++ b/src/rotator_library/usage/selection/strategies/__init__.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Rotation strategy implementations.""" + +from .balanced import BalancedStrategy +from .sequential import SequentialStrategy + +__all__ = ["BalancedStrategy", "SequentialStrategy"] diff --git a/src/rotator_library/usage/selection/strategies/balanced.py b/src/rotator_library/usage/selection/strategies/balanced.py new file mode 100644 index 00000000..6d070117 --- /dev/null +++ b/src/rotator_library/usage/selection/strategies/balanced.py @@ -0,0 +1,155 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Balanced rotation strategy. + +Distributes load evenly across credentials using weighted random selection. +""" + +import random +import logging +from typing import Dict, List, Optional + +from ...types import CredentialState, SelectionContext, RotationMode + +lib_logger = logging.getLogger("rotator_library") + + +class BalancedStrategy: + """ + Balanced credential rotation strategy. + + Uses weighted random selection where less-used credentials have + higher probability of being selected. The tolerance parameter + controls how much randomness is introduced. + + Weight formula: weight = (max_usage - credential_usage) + tolerance + 1 + """ + + def __init__(self, tolerance: float = 3.0): + """ + Initialize balanced strategy. + + Args: + tolerance: Controls randomness of selection. + - 0.0: Deterministic, least-used always selected + - 2.0-4.0: Recommended, balanced randomness + - 5.0+: High randomness + """ + self.tolerance = tolerance + + @property + def name(self) -> str: + return "balanced" + + @property + def mode(self) -> RotationMode: + return RotationMode.BALANCED + + def select( + self, + context: SelectionContext, + states: Dict[str, CredentialState], + ) -> Optional[str]: + """ + Select a credential using weighted random selection. + + Args: + context: Selection context with candidates and usage info + states: Dict of stable_id -> CredentialState + + Returns: + Selected stable_id, or None if no candidates + """ + if not context.candidates: + return None + + if len(context.candidates) == 1: + return context.candidates[0] + + # Group by priority for tiered selection + priority_groups = self._group_by_priority( + context.candidates, context.priorities + ) + + # Try each priority tier in order + for priority in sorted(priority_groups.keys()): + candidates = priority_groups[priority] + if not candidates: + continue + + # Calculate weights for this tier + weights = self._calculate_weights(candidates, context.usage_counts) + + # Weighted random selection + selected = self._weighted_random_choice(candidates, weights) + if selected: + return selected + + # Fallback: first candidate + return context.candidates[0] + + def _group_by_priority( + self, + candidates: List[str], + priorities: Dict[str, int], + ) -> Dict[int, List[str]]: + """Group candidates by priority tier.""" + groups: Dict[int, List[str]] = {} + for stable_id in candidates: + priority = priorities.get(stable_id, 999) + groups.setdefault(priority, []).append(stable_id) + return groups + + def _calculate_weights( + self, + candidates: List[str], + usage_counts: Dict[str, int], + ) -> List[float]: + """ + Calculate selection weights for candidates. + + Weight formula: weight = (max_usage - credential_usage) + tolerance + 1 + """ + if not candidates: + return [] + + # Get usage counts + usages = [usage_counts.get(stable_id, 0) for stable_id in candidates] + max_usage = max(usages) if usages else 0 + + # Calculate weights + weights = [] + for usage in usages: + weight = (max_usage - usage) + self.tolerance + 1 + weights.append(max(weight, 0.1)) # Ensure minimum weight + + return weights + + def _weighted_random_choice( + self, + candidates: List[str], + weights: List[float], + ) -> Optional[str]: + """Select a candidate using weighted random choice.""" + if not candidates: + return None + + if len(candidates) == 1: + return candidates[0] + + # Normalize weights + total = sum(weights) + if total <= 0: + return random.choice(candidates) + + # Weighted selection + r = random.uniform(0, total) + cumulative = 0 + for candidate, weight in zip(candidates, weights): + cumulative += weight + if r <= cumulative: + return candidate + + return candidates[-1] diff --git a/src/rotator_library/usage/selection/strategies/sequential.py b/src/rotator_library/usage/selection/strategies/sequential.py new file mode 100644 index 00000000..e00029b4 --- /dev/null +++ b/src/rotator_library/usage/selection/strategies/sequential.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Sequential rotation strategy. + +Uses one credential until exhausted, then moves to the next. +Good for providers that benefit from request caching. +""" + +import logging +from typing import Dict, List, Optional + +from ...types import CredentialState, SelectionContext, RotationMode + +lib_logger = logging.getLogger("rotator_library") + + +class SequentialStrategy: + """ + Sequential credential rotation strategy. + + Sticks to one credential until it's exhausted (rate limited, + quota exceeded, etc.), then moves to the next in priority order. + + This is useful for providers where repeated requests to the same + credential benefit from caching (e.g., context caching in LLMs). + """ + + def __init__(self, fallback_multiplier: int = 1): + """ + Initialize sequential strategy. + + Args: + fallback_multiplier: Default concurrent slots per priority + when not explicitly configured + """ + self.fallback_multiplier = fallback_multiplier + # Track current "sticky" credential per (provider, model_group) + self._current: Dict[tuple, str] = {} + + @property + def name(self) -> str: + return "sequential" + + @property + def mode(self) -> RotationMode: + return RotationMode.SEQUENTIAL + + def select( + self, + context: SelectionContext, + states: Dict[str, CredentialState], + ) -> Optional[str]: + """ + Select a credential using sequential/sticky selection. + + Prefers the currently active credential if it's still available. + Otherwise, selects the first available by priority. + + Args: + context: Selection context with candidates and usage info + states: Dict of stable_id -> CredentialState + + Returns: + Selected stable_id, or None if no candidates + """ + if not context.candidates: + return None + + if len(context.candidates) == 1: + return context.candidates[0] + + key = (context.provider, context.quota_group or context.model) + + # Check if current sticky credential is still available + current = self._current.get(key) + if current and current in context.candidates: + return current + + # Current not available - select new one by priority + selected = self._select_by_priority(context.candidates, context.priorities) + + # Make it sticky + if selected: + self._current[key] = selected + lib_logger.debug(f"Sequential: switched to credential {selected} for {key}") + + return selected + + def mark_exhausted(self, provider: str, model_or_group: str) -> None: + """ + Mark current credential as exhausted, forcing rotation. + + Args: + provider: Provider name + model_or_group: Model or quota group + """ + key = (provider, model_or_group) + if key in self._current: + old = self._current[key] + del self._current[key] + lib_logger.debug(f"Sequential: marked {old} exhausted for {key}") + + def get_current(self, provider: str, model_or_group: str) -> Optional[str]: + """ + Get the currently sticky credential. + + Args: + provider: Provider name + model_or_group: Model or quota group + + Returns: + Current sticky credential stable_id, or None + """ + key = (provider, model_or_group) + return self._current.get(key) + + def _select_by_priority( + self, + candidates: List[str], + priorities: Dict[str, int], + ) -> Optional[str]: + """Select the highest priority (lowest number) candidate.""" + if not candidates: + return None + + # Sort by priority (lower = higher priority) + sorted_candidates = sorted(candidates, key=lambda c: priorities.get(c, 999)) + + return sorted_candidates[0] + + def clear_sticky(self, provider: Optional[str] = None) -> None: + """ + Clear sticky credential state. + + Args: + provider: If specified, only clear for this provider + """ + if provider: + keys_to_remove = [k for k in self._current if k[0] == provider] + for key in keys_to_remove: + del self._current[key] + else: + self._current.clear() diff --git a/src/rotator_library/usage/tracking/__init__.py b/src/rotator_library/usage/tracking/__init__.py new file mode 100644 index 00000000..e459d28f --- /dev/null +++ b/src/rotator_library/usage/tracking/__init__.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Usage tracking and window management.""" + +from .engine import TrackingEngine +from .windows import WindowManager + +__all__ = ["TrackingEngine", "WindowManager"] diff --git a/src/rotator_library/usage/tracking/engine.py b/src/rotator_library/usage/tracking/engine.py new file mode 100644 index 00000000..4e19d388 --- /dev/null +++ b/src/rotator_library/usage/tracking/engine.py @@ -0,0 +1,656 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Tracking engine for usage recording. + +Central component for recording requests, successes, and failures. +""" + +import asyncio +import logging +import time +from typing import Any, Dict, List, Optional, Set + +from ..types import ( + UsageStats, + WindowStats, + CredentialState, + CooldownInfo, + FairCycleState, + TrackingMode, + FAIR_CYCLE_GLOBAL_KEY, +) +from ..config import WindowDefinition, ProviderUsageConfig +from .windows import WindowManager + +lib_logger = logging.getLogger("rotator_library") + + +class TrackingEngine: + """ + Central engine for usage tracking. + + Responsibilities: + - Recording request successes and failures + - Managing usage windows + - Updating global statistics + - Managing cooldowns + - Tracking fair cycle state + """ + + def __init__( + self, + window_manager: WindowManager, + config: ProviderUsageConfig, + ): + """ + Initialize tracking engine. + + Args: + window_manager: WindowManager instance for window operations + config: Provider usage configuration + """ + self._windows = window_manager + self._config = config + self._lock = asyncio.Lock() + + async def record_success( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + prompt_tokens: int = 0, + completion_tokens: int = 0, + prompt_tokens_cache_read: int = 0, + prompt_tokens_cache_write: int = 0, + thinking_tokens: int = 0, + approx_cost: float = 0.0, + request_count: int = 1, + response_headers: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Record a successful request. + + Args: + state: Credential state to update + model: Model that was used + quota_group: Quota group for this model (None = use model name) + prompt_tokens: Prompt tokens used + completion_tokens: Completion tokens used + prompt_tokens_cached: Cached prompt tokens (e.g., from Claude) + response_headers: Optional response headers with rate limit info + """ + async with self._lock: + now = time.time() + group_key = quota_group or model + fair_cycle_key = self._resolve_fair_cycle_key(group_key) + + # Update usage stats + usage = state.usage + usage.total_requests += request_count + usage.total_successes += request_count + output_tokens = completion_tokens + thinking_tokens + usage.total_tokens += ( + prompt_tokens + + completion_tokens + + thinking_tokens + + prompt_tokens_cache_read + + prompt_tokens_cache_write + ) + usage.total_prompt_tokens += prompt_tokens + usage.total_completion_tokens += completion_tokens + usage.total_thinking_tokens += thinking_tokens + usage.total_output_tokens += output_tokens + usage.total_prompt_tokens_cache_read += prompt_tokens_cache_read + usage.total_prompt_tokens_cache_write += prompt_tokens_cache_write + usage.total_approx_cost += approx_cost + usage.last_used_at = now + if usage.first_used_at is None: + usage.first_used_at = now + + self._update_scoped_usage( + state, + scope="model", + key=model, + now=now, + request_count=request_count, + success_count=request_count, + failure_count=0, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + prompt_tokens_cache_read=prompt_tokens_cache_read, + prompt_tokens_cache_write=prompt_tokens_cache_write, + thinking_tokens=thinking_tokens, + approx_cost=approx_cost, + ) + if group_key != model: + self._update_scoped_usage( + state, + scope="group", + key=group_key, + now=now, + request_count=request_count, + success_count=request_count, + failure_count=0, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + prompt_tokens_cache_read=prompt_tokens_cache_read, + prompt_tokens_cache_write=prompt_tokens_cache_write, + thinking_tokens=thinking_tokens, + approx_cost=approx_cost, + ) + + # Update per-model request count (for quota group sync) + usage.model_request_counts[model] = ( + usage.model_request_counts.get(model, 0) + request_count + ) + + # Record in all windows + for window_def in self._config.windows: + scoped_usage = self._get_usage_for_window( + state, window_def, model, group_key + ) + if scoped_usage is None: + continue + window = self._windows.record_request( + scoped_usage.windows, + window_def.name, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + prompt_tokens_cache_read=prompt_tokens_cache_read, + prompt_tokens_cache_write=prompt_tokens_cache_write, + thinking_tokens=thinking_tokens, + approx_cost=approx_cost, + request_count=request_count, + success_count=request_count, + ) + if window.limit is not None and window.request_count >= window.limit: + if self._config.fair_cycle.enabled: + self._mark_exhausted(state, fair_cycle_key, "window_limit") + + # Update from response headers if provided + if response_headers: + self._update_from_headers(state, response_headers, model, group_key) + + # Update fair cycle request count + if self._config.fair_cycle.enabled: + fc_state = state.fair_cycle.get(fair_cycle_key) + if not fc_state: + fc_state = FairCycleState(model_or_group=fair_cycle_key) + state.fair_cycle[fair_cycle_key] = fc_state + fc_state.cycle_request_count += request_count + + state.last_updated = now + + async def record_failure( + self, + state: CredentialState, + model: str, + error_type: str, + quota_group: Optional[str] = None, + cooldown_duration: Optional[float] = None, + quota_reset_timestamp: Optional[float] = None, + mark_exhausted: bool = False, + request_count: int = 1, + prompt_tokens: int = 0, + completion_tokens: int = 0, + thinking_tokens: int = 0, + prompt_tokens_cache_read: int = 0, + prompt_tokens_cache_write: int = 0, + approx_cost: float = 0.0, + ) -> None: + """ + Record a failed request. + + Args: + state: Credential state to update + model: Model that was used + error_type: Type of error (quota_exceeded, rate_limit, etc.) + quota_group: Quota group for this model + cooldown_duration: How long to cool down (if applicable) + quota_reset_timestamp: When quota resets (from API) + mark_exhausted: Whether to mark as exhausted for fair cycle + """ + async with self._lock: + now = time.time() + group_key = quota_group or model + fair_cycle_key = self._resolve_fair_cycle_key(group_key) + + # Update failure stats + state.usage.total_requests += request_count + state.usage.total_failures += request_count + output_tokens = completion_tokens + thinking_tokens + state.usage.total_tokens += ( + prompt_tokens + + completion_tokens + + thinking_tokens + + prompt_tokens_cache_read + + prompt_tokens_cache_write + ) + state.usage.total_prompt_tokens += prompt_tokens + state.usage.total_completion_tokens += completion_tokens + state.usage.total_thinking_tokens += thinking_tokens + state.usage.total_output_tokens += output_tokens + state.usage.total_prompt_tokens_cache_read += prompt_tokens_cache_read + state.usage.total_prompt_tokens_cache_write += prompt_tokens_cache_write + state.usage.total_approx_cost += approx_cost + state.usage.last_used_at = now + + self._update_scoped_usage( + state, + scope="model", + key=model, + now=now, + request_count=request_count, + success_count=0, + failure_count=request_count, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + prompt_tokens_cache_read=prompt_tokens_cache_read, + prompt_tokens_cache_write=prompt_tokens_cache_write, + thinking_tokens=thinking_tokens, + approx_cost=approx_cost, + ) + if group_key != model: + self._update_scoped_usage( + state, + scope="group", + key=group_key, + now=now, + request_count=request_count, + success_count=0, + failure_count=request_count, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + prompt_tokens_cache_read=prompt_tokens_cache_read, + prompt_tokens_cache_write=prompt_tokens_cache_write, + thinking_tokens=thinking_tokens, + approx_cost=approx_cost, + ) + + # Update per-model request count (for quota group sync) + state.usage.model_request_counts[model] = ( + state.usage.model_request_counts.get(model, 0) + request_count + ) + + # Record failure in windows (counts against quota) + for window_def in self._config.windows: + scoped_usage = self._get_usage_for_window( + state, window_def, model, group_key + ) + if scoped_usage is None: + continue + window = self._windows.record_request( + scoped_usage.windows, + window_def.name, + request_count=request_count, + failure_count=request_count, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + prompt_tokens_cache_read=prompt_tokens_cache_read, + prompt_tokens_cache_write=prompt_tokens_cache_write, + thinking_tokens=thinking_tokens, + approx_cost=approx_cost, + ) + if window.limit is not None and window.request_count >= window.limit: + if self._config.fair_cycle.enabled: + self._mark_exhausted(state, fair_cycle_key, "window_limit") + + # Apply cooldown if specified + if cooldown_duration is not None and cooldown_duration > 0: + self._apply_cooldown( + state=state, + reason=error_type, + duration=cooldown_duration, + model_or_group=group_key, + source="error", + ) + + # Use quota reset timestamp if provided + if quota_reset_timestamp is not None: + self._apply_cooldown( + state=state, + reason=error_type, + until=quota_reset_timestamp, + model_or_group=group_key, + source="api_quota", + ) + + # Mark exhausted for fair cycle if requested + if mark_exhausted: + self._mark_exhausted(state, fair_cycle_key, error_type) + + if self._config.fair_cycle.enabled: + fc_state = state.fair_cycle.get(fair_cycle_key) + if not fc_state: + fc_state = FairCycleState(model_or_group=fair_cycle_key) + state.fair_cycle[fair_cycle_key] = fc_state + fc_state.cycle_request_count += request_count + + state.last_updated = now + + async def acquire( + self, + state: CredentialState, + model: str, + ) -> bool: + """ + Acquire a credential for a request (increment active count). + + Args: + state: Credential state + model: Model being used + + Returns: + True if acquired, False if at max concurrent + """ + async with self._lock: + # Check concurrent limit + if state.max_concurrent is not None: + if state.active_requests >= state.max_concurrent: + return False + + state.active_requests += 1 + return True + + async def apply_cooldown( + self, + state: CredentialState, + reason: str, + duration: Optional[float] = None, + until: Optional[float] = None, + model_or_group: Optional[str] = None, + source: str = "system", + ) -> None: + """ + Apply a cooldown to a credential. + + Args: + state: Credential state + reason: Why the cooldown was applied + duration: Cooldown duration in seconds (if not using 'until') + until: Timestamp when cooldown ends (if not using 'duration') + model_or_group: Scope of cooldown (None = credential-wide) + source: Source of cooldown (system, custom_cap, rate_limit, etc.) + """ + async with self._lock: + self._apply_cooldown( + state=state, + reason=reason, + duration=duration, + until=until, + model_or_group=model_or_group, + source=source, + ) + + async def clear_cooldown( + self, + state: CredentialState, + model_or_group: Optional[str] = None, + ) -> None: + """ + Clear a cooldown from a credential. + + Args: + state: Credential state + model_or_group: Scope of cooldown to clear (None = global) + """ + async with self._lock: + key = model_or_group or "_global_" + if key in state.cooldowns: + del state.cooldowns[key] + + async def mark_exhausted( + self, + state: CredentialState, + model_or_group: str, + reason: str, + ) -> None: + """ + Mark a credential as exhausted for fair cycle. + + Args: + state: Credential state + model_or_group: Scope of exhaustion + reason: Why credential was exhausted + """ + async with self._lock: + self._mark_exhausted(state, model_or_group, reason) + + async def reset_fair_cycle( + self, + state: CredentialState, + model_or_group: str, + ) -> None: + """ + Reset fair cycle state for a credential. + + Args: + state: Credential state + model_or_group: Scope to reset + """ + async with self._lock: + if model_or_group in state.fair_cycle: + fc_state = state.fair_cycle[model_or_group] + fc_state.exhausted = False + fc_state.exhausted_at = None + fc_state.exhausted_reason = None + fc_state.cycle_request_count = 0 + + def get_window_usage( + self, + state: CredentialState, + window_name: str, + ) -> int: + """ + Get request count for a specific window. + + Args: + state: Credential state + window_name: Name of window + + Returns: + Request count (0 if window doesn't exist) + """ + window = self._windows.get_active_window(state.usage.windows, window_name) + return window.request_count if window else 0 + + def get_primary_window_usage(self, state: CredentialState) -> int: + """ + Get request count for the primary window. + + Args: + state: Credential state + + Returns: + Request count (0 if no primary window) + """ + window = self._windows.get_primary_window(state.usage.windows) + return window.request_count if window else 0 + + # ========================================================================= + # PRIVATE METHODS + # ========================================================================= + + def _apply_cooldown( + self, + state: CredentialState, + reason: str, + duration: Optional[float] = None, + until: Optional[float] = None, + model_or_group: Optional[str] = None, + source: str = "system", + ) -> None: + """Internal cooldown application (no lock).""" + now = time.time() + + if until is not None: + cooldown_until = until + elif duration is not None: + cooldown_until = now + duration + else: + return # No cooldown specified + + key = model_or_group or "_global_" + + # Check for existing cooldown + existing = state.cooldowns.get(key) + backoff_count = 0 + if existing and existing.is_active: + backoff_count = existing.backoff_count + 1 + + state.cooldowns[key] = CooldownInfo( + reason=reason, + until=cooldown_until, + started_at=now, + source=source, + model_or_group=model_or_group, + backoff_count=backoff_count, + ) + + # Check if cooldown qualifies as exhaustion + cooldown_duration = cooldown_until - now + if cooldown_duration >= self._config.exhaustion_cooldown_threshold: + if self._config.fair_cycle.enabled and model_or_group: + fair_cycle_key = self._resolve_fair_cycle_key(model_or_group) + self._mark_exhausted(state, fair_cycle_key, f"cooldown_{reason}") + + def _mark_exhausted( + self, + state: CredentialState, + model_or_group: str, + reason: str, + ) -> None: + """Internal exhaustion marking (no lock).""" + now = time.time() + + if model_or_group not in state.fair_cycle: + state.fair_cycle[model_or_group] = FairCycleState( + model_or_group=model_or_group + ) + + fc_state = state.fair_cycle[model_or_group] + fc_state.exhausted = True + fc_state.exhausted_at = now + fc_state.exhausted_reason = reason + + lib_logger.debug( + f"Credential {state.stable_id} marked exhausted for {model_or_group}: {reason}" + ) + + def _resolve_fair_cycle_key(self, group_key: str) -> str: + """Resolve fair cycle tracking key based on config.""" + if self._config.fair_cycle.tracking_mode == TrackingMode.CREDENTIAL: + return FAIR_CYCLE_GLOBAL_KEY + return group_key + + def _get_usage_for_window( + self, + state: CredentialState, + window_def: WindowDefinition, + model: str, + group_key: str, + ) -> Optional[UsageStats]: + """Get usage stats for a window definition's scope.""" + scope_key = None + if window_def.applies_to == "model": + scope_key = model + elif window_def.applies_to == "group": + scope_key = group_key + return state.get_usage_for_scope(window_def.applies_to, scope_key) + + def _update_scoped_usage( + self, + state: CredentialState, + scope: str, + key: Optional[str], + now: float, + request_count: int, + success_count: int, + failure_count: int, + prompt_tokens: int, + completion_tokens: int, + prompt_tokens_cache_read: int, + prompt_tokens_cache_write: int, + thinking_tokens: int, + approx_cost: float, + ) -> None: + """Update scoped usage stats.""" + usage = state.get_usage_for_scope(scope, key) + if not usage: + return + output_tokens = completion_tokens + thinking_tokens + usage.total_requests += request_count + usage.total_successes += success_count + usage.total_failures += failure_count + usage.total_tokens += ( + prompt_tokens + + completion_tokens + + thinking_tokens + + prompt_tokens_cache_read + + prompt_tokens_cache_write + ) + usage.total_prompt_tokens += prompt_tokens + usage.total_completion_tokens += completion_tokens + usage.total_thinking_tokens += thinking_tokens + usage.total_output_tokens += output_tokens + usage.total_prompt_tokens_cache_read += prompt_tokens_cache_read + usage.total_prompt_tokens_cache_write += prompt_tokens_cache_write + usage.total_approx_cost += approx_cost + usage.last_used_at = now + if usage.first_used_at is None: + usage.first_used_at = now + + def _update_from_headers( + self, + state: CredentialState, + headers: Dict[str, Any], + model: str, + group_key: str, + ) -> None: + """Update state from API response headers.""" + # Common header patterns for rate limiting + # X-RateLimit-Remaining, X-RateLimit-Reset, etc. + remaining = headers.get("x-ratelimit-remaining") + reset = headers.get("x-ratelimit-reset") + limit = headers.get("x-ratelimit-limit") + + # Update primary window if we have limit info + primary_def = self._windows.get_primary_definition() + if primary_def is None: + return + + scope_key = None + if primary_def.applies_to == "model": + scope_key = model + elif primary_def.applies_to == "group": + scope_key = group_key + + usage = state.get_usage_for_scope( + primary_def.applies_to, scope_key, create=False + ) + if usage is None: + return + + if limit is not None: + try: + limit_int = int(limit) + primary = self._windows.get_primary_window(usage.windows) + if primary: + primary.limit = limit_int + except (ValueError, TypeError): + pass + + if reset is not None: + try: + reset_float = float(reset) + # If reset is in the past, it might be a Unix timestamp + # If it's a small number, it might be seconds until reset + if reset_float < 1000000000: # Less than ~2001, probably relative + reset_float = time.time() + reset_float + primary = self._windows.get_primary_window(usage.windows) + if primary: + primary.reset_at = reset_float + except (ValueError, TypeError): + pass diff --git a/src/rotator_library/usage/tracking/windows.py b/src/rotator_library/usage/tracking/windows.py new file mode 100644 index 00000000..dbc64fb6 --- /dev/null +++ b/src/rotator_library/usage/tracking/windows.py @@ -0,0 +1,461 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Window management for usage tracking. + +Handles time-based usage windows with various reset modes. +""" + +import time +import logging +from dataclasses import dataclass, field +from datetime import datetime, timezone, time as dt_time +from typing import Any, Dict, List, Optional, Tuple + +from ..types import WindowStats, ResetMode +from ..config import WindowDefinition + +lib_logger = logging.getLogger("rotator_library") + + +class WindowManager: + """ + Manages usage tracking windows for credentials. + + Handles: + - Rolling windows (e.g., last 5 hours) + - Fixed daily windows (reset at specific UTC time) + - Calendar windows (weekly, monthly) + - API-authoritative windows (provider determines reset) + """ + + def __init__( + self, + window_definitions: List[WindowDefinition], + daily_reset_time_utc: str = "03:00", + ): + """ + Initialize window manager. + + Args: + window_definitions: List of window configurations + daily_reset_time_utc: Time for daily reset in HH:MM format + """ + self.definitions = {w.name: w for w in window_definitions} + self.daily_reset_time_utc = self._parse_time(daily_reset_time_utc) + + def get_active_window( + self, + windows: Dict[str, WindowStats], + window_name: str, + ) -> Optional[WindowStats]: + """ + Get an active (non-expired) window by name. + + Args: + windows: Current windows dict for a credential + window_name: Name of window to get + + Returns: + WindowStats if active, None if expired or doesn't exist + """ + window = windows.get(window_name) + if window is None: + return None + + definition = self.definitions.get(window_name) + if definition is None: + return window # Unknown window, return as-is + + # Check if window needs reset + if self._should_reset(window, definition): + return None + + return window + + def get_or_create_window( + self, + windows: Dict[str, WindowStats], + window_name: str, + limit: Optional[int] = None, + ) -> WindowStats: + """ + Get an active window or create a new one. + + Args: + windows: Current windows dict for a credential + window_name: Name of window to get/create + limit: Optional request limit for the window + + Returns: + Active WindowStats (may be newly created) + """ + window = self.get_active_window(windows, window_name) + if window is not None: + return window + + # Create new window + definition = self.definitions.get(window_name) + now = time.time() + + new_window = WindowStats( + name=window_name, + started_at=now, + reset_at=self._calculate_reset_time(definition, now) + if definition + else None, + limit=limit, + ) + + windows[window_name] = new_window + return new_window + + def record_request( + self, + windows: Dict[str, WindowStats], + window_name: str, + prompt_tokens: int = 0, + completion_tokens: int = 0, + prompt_tokens_cache_read: int = 0, + prompt_tokens_cache_write: int = 0, + thinking_tokens: int = 0, + approx_cost: float = 0.0, + request_count: int = 1, + success_count: int = 0, + failure_count: int = 0, + limit: Optional[int] = None, + ) -> WindowStats: + """ + Record a request in a window. + + Args: + windows: Current windows dict for a credential + window_name: Name of window to record in + prompt_tokens: Prompt tokens used + completion_tokens: Completion tokens used + prompt_tokens_cached: Cached prompt tokens (e.g., from Claude) + limit: Optional request limit for the window + + Returns: + Updated WindowStats + """ + window = self.get_or_create_window(windows, window_name, limit) + + window.request_count += request_count + window.success_count += success_count + window.failure_count += failure_count + window.prompt_tokens += prompt_tokens + window.prompt_tokens_cache_read += prompt_tokens_cache_read + window.prompt_tokens_cache_write += prompt_tokens_cache_write + window.completion_tokens += completion_tokens + window.thinking_tokens += thinking_tokens + window.output_tokens += completion_tokens + thinking_tokens + window.approx_cost += approx_cost + window.total_tokens += ( + prompt_tokens + + completion_tokens + + thinking_tokens + + prompt_tokens_cache_read + + prompt_tokens_cache_write + ) + + return window + + def get_primary_window( + self, + windows: Dict[str, WindowStats], + ) -> Optional[WindowStats]: + """ + Get the primary window used for rotation decisions. + + Args: + windows: Current windows dict for a credential + + Returns: + Primary WindowStats or None + """ + for name, definition in self.definitions.items(): + if definition.is_primary: + return self.get_active_window(windows, name) + return None + + def get_primary_definition(self) -> Optional[WindowDefinition]: + """Get the primary window definition.""" + for definition in self.definitions.values(): + if definition.is_primary: + return definition + return None + + def get_window_remaining( + self, + windows: Dict[str, WindowStats], + window_name: str, + ) -> Optional[int]: + """ + Get remaining requests in a window. + + Args: + windows: Current windows dict for a credential + window_name: Name of window to check + + Returns: + Remaining requests, or None if unlimited/unknown + """ + window = self.get_active_window(windows, window_name) + if window is None: + return None + return window.remaining + + def check_expired_windows( + self, + windows: Dict[str, WindowStats], + ) -> List[str]: + """ + Check for and remove expired windows. + + Args: + windows: Current windows dict for a credential + + Returns: + List of removed window names + """ + expired = [] + for name in list(windows.keys()): + if self.get_active_window(windows, name) is None: + del windows[name] + expired.append(name) + return expired + + def update_limit( + self, + windows: Dict[str, WindowStats], + window_name: str, + new_limit: int, + ) -> None: + """ + Update the limit for a window (e.g., from API response). + + Args: + windows: Current windows dict for a credential + window_name: Name of window to update + new_limit: New request limit + """ + window = windows.get(window_name) + if window is not None: + window.limit = new_limit + + def update_reset_time( + self, + windows: Dict[str, WindowStats], + window_name: str, + reset_timestamp: float, + ) -> None: + """ + Update the reset time for a window (e.g., from API response). + + Args: + windows: Current windows dict for a credential + window_name: Name of window to update + reset_timestamp: New reset timestamp + """ + window = windows.get(window_name) + if window is not None: + window.reset_at = reset_timestamp + + # ========================================================================= + # PRIVATE METHODS + # ========================================================================= + + def _should_reset(self, window: WindowStats, definition: WindowDefinition) -> bool: + """ + Check if a window should be reset based on its definition. + """ + now = time.time() + + # If window has an explicit reset time, use it + if window.reset_at is not None: + return now >= window.reset_at + + # If window has no start time, it needs reset + if window.started_at is None: + return True + + # Check based on reset mode + if definition.reset_mode == ResetMode.ROLLING: + if definition.duration_seconds is None: + return False # Infinite window + return now >= window.started_at + definition.duration_seconds + + elif definition.reset_mode == ResetMode.FIXED_DAILY: + return self._past_daily_reset(window.started_at, now) + + elif definition.reset_mode == ResetMode.CALENDAR_WEEKLY: + return self._past_weekly_reset(window.started_at, now) + + elif definition.reset_mode == ResetMode.CALENDAR_MONTHLY: + return self._past_monthly_reset(window.started_at, now) + + elif definition.reset_mode == ResetMode.API_AUTHORITATIVE: + # Only reset if explicit reset_at is set and passed + return False + + return False + + def _calculate_reset_time( + self, + definition: WindowDefinition, + start_time: float, + ) -> Optional[float]: + """ + Calculate when a window should reset based on its definition. + """ + if definition.reset_mode == ResetMode.ROLLING: + if definition.duration_seconds is None: + return None # Infinite window + return start_time + definition.duration_seconds + + elif definition.reset_mode == ResetMode.FIXED_DAILY: + return self._next_daily_reset(start_time) + + elif definition.reset_mode == ResetMode.CALENDAR_WEEKLY: + return self._next_weekly_reset(start_time) + + elif definition.reset_mode == ResetMode.CALENDAR_MONTHLY: + return self._next_monthly_reset(start_time) + + elif definition.reset_mode == ResetMode.API_AUTHORITATIVE: + return None # Will be set by API response + + return None + + def _parse_time(self, time_str: str) -> dt_time: + """Parse HH:MM time string.""" + try: + parts = time_str.split(":") + return dt_time(hour=int(parts[0]), minute=int(parts[1])) + except (ValueError, IndexError): + return dt_time(hour=3, minute=0) # Default 03:00 + + def _past_daily_reset(self, started_at: float, now: float) -> bool: + """Check if we've passed the daily reset time since window started.""" + start_dt = datetime.fromtimestamp(started_at, tz=timezone.utc) + now_dt = datetime.fromtimestamp(now, tz=timezone.utc) + + # Get reset time for the day after start + reset_dt = start_dt.replace( + hour=self.daily_reset_time_utc.hour, + minute=self.daily_reset_time_utc.minute, + second=0, + microsecond=0, + ) + if reset_dt <= start_dt: + # Reset time already passed today, use tomorrow + from datetime import timedelta + + reset_dt += timedelta(days=1) + + return now_dt >= reset_dt + + def _next_daily_reset(self, from_time: float) -> float: + """Calculate next daily reset timestamp.""" + from datetime import timedelta + + from_dt = datetime.fromtimestamp(from_time, tz=timezone.utc) + reset_dt = from_dt.replace( + hour=self.daily_reset_time_utc.hour, + minute=self.daily_reset_time_utc.minute, + second=0, + microsecond=0, + ) + if reset_dt <= from_dt: + reset_dt += timedelta(days=1) + + return reset_dt.timestamp() + + def _past_weekly_reset(self, started_at: float, now: float) -> bool: + """Check if we've passed the weekly reset (Sunday 03:00 UTC).""" + start_dt = datetime.fromtimestamp(started_at, tz=timezone.utc) + now_dt = datetime.fromtimestamp(now, tz=timezone.utc) + + # Get start of next week (Sunday 03:00 UTC) + days_until_sunday = (6 - start_dt.weekday()) % 7 + if days_until_sunday == 0 and start_dt.hour >= 3: + days_until_sunday = 7 + + from datetime import timedelta + + reset_dt = start_dt.replace( + hour=3, minute=0, second=0, microsecond=0 + ) + timedelta(days=days_until_sunday) + + return now_dt >= reset_dt + + def _next_weekly_reset(self, from_time: float) -> float: + """Calculate next weekly reset timestamp.""" + from datetime import timedelta + + from_dt = datetime.fromtimestamp(from_time, tz=timezone.utc) + days_until_sunday = (6 - from_dt.weekday()) % 7 + if days_until_sunday == 0 and from_dt.hour >= 3: + days_until_sunday = 7 + + reset_dt = from_dt.replace( + hour=3, minute=0, second=0, microsecond=0 + ) + timedelta(days=days_until_sunday) + + return reset_dt.timestamp() + + def _past_monthly_reset(self, started_at: float, now: float) -> bool: + """Check if we've passed the monthly reset (1st 03:00 UTC).""" + start_dt = datetime.fromtimestamp(started_at, tz=timezone.utc) + now_dt = datetime.fromtimestamp(now, tz=timezone.utc) + + # Get 1st of next month + if start_dt.month == 12: + reset_dt = start_dt.replace( + year=start_dt.year + 1, + month=1, + day=1, + hour=3, + minute=0, + second=0, + microsecond=0, + ) + else: + reset_dt = start_dt.replace( + month=start_dt.month + 1, + day=1, + hour=3, + minute=0, + second=0, + microsecond=0, + ) + + return now_dt >= reset_dt + + def _next_monthly_reset(self, from_time: float) -> float: + """Calculate next monthly reset timestamp.""" + from_dt = datetime.fromtimestamp(from_time, tz=timezone.utc) + + if from_dt.month == 12: + reset_dt = from_dt.replace( + year=from_dt.year + 1, + month=1, + day=1, + hour=3, + minute=0, + second=0, + microsecond=0, + ) + else: + reset_dt = from_dt.replace( + month=from_dt.month + 1, + day=1, + hour=3, + minute=0, + second=0, + microsecond=0, + ) + + return reset_dt.timestamp() diff --git a/src/rotator_library/usage/types.py b/src/rotator_library/usage/types.py new file mode 100644 index 00000000..0132b5c7 --- /dev/null +++ b/src/rotator_library/usage/types.py @@ -0,0 +1,379 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Type definitions for the usage tracking package. + +This module contains dataclasses and type definitions specific to +usage tracking, limits, and credential selection. +""" + +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union + + +# ============================================================================= +# ENUMS +# ============================================================================= + + +FAIR_CYCLE_GLOBAL_KEY = "_credential_" + + +class ResetMode(str, Enum): + """How a usage window resets.""" + + ROLLING = "rolling" # Continuous rolling window + FIXED_DAILY = "fixed_daily" # Reset at specific time each day + CALENDAR_WEEKLY = "calendar_weekly" # Reset at start of week + CALENDAR_MONTHLY = "calendar_monthly" # Reset at start of month + API_AUTHORITATIVE = "api_authoritative" # Provider API determines reset + + +class LimitResult(str, Enum): + """Result of a limit check.""" + + ALLOWED = "allowed" + BLOCKED_WINDOW = "blocked_window" + BLOCKED_COOLDOWN = "blocked_cooldown" + BLOCKED_FAIR_CYCLE = "blocked_fair_cycle" + BLOCKED_CUSTOM_CAP = "blocked_custom_cap" + BLOCKED_CONCURRENT = "blocked_concurrent" + + +class RotationMode(str, Enum): + """How credentials are rotated.""" + + BALANCED = "balanced" # Weighted random selection + SEQUENTIAL = "sequential" # Sticky until exhausted + + +class TrackingMode(str, Enum): + """How fair cycle tracks exhaustion.""" + + MODEL_GROUP = "model_group" # Track per quota group or model + CREDENTIAL = "credential" # Track per credential globally + + +class CooldownMode(str, Enum): + """How custom cap cooldowns are calculated.""" + + QUOTA_RESET = "quota_reset" # Wait until quota window resets + OFFSET = "offset" # Add offset seconds to current time + FIXED = "fixed" # Use fixed duration + + +# ============================================================================= +# WINDOW TYPES +# ============================================================================= + + +@dataclass +class WindowStats: + """ + Statistics for a single usage window. + + Tracks usage within a specific time window (e.g., 5-hour, daily). + """ + + name: str # Window identifier (e.g., "5h", "daily") + request_count: int = 0 + success_count: int = 0 + failure_count: int = 0 + total_tokens: int = 0 + prompt_tokens: int = 0 + completion_tokens: int = 0 + thinking_tokens: int = 0 + output_tokens: int = 0 + prompt_tokens_cache_read: int = 0 + prompt_tokens_cache_write: int = 0 + approx_cost: float = 0.0 + started_at: Optional[float] = None # Timestamp when window started + reset_at: Optional[float] = None # Timestamp when window resets + limit: Optional[int] = None # Max requests allowed (None = unlimited) + + @property + def remaining(self) -> Optional[int]: + """Remaining requests in this window, or None if unlimited.""" + if self.limit is None: + return None + return max(0, self.limit - self.request_count) + + @property + def is_exhausted(self) -> bool: + """True if limit reached.""" + if self.limit is None: + return False + return self.request_count >= self.limit + + +@dataclass +class UsageStats: + """ + Aggregated usage statistics for a credential. + + Contains both per-window and global (all-time) statistics. + """ + + windows: Dict[str, WindowStats] = field(default_factory=dict) + total_requests: int = 0 + total_successes: int = 0 + total_failures: int = 0 + total_tokens: int = 0 + total_prompt_tokens: int = 0 + total_completion_tokens: int = 0 + total_thinking_tokens: int = 0 + total_output_tokens: int = 0 + total_prompt_tokens_cache_read: int = 0 + total_prompt_tokens_cache_write: int = 0 + total_approx_cost: float = 0.0 + first_used_at: Optional[float] = None + last_used_at: Optional[float] = None + + # Per-model request counts (for quota group synchronization) + # Key: normalized model name, Value: request count + model_request_counts: Dict[str, int] = field(default_factory=dict) + + +# ============================================================================= +# COOLDOWN TYPES +# ============================================================================= + + +@dataclass +class CooldownInfo: + """ + Information about a cooldown period. + + Cooldowns temporarily block a credential from being used. + """ + + reason: str # Why the cooldown was applied + until: float # Timestamp when cooldown ends + started_at: float # Timestamp when cooldown started + source: str = "system" # "system", "custom_cap", "rate_limit", "provider_hook" + model_or_group: Optional[str] = None # Scope of cooldown (None = credential-wide) + backoff_count: int = 0 # Number of consecutive cooldowns + + @property + def remaining_seconds(self) -> float: + """Seconds remaining in cooldown.""" + import time + + return max(0.0, self.until - time.time()) + + @property + def is_active(self) -> bool: + """True if cooldown is still in effect.""" + import time + + return time.time() < self.until + + +# ============================================================================= +# FAIR CYCLE TYPES +# ============================================================================= + + +@dataclass +class FairCycleState: + """ + Fair cycle state for a credential. + + Tracks whether a credential has been exhausted in the current cycle. + """ + + exhausted: bool = False + exhausted_at: Optional[float] = None + exhausted_reason: Optional[str] = None + cycle_request_count: int = 0 # Requests in current cycle + model_or_group: Optional[str] = None # Scope of exhaustion + + +@dataclass +class GlobalFairCycleState: + """ + Global fair cycle state for a provider. + + Tracks the overall cycle across all credentials. + """ + + cycle_start: float = 0.0 # Timestamp when current cycle started + all_exhausted_at: Optional[float] = None # When all credentials exhausted + cycle_count: int = 0 # How many full cycles completed + + +# ============================================================================= +# CREDENTIAL STATE +# ============================================================================= + + +@dataclass +class CredentialState: + """ + Complete state for a single credential. + + This is the primary storage unit for credential data. + """ + + # Identity + stable_id: str # Email (OAuth) or hash (API key) + provider: str + accessor: str # Current file path or API key + display_name: Optional[str] = None + tier: Optional[str] = None + priority: int = 999 # Lower = higher priority + + # Usage stats + usage: UsageStats = field(default_factory=UsageStats) + + # Per-model usage stats (for per-model windows) + model_usage: Dict[str, UsageStats] = field(default_factory=dict) + + # Per-quota-group usage stats (for shared quota windows) + group_usage: Dict[str, UsageStats] = field(default_factory=dict) + + # Cooldowns (keyed by model/group or "_global_") + cooldowns: Dict[str, CooldownInfo] = field(default_factory=dict) + + # Fair cycle state (keyed by model/group) + fair_cycle: Dict[str, FairCycleState] = field(default_factory=dict) + + # Active requests (for concurrent request limiting) + active_requests: int = 0 + max_concurrent: Optional[int] = None + + # Metadata + created_at: Optional[float] = None + last_updated: Optional[float] = None + + def get_cooldown( + self, model_or_group: Optional[str] = None + ) -> Optional[CooldownInfo]: + """Get active cooldown for given scope.""" + import time + + now = time.time() + + # Check specific cooldown + if model_or_group: + cooldown = self.cooldowns.get(model_or_group) + if cooldown and cooldown.until > now: + return cooldown + + # Check global cooldown + global_cooldown = self.cooldowns.get("_global_") + if global_cooldown and global_cooldown.until > now: + return global_cooldown + + return None + + def is_fair_cycle_exhausted(self, model_or_group: str) -> bool: + """Check if exhausted for fair cycle purposes.""" + state = self.fair_cycle.get(model_or_group) + return state.exhausted if state else False + + def get_usage_for_scope( + self, + scope: str, + key: Optional[str] = None, + create: bool = True, + ) -> Optional[UsageStats]: + """Get usage stats for a given scope.""" + if scope == "credential": + return self.usage + if scope == "model": + if not key: + return self.usage + if create: + return self.model_usage.setdefault(key, UsageStats()) + return self.model_usage.get(key) + if scope == "group": + if not key: + return self.usage + if create: + return self.group_usage.setdefault(key, UsageStats()) + return self.group_usage.get(key) + return self.usage + + +# ============================================================================= +# SELECTION TYPES +# ============================================================================= + + +@dataclass +class SelectionContext: + """ + Context passed to rotation strategies during credential selection. + + Contains all information needed to make a selection decision. + """ + + provider: str + model: str + quota_group: Optional[str] # Quota group for this model + candidates: List[str] # Stable IDs of available candidates + priorities: Dict[str, int] # stable_id -> priority + usage_counts: Dict[str, int] # stable_id -> request count for relevant window + rotation_mode: RotationMode + rotation_tolerance: float + deadline: float + + +@dataclass +class LimitCheckResult: + """ + Result of checking all limits for a credential. + + Used by LimitEngine to report why a credential was blocked. + """ + + allowed: bool + result: LimitResult = LimitResult.ALLOWED + reason: Optional[str] = None + blocked_until: Optional[float] = None # When the block expires + + @classmethod + def ok(cls) -> "LimitCheckResult": + """Create an allowed result.""" + return cls(allowed=True, result=LimitResult.ALLOWED) + + @classmethod + def blocked( + cls, + result: LimitResult, + reason: str, + blocked_until: Optional[float] = None, + ) -> "LimitCheckResult": + """Create a blocked result.""" + return cls( + allowed=False, + result=result, + reason=reason, + blocked_until=blocked_until, + ) + + +# ============================================================================= +# STORAGE TYPES +# ============================================================================= + + +@dataclass +class StorageSchema: + """ + Root schema for usage.json storage file. + """ + + schema_version: int = 2 + updated_at: Optional[str] = None # ISO format + credentials: Dict[str, Dict[str, Any]] = field(default_factory=dict) + accessor_index: Dict[str, str] = field( + default_factory=dict + ) # accessor -> stable_id + fair_cycle_global: Dict[str, Dict[str, Any]] = field( + default_factory=dict + ) # provider -> GlobalFairCycleState diff --git a/src/rotator_library/usage_manager.py b/src/rotator_library/usage_manager.py index 46e30bbc..5430dc80 100644 --- a/src/rotator_library/usage_manager.py +++ b/src/rotator_library/usage_manager.py @@ -1,3980 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-only # Copyright (c) 2026 Mirrowel -import json -import os -import time -import logging -import asyncio -import random -from datetime import date, datetime, timezone, time as dt_time -from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Tuple, Union -import aiofiles -import litellm +"""Compatibility shim for legacy imports.""" -from .error_handler import ClassifiedError, NoAvailableKeysError, mask_credential -from .providers import PROVIDER_PLUGINS -from .utils.resilient_io import ResilientStateWriter -from .utils.paths import get_data_file -from .config import ( - DEFAULT_FAIR_CYCLE_DURATION, - DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD, - DEFAULT_CUSTOM_CAP_COOLDOWN_MODE, - DEFAULT_CUSTOM_CAP_COOLDOWN_VALUE, - COOLDOWN_BACKOFF_TIERS, - COOLDOWN_BACKOFF_MAX, - COOLDOWN_AUTH_ERROR, - COOLDOWN_TRANSIENT_ERROR, - COOLDOWN_RATE_LIMIT_DEFAULT, -) +from .usage import UsageManager, CredentialContext -lib_logger = logging.getLogger("rotator_library") -lib_logger.propagate = False -if not lib_logger.handlers: - lib_logger.addHandler(logging.NullHandler()) - - -class UsageManager: - """ - Manages usage statistics and cooldowns for API keys with asyncio-safe locking, - asynchronous file I/O, lazy-loading mechanism, and weighted random credential rotation. - - The credential rotation strategy can be configured via the `rotation_tolerance` parameter: - - - **tolerance = 0.0**: Deterministic least-used selection. The credential with - the lowest usage count is always selected. This provides predictable, perfectly balanced - load distribution but may be vulnerable to fingerprinting. - - - **tolerance = 2.0 - 4.0 (default, recommended)**: Balanced weighted randomness. Credentials are selected - randomly with weights biased toward less-used ones. Credentials within 2 uses of the - maximum can still be selected with reasonable probability. This provides security through - unpredictability while maintaining good load balance. - - - **tolerance = 5.0+**: High randomness. Even heavily-used credentials have significant - selection probability. Useful for stress testing or maximum unpredictability, but may - result in less balanced load distribution. - - The weight formula is: `weight = (max_usage - credential_usage) + tolerance + 1` - - This ensures lower-usage credentials are preferred while tolerance controls how much - randomness is introduced into the selection process. - - Additionally, providers can specify a rotation mode: - - "balanced" (default): Rotate credentials to distribute load evenly - - "sequential": Use one credential until exhausted (preserves caching) - """ - - def __init__( - self, - file_path: Optional[Union[str, Path]] = None, - daily_reset_time_utc: Optional[str] = "03:00", - rotation_tolerance: float = 0.0, - provider_rotation_modes: Optional[Dict[str, str]] = None, - provider_plugins: Optional[Dict[str, Any]] = None, - priority_multipliers: Optional[Dict[str, Dict[int, int]]] = None, - priority_multipliers_by_mode: Optional[ - Dict[str, Dict[str, Dict[int, int]]] - ] = None, - sequential_fallback_multipliers: Optional[Dict[str, int]] = None, - fair_cycle_enabled: Optional[Dict[str, bool]] = None, - fair_cycle_tracking_mode: Optional[Dict[str, str]] = None, - fair_cycle_cross_tier: Optional[Dict[str, bool]] = None, - fair_cycle_duration: Optional[Dict[str, int]] = None, - exhaustion_cooldown_threshold: Optional[Dict[str, int]] = None, - custom_caps: Optional[ - Dict[str, Dict[Union[int, Tuple[int, ...], str], Dict[str, Dict[str, Any]]]] - ] = None, - ): - """ - Initialize the UsageManager. - - Args: - file_path: Path to the usage data JSON file. If None, uses get_data_file("key_usage.json"). - Can be absolute Path, relative Path, or string. - daily_reset_time_utc: Time in UTC when daily stats should reset (HH:MM format) - rotation_tolerance: Tolerance for weighted random credential rotation. - - 0.0: Deterministic, least-used credential always selected - - tolerance = 2.0 - 4.0 (default, recommended): Balanced randomness, can pick credentials within 2 uses of max - - 5.0+: High randomness, more unpredictable selection patterns - provider_rotation_modes: Dict mapping provider names to rotation modes. - - "balanced": Rotate credentials to distribute load evenly (default) - - "sequential": Use one credential until exhausted (preserves caching) - provider_plugins: Dict mapping provider names to provider plugin instances. - Used for per-provider usage reset configuration (window durations, field names). - priority_multipliers: Dict mapping provider -> priority -> multiplier. - Universal multipliers that apply regardless of rotation mode. - Example: {"antigravity": {1: 5, 2: 3}} - priority_multipliers_by_mode: Dict mapping provider -> mode -> priority -> multiplier. - Mode-specific overrides. Example: {"antigravity": {"balanced": {3: 1}}} - sequential_fallback_multipliers: Dict mapping provider -> fallback multiplier. - Used in sequential mode when priority not in priority_multipliers. - Example: {"antigravity": 2} - fair_cycle_enabled: Dict mapping provider -> bool to enable fair cycle rotation. - When enabled, credentials must all exhaust before any can be reused. - Default: enabled for sequential mode only. - fair_cycle_tracking_mode: Dict mapping provider -> tracking mode. - - "model_group": Track per quota group or model (default) - - "credential": Track per credential globally - fair_cycle_cross_tier: Dict mapping provider -> bool for cross-tier tracking. - - False: Each tier cycles independently (default) - - True: All credentials must exhaust regardless of tier - fair_cycle_duration: Dict mapping provider -> cycle duration in seconds. - Default: 86400 (24 hours) - exhaustion_cooldown_threshold: Dict mapping provider -> threshold in seconds. - A cooldown must exceed this to qualify as "exhausted". Default: 300 (5 min) - custom_caps: Dict mapping provider -> tier -> model/group -> cap config. - Allows setting custom usage limits per tier, per model or quota group. - See ProviderInterface.default_custom_caps for format details. - """ - # Resolve file_path - use default if not provided - if file_path is None: - self.file_path = str(get_data_file("key_usage.json")) - elif isinstance(file_path, Path): - self.file_path = str(file_path) - else: - # String path - could be relative or absolute - self.file_path = file_path - self.rotation_tolerance = rotation_tolerance - self.provider_rotation_modes = provider_rotation_modes or {} - self.provider_plugins = provider_plugins or PROVIDER_PLUGINS - self.priority_multipliers = priority_multipliers or {} - self.priority_multipliers_by_mode = priority_multipliers_by_mode or {} - self.sequential_fallback_multipliers = sequential_fallback_multipliers or {} - self._provider_instances: Dict[str, Any] = {} # Cache for provider instances - self.key_states: Dict[str, Dict[str, Any]] = {} - - # Fair cycle rotation configuration - self.fair_cycle_enabled = fair_cycle_enabled or {} - self.fair_cycle_tracking_mode = fair_cycle_tracking_mode or {} - self.fair_cycle_cross_tier = fair_cycle_cross_tier or {} - self.fair_cycle_duration = fair_cycle_duration or {} - self.exhaustion_cooldown_threshold = exhaustion_cooldown_threshold or {} - self.custom_caps = custom_caps or {} - # In-memory cycle state: {provider: {tier_key: {tracking_key: {"cycle_started_at": float, "exhausted": Set[str]}}}} - self._cycle_exhausted: Dict[str, Dict[str, Dict[str, Dict[str, Any]]]] = {} - - self._data_lock = asyncio.Lock() - self._usage_data: Optional[Dict] = None - self._initialized = asyncio.Event() - self._init_lock = asyncio.Lock() - - self._timeout_lock = asyncio.Lock() - self._claimed_on_timeout: Set[str] = set() - - # Resilient writer for usage data persistence - self._state_writer = ResilientStateWriter(file_path, lib_logger) - - if daily_reset_time_utc: - hour, minute = map(int, daily_reset_time_utc.split(":")) - self.daily_reset_time_utc = dt_time( - hour=hour, minute=minute, tzinfo=timezone.utc - ) - else: - self.daily_reset_time_utc = None - - def _get_rotation_mode(self, provider: str) -> str: - """ - Get the rotation mode for a provider. - - Args: - provider: Provider name (e.g., "antigravity", "gemini_cli") - - Returns: - "balanced" or "sequential" - """ - return self.provider_rotation_modes.get(provider, "balanced") - - # ========================================================================= - # FAIR CYCLE ROTATION HELPERS - # ========================================================================= - - def _is_fair_cycle_enabled(self, provider: str, rotation_mode: str) -> bool: - """ - Check if fair cycle rotation is enabled for a provider. - - Args: - provider: Provider name - rotation_mode: Current rotation mode ("balanced" or "sequential") - - Returns: - True if fair cycle is enabled - """ - # Check provider-specific setting first - if provider in self.fair_cycle_enabled: - return self.fair_cycle_enabled[provider] - # Default: enabled only for sequential mode - return rotation_mode == "sequential" - - def _get_fair_cycle_tracking_mode(self, provider: str) -> str: - """ - Get fair cycle tracking mode for a provider. - - Returns: - "model_group" or "credential" - """ - return self.fair_cycle_tracking_mode.get(provider, "model_group") - - def _is_fair_cycle_cross_tier(self, provider: str) -> bool: - """ - Check if fair cycle tracks across all tiers (ignoring priority boundaries). - - Returns: - True if cross-tier tracking is enabled - """ - return self.fair_cycle_cross_tier.get(provider, False) - - def _get_fair_cycle_duration(self, provider: str) -> int: - """ - Get fair cycle duration in seconds for a provider. - - Returns: - Duration in seconds (default 86400 = 24 hours) - """ - return self.fair_cycle_duration.get(provider, DEFAULT_FAIR_CYCLE_DURATION) - - def _get_exhaustion_cooldown_threshold(self, provider: str) -> int: - """ - Get exhaustion cooldown threshold in seconds for a provider. - - A cooldown must exceed this duration to qualify as "exhausted" for fair cycle. - - Returns: - Threshold in seconds (default 300 = 5 minutes) - """ - return self.exhaustion_cooldown_threshold.get( - provider, DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD - ) - - # ========================================================================= - # CUSTOM CAPS HELPERS - # ========================================================================= - - def _get_custom_cap_config( - self, - provider: str, - tier_priority: int, - model: str, - ) -> Optional[Dict[str, Any]]: - """ - Get custom cap config for a provider/tier/model combination. - - Resolution order: - 1. tier + model (exact match) - 2. tier + group (model's quota group) - 3. "default" + model - 4. "default" + group - - Args: - provider: Provider name - tier_priority: Credential's priority level - model: Model name (with provider prefix) - - Returns: - Cap config dict or None if no custom cap applies - """ - provider_caps = self.custom_caps.get(provider) - if not provider_caps: - return None - - # Strip provider prefix from model - clean_model = model.split("/")[-1] if "/" in model else model - - # Get quota group for this model - group = self._get_model_quota_group_by_provider(provider, model) - - # Try to find matching tier config - tier_config = None - default_config = None - - for tier_key, models_config in provider_caps.items(): - if tier_key == "default": - default_config = models_config - continue - - # Check if this tier_key matches our priority - if isinstance(tier_key, int) and tier_key == tier_priority: - tier_config = models_config - break - elif isinstance(tier_key, tuple) and tier_priority in tier_key: - tier_config = models_config - break - - # Resolution order for tier config - if tier_config: - # Try model first - if clean_model in tier_config: - return tier_config[clean_model] - # Try group - if group and group in tier_config: - return tier_config[group] - - # Resolution order for default config - if default_config: - # Try model first - if clean_model in default_config: - return default_config[clean_model] - # Try group - if group and group in default_config: - return default_config[group] - - return None - - def _get_model_quota_group_by_provider( - self, provider: str, model: str - ) -> Optional[str]: - """ - Get quota group for a model using provider name instead of credential. - - Args: - provider: Provider name - model: Model name - - Returns: - Group name or None - """ - plugin_instance = self._get_provider_instance(provider) - if plugin_instance and hasattr(plugin_instance, "get_model_quota_group"): - return plugin_instance.get_model_quota_group(model) - return None - - def _resolve_custom_cap_max( - self, - provider: str, - model: str, - cap_config: Dict[str, Any], - actual_max: Optional[int], - ) -> Optional[int]: - """ - Resolve custom cap max_requests value, handling percentages and clamping. - - Args: - provider: Provider name - model: Model name (for logging) - cap_config: Custom cap configuration - actual_max: Actual API max requests (may be None if unknown) - - Returns: - Resolved cap value (clamped), or None if can't be calculated - """ - max_requests = cap_config.get("max_requests") - if max_requests is None: - return None - - # Handle percentage - if isinstance(max_requests, str) and max_requests.endswith("%"): - if actual_max is None: - lib_logger.warning( - f"Custom cap '{max_requests}' for {provider}/{model} requires known max_requests. " - f"Skipping until quota baseline is fetched. Use absolute value for immediate enforcement." - ) - return None - try: - percentage = float(max_requests.rstrip("%")) / 100.0 - calculated = int(actual_max * percentage) - except ValueError: - lib_logger.warning( - f"Invalid percentage cap '{max_requests}' for {provider}/{model}" - ) - return None - else: - # Absolute value - try: - calculated = int(max_requests) - except (ValueError, TypeError): - lib_logger.warning( - f"Invalid cap value '{max_requests}' for {provider}/{model}" - ) - return None - - # Clamp to actual max (can only be MORE restrictive) - if actual_max is not None: - return min(calculated, actual_max) - return calculated - - def _calculate_custom_cooldown_until( - self, - cap_config: Dict[str, Any], - window_start_ts: Optional[float], - natural_reset_ts: Optional[float], - ) -> Optional[float]: - """ - Calculate when custom cap cooldown should end, clamped to natural reset. - - Args: - cap_config: Custom cap configuration - window_start_ts: When first request was made (for fixed mode) - natural_reset_ts: Natural quota reset timestamp - - Returns: - Cooldown end timestamp (clamped), or None if can't calculate - """ - mode = cap_config.get("cooldown_mode", DEFAULT_CUSTOM_CAP_COOLDOWN_MODE) - value = cap_config.get("cooldown_value", DEFAULT_CUSTOM_CAP_COOLDOWN_VALUE) - - if mode == "quota_reset": - calculated = natural_reset_ts - elif mode == "offset": - if natural_reset_ts is None: - return None - calculated = natural_reset_ts + value - elif mode == "fixed": - if window_start_ts is None: - return None - calculated = window_start_ts + value - else: - lib_logger.warning(f"Unknown cooldown_mode '{mode}', using quota_reset") - calculated = natural_reset_ts - - if calculated is None: - return None - - # Clamp to natural reset (can only be MORE restrictive = longer cooldown) - if natural_reset_ts is not None: - return max(calculated, natural_reset_ts) - return calculated - - def _check_and_apply_custom_cap( - self, - credential: str, - model: str, - request_count: int, - ) -> bool: - """ - Check if custom cap is exceeded and apply cooldown if so. - - This should be called after incrementing request_count in record_success(). - - Args: - credential: Credential identifier - model: Model name (with provider prefix) - request_count: Current request count for this model - - Returns: - True if cap exceeded and cooldown applied, False otherwise - """ - provider = self._get_provider_from_credential(credential) - if not provider: - return False - - priority = self._get_credential_priority(credential, provider) - cap_config = self._get_custom_cap_config(provider, priority, model) - if not cap_config: - return False - - # Get model data for actual max and timing info - key_data = self._usage_data.get(credential, {}) - model_data = key_data.get("models", {}).get(model, {}) - actual_max = model_data.get("quota_max_requests") - window_start_ts = model_data.get("window_start_ts") - natural_reset_ts = model_data.get("quota_reset_ts") - - # Resolve custom cap max - custom_max = self._resolve_custom_cap_max( - provider, model, cap_config, actual_max - ) - if custom_max is None: - return False - - # Check if exceeded - if request_count < custom_max: - return False - - # Calculate cooldown end time - cooldown_until = self._calculate_custom_cooldown_until( - cap_config, window_start_ts, natural_reset_ts - ) - if cooldown_until is None: - # Can't calculate cooldown, use natural reset if available - if natural_reset_ts: - cooldown_until = natural_reset_ts - else: - lib_logger.warning( - f"Custom cap hit for {mask_credential(credential)}/{model} but can't calculate cooldown. " - f"Skipping cooldown application." - ) - return False - - now_ts = time.time() - - # Apply cooldown - model_cooldowns = key_data.setdefault("model_cooldowns", {}) - model_cooldowns[model] = cooldown_until - - # Store custom cap info in model data for reference - model_data["custom_cap_max"] = custom_max - model_data["custom_cap_hit_at"] = now_ts - model_data["custom_cap_cooldown_until"] = cooldown_until - - hours_until = (cooldown_until - now_ts) / 3600 - lib_logger.info( - f"Custom cap hit: {mask_credential(credential)} reached {request_count}/{custom_max} " - f"for {model}. Cooldown for {hours_until:.1f}h" - ) - - # Sync cooldown across quota group - group = self._get_model_quota_group(credential, model) - if group: - grouped_models = self._get_grouped_models(credential, group) - for grouped_model in grouped_models: - if grouped_model != model: - model_cooldowns[grouped_model] = cooldown_until - - # Check if this should trigger fair cycle exhaustion - cooldown_duration = cooldown_until - now_ts - threshold = self._get_exhaustion_cooldown_threshold(provider) - if cooldown_duration > threshold: - rotation_mode = self._get_rotation_mode(provider) - if self._is_fair_cycle_enabled(provider, rotation_mode): - tier_key = self._get_tier_key(provider, priority) - tracking_key = self._get_tracking_key(credential, model, provider) - self._mark_credential_exhausted( - credential, provider, tier_key, tracking_key - ) - - return True - - def _get_tier_key(self, provider: str, priority: int) -> str: - """ - Get the tier key for cycle tracking based on cross_tier setting. - - Args: - provider: Provider name - priority: Credential priority level - - Returns: - "__all_tiers__" if cross-tier enabled, else str(priority) - """ - if self._is_fair_cycle_cross_tier(provider): - return "__all_tiers__" - return str(priority) - - def _get_tracking_key(self, credential: str, model: str, provider: str) -> str: - """ - Get the key for exhaustion tracking based on tracking mode. - - Args: - credential: Credential identifier - model: Model name (with provider prefix) - provider: Provider name - - Returns: - Tracking key string (quota group name, model name, or "__credential__") - """ - mode = self._get_fair_cycle_tracking_mode(provider) - if mode == "credential": - return "__credential__" - # model_group mode: use quota group if exists, else model - group = self._get_model_quota_group(credential, model) - return group if group else model - - def _get_credential_priority(self, credential: str, provider: str) -> int: - """ - Get the priority level for a credential. - - Args: - credential: Credential identifier - provider: Provider name - - Returns: - Priority level (default 999 if unknown) - """ - plugin_instance = self._get_provider_instance(provider) - if plugin_instance and hasattr(plugin_instance, "get_credential_priority"): - priority = plugin_instance.get_credential_priority(credential) - if priority is not None: - return priority - return 999 - - def _get_cycle_data( - self, provider: str, tier_key: str, tracking_key: str - ) -> Optional[Dict[str, Any]]: - """ - Get cycle data for a provider/tier/tracking key combination. - - Returns: - Cycle data dict or None if not exists - """ - return ( - self._cycle_exhausted.get(provider, {}).get(tier_key, {}).get(tracking_key) - ) - - def _ensure_cycle_structure( - self, provider: str, tier_key: str, tracking_key: str - ) -> Dict[str, Any]: - """ - Ensure the nested cycle structure exists and return the cycle data dict. - """ - if provider not in self._cycle_exhausted: - self._cycle_exhausted[provider] = {} - if tier_key not in self._cycle_exhausted[provider]: - self._cycle_exhausted[provider][tier_key] = {} - if tracking_key not in self._cycle_exhausted[provider][tier_key]: - self._cycle_exhausted[provider][tier_key][tracking_key] = { - "cycle_started_at": None, - "exhausted": set(), - } - return self._cycle_exhausted[provider][tier_key][tracking_key] - - def _mark_credential_exhausted( - self, - credential: str, - provider: str, - tier_key: str, - tracking_key: str, - ) -> None: - """ - Mark a credential as exhausted for fair cycle tracking. - - Starts the cycle timer on first exhaustion. - Skips if credential is already in the exhausted set (prevents duplicate logging). - """ - cycle_data = self._ensure_cycle_structure(provider, tier_key, tracking_key) - - # Skip if already exhausted in this cycle (prevents duplicate logging) - if credential in cycle_data.get("exhausted", set()): - return - - # Start cycle timer on first exhaustion - if cycle_data["cycle_started_at"] is None: - cycle_data["cycle_started_at"] = time.time() - lib_logger.info( - f"Fair cycle started for {provider} tier={tier_key} tracking='{tracking_key}'" - ) - - cycle_data["exhausted"].add(credential) - lib_logger.info( - f"Fair cycle: marked {mask_credential(credential)} exhausted " - f"for {tracking_key} ({len(cycle_data['exhausted'])} total)" - ) - - def _is_credential_exhausted_in_cycle( - self, - credential: str, - provider: str, - tier_key: str, - tracking_key: str, - ) -> bool: - """ - Check if a credential was exhausted in the current cycle. - """ - cycle_data = self._get_cycle_data(provider, tier_key, tracking_key) - if cycle_data is None: - return False - return credential in cycle_data.get("exhausted", set()) - - def _is_cycle_expired( - self, provider: str, tier_key: str, tracking_key: str - ) -> bool: - """ - Check if the current cycle has exceeded its duration. - """ - cycle_data = self._get_cycle_data(provider, tier_key, tracking_key) - if cycle_data is None: - return False - cycle_started = cycle_data.get("cycle_started_at") - if cycle_started is None: - return False - duration = self._get_fair_cycle_duration(provider) - return time.time() >= cycle_started + duration - - def _should_reset_cycle( - self, - provider: str, - tier_key: str, - tracking_key: str, - all_credentials_in_tier: List[str], - available_not_on_cooldown: Optional[List[str]] = None, - ) -> bool: - """ - Check if cycle should reset. - - Returns True if: - 1. Cycle duration has expired, OR - 2. No credentials remain available (after cooldown + fair cycle exclusion), OR - 3. All credentials in the tier have been marked exhausted (fallback) - """ - # Check duration first - if self._is_cycle_expired(provider, tier_key, tracking_key): - return True - - cycle_data = self._get_cycle_data(provider, tier_key, tracking_key) - if cycle_data is None: - return False - - # If available credentials are provided, reset when none remain usable - if available_not_on_cooldown is not None: - has_available = any( - not self._is_credential_exhausted_in_cycle( - cred, provider, tier_key, tracking_key - ) - for cred in available_not_on_cooldown - ) - if not has_available and len(all_credentials_in_tier) > 0: - return True - - exhausted = cycle_data.get("exhausted", set()) - # All must be exhausted (and there must be at least one credential) - return ( - len(exhausted) >= len(all_credentials_in_tier) - and len(all_credentials_in_tier) > 0 - ) - - def _reset_cycle(self, provider: str, tier_key: str, tracking_key: str) -> None: - """ - Reset exhaustion tracking for a completed cycle. - """ - cycle_data = self._get_cycle_data(provider, tier_key, tracking_key) - if cycle_data: - exhausted_count = len(cycle_data.get("exhausted", set())) - lib_logger.info( - f"Fair cycle complete for {provider} tier={tier_key} " - f"tracking='{tracking_key}' - resetting ({exhausted_count} credentials cycled)" - ) - cycle_data["cycle_started_at"] = None - cycle_data["exhausted"] = set() - - def _get_all_credentials_for_tier_key( - self, - provider: str, - tier_key: str, - available_keys: List[str], - credential_priorities: Optional[Dict[str, int]], - ) -> List[str]: - """ - Get all credentials that belong to a tier key. - - Args: - provider: Provider name - tier_key: Either "__all_tiers__" or str(priority) - available_keys: List of available credential identifiers - credential_priorities: Dict mapping credentials to priorities - - Returns: - List of credentials belonging to this tier key - """ - if tier_key == "__all_tiers__": - # Cross-tier: all credentials for this provider - return list(available_keys) - else: - # Within-tier: only credentials with matching priority - priority = int(tier_key) - if credential_priorities: - return [ - k - for k in available_keys - if credential_priorities.get(k, 999) == priority - ] - return list(available_keys) - - def _count_fair_cycle_excluded( - self, - provider: str, - tier_key: str, - tracking_key: str, - candidates: List[str], - ) -> int: - """ - Count how many candidates are excluded by fair cycle. - - Args: - provider: Provider name - tier_key: Tier key for tracking - tracking_key: Model/group tracking key - candidates: List of candidate credentials (not on cooldown) - - Returns: - Number of candidates excluded by fair cycle - """ - count = 0 - for cred in candidates: - if self._is_credential_exhausted_in_cycle( - cred, provider, tier_key, tracking_key - ): - count += 1 - return count - - def _get_priority_multiplier( - self, provider: str, priority: int, rotation_mode: str - ) -> int: - """ - Get the concurrency multiplier for a provider/priority/mode combination. - - Lookup order: - 1. Mode-specific tier override: priority_multipliers_by_mode[provider][mode][priority] - 2. Universal tier multiplier: priority_multipliers[provider][priority] - 3. Sequential fallback (if mode is sequential): sequential_fallback_multipliers[provider] - 4. Global default: 1 (no multiplier effect) - - Args: - provider: Provider name (e.g., "antigravity") - priority: Priority level (1 = highest priority) - rotation_mode: Current rotation mode ("sequential" or "balanced") - - Returns: - Multiplier value - """ - provider_lower = provider.lower() - - # 1. Check mode-specific override - if provider_lower in self.priority_multipliers_by_mode: - mode_multipliers = self.priority_multipliers_by_mode[provider_lower] - if rotation_mode in mode_multipliers: - if priority in mode_multipliers[rotation_mode]: - return mode_multipliers[rotation_mode][priority] - - # 2. Check universal tier multiplier - if provider_lower in self.priority_multipliers: - if priority in self.priority_multipliers[provider_lower]: - return self.priority_multipliers[provider_lower][priority] - - # 3. Sequential fallback (only for sequential mode) - if rotation_mode == "sequential": - if provider_lower in self.sequential_fallback_multipliers: - return self.sequential_fallback_multipliers[provider_lower] - - # 4. Global default - return 1 - - def _get_provider_from_credential(self, credential: str) -> Optional[str]: - """ - Extract provider name from credential path or identifier. - - Supports multiple credential formats: - - OAuth: "oauth_creds/antigravity_oauth_15.json" -> "antigravity" - - OAuth: "C:\\...\\oauth_creds\\gemini_cli_oauth_1.json" -> "gemini_cli" - - OAuth filename only: "antigravity_oauth_1.json" -> "antigravity" - - API key style: extracted from model names in usage data (e.g., "firmware/model" -> "firmware") - - Args: - credential: The credential identifier (path or key) - - Returns: - Provider name string or None if cannot be determined - """ - import re - - # Pattern: env:// URI format (e.g., "env://antigravity/1" -> "antigravity") - if credential.startswith("env://"): - parts = credential[6:].split("/") # Remove "env://" prefix - if parts and parts[0]: - return parts[0].lower() - # Malformed env:// URI (empty provider name) - lib_logger.warning(f"Malformed env:// credential URI: {credential}") - return None - - # Normalize path separators - normalized = credential.replace("\\", "/") - - # Pattern: path ending with {provider}_oauth_{number}.json - match = re.search(r"/([a-z_]+)_oauth_\d+\.json$", normalized, re.IGNORECASE) - if match: - return match.group(1).lower() - - # Pattern: oauth_creds/{provider}_... - match = re.search(r"oauth_creds/([a-z_]+)_", normalized, re.IGNORECASE) - if match: - return match.group(1).lower() - - # Pattern: filename only {provider}_oauth_{number}.json (no path) - match = re.match(r"([a-z_]+)_oauth_\d+\.json$", normalized, re.IGNORECASE) - if match: - return match.group(1).lower() - - # Pattern: API key prefixes for specific providers - # These are raw API keys with recognizable prefixes - api_key_prefixes = { - "sk-nano-": "nanogpt", - "sk-or-": "openrouter", - "sk-ant-": "anthropic", - } - for prefix, provider in api_key_prefixes.items(): - if credential.startswith(prefix): - return provider - - # Fallback: For raw API keys, extract provider from model names in usage data - # This handles providers like firmware, chutes, nanogpt that use credential-level quota - if self._usage_data and credential in self._usage_data: - cred_data = self._usage_data[credential] - - # Check "models" section first (for per_model mode and quota tracking) - models_data = cred_data.get("models", {}) - if models_data: - # Get first model name and extract provider prefix - first_model = next(iter(models_data.keys()), None) - if first_model and "/" in first_model: - provider = first_model.split("/")[0].lower() - return provider - - # Fallback to "daily" section (legacy structure) - daily_data = cred_data.get("daily", {}) - daily_models = daily_data.get("models", {}) - if daily_models: - # Get first model name and extract provider prefix - first_model = next(iter(daily_models.keys()), None) - if first_model and "/" in first_model: - provider = first_model.split("/")[0].lower() - return provider - - return None - - def _get_provider_instance(self, provider: str) -> Optional[Any]: - """ - Get or create a provider plugin instance. - - Args: - provider: The provider name - - Returns: - Provider plugin instance or None - """ - if not provider: - return None - - plugin_class = self.provider_plugins.get(provider) - if not plugin_class: - return None - - # Get or create provider instance from cache - if provider not in self._provider_instances: - # Instantiate the plugin if it's a class, or use it directly if already an instance - if isinstance(plugin_class, type): - self._provider_instances[provider] = plugin_class() - else: - self._provider_instances[provider] = plugin_class - - return self._provider_instances[provider] - - def _get_usage_reset_config(self, credential: str) -> Optional[Dict[str, Any]]: - """ - Get the usage reset configuration for a credential from its provider plugin. - - Args: - credential: The credential identifier - - Returns: - Configuration dict with window_seconds, field_name, etc. - or None to use default daily reset. - """ - provider = self._get_provider_from_credential(credential) - plugin_instance = self._get_provider_instance(provider) - - if plugin_instance and hasattr(plugin_instance, "get_usage_reset_config"): - return plugin_instance.get_usage_reset_config(credential) - - return None - - def _get_reset_mode(self, credential: str) -> str: - """ - Get the reset mode for a credential: 'credential' or 'per_model'. - - Args: - credential: The credential identifier - - Returns: - "per_model" or "credential" (default) - """ - config = self._get_usage_reset_config(credential) - return config.get("mode", "credential") if config else "credential" - - def _get_model_quota_group(self, credential: str, model: str) -> Optional[str]: - """ - Get the quota group for a model, if the provider defines one. - - Args: - credential: The credential identifier - model: Model name (with or without provider prefix) - - Returns: - Group name (e.g., "claude") or None if not grouped - """ - provider = self._get_provider_from_credential(credential) - plugin_instance = self._get_provider_instance(provider) - - if plugin_instance and hasattr(plugin_instance, "get_model_quota_group"): - return plugin_instance.get_model_quota_group(model) - - return None - - def _get_grouped_models(self, credential: str, group: str) -> List[str]: - """ - Get all model names in a quota group (with provider prefix), normalized. - - Returns only public-facing model names, deduplicated. Internal variants - (e.g., claude-sonnet-4-5-thinking) are normalized to their public name - (e.g., claude-sonnet-4.5). - - Args: - credential: The credential identifier - group: Group name (e.g., "claude") - - Returns: - List of normalized, deduplicated model names with provider prefix - (e.g., ["antigravity/claude-sonnet-4.5", "antigravity/claude-opus-4.5"]) - """ - provider = self._get_provider_from_credential(credential) - plugin_instance = self._get_provider_instance(provider) - - if plugin_instance and hasattr(plugin_instance, "get_models_in_quota_group"): - models = plugin_instance.get_models_in_quota_group(group) - - # Normalize and deduplicate - if hasattr(plugin_instance, "normalize_model_for_tracking"): - seen = set() - normalized = [] - for m in models: - prefixed = f"{provider}/{m}" - norm = plugin_instance.normalize_model_for_tracking(prefixed) - if norm not in seen: - seen.add(norm) - normalized.append(norm) - return normalized - - # Fallback: just add provider prefix - return [f"{provider}/{m}" for m in models] - - return [] - - def _get_model_usage_weight(self, credential: str, model: str) -> int: - """ - Get the usage weight for a model when calculating grouped usage. - - Args: - credential: The credential identifier - model: Model name (with or without provider prefix) - - Returns: - Weight multiplier (default 1 if not configured) - """ - provider = self._get_provider_from_credential(credential) - plugin_instance = self._get_provider_instance(provider) - - if plugin_instance and hasattr(plugin_instance, "get_model_usage_weight"): - return plugin_instance.get_model_usage_weight(model) - - return 1 - - def _normalize_model(self, credential: str, model: str) -> str: - """ - Normalize model name using provider's mapping. - - Converts internal model names (e.g., claude-sonnet-4-5-thinking) to - public-facing names (e.g., claude-sonnet-4.5) for consistent storage. - - Args: - credential: The credential identifier - model: Model name (with or without provider prefix) - - Returns: - Normalized model name (provider prefix preserved if present) - """ - provider = self._get_provider_from_credential(credential) - plugin_instance = self._get_provider_instance(provider) - - if plugin_instance and hasattr(plugin_instance, "normalize_model_for_tracking"): - return plugin_instance.normalize_model_for_tracking(model) - - return model - - # Providers where request_count should be used for credential selection - # instead of success_count (because failed requests also consume quota) - _REQUEST_COUNT_PROVIDERS = {"antigravity", "gemini_cli", "chutes", "nanogpt"} - - def _get_grouped_usage_count(self, key: str, model: str) -> int: - """ - Get usage count for credential selection, considering quota groups. - - For providers in _REQUEST_COUNT_PROVIDERS (e.g., antigravity), uses - request_count instead of success_count since failed requests also - consume quota. - - If the model belongs to a quota group, the request_count is already - synced across all models in the group (by record_success/record_failure), - so we just read from the requested model directly. - - Args: - key: Credential identifier - model: Model name (with provider prefix, e.g., "antigravity/claude-sonnet-4-5") - - Returns: - Usage count for the model (synced across group if applicable) - """ - # Determine usage field based on provider - # Some providers (antigravity) count failed requests against quota - provider = self._get_provider_from_credential(key) - usage_field = ( - "request_count" - if provider in self._REQUEST_COUNT_PROVIDERS - else "success_count" - ) - - # For providers with synced quota groups (antigravity), request_count - # is already synced across all models in the group, so just read directly. - # For other providers, we still need to sum success_count across group. - if provider in self._REQUEST_COUNT_PROVIDERS: - # request_count is synced - just read the model's value - return self._get_usage_count(key, model, usage_field) - - # For non-synced providers, check if model is in a quota group and sum - group = self._get_model_quota_group(key, model) - - if group: - # Get all models in the group - grouped_models = self._get_grouped_models(key, group) - - # Sum weighted usage across all models in the group - total_weighted_usage = 0 - for grouped_model in grouped_models: - usage = self._get_usage_count(key, grouped_model, usage_field) - weight = self._get_model_usage_weight(key, grouped_model) - total_weighted_usage += usage * weight - return total_weighted_usage - - # Not grouped - return individual model usage (no weight applied) - return self._get_usage_count(key, model, usage_field) - - def _get_quota_display(self, key: str, model: str) -> str: - """ - Get a formatted quota display string for logging. - - For antigravity (providers in _REQUEST_COUNT_PROVIDERS), returns: - "quota: 170/250 [32%]" format - - For other providers, returns: - "usage: 170" format (no max available) - - Args: - key: Credential identifier - model: Model name (with provider prefix) - - Returns: - Formatted string for logging - """ - provider = self._get_provider_from_credential(key) - - if provider not in self._REQUEST_COUNT_PROVIDERS: - # Non-antigravity: just show usage count - usage = self._get_usage_count(key, model, "success_count") - return f"usage: {usage}" - - # Antigravity: show quota display with remaining percentage - if self._usage_data is None: - return "quota: 0/? [100%]" - - # Normalize model name for consistent lookup (data is stored under normalized names) - model = self._normalize_model(key, model) - - key_data = self._usage_data.get(key, {}) - model_data = key_data.get("models", {}).get(model, {}) - - request_count = model_data.get("request_count", 0) - max_requests = model_data.get("quota_max_requests") - - if max_requests: - remaining = max_requests - request_count - remaining_pct = ( - int((remaining / max_requests) * 100) if max_requests > 0 else 0 - ) - return f"quota: {request_count}/{max_requests} [{remaining_pct}%]" - else: - return f"quota: {request_count}" - - def _get_usage_field_name(self, credential: str) -> str: - """ - Get the usage tracking field name for a credential. - - Returns the provider-specific field name if configured, - otherwise falls back to "daily". - - Args: - credential: The credential identifier - - Returns: - Field name string (e.g., "5h_window", "weekly", "daily") - """ - config = self._get_usage_reset_config(credential) - if config and "field_name" in config: - return config["field_name"] - - # Check provider default - provider = self._get_provider_from_credential(credential) - plugin_instance = self._get_provider_instance(provider) - - if plugin_instance and hasattr(plugin_instance, "get_default_usage_field_name"): - return plugin_instance.get_default_usage_field_name() - - return "daily" - - def _get_usage_count( - self, key: str, model: str, field: str = "success_count" - ) -> int: - """ - Get the current usage count for a model from the appropriate usage structure. - - Supports both: - - New per-model structure: {"models": {"model_name": {"success_count": N, ...}}} - - Legacy structure: {"daily": {"models": {"model_name": {"success_count": N, ...}}}} - - Args: - key: Credential identifier - model: Model name - field: The field to read for usage count (default: "success_count"). - Use "request_count" for providers where failed requests also - consume quota (e.g., antigravity). - - Returns: - Usage count for the model in the current window/period - """ - if self._usage_data is None: - return 0 - - # Normalize model name for consistent lookup (data is stored under normalized names) - model = self._normalize_model(key, model) - - key_data = self._usage_data.get(key, {}) - reset_mode = self._get_reset_mode(key) - - if reset_mode == "per_model": - # New per-model structure: key_data["models"][model][field] - return key_data.get("models", {}).get(model, {}).get(field, 0) - else: - # Legacy structure: key_data["daily"]["models"][model][field] - return ( - key_data.get("daily", {}).get("models", {}).get(model, {}).get(field, 0) - ) - - # ========================================================================= - # TIMESTAMP FORMATTING HELPERS - # ========================================================================= - - def _format_timestamp_local(self, ts: Optional[float]) -> Optional[str]: - """ - Format Unix timestamp as local time string with timezone offset. - - Args: - ts: Unix timestamp or None - - Returns: - Formatted string like "2025-12-07 14:30:17 +0100" or None - """ - if ts is None: - return None - try: - dt = datetime.fromtimestamp(ts).astimezone() # Local timezone - # Use UTC offset for conciseness (works on all platforms) - return dt.strftime("%Y-%m-%d %H:%M:%S %z") - except (OSError, ValueError, OverflowError): - return None - - def _add_readable_timestamps(self, data: Dict) -> Dict: - """ - Add human-readable timestamp fields to usage data before saving. - - Adds 'window_started' and 'quota_resets' fields derived from - Unix timestamps for easier debugging and monitoring. - - Args: - data: The usage data dict to enhance - - Returns: - The same dict with readable timestamp fields added - """ - for key, key_data in data.items(): - # Handle per-model structure - models = key_data.get("models", {}) - for model_name, model_stats in models.items(): - if not isinstance(model_stats, dict): - continue - - # Add readable window start time - window_start = model_stats.get("window_start_ts") - if window_start: - model_stats["window_started"] = self._format_timestamp_local( - window_start - ) - elif "window_started" in model_stats: - del model_stats["window_started"] - - # Add readable reset time - quota_reset = model_stats.get("quota_reset_ts") - if quota_reset: - model_stats["quota_resets"] = self._format_timestamp_local( - quota_reset - ) - elif "quota_resets" in model_stats: - del model_stats["quota_resets"] - - return data - - def _sort_sequential( - self, - candidates: List[Tuple[str, int]], - credential_priorities: Optional[Dict[str, int]] = None, - ) -> List[Tuple[str, int]]: - """ - Sort credentials for sequential mode with position retention. - - Credentials maintain their position based on established usage patterns, - ensuring that actively-used credentials remain primary until exhausted. - - Sorting order (within each sort key, lower value = higher priority): - 1. Priority tier (lower number = higher priority) - 2. Usage count (higher = more established in rotation, maintains position) - 3. Last used timestamp (higher = more recent, tiebreaker for stickiness) - 4. Credential ID (alphabetical, stable ordering) - - Args: - candidates: List of (credential_id, usage_count) tuples - credential_priorities: Optional dict mapping credentials to priority levels - - Returns: - Sorted list of candidates (same format as input) - """ - if not candidates: - return [] - - if len(candidates) == 1: - return candidates - - def sort_key(item: Tuple[str, int]) -> Tuple[int, int, float, str]: - cred, usage_count = item - priority = ( - credential_priorities.get(cred, 999) if credential_priorities else 999 - ) - last_used = ( - self._usage_data.get(cred, {}).get("last_used_ts", 0) - if self._usage_data - else 0 - ) - return ( - priority, # ASC: lower priority number = higher priority - -usage_count, # DESC: higher usage = more established - -last_used, # DESC: more recent = preferred for ties - cred, # ASC: stable alphabetical ordering - ) - - sorted_candidates = sorted(candidates, key=sort_key) - - # Debug logging - show top 3 credentials in ordering - if lib_logger.isEnabledFor(logging.DEBUG): - order_info = [ - f"{mask_credential(c)}(p={credential_priorities.get(c, 999) if credential_priorities else 'N/A'}, u={u})" - for c, u in sorted_candidates[:3] - ] - lib_logger.debug(f"Sequential ordering: {' → '.join(order_info)}") - - return sorted_candidates - - # ========================================================================= - # FAIR CYCLE PERSISTENCE - # ========================================================================= - - def _serialize_cycle_state(self) -> Dict[str, Any]: - """ - Serialize in-memory cycle state for JSON persistence. - - Converts sets to lists for JSON compatibility. - """ - result: Dict[str, Any] = {} - for provider, tier_data in self._cycle_exhausted.items(): - result[provider] = {} - for tier_key, tracking_data in tier_data.items(): - result[provider][tier_key] = {} - for tracking_key, cycle_data in tracking_data.items(): - result[provider][tier_key][tracking_key] = { - "cycle_started_at": cycle_data.get("cycle_started_at"), - "exhausted": list(cycle_data.get("exhausted", set())), - } - return result - - def _deserialize_cycle_state(self, data: Dict[str, Any]) -> None: - """ - Deserialize cycle state from JSON and populate in-memory structure. - - Converts lists back to sets and validates expired cycles. - """ - self._cycle_exhausted = {} - now_ts = time.time() - - for provider, tier_data in data.items(): - if not isinstance(tier_data, dict): - continue - self._cycle_exhausted[provider] = {} - - for tier_key, tracking_data in tier_data.items(): - if not isinstance(tracking_data, dict): - continue - self._cycle_exhausted[provider][tier_key] = {} - - for tracking_key, cycle_data in tracking_data.items(): - if not isinstance(cycle_data, dict): - continue - - cycle_started = cycle_data.get("cycle_started_at") - exhausted_list = cycle_data.get("exhausted", []) - - # Check if cycle has expired - if cycle_started is not None: - duration = self._get_fair_cycle_duration(provider) - if now_ts >= cycle_started + duration: - # Cycle expired - skip (don't restore) - lib_logger.debug( - f"Fair cycle expired for {provider}/{tier_key}/{tracking_key} - not restoring" - ) - continue - - # Restore valid cycle - self._cycle_exhausted[provider][tier_key][tracking_key] = { - "cycle_started_at": cycle_started, - "exhausted": set(exhausted_list) if exhausted_list else set(), - } - - # Log restoration summary - total_cycles = sum( - len(tracking) - for tier in self._cycle_exhausted.values() - for tracking in tier.values() - ) - if total_cycles > 0: - lib_logger.info(f"Restored {total_cycles} active fair cycle(s) from disk") - - async def _lazy_init(self): - """Initializes the usage data by loading it from the file asynchronously.""" - async with self._init_lock: - if not self._initialized.is_set(): - await self._load_usage() - await self._reset_daily_stats_if_needed() - self._initialized.set() - - async def _load_usage(self): - """Loads usage data from the JSON file asynchronously with resilience.""" - async with self._data_lock: - if not os.path.exists(self.file_path): - self._usage_data = {} - return - - try: - async with aiofiles.open(self.file_path, "r") as f: - content = await f.read() - self._usage_data = json.loads(content) if content.strip() else {} - except FileNotFoundError: - # File deleted between exists check and open - self._usage_data = {} - except json.JSONDecodeError as e: - lib_logger.warning( - f"Corrupted usage file {self.file_path}: {e}. Starting fresh." - ) - self._usage_data = {} - except (OSError, PermissionError, IOError) as e: - lib_logger.warning( - f"Cannot read usage file {self.file_path}: {e}. Using empty state." - ) - self._usage_data = {} - - # Restore fair cycle state from persisted data - fair_cycle_data = self._usage_data.get("__fair_cycle__", {}) - if fair_cycle_data: - self._deserialize_cycle_state(fair_cycle_data) - - async def _save_usage(self): - """Saves the current usage data using the resilient state writer.""" - if self._usage_data is None: - return - - async with self._data_lock: - # Add human-readable timestamp fields before saving - self._add_readable_timestamps(self._usage_data) - - # Persist fair cycle state (separate from credential data) - if self._cycle_exhausted: - self._usage_data["__fair_cycle__"] = self._serialize_cycle_state() - elif "__fair_cycle__" in self._usage_data: - # Clean up empty cycle data - del self._usage_data["__fair_cycle__"] - - # Hand off to resilient writer - handles retries and disk failures - self._state_writer.write(self._usage_data) - - async def _get_usage_data_snapshot(self) -> Dict[str, Any]: - """ - Get a shallow copy of the current usage data. - - Returns: - Copy of usage data dict (safe for reading without lock) - """ - await self._lazy_init() - async with self._data_lock: - return dict(self._usage_data) if self._usage_data else {} - - async def get_available_credentials_for_model( - self, credentials: List[str], model: str - ) -> List[str]: - """ - Get credentials that are not on cooldown for a specific model. - - Filters out credentials where: - - key_cooldown_until > now (key-level cooldown) - - model_cooldowns[model] > now (model-specific cooldown, includes quota exhausted) - - Args: - credentials: List of credential identifiers to check - model: Model name to check cooldowns for - - Returns: - List of credentials that are available (not on cooldown) for this model - """ - await self._lazy_init() - now = time.time() - available = [] - - async with self._data_lock: - for key in credentials: - key_data = self._usage_data.get(key, {}) - - # Skip if key-level cooldown is active - if (key_data.get("key_cooldown_until") or 0) > now: - continue - - # Normalize model name for consistent cooldown lookup - # (cooldowns are stored under normalized names by record_failure) - # For providers without normalize_model_for_tracking (non-Antigravity), - # this returns the model unchanged, so cooldown lookups work as before. - normalized_model = self._normalize_model(key, model) - - # Skip if model-specific cooldown is active - if ( - key_data.get("model_cooldowns", {}).get(normalized_model) or 0 - ) > now: - continue - - available.append(key) - - return available - - async def get_credential_availability_stats( - self, - credentials: List[str], - model: str, - credential_priorities: Optional[Dict[str, int]] = None, - ) -> Dict[str, int]: - """ - Get credential availability statistics including cooldown and fair cycle exclusions. - - This is used for logging to show why credentials are excluded. - - Args: - credentials: List of credential identifiers to check - model: Model name to check - credential_priorities: Optional dict mapping credentials to priorities - - Returns: - Dict with: - "total": Total credentials - "on_cooldown": Count on cooldown - "fair_cycle_excluded": Count excluded by fair cycle - "available": Count available for selection - """ - await self._lazy_init() - now = time.time() - - total = len(credentials) - on_cooldown = 0 - not_on_cooldown = [] - - # First pass: check cooldowns - async with self._data_lock: - for key in credentials: - key_data = self._usage_data.get(key, {}) - - # Check if key-level or model-level cooldown is active - normalized_model = self._normalize_model(key, model) - if (key_data.get("key_cooldown_until") or 0) > now or ( - key_data.get("model_cooldowns", {}).get(normalized_model) or 0 - ) > now: - on_cooldown += 1 - else: - not_on_cooldown.append(key) - - # Second pass: check fair cycle exclusions (only for non-cooldown credentials) - fair_cycle_excluded = 0 - if not_on_cooldown: - provider = self._get_provider_from_credential(not_on_cooldown[0]) - if provider: - rotation_mode = self._get_rotation_mode(provider) - if self._is_fair_cycle_enabled(provider, rotation_mode): - # Check each credential against its own tier's exhausted set - for key in not_on_cooldown: - key_priority = ( - credential_priorities.get(key, 999) - if credential_priorities - else 999 - ) - tier_key = self._get_tier_key(provider, key_priority) - tracking_key = self._get_tracking_key(key, model, provider) - - if self._is_credential_exhausted_in_cycle( - key, provider, tier_key, tracking_key - ): - fair_cycle_excluded += 1 - - available = total - on_cooldown - fair_cycle_excluded - - return { - "total": total, - "on_cooldown": on_cooldown, - "fair_cycle_excluded": fair_cycle_excluded, - "available": available, - } - - async def get_soonest_cooldown_end( - self, - credentials: List[str], - model: str, - ) -> Optional[float]: - """ - Find the soonest time when any credential will come off cooldown. - - This is used for smart waiting logic - if no credentials are available, - we can determine whether to wait (if soonest cooldown < deadline) or - fail fast (if soonest cooldown > deadline). - - Args: - credentials: List of credential identifiers to check - model: Model name to check cooldowns for - - Returns: - Timestamp of soonest cooldown end, or None if no credentials are on cooldown - """ - await self._lazy_init() - now = time.time() - soonest_end = None - - async with self._data_lock: - for key in credentials: - key_data = self._usage_data.get(key, {}) - normalized_model = self._normalize_model(key, model) - - # Check key-level cooldown - key_cooldown = key_data.get("key_cooldown_until") or 0 - if key_cooldown > now: - if soonest_end is None or key_cooldown < soonest_end: - soonest_end = key_cooldown - - # Check model-level cooldown - model_cooldown = ( - key_data.get("model_cooldowns", {}).get(normalized_model) or 0 - ) - if model_cooldown > now: - if soonest_end is None or model_cooldown < soonest_end: - soonest_end = model_cooldown - - return soonest_end - - async def _reset_daily_stats_if_needed(self): - """ - Checks if usage stats need to be reset for any key. - - Supports three reset modes: - 1. per_model: Each model has its own window, resets based on quota_reset_ts or fallback window - 2. credential: One window per credential (legacy with custom window duration) - 3. daily: Legacy daily reset at daily_reset_time_utc - """ - if self._usage_data is None: - return - - now_utc = datetime.now(timezone.utc) - now_ts = time.time() - today_str = now_utc.date().isoformat() - needs_saving = False - - for key, data in self._usage_data.items(): - reset_config = self._get_usage_reset_config(key) - - if reset_config: - reset_mode = reset_config.get("mode", "credential") - - if reset_mode == "per_model": - # Per-model window reset - needs_saving |= await self._check_per_model_resets( - key, data, reset_config, now_ts - ) - else: - # Credential-level window reset (legacy) - needs_saving |= await self._check_window_reset( - key, data, reset_config, now_ts - ) - elif self.daily_reset_time_utc: - # Legacy daily reset - needs_saving |= await self._check_daily_reset( - key, data, now_utc, today_str, now_ts - ) - - if needs_saving: - await self._save_usage() - - async def _check_per_model_resets( - self, - key: str, - data: Dict[str, Any], - reset_config: Dict[str, Any], - now_ts: float, - ) -> bool: - """ - Check and perform per-model resets for a credential. - - Each model resets independently based on: - 1. quota_reset_ts (authoritative, from quota exhausted error) if set - 2. window_start_ts + window_seconds (fallback) otherwise - - Grouped models reset together - all models in a group must be ready. - - Args: - key: Credential identifier - data: Usage data for this credential - reset_config: Provider's reset configuration - now_ts: Current timestamp - - Returns: - True if data was modified and needs saving - """ - window_seconds = reset_config.get("window_seconds", 86400) - models_data = data.get("models", {}) - - if not models_data: - return False - - modified = False - processed_groups = set() - - for model, model_data in list(models_data.items()): - # Check if this model is in a quota group - group = self._get_model_quota_group(key, model) - - if group: - if group in processed_groups: - continue # Already handled this group - - # Check if entire group should reset - if self._should_group_reset( - key, group, models_data, window_seconds, now_ts - ): - # Archive and reset all models in group - grouped_models = self._get_grouped_models(key, group) - archived_count = 0 - - for grouped_model in grouped_models: - if grouped_model in models_data: - gm_data = models_data[grouped_model] - self._archive_model_to_global(data, grouped_model, gm_data) - self._reset_model_data(gm_data) - archived_count += 1 - - if archived_count > 0: - lib_logger.info( - f"Reset model group '{group}' ({archived_count} models) for {mask_credential(key)}" - ) - modified = True - - processed_groups.add(group) - - else: - # Ungrouped model - check individually - if self._should_model_reset(model_data, window_seconds, now_ts): - self._archive_model_to_global(data, model, model_data) - self._reset_model_data(model_data) - lib_logger.info(f"Reset model {model} for {mask_credential(key)}") - modified = True - - # Preserve unexpired cooldowns - if modified: - self._preserve_unexpired_cooldowns(key, data, now_ts) - if "failures" in data: - data["failures"] = {} - - return modified - - def _should_model_reset( - self, model_data: Dict[str, Any], window_seconds: int, now_ts: float - ) -> bool: - """ - Check if a single model should reset. - - Returns True if: - - quota_reset_ts is set AND now >= quota_reset_ts, OR - - quota_reset_ts is NOT set AND now >= window_start_ts + window_seconds - """ - quota_reset = model_data.get("quota_reset_ts") - window_start = model_data.get("window_start_ts") - - if quota_reset: - return now_ts >= quota_reset - elif window_start: - return now_ts >= window_start + window_seconds - return False - - def _should_group_reset( - self, - key: str, - group: str, - models_data: Dict[str, Dict], - window_seconds: int, - now_ts: float, - ) -> bool: - """ - Check if all models in a group should reset. - - All models in the group must be ready to reset. - If any model has an active cooldown/window, the whole group waits. - """ - grouped_models = self._get_grouped_models(key, group) - - # Track if any model in group has data - any_has_data = False - - for grouped_model in grouped_models: - model_data = models_data.get(grouped_model, {}) - - if not model_data or ( - model_data.get("window_start_ts") is None - and model_data.get("success_count", 0) == 0 - ): - continue # No stats for this model yet - - any_has_data = True - - if not self._should_model_reset(model_data, window_seconds, now_ts): - return False # At least one model not ready - - return any_has_data - - def _archive_model_to_global( - self, data: Dict[str, Any], model: str, model_data: Dict[str, Any] - ) -> None: - """Archive a single model's stats to global.""" - global_data = data.setdefault("global", {"models": {}}) - global_model = global_data["models"].setdefault( - model, - { - "success_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - }, - ) - - global_model["success_count"] += model_data.get("success_count", 0) - global_model["prompt_tokens"] += model_data.get("prompt_tokens", 0) - global_model["prompt_tokens_cached"] = global_model.get( - "prompt_tokens_cached", 0 - ) + model_data.get("prompt_tokens_cached", 0) - global_model["completion_tokens"] += model_data.get("completion_tokens", 0) - global_model["approx_cost"] += model_data.get("approx_cost", 0.0) - - def _reset_model_data(self, model_data: Dict[str, Any]) -> None: - """Reset a model's window and stats.""" - model_data["window_start_ts"] = None - model_data["quota_reset_ts"] = None - model_data["success_count"] = 0 - model_data["failure_count"] = 0 - model_data["request_count"] = 0 - model_data["prompt_tokens"] = 0 - model_data["completion_tokens"] = 0 - model_data["approx_cost"] = 0.0 - # Reset quota baseline fields only if they exist (Antigravity-specific) - # These are added by update_quota_baseline(), only called for Antigravity - if "baseline_remaining_fraction" in model_data: - model_data["baseline_remaining_fraction"] = None - model_data["baseline_fetched_at"] = None - model_data["requests_at_baseline"] = None - # Reset quota display but keep max_requests (it doesn't change between periods) - max_req = model_data.get("quota_max_requests") - if max_req: - model_data["quota_display"] = f"0/{max_req}" - - async def _check_window_reset( - self, - key: str, - data: Dict[str, Any], - reset_config: Dict[str, Any], - now_ts: float, - ) -> bool: - """ - Check and perform rolling window reset for a credential. - - Args: - key: Credential identifier - data: Usage data for this credential - reset_config: Provider's reset configuration - now_ts: Current timestamp - - Returns: - True if data was modified and needs saving - """ - window_seconds = reset_config.get("window_seconds", 86400) # Default 24h - field_name = reset_config.get("field_name", "window") - description = reset_config.get("description", "rolling window") - - # Get current window data - window_data = data.get(field_name, {}) - window_start = window_data.get("start_ts") - - # No window started yet - nothing to reset - if window_start is None: - return False - - # Check if window has expired - window_end = window_start + window_seconds - if now_ts < window_end: - # Window still active - return False - - # Window expired - perform reset - hours_elapsed = (now_ts - window_start) / 3600 - lib_logger.info( - f"Resetting {field_name} for {mask_credential(key)} - " - f"{description} expired after {hours_elapsed:.1f}h" - ) - - # Archive to global - self._archive_to_global(data, window_data) - - # Preserve unexpired cooldowns - self._preserve_unexpired_cooldowns(key, data, now_ts) - - # Reset window stats (but don't start new window until first request) - data[field_name] = {"start_ts": None, "models": {}} - - # Reset consecutive failures - if "failures" in data: - data["failures"] = {} - - return True - - async def _check_daily_reset( - self, - key: str, - data: Dict[str, Any], - now_utc: datetime, - today_str: str, - now_ts: float, - ) -> bool: - """ - Check and perform legacy daily reset for a credential. - - Args: - key: Credential identifier - data: Usage data for this credential - now_utc: Current datetime in UTC - today_str: Today's date as ISO string - now_ts: Current timestamp - - Returns: - True if data was modified and needs saving - """ - last_reset_str = data.get("last_daily_reset", "") - - if last_reset_str == today_str: - return False - - last_reset_dt = None - if last_reset_str: - try: - last_reset_dt = datetime.fromisoformat(last_reset_str).replace( - tzinfo=timezone.utc - ) - except ValueError: - pass - - # Determine the reset threshold for today - reset_threshold_today = datetime.combine( - now_utc.date(), self.daily_reset_time_utc - ) - - if not ( - last_reset_dt is None or last_reset_dt < reset_threshold_today <= now_utc - ): - return False - - lib_logger.debug(f"Performing daily reset for key {mask_credential(key)}") - - # Preserve unexpired cooldowns - self._preserve_unexpired_cooldowns(key, data, now_ts) - - # Reset consecutive failures - if "failures" in data: - data["failures"] = {} - - # Archive daily stats to global - daily_data = data.get("daily", {}) - if daily_data: - self._archive_to_global(data, daily_data) - - # Reset daily stats - data["daily"] = {"date": today_str, "models": {}} - data["last_daily_reset"] = today_str - - return True - - def _archive_to_global( - self, data: Dict[str, Any], source_data: Dict[str, Any] - ) -> None: - """ - Archive usage stats from a source field (daily/window) to global. - - Args: - data: The credential's usage data - source_data: The source field data to archive (has "models" key) - """ - global_data = data.setdefault("global", {"models": {}}) - for model, stats in source_data.get("models", {}).items(): - global_model_stats = global_data["models"].setdefault( - model, - { - "success_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - }, - ) - global_model_stats["success_count"] += stats.get("success_count", 0) - global_model_stats["prompt_tokens"] += stats.get("prompt_tokens", 0) - global_model_stats["prompt_tokens_cached"] = global_model_stats.get( - "prompt_tokens_cached", 0 - ) + stats.get("prompt_tokens_cached", 0) - global_model_stats["completion_tokens"] += stats.get("completion_tokens", 0) - global_model_stats["approx_cost"] += stats.get("approx_cost", 0.0) - - def _preserve_unexpired_cooldowns( - self, key: str, data: Dict[str, Any], now_ts: float - ) -> None: - """ - Preserve unexpired cooldowns during reset (important for long quota cooldowns). - - Args: - key: Credential identifier (for logging) - data: The credential's usage data - now_ts: Current timestamp - """ - # Preserve unexpired model cooldowns - if "model_cooldowns" in data: - active_cooldowns = { - model: end_time - for model, end_time in data["model_cooldowns"].items() - if end_time > now_ts - } - if active_cooldowns: - max_remaining = max( - end_time - now_ts for end_time in active_cooldowns.values() - ) - hours_remaining = max_remaining / 3600 - lib_logger.info( - f"Preserving {len(active_cooldowns)} active cooldown(s) " - f"for key {mask_credential(key)} during reset " - f"(longest: {hours_remaining:.1f}h remaining)" - ) - data["model_cooldowns"] = active_cooldowns - else: - data["model_cooldowns"] = {} - - # Preserve unexpired key-level cooldown - if data.get("key_cooldown_until"): - if data["key_cooldown_until"] <= now_ts: - data["key_cooldown_until"] = None - else: - hours_remaining = (data["key_cooldown_until"] - now_ts) / 3600 - lib_logger.info( - f"Preserving key-level cooldown for {mask_credential(key)} " - f"during reset ({hours_remaining:.1f}h remaining)" - ) - else: - data["key_cooldown_until"] = None - - def _initialize_key_states(self, keys: List[str]): - """Initializes state tracking for all provided keys if not already present.""" - for key in keys: - if key not in self.key_states: - self.key_states[key] = { - "lock": asyncio.Lock(), - "condition": asyncio.Condition(), - "models_in_use": {}, # Dict[model_name, concurrent_count] - } - - def _select_weighted_random(self, candidates: List[tuple], tolerance: float) -> str: - """ - Selects a credential using weighted random selection based on usage counts. - - Args: - candidates: List of (credential_id, usage_count) tuples - tolerance: Tolerance value for weight calculation - - Returns: - Selected credential ID - - Formula: - weight = (max_usage - credential_usage) + tolerance + 1 - - This formula ensures: - - Lower usage = higher weight = higher selection probability - - Tolerance adds variability: higher tolerance means more randomness - - The +1 ensures all credentials have at least some chance of selection - """ - if not candidates: - raise ValueError("Cannot select from empty candidate list") - - if len(candidates) == 1: - return candidates[0][0] - - # Extract usage counts - usage_counts = [usage for _, usage in candidates] - max_usage = max(usage_counts) - - # Calculate weights using the formula: (max - current) + tolerance + 1 - weights = [] - for credential, usage in candidates: - weight = (max_usage - usage) + tolerance + 1 - weights.append(weight) - - # Log weight distribution for debugging - if lib_logger.isEnabledFor(logging.DEBUG): - total_weight = sum(weights) - weight_info = ", ".join( - f"{mask_credential(cred)}: w={w:.1f} ({w / total_weight * 100:.1f}%)" - for (cred, _), w in zip(candidates, weights) - ) - # lib_logger.debug(f"Weighted selection candidates: {weight_info}") - - # Random selection with weights - selected_credential = random.choices( - [cred for cred, _ in candidates], weights=weights, k=1 - )[0] - - return selected_credential - - async def acquire_key( - self, - available_keys: List[str], - model: str, - deadline: float, - max_concurrent: int = 1, - credential_priorities: Optional[Dict[str, int]] = None, - credential_tier_names: Optional[Dict[str, str]] = None, - all_provider_credentials: Optional[List[str]] = None, - ) -> str: - """ - Acquires the best available key using a tiered, model-aware locking strategy, - respecting a global deadline and credential priorities. - - Priority Logic: - - Groups credentials by priority level (1=highest, 2=lower, etc.) - - Always tries highest priority (lowest number) first - - Within same priority, sorts by usage count (load balancing) - - Only moves to next priority if all higher-priority keys exhausted/busy - - Args: - available_keys: List of credential identifiers to choose from - model: Model name being requested - deadline: Timestamp after which to stop trying - max_concurrent: Maximum concurrent requests allowed per credential - credential_priorities: Optional dict mapping credentials to priority levels (1=highest) - credential_tier_names: Optional dict mapping credentials to tier names (for logging) - all_provider_credentials: Full list of provider credentials (used for cycle reset checks) - - Returns: - Selected credential identifier - - Raises: - NoAvailableKeysError: If no key could be acquired within the deadline - """ - await self._lazy_init() - await self._reset_daily_stats_if_needed() - self._initialize_key_states(available_keys) - - # Normalize model name for consistent cooldown lookup - # (cooldowns are stored under normalized names by record_failure) - # Use first credential for provider detection; all credentials passed here - # are for the same provider (filtered by client.py before calling acquire_key). - # For providers without normalize_model_for_tracking (non-Antigravity), - # this returns the model unchanged, so cooldown lookups work as before. - normalized_model = ( - self._normalize_model(available_keys[0], model) if available_keys else model - ) - - # This loop continues as long as the global deadline has not been met. - while time.time() < deadline: - now = time.time() - - # Group credentials by priority level (if priorities provided) - if credential_priorities: - # Group keys by priority level - priority_groups = {} - async with self._data_lock: - for key in available_keys: - key_data = self._usage_data.get(key, {}) - - # Skip keys on cooldown (use normalized model for lookup) - if (key_data.get("key_cooldown_until") or 0) > now or ( - key_data.get("model_cooldowns", {}).get(normalized_model) - or 0 - ) > now: - continue - - # Get priority for this key (default to 999 if not specified) - priority = credential_priorities.get(key, 999) - - # Get usage count for load balancing within priority groups - # Uses grouped usage if model is in a quota group - usage_count = self._get_grouped_usage_count(key, model) - - # Group by priority - if priority not in priority_groups: - priority_groups[priority] = [] - priority_groups[priority].append((key, usage_count)) - - # Try priority groups in order (1, 2, 3, ...) - sorted_priorities = sorted(priority_groups.keys()) - - for priority_level in sorted_priorities: - keys_in_priority = priority_groups[priority_level] - - # Determine selection method based on provider's rotation mode - provider = model.split("/")[0] if "/" in model else "" - rotation_mode = self._get_rotation_mode(provider) - - # Fair cycle filtering - if provider and self._is_fair_cycle_enabled( - provider, rotation_mode - ): - tier_key = self._get_tier_key(provider, priority_level) - tracking_key = self._get_tracking_key( - keys_in_priority[0][0] if keys_in_priority else "", - model, - provider, - ) - - # Get all credentials for this tier (for cycle completion check) - all_tier_creds = self._get_all_credentials_for_tier_key( - provider, - tier_key, - all_provider_credentials or available_keys, - credential_priorities, - ) - - # Check if cycle should reset (all exhausted, expired, or none available) - if self._should_reset_cycle( - provider, - tier_key, - tracking_key, - all_tier_creds, - available_not_on_cooldown=[ - key for key, _ in keys_in_priority - ], - ): - self._reset_cycle(provider, tier_key, tracking_key) - - # Filter out exhausted credentials - filtered_keys = [] - for key, usage_count in keys_in_priority: - if not self._is_credential_exhausted_in_cycle( - key, provider, tier_key, tracking_key - ): - filtered_keys.append((key, usage_count)) - - keys_in_priority = filtered_keys - - # Calculate effective concurrency based on priority tier - multiplier = self._get_priority_multiplier( - provider, priority_level, rotation_mode - ) - effective_max_concurrent = max_concurrent * multiplier - - # Within each priority group, use existing tier1/tier2 logic - tier1_keys, tier2_keys = [], [] - for key, usage_count in keys_in_priority: - key_state = self.key_states[key] - - # Tier 1: Completely idle keys (preferred) - if not key_state["models_in_use"]: - tier1_keys.append((key, usage_count)) - # Tier 2: Keys that can accept more concurrent requests - elif ( - key_state["models_in_use"].get(model, 0) - < effective_max_concurrent - ): - tier2_keys.append((key, usage_count)) - - if rotation_mode == "sequential": - # Sequential mode: sort credentials by priority, usage, recency - # Keep all candidates in sorted order (no filtering to single key) - selection_method = "sequential" - if tier1_keys: - tier1_keys = self._sort_sequential( - tier1_keys, credential_priorities - ) - if tier2_keys: - tier2_keys = self._sort_sequential( - tier2_keys, credential_priorities - ) - elif self.rotation_tolerance > 0: - # Balanced mode with weighted randomness - selection_method = "weighted-random" - if tier1_keys: - selected_key = self._select_weighted_random( - tier1_keys, self.rotation_tolerance - ) - tier1_keys = [ - (k, u) for k, u in tier1_keys if k == selected_key - ] - if tier2_keys: - selected_key = self._select_weighted_random( - tier2_keys, self.rotation_tolerance - ) - tier2_keys = [ - (k, u) for k, u in tier2_keys if k == selected_key - ] - else: - # Deterministic: sort by usage within each tier - selection_method = "least-used" - tier1_keys.sort(key=lambda x: x[1]) - tier2_keys.sort(key=lambda x: x[1]) - - # Try to acquire from Tier 1 first - for key, usage in tier1_keys: - state = self.key_states[key] - async with state["lock"]: - if not state["models_in_use"]: - state["models_in_use"][model] = 1 - tier_name = ( - credential_tier_names.get(key, "unknown") - if credential_tier_names - else "unknown" - ) - quota_display = self._get_quota_display(key, model) - lib_logger.info( - f"Acquired key {mask_credential(key)} for model {model} " - f"(tier: {tier_name}, priority: {priority_level}, selection: {selection_method}, {quota_display})" - ) - return key - - # Then try Tier 2 - for key, usage in tier2_keys: - state = self.key_states[key] - async with state["lock"]: - current_count = state["models_in_use"].get(model, 0) - if current_count < effective_max_concurrent: - state["models_in_use"][model] = current_count + 1 - tier_name = ( - credential_tier_names.get(key, "unknown") - if credential_tier_names - else "unknown" - ) - quota_display = self._get_quota_display(key, model) - lib_logger.info( - f"Acquired key {mask_credential(key)} for model {model} " - f"(tier: {tier_name}, priority: {priority_level}, selection: {selection_method}, concurrent: {state['models_in_use'][model]}/{effective_max_concurrent}, {quota_display})" - ) - return key - - # If we get here, all priority groups were exhausted but keys might become available - # Collect all keys across all priorities for waiting - all_potential_keys = [] - for keys_list in priority_groups.values(): - all_potential_keys.extend(keys_list) - - if not all_potential_keys: - # All credentials are on cooldown - check if waiting makes sense - soonest_end = await self.get_soonest_cooldown_end( - available_keys, model - ) - - if soonest_end is None: - # No cooldowns active but no keys available (shouldn't happen) - lib_logger.warning( - "No keys eligible and no cooldowns active. Re-evaluating..." - ) - await asyncio.sleep(1) - continue - - remaining_budget = deadline - time.time() - wait_needed = soonest_end - time.time() - - if wait_needed > remaining_budget: - # Fail fast - no credential will be available in time - lib_logger.warning( - f"All credentials on cooldown. Soonest available in {wait_needed:.1f}s, " - f"but only {remaining_budget:.1f}s budget remaining. Failing fast." - ) - break # Exit loop, will raise NoAvailableKeysError - - # Wait for the credential to become available - lib_logger.info( - f"All credentials on cooldown. Waiting {wait_needed:.1f}s for soonest credential..." - ) - await asyncio.sleep(min(wait_needed + 0.1, remaining_budget)) - continue - - # Wait for the highest priority key with lowest usage - best_priority = min(priority_groups.keys()) - best_priority_keys = priority_groups[best_priority] - best_wait_key = min(best_priority_keys, key=lambda x: x[1])[0] - wait_condition = self.key_states[best_wait_key]["condition"] - - lib_logger.info( - f"All Priority-{best_priority} keys are busy. Waiting for highest priority credential to become available..." - ) - - else: - # Original logic when no priorities specified - - # Determine selection method based on provider's rotation mode - provider = model.split("/")[0] if "/" in model else "" - rotation_mode = self._get_rotation_mode(provider) - - # Calculate effective concurrency for default priority (999) - # When no priorities are specified, all credentials get default priority - default_priority = 999 - multiplier = self._get_priority_multiplier( - provider, default_priority, rotation_mode - ) - effective_max_concurrent = max_concurrent * multiplier - - tier1_keys, tier2_keys = [], [] - - # First, filter the list of available keys to exclude any on cooldown. - async with self._data_lock: - for key in available_keys: - key_data = self._usage_data.get(key, {}) - - # Skip keys on cooldown (use normalized model for lookup) - if (key_data.get("key_cooldown_until") or 0) > now or ( - key_data.get("model_cooldowns", {}).get(normalized_model) - or 0 - ) > now: - continue - - # Prioritize keys based on their current usage to ensure load balancing. - # Uses grouped usage if model is in a quota group - usage_count = self._get_grouped_usage_count(key, model) - key_state = self.key_states[key] - - # Tier 1: Completely idle keys (preferred). - if not key_state["models_in_use"]: - tier1_keys.append((key, usage_count)) - # Tier 2: Keys that can accept more concurrent requests for this model. - elif ( - key_state["models_in_use"].get(model, 0) - < effective_max_concurrent - ): - tier2_keys.append((key, usage_count)) - - # Fair cycle filtering (non-priority case) - if provider and self._is_fair_cycle_enabled(provider, rotation_mode): - tier_key = self._get_tier_key(provider, default_priority) - tracking_key = self._get_tracking_key( - available_keys[0] if available_keys else "", - model, - provider, - ) - - # Get all credentials for this tier (for cycle completion check) - all_tier_creds = self._get_all_credentials_for_tier_key( - provider, - tier_key, - all_provider_credentials or available_keys, - None, - ) - - # Check if cycle should reset (all exhausted, expired, or none available) - if self._should_reset_cycle( - provider, - tier_key, - tracking_key, - all_tier_creds, - available_not_on_cooldown=[ - key for key, _ in (tier1_keys + tier2_keys) - ], - ): - self._reset_cycle(provider, tier_key, tracking_key) - - # Filter out exhausted credentials from both tiers - tier1_keys = [ - (key, usage) - for key, usage in tier1_keys - if not self._is_credential_exhausted_in_cycle( - key, provider, tier_key, tracking_key - ) - ] - tier2_keys = [ - (key, usage) - for key, usage in tier2_keys - if not self._is_credential_exhausted_in_cycle( - key, provider, tier_key, tracking_key - ) - ] - - if rotation_mode == "sequential": - # Sequential mode: sort credentials by priority, usage, recency - # Keep all candidates in sorted order (no filtering to single key) - selection_method = "sequential" - if tier1_keys: - tier1_keys = self._sort_sequential( - tier1_keys, credential_priorities - ) - if tier2_keys: - tier2_keys = self._sort_sequential( - tier2_keys, credential_priorities - ) - elif self.rotation_tolerance > 0: - # Balanced mode with weighted randomness - selection_method = "weighted-random" - if tier1_keys: - selected_key = self._select_weighted_random( - tier1_keys, self.rotation_tolerance - ) - tier1_keys = [ - (k, u) for k, u in tier1_keys if k == selected_key - ] - if tier2_keys: - selected_key = self._select_weighted_random( - tier2_keys, self.rotation_tolerance - ) - tier2_keys = [ - (k, u) for k, u in tier2_keys if k == selected_key - ] - else: - # Deterministic: sort by usage within each tier - selection_method = "least-used" - tier1_keys.sort(key=lambda x: x[1]) - tier2_keys.sort(key=lambda x: x[1]) - - # Attempt to acquire a key from Tier 1 first. - for key, usage in tier1_keys: - state = self.key_states[key] - async with state["lock"]: - if not state["models_in_use"]: - state["models_in_use"][model] = 1 - tier_name = ( - credential_tier_names.get(key) - if credential_tier_names - else None - ) - tier_info = f"tier: {tier_name}, " if tier_name else "" - quota_display = self._get_quota_display(key, model) - lib_logger.info( - f"Acquired key {mask_credential(key)} for model {model} " - f"({tier_info}selection: {selection_method}, {quota_display})" - ) - return key - - # If no Tier 1 keys are available, try Tier 2. - for key, usage in tier2_keys: - state = self.key_states[key] - async with state["lock"]: - current_count = state["models_in_use"].get(model, 0) - if current_count < effective_max_concurrent: - state["models_in_use"][model] = current_count + 1 - tier_name = ( - credential_tier_names.get(key) - if credential_tier_names - else None - ) - tier_info = f"tier: {tier_name}, " if tier_name else "" - quota_display = self._get_quota_display(key, model) - lib_logger.info( - f"Acquired key {mask_credential(key)} for model {model} " - f"({tier_info}selection: {selection_method}, concurrent: {state['models_in_use'][model]}/{effective_max_concurrent}, {quota_display})" - ) - return key - - # If all eligible keys are locked, wait for a key to be released. - lib_logger.info( - "All eligible keys are currently locked for this model. Waiting..." - ) - - all_potential_keys = tier1_keys + tier2_keys - if not all_potential_keys: - # All credentials are on cooldown - check if waiting makes sense - soonest_end = await self.get_soonest_cooldown_end( - available_keys, model - ) - - if soonest_end is None: - # No cooldowns active but no keys available (shouldn't happen) - lib_logger.warning( - "No keys eligible and no cooldowns active. Re-evaluating..." - ) - await asyncio.sleep(1) - continue - - remaining_budget = deadline - time.time() - wait_needed = soonest_end - time.time() - - if wait_needed > remaining_budget: - # Fail fast - no credential will be available in time - lib_logger.warning( - f"All credentials on cooldown. Soonest available in {wait_needed:.1f}s, " - f"but only {remaining_budget:.1f}s budget remaining. Failing fast." - ) - break # Exit loop, will raise NoAvailableKeysError - - # Wait for the credential to become available - lib_logger.info( - f"All credentials on cooldown. Waiting {wait_needed:.1f}s for soonest credential..." - ) - await asyncio.sleep(min(wait_needed + 0.1, remaining_budget)) - continue - - # Wait on the condition of the key with the lowest current usage. - best_wait_key = min(all_potential_keys, key=lambda x: x[1])[0] - wait_condition = self.key_states[best_wait_key]["condition"] - - try: - async with wait_condition: - remaining_budget = deadline - time.time() - if remaining_budget <= 0: - break # Exit if the budget has already been exceeded. - # Wait for a notification, but no longer than the remaining budget or 1 second. - await asyncio.wait_for( - wait_condition.wait(), timeout=min(1, remaining_budget) - ) - lib_logger.info("Notified that a key was released. Re-evaluating...") - except asyncio.TimeoutError: - # This is not an error, just a timeout for the wait. The main loop will re-evaluate. - lib_logger.info("Wait timed out. Re-evaluating for any available key.") - - # If the loop exits, it means the deadline was exceeded. - raise NoAvailableKeysError( - f"Could not acquire a key for model {model} within the global time budget." - ) - - async def release_key(self, key: str, model: str): - """Releases a key's lock for a specific model and notifies waiting tasks.""" - if key not in self.key_states: - return - - state = self.key_states[key] - async with state["lock"]: - if model in state["models_in_use"]: - state["models_in_use"][model] -= 1 - remaining = state["models_in_use"][model] - if remaining <= 0: - del state["models_in_use"][model] # Clean up when count reaches 0 - lib_logger.info( - f"Released credential {mask_credential(key)} from model {model} " - f"(remaining concurrent: {max(0, remaining)})" - ) - else: - lib_logger.warning( - f"Attempted to release credential {mask_credential(key)} for model {model}, but it was not in use." - ) - - # Notify all tasks waiting on this key's condition - async with state["condition"]: - state["condition"].notify_all() - - async def record_success( - self, - key: str, - model: str, - completion_response: Optional[litellm.ModelResponse] = None, - ): - """ - Records a successful API call, resetting failure counters. - It safely handles cases where token usage data is not available. - - Supports two modes based on provider configuration: - - per_model: Each model has its own window_start_ts and stats in key_data["models"] - - credential: Legacy mode with key_data["daily"]["models"] - """ - await self._lazy_init() - - # Normalize model name to public-facing name for consistent tracking - model = self._normalize_model(key, model) - - async with self._data_lock: - now_ts = time.time() - today_utc_str = datetime.now(timezone.utc).date().isoformat() - - reset_config = self._get_usage_reset_config(key) - reset_mode = ( - reset_config.get("mode", "credential") if reset_config else "credential" - ) - - if reset_mode == "per_model": - # New per-model structure - key_data = self._usage_data.setdefault( - key, - { - "models": {}, - "global": {"models": {}}, - "model_cooldowns": {}, - "failures": {}, - }, - ) - - # Ensure models dict exists - if "models" not in key_data: - key_data["models"] = {} - - # Get or create per-model data with window tracking - model_data = key_data["models"].setdefault( - model, - { - "window_start_ts": None, - "quota_reset_ts": None, - "success_count": 0, - "failure_count": 0, - "request_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - }, - ) - - # Start window on first request for this model - if model_data.get("window_start_ts") is None: - model_data["window_start_ts"] = now_ts - - # Set expected quota reset time from provider config - window_seconds = ( - reset_config.get("window_seconds", 0) if reset_config else 0 - ) - if window_seconds > 0: - model_data["quota_reset_ts"] = now_ts + window_seconds - - window_hours = window_seconds / 3600 if window_seconds else 0 - lib_logger.info( - f"Started {window_hours:.1f}h window for model {model} on {mask_credential(key)}" - ) - - # Record stats - model_data["success_count"] += 1 - model_data["request_count"] = model_data.get("request_count", 0) + 1 - - # Sync request_count across quota group (for providers with shared quota pools) - new_request_count = model_data["request_count"] - group = self._get_model_quota_group(key, model) - if group: - grouped_models = self._get_grouped_models(key, group) - for grouped_model in grouped_models: - if grouped_model != model: - other_model_data = key_data["models"].setdefault( - grouped_model, - { - "window_start_ts": None, - "quota_reset_ts": None, - "success_count": 0, - "failure_count": 0, - "request_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - }, - ) - other_model_data["request_count"] = new_request_count - # Sync window timing (shared quota pool = shared window) - window_start = model_data.get("window_start_ts") - if window_start: - other_model_data["window_start_ts"] = window_start - # Also sync quota_max_requests if set - max_req = model_data.get("quota_max_requests") - if max_req: - other_model_data["quota_max_requests"] = max_req - other_model_data["quota_display"] = ( - f"{new_request_count}/{max_req}" - ) - - # Update quota_display if max_requests is set (Antigravity-specific) - max_req = model_data.get("quota_max_requests") - if max_req: - model_data["quota_display"] = ( - f"{model_data['request_count']}/{max_req}" - ) - - # Check custom cap - if self._check_and_apply_custom_cap( - key, model, model_data["request_count"] - ): - # Custom cap exceeded, cooldown applied - # Continue to record tokens/cost but credential will be skipped next time - pass - - usage_data_ref = model_data # For token/cost recording below - - else: - # Legacy credential-level structure - key_data = self._usage_data.setdefault( - key, - { - "daily": {"date": today_utc_str, "models": {}}, - "global": {"models": {}}, - "model_cooldowns": {}, - "failures": {}, - }, - ) - - if "last_daily_reset" not in key_data: - key_data["last_daily_reset"] = today_utc_str - - # Get or create model data in daily structure - usage_data_ref = key_data["daily"]["models"].setdefault( - model, - { - "success_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - }, - ) - usage_data_ref["success_count"] += 1 - - # Reset failures for this model - model_failures = key_data.setdefault("failures", {}).setdefault(model, {}) - model_failures["consecutive_failures"] = 0 - - # Clear transient cooldown on success (but NOT quota_reset_ts) - if model in key_data.get("model_cooldowns", {}): - del key_data["model_cooldowns"][model] - - # Record token and cost usage - if ( - completion_response - and hasattr(completion_response, "usage") - and completion_response.usage - ): - usage = completion_response.usage - prompt_total = usage.prompt_tokens - - # Extract cached tokens from prompt_tokens_details if present - cached_tokens = 0 - prompt_details = getattr(usage, "prompt_tokens_details", None) - if prompt_details: - if isinstance(prompt_details, dict): - cached_tokens = prompt_details.get("cached_tokens", 0) or 0 - elif hasattr(prompt_details, "cached_tokens"): - cached_tokens = prompt_details.cached_tokens or 0 - - # Store uncached tokens (prompt_tokens is total, subtract cached) - uncached_tokens = prompt_total - cached_tokens - usage_data_ref["prompt_tokens"] += uncached_tokens - - # Store cached tokens separately - if cached_tokens > 0: - usage_data_ref["prompt_tokens_cached"] = ( - usage_data_ref.get("prompt_tokens_cached", 0) + cached_tokens - ) - - usage_data_ref["completion_tokens"] += getattr( - usage, "completion_tokens", 0 - ) - lib_logger.info( - f"Recorded usage from response object for key {mask_credential(key)}" - ) - try: - provider_name = model.split("/")[0] - provider_instance = self._get_provider_instance(provider_name) - - if provider_instance and getattr( - provider_instance, "skip_cost_calculation", False - ): - lib_logger.debug( - f"Skipping cost calculation for provider '{provider_name}' (custom provider)." - ) - else: - if isinstance(completion_response, litellm.EmbeddingResponse): - model_info = litellm.get_model_info(model) - input_cost = model_info.get("input_cost_per_token") - if input_cost: - cost = ( - completion_response.usage.prompt_tokens * input_cost - ) - else: - cost = None - else: - cost = litellm.completion_cost( - completion_response=completion_response, model=model - ) - - if cost is not None: - usage_data_ref["approx_cost"] += cost - except Exception as e: - lib_logger.warning( - f"Could not calculate cost for model {model}: {e}" - ) - elif isinstance(completion_response, asyncio.Future) or hasattr( - completion_response, "__aiter__" - ): - pass # Stream - usage recorded from chunks - else: - lib_logger.warning( - f"No usage data found in completion response for model {model}. Recording success without token count." - ) - - key_data["last_used_ts"] = now_ts - - await self._save_usage() - - async def record_failure( - self, - key: str, - model: str, - classified_error: ClassifiedError, - increment_consecutive_failures: bool = True, - ): - """Records a failure and applies cooldowns based on error type. - - Distinguishes between: - - quota_exceeded: Long cooldown with exact reset time (from quota_reset_timestamp) - Sets quota_reset_ts on model (and group) - this becomes authoritative stats reset time - - rate_limit: Short transient cooldown (just wait and retry) - Only sets model_cooldowns - does NOT affect stats reset timing - - Args: - key: The API key or credential identifier - model: The model name - classified_error: The classified error object - increment_consecutive_failures: Whether to increment the failure counter. - Set to False for provider-level errors that shouldn't count against the key. - """ - await self._lazy_init() - - # Normalize model name to public-facing name for consistent tracking - model = self._normalize_model(key, model) - - async with self._data_lock: - now_ts = time.time() - today_utc_str = datetime.now(timezone.utc).date().isoformat() - - reset_config = self._get_usage_reset_config(key) - reset_mode = ( - reset_config.get("mode", "credential") if reset_config else "credential" - ) - - # Initialize key data with appropriate structure - if reset_mode == "per_model": - key_data = self._usage_data.setdefault( - key, - { - "models": {}, - "global": {"models": {}}, - "model_cooldowns": {}, - "failures": {}, - }, - ) - else: - key_data = self._usage_data.setdefault( - key, - { - "daily": {"date": today_utc_str, "models": {}}, - "global": {"models": {}}, - "model_cooldowns": {}, - "failures": {}, - }, - ) - - # Provider-level errors (transient issues) should not count against the key - provider_level_errors = {"server_error", "api_connection"} - - # Determine if we should increment the failure counter - should_increment = ( - increment_consecutive_failures - and classified_error.error_type not in provider_level_errors - ) - - # Calculate cooldown duration based on error type - cooldown_seconds = None - model_cooldowns = key_data.setdefault("model_cooldowns", {}) - - # Capture existing cooldown BEFORE we modify it - # Used to determine if this is a fresh exhaustion vs re-processing - existing_cooldown_before = model_cooldowns.get(model) - was_already_on_cooldown = ( - existing_cooldown_before is not None - and existing_cooldown_before > now_ts - ) - - if classified_error.error_type == "quota_exceeded": - # Quota exhausted - use authoritative reset timestamp if available - quota_reset_ts = classified_error.quota_reset_timestamp - cooldown_seconds = ( - classified_error.retry_after or COOLDOWN_RATE_LIMIT_DEFAULT - ) - - if quota_reset_ts and reset_mode == "per_model": - # Set quota_reset_ts on model - this becomes authoritative stats reset time - models_data = key_data.setdefault("models", {}) - model_data = models_data.setdefault( - model, - { - "window_start_ts": None, - "quota_reset_ts": None, - "success_count": 0, - "failure_count": 0, - "request_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - }, - ) - model_data["quota_reset_ts"] = quota_reset_ts - # Track failure for quota estimation (request still consumes quota) - model_data["failure_count"] = model_data.get("failure_count", 0) + 1 - model_data["request_count"] = model_data.get("request_count", 0) + 1 - - # Clamp request_count to quota_max_requests when quota is exhausted - # This prevents display overflow (e.g., 151/150) when requests are - # counted locally before API refresh corrects the value - max_req = model_data.get("quota_max_requests") - if max_req is not None and model_data["request_count"] > max_req: - model_data["request_count"] = max_req - # Update quota_display with clamped value - model_data["quota_display"] = f"{max_req}/{max_req}" - new_request_count = model_data["request_count"] - - # Apply to all models in the same quota group - group = self._get_model_quota_group(key, model) - if group: - grouped_models = self._get_grouped_models(key, group) - for grouped_model in grouped_models: - group_model_data = models_data.setdefault( - grouped_model, - { - "window_start_ts": None, - "quota_reset_ts": None, - "success_count": 0, - "failure_count": 0, - "request_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - }, - ) - group_model_data["quota_reset_ts"] = quota_reset_ts - # Sync request_count across quota group - group_model_data["request_count"] = new_request_count - # Also sync quota_max_requests if set - max_req = model_data.get("quota_max_requests") - if max_req: - group_model_data["quota_max_requests"] = max_req - group_model_data["quota_display"] = ( - f"{new_request_count}/{max_req}" - ) - # Also set transient cooldown for selection logic - model_cooldowns[grouped_model] = quota_reset_ts - - reset_dt = datetime.fromtimestamp( - quota_reset_ts, tz=timezone.utc - ) - lib_logger.info( - f"Quota exhausted for group '{group}' ({len(grouped_models)} models) " - f"on {mask_credential(key)}. Resets at {reset_dt.isoformat()}" - ) - else: - reset_dt = datetime.fromtimestamp( - quota_reset_ts, tz=timezone.utc - ) - hours = (quota_reset_ts - now_ts) / 3600 - lib_logger.info( - f"Quota exhausted for model {model} on {mask_credential(key)}. " - f"Resets at {reset_dt.isoformat()} ({hours:.1f}h)" - ) - - # Set transient cooldown for selection logic - model_cooldowns[model] = quota_reset_ts - else: - # No authoritative timestamp or legacy mode - just use retry_after - model_cooldowns[model] = now_ts + cooldown_seconds - hours = cooldown_seconds / 3600 - lib_logger.info( - f"Quota exhausted on {mask_credential(key)} for model {model}. " - f"Cooldown: {cooldown_seconds}s ({hours:.1f}h)" - ) - - # Mark credential as exhausted for fair cycle if cooldown exceeds threshold - # BUT only if this is a FRESH exhaustion (wasn't already on cooldown) - # This prevents re-marking after cycle reset - if not was_already_on_cooldown: - effective_cooldown = ( - (quota_reset_ts - now_ts) - if quota_reset_ts - else (cooldown_seconds or 0) - ) - provider = self._get_provider_from_credential(key) - if provider: - threshold = self._get_exhaustion_cooldown_threshold(provider) - if effective_cooldown > threshold: - rotation_mode = self._get_rotation_mode(provider) - if self._is_fair_cycle_enabled(provider, rotation_mode): - priority = self._get_credential_priority(key, provider) - tier_key = self._get_tier_key(provider, priority) - tracking_key = self._get_tracking_key( - key, model, provider - ) - self._mark_credential_exhausted( - key, provider, tier_key, tracking_key - ) - - elif classified_error.error_type == "rate_limit": - # Transient rate limit - just set short cooldown (does NOT set quota_reset_ts) - cooldown_seconds = ( - classified_error.retry_after or COOLDOWN_RATE_LIMIT_DEFAULT - ) - model_cooldowns[model] = now_ts + cooldown_seconds - lib_logger.info( - f"Rate limit on {mask_credential(key)} for model {model}. " - f"Transient cooldown: {cooldown_seconds}s" - ) - - elif classified_error.error_type == "authentication": - # Apply a 5-minute key-level lockout for auth errors - key_data["key_cooldown_until"] = now_ts + COOLDOWN_AUTH_ERROR - cooldown_seconds = COOLDOWN_AUTH_ERROR - model_cooldowns[model] = now_ts + cooldown_seconds - lib_logger.warning( - f"Authentication error on key {mask_credential(key)}. Applying 5-minute key-level lockout." - ) - - # If we should increment failures, calculate escalating backoff - if should_increment: - failures_data = key_data.setdefault("failures", {}) - model_failures = failures_data.setdefault( - model, {"consecutive_failures": 0} - ) - model_failures["consecutive_failures"] += 1 - count = model_failures["consecutive_failures"] - - # If cooldown wasn't set by specific error type, use escalating backoff - if cooldown_seconds is None: - cooldown_seconds = COOLDOWN_BACKOFF_TIERS.get( - count, COOLDOWN_BACKOFF_MAX - ) - model_cooldowns[model] = now_ts + cooldown_seconds - lib_logger.warning( - f"Failure #{count} for key {mask_credential(key)} with model {model}. " - f"Error type: {classified_error.error_type}, cooldown: {cooldown_seconds}s" - ) - else: - # Provider-level errors: apply short cooldown but don't count against key - if cooldown_seconds is None: - cooldown_seconds = COOLDOWN_TRANSIENT_ERROR - model_cooldowns[model] = now_ts + cooldown_seconds - lib_logger.info( - f"Provider-level error ({classified_error.error_type}) for key {mask_credential(key)} " - f"with model {model}. NOT incrementing failures. Cooldown: {cooldown_seconds}s" - ) - - # Check for key-level lockout condition - await self._check_key_lockout(key, key_data) - - # Track failure count for quota estimation (all failures consume quota) - # This is separate from consecutive_failures which is for backoff logic - if reset_mode == "per_model": - models_data = key_data.setdefault("models", {}) - model_data = models_data.setdefault( - model, - { - "window_start_ts": None, - "quota_reset_ts": None, - "success_count": 0, - "failure_count": 0, - "request_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - }, - ) - # Only increment if not already incremented in quota_exceeded branch - if classified_error.error_type != "quota_exceeded": - model_data["failure_count"] = model_data.get("failure_count", 0) + 1 - model_data["request_count"] = model_data.get("request_count", 0) + 1 - - # Sync request_count across quota group - new_request_count = model_data["request_count"] - group = self._get_model_quota_group(key, model) - if group: - grouped_models = self._get_grouped_models(key, group) - for grouped_model in grouped_models: - if grouped_model != model: - other_model_data = models_data.setdefault( - grouped_model, - { - "window_start_ts": None, - "quota_reset_ts": None, - "success_count": 0, - "failure_count": 0, - "request_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - }, - ) - other_model_data["request_count"] = new_request_count - # Also sync quota_max_requests if set - max_req = model_data.get("quota_max_requests") - if max_req: - other_model_data["quota_max_requests"] = max_req - other_model_data["quota_display"] = ( - f"{new_request_count}/{max_req}" - ) - - key_data["last_failure"] = { - "timestamp": now_ts, - "model": model, - "error": str(classified_error.original_exception), - } - - await self._save_usage() - - async def update_quota_baseline( - self, - credential: str, - model: str, - remaining_fraction: float, - max_requests: Optional[int] = None, - reset_timestamp: Optional[float] = None, - ) -> Optional[Dict[str, Any]]: - """ - Update quota baseline data for a credential/model after fetching from API. - - This stores the current quota state as a baseline, which is used to - estimate remaining quota based on subsequent request counts. - - When quota is exhausted (remaining_fraction <= 0.0) and a valid reset_timestamp - is provided, this also sets model_cooldowns to prevent wasted requests. - - Args: - credential: Credential identifier (file path or env:// URI) - model: Model name (with or without provider prefix) - remaining_fraction: Current remaining quota as fraction (0.0 to 1.0) - max_requests: Maximum requests allowed per quota period (e.g., 250 for Claude) - reset_timestamp: Unix timestamp when quota resets. Only trusted when - remaining_fraction < 1.0 (quota has been used). API returns garbage - reset times for unused quota (100%). - - Returns: - None if no cooldown was set/updated, otherwise: - { - "group_or_model": str, # quota group name or model name if ungrouped - "hours_until_reset": float, - } - """ - await self._lazy_init() - async with self._data_lock: - now_ts = time.time() - - # Get or create key data structure - key_data = self._usage_data.setdefault( - credential, - { - "models": {}, - "global": {"models": {}}, - "model_cooldowns": {}, - "failures": {}, - }, - ) - - # Ensure models dict exists - if "models" not in key_data: - key_data["models"] = {} - - # Get or create per-model data - model_data = key_data["models"].setdefault( - model, - { - "window_start_ts": None, - "quota_reset_ts": None, - "success_count": 0, - "failure_count": 0, - "request_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - "baseline_remaining_fraction": None, - "baseline_fetched_at": None, - "requests_at_baseline": None, - }, - ) - - # Calculate actual used requests from API's remaining fraction - # The API is authoritative - sync our local count to match reality - if max_requests is not None: - used_requests = int((1.0 - remaining_fraction) * max_requests) - else: - # Estimate max_requests from provider's quota cost - # This matches how get_max_requests_for_model() calculates it - provider = self._get_provider_from_credential(credential) - plugin_instance = self._get_provider_instance(provider) - if plugin_instance and hasattr( - plugin_instance, "get_max_requests_for_model" - ): - # Get tier from provider's cache - tier = getattr(plugin_instance, "project_tier_cache", {}).get( - credential, "standard-tier" - ) - # Strip provider prefix from model if present - clean_model = model.split("/")[-1] if "/" in model else model - max_requests = plugin_instance.get_max_requests_for_model( - clean_model, tier - ) - used_requests = int((1.0 - remaining_fraction) * max_requests) - else: - # Fallback: keep existing count if we can't calculate - used_requests = model_data.get("request_count", 0) - max_requests = model_data.get("quota_max_requests") - - # Sync local request count to API's authoritative value - # Use max() to prevent API from resetting our count if it returns stale/cached 100% - # The API can only increase our count (if we missed requests), not decrease it - # See: https://github.com/Mirrowel/LLM-API-Key-Proxy/issues/75 - current_count = model_data.get("request_count", 0) - synced_count = max(current_count, used_requests) - model_data["request_count"] = synced_count - model_data["requests_at_baseline"] = synced_count - - # Update baseline fields - model_data["baseline_remaining_fraction"] = remaining_fraction - model_data["baseline_fetched_at"] = now_ts - - # Update max_requests and quota_display - if max_requests is not None: - model_data["quota_max_requests"] = max_requests - model_data["quota_display"] = f"{synced_count}/{max_requests}" - - # Handle reset_timestamp: only trust it when quota has been used (< 100%) - # API returns garbage reset times for unused quota - valid_reset_ts = ( - reset_timestamp is not None - and remaining_fraction < 1.0 - and reset_timestamp > now_ts - ) - - if valid_reset_ts: - model_data["quota_reset_ts"] = reset_timestamp - - # Set cooldowns when quota is exhausted - model_cooldowns = key_data.setdefault("model_cooldowns", {}) - is_exhausted = remaining_fraction <= 0.0 - cooldown_set_info = ( - None # Will be returned if cooldown was newly set/updated - ) - - if is_exhausted and valid_reset_ts: - # Check if there was an existing ACTIVE cooldown before we update - # This distinguishes between fresh exhaustion vs refresh of existing state - existing_cooldown = model_cooldowns.get(model) - was_already_on_cooldown = ( - existing_cooldown is not None and existing_cooldown > now_ts - ) - - # Only update cooldown if not set or differs by more than 5 minutes - should_update = ( - existing_cooldown is None - or abs(existing_cooldown - reset_timestamp) > 300 - ) - if should_update: - model_cooldowns[model] = reset_timestamp - hours_until_reset = (reset_timestamp - now_ts) / 3600 - # Determine group or model name for logging - group = self._get_model_quota_group(credential, model) - cooldown_set_info = { - "group_or_model": group if group else model.split("/")[-1], - "hours_until_reset": hours_until_reset, - } - - # Mark credential as exhausted in fair cycle if cooldown exceeds threshold - # BUT only if this is a FRESH exhaustion (wasn't already on cooldown) - # This prevents re-marking after cycle reset when quota refresh sees existing cooldown - if not was_already_on_cooldown: - cooldown_duration = reset_timestamp - now_ts - provider = self._get_provider_from_credential(credential) - if provider: - threshold = self._get_exhaustion_cooldown_threshold(provider) - if cooldown_duration > threshold: - rotation_mode = self._get_rotation_mode(provider) - if self._is_fair_cycle_enabled(provider, rotation_mode): - priority = self._get_credential_priority( - credential, provider - ) - tier_key = self._get_tier_key(provider, priority) - tracking_key = self._get_tracking_key( - credential, model, provider - ) - self._mark_credential_exhausted( - credential, provider, tier_key, tracking_key - ) - - # Defensive clamp: ensure request_count doesn't exceed max when exhausted - if ( - max_requests is not None - and model_data["request_count"] > max_requests - ): - model_data["request_count"] = max_requests - model_data["quota_display"] = f"{max_requests}/{max_requests}" - - # Sync baseline fields and quota info across quota group - group = self._get_model_quota_group(credential, model) - if group: - grouped_models = self._get_grouped_models(credential, group) - for grouped_model in grouped_models: - if grouped_model != model: - other_model_data = key_data["models"].setdefault( - grouped_model, - { - "window_start_ts": None, - "quota_reset_ts": None, - "success_count": 0, - "failure_count": 0, - "request_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - }, - ) - # Sync request tracking (use synced_count to prevent reset bug) - other_model_data["request_count"] = synced_count - if max_requests is not None: - other_model_data["quota_max_requests"] = max_requests - other_model_data["quota_display"] = ( - f"{synced_count}/{max_requests}" - ) - # Sync baseline fields - other_model_data["baseline_remaining_fraction"] = ( - remaining_fraction - ) - other_model_data["baseline_fetched_at"] = now_ts - other_model_data["requests_at_baseline"] = synced_count - # Sync reset timestamp if valid - if valid_reset_ts: - other_model_data["quota_reset_ts"] = reset_timestamp - # Sync window start time - window_start = model_data.get("window_start_ts") - if window_start: - other_model_data["window_start_ts"] = window_start - # Sync cooldown if exhausted (with ±5 min check) - if is_exhausted and valid_reset_ts: - existing_grouped = model_cooldowns.get(grouped_model) - should_update_grouped = ( - existing_grouped is None - or abs(existing_grouped - reset_timestamp) > 300 - ) - if should_update_grouped: - model_cooldowns[grouped_model] = reset_timestamp - - # Defensive clamp for grouped models when exhausted - if ( - max_requests is not None - and other_model_data["request_count"] > max_requests - ): - other_model_data["request_count"] = max_requests - other_model_data["quota_display"] = ( - f"{max_requests}/{max_requests}" - ) - - lib_logger.debug( - f"Updated quota baseline for {mask_credential(credential)} model={model}: " - f"remaining={remaining_fraction:.2%}, synced_request_count={synced_count}" - ) - - await self._save_usage() - return cooldown_set_info - - async def _check_key_lockout(self, key: str, key_data: Dict): - """ - Checks if a key should be locked out due to multiple model failures. - - NOTE: This check is currently disabled. The original logic counted individual - models in long-term lockout, but this caused issues with quota groups - when - a single quota group (e.g., "claude" with 5 models) was exhausted, it would - count as 5 lockouts and trigger key-level lockout, blocking other quota groups - (like gemini) that were still available. - - The per-model and per-group cooldowns already handle quota exhaustion properly. - """ - # Disabled - see docstring above - pass - - async def get_stats_for_endpoint( - self, - provider_filter: Optional[str] = None, - include_global: bool = True, - ) -> Dict[str, Any]: - """ - Get usage stats formatted for the /v1/quota-stats endpoint. - - Aggregates data from key_usage.json grouped by provider. - Includes both current period stats and global (lifetime) stats. - - Args: - provider_filter: If provided, only return stats for this provider - include_global: If True, include global/lifetime stats alongside current - - Returns: - { - "providers": { - "provider_name": { - "credential_count": int, - "active_count": int, - "on_cooldown_count": int, - "total_requests": int, - "tokens": { - "input_cached": int, - "input_uncached": int, - "input_cache_pct": float, - "output": int - }, - "approx_cost": float | None, - "credentials": [...], - "global": {...} # If include_global is True - } - }, - "summary": {...}, - "global_summary": {...}, # If include_global is True - "timestamp": float - } - """ - await self._lazy_init() - - now_ts = time.time() - providers: Dict[str, Dict[str, Any]] = {} - # Track global stats separately - global_providers: Dict[str, Dict[str, Any]] = {} - - async with self._data_lock: - if not self._usage_data: - return { - "providers": {}, - "summary": { - "total_providers": 0, - "total_credentials": 0, - "active_credentials": 0, - "exhausted_credentials": 0, - "total_requests": 0, - "tokens": { - "input_cached": 0, - "input_uncached": 0, - "input_cache_pct": 0, - "output": 0, - }, - "approx_total_cost": 0.0, - }, - "global_summary": { - "total_providers": 0, - "total_credentials": 0, - "total_requests": 0, - "tokens": { - "input_cached": 0, - "input_uncached": 0, - "input_cache_pct": 0, - "output": 0, - }, - "approx_total_cost": 0.0, - }, - "data_source": "cache", - "timestamp": now_ts, - } - - for credential, cred_data in self._usage_data.items(): - # Extract provider from credential path - provider = self._get_provider_from_credential(credential) - if not provider: - continue - - # Apply filter if specified - if provider_filter and provider != provider_filter: - continue - - # Initialize provider entry - if provider not in providers: - providers[provider] = { - "credential_count": 0, - "active_count": 0, - "on_cooldown_count": 0, - "exhausted_count": 0, - "total_requests": 0, - "tokens": { - "input_cached": 0, - "input_uncached": 0, - "input_cache_pct": 0, - "output": 0, - }, - "approx_cost": 0.0, - "credentials": [], - } - global_providers[provider] = { - "total_requests": 0, - "tokens": { - "input_cached": 0, - "input_uncached": 0, - "input_cache_pct": 0, - "output": 0, - }, - "approx_cost": 0.0, - } - - prov_stats = providers[provider] - prov_stats["credential_count"] += 1 - - # Determine credential status and cooldowns - key_cooldown = cred_data.get("key_cooldown_until", 0) or 0 - model_cooldowns = cred_data.get("model_cooldowns", {}) - - # Build active cooldowns with remaining time - active_cooldowns = {} - for model, cooldown_ts in model_cooldowns.items(): - if cooldown_ts > now_ts: - remaining_seconds = int(cooldown_ts - now_ts) - active_cooldowns[model] = { - "until_ts": cooldown_ts, - "remaining_seconds": remaining_seconds, - } - - key_cooldown_remaining = None - if key_cooldown > now_ts: - key_cooldown_remaining = int(key_cooldown - now_ts) - - has_active_cooldown = key_cooldown > now_ts or len(active_cooldowns) > 0 - - # Check if exhausted (all quota groups exhausted for Antigravity) - is_exhausted = False - models_data = cred_data.get("models", {}) - if models_data: - # Check if any model has remaining quota - all_exhausted = True - for model_stats in models_data.values(): - if isinstance(model_stats, dict): - baseline = model_stats.get("baseline_remaining_fraction") - if baseline is None or baseline > 0: - all_exhausted = False - break - if all_exhausted and len(models_data) > 0: - is_exhausted = True - - if is_exhausted: - prov_stats["exhausted_count"] += 1 - status = "exhausted" - elif has_active_cooldown: - prov_stats["on_cooldown_count"] += 1 - status = "cooldown" - else: - prov_stats["active_count"] += 1 - status = "active" - - # Aggregate token stats (current period) - cred_tokens = { - "input_cached": 0, - "input_uncached": 0, - "output": 0, - } - cred_requests = 0 - cred_cost = 0.0 - - # Aggregate global token stats - cred_global_tokens = { - "input_cached": 0, - "input_uncached": 0, - "output": 0, - } - cred_global_requests = 0 - cred_global_cost = 0.0 - - # Handle per-model structure (current period) - if models_data: - for model_name, model_stats in models_data.items(): - if not isinstance(model_stats, dict): - continue - # Prefer request_count if available and non-zero, else fall back to success+failure - req_count = model_stats.get("request_count", 0) - if req_count > 0: - cred_requests += req_count - else: - cred_requests += model_stats.get("success_count", 0) - cred_requests += model_stats.get("failure_count", 0) - # Token stats - track cached separately - cred_tokens["input_cached"] += model_stats.get( - "prompt_tokens_cached", 0 - ) - cred_tokens["input_uncached"] += model_stats.get( - "prompt_tokens", 0 - ) - cred_tokens["output"] += model_stats.get("completion_tokens", 0) - cred_cost += model_stats.get("approx_cost", 0.0) - - # Handle legacy daily structure - daily_data = cred_data.get("daily", {}) - daily_models = daily_data.get("models", {}) - for model_name, model_stats in daily_models.items(): - if not isinstance(model_stats, dict): - continue - cred_requests += model_stats.get("success_count", 0) - cred_tokens["input_cached"] += model_stats.get( - "prompt_tokens_cached", 0 - ) - cred_tokens["input_uncached"] += model_stats.get("prompt_tokens", 0) - cred_tokens["output"] += model_stats.get("completion_tokens", 0) - cred_cost += model_stats.get("approx_cost", 0.0) - - # Handle global stats - global_data = cred_data.get("global", {}) - global_models = global_data.get("models", {}) - for model_name, model_stats in global_models.items(): - if not isinstance(model_stats, dict): - continue - cred_global_requests += model_stats.get("success_count", 0) - cred_global_tokens["input_cached"] += model_stats.get( - "prompt_tokens_cached", 0 - ) - cred_global_tokens["input_uncached"] += model_stats.get( - "prompt_tokens", 0 - ) - cred_global_tokens["output"] += model_stats.get( - "completion_tokens", 0 - ) - cred_global_cost += model_stats.get("approx_cost", 0.0) - - # Add current period stats to global totals - cred_global_requests += cred_requests - cred_global_tokens["input_cached"] += cred_tokens["input_cached"] - cred_global_tokens["input_uncached"] += cred_tokens["input_uncached"] - cred_global_tokens["output"] += cred_tokens["output"] - cred_global_cost += cred_cost - - # Build credential entry - # Mask credential identifier for display - if credential.startswith("env://"): - identifier = credential - else: - identifier = Path(credential).name - - cred_entry = { - "identifier": identifier, - "full_path": credential, - "status": status, - "last_used_ts": cred_data.get("last_used_ts"), - "requests": cred_requests, - "tokens": cred_tokens, - "approx_cost": cred_cost if cred_cost > 0 else None, - } - - # Add cooldown info - if key_cooldown_remaining is not None: - cred_entry["key_cooldown_remaining"] = key_cooldown_remaining - if active_cooldowns: - cred_entry["model_cooldowns"] = active_cooldowns - - # Add global stats for this credential - if include_global: - # Calculate global cache percentage - global_total_input = ( - cred_global_tokens["input_cached"] - + cred_global_tokens["input_uncached"] - ) - global_cache_pct = ( - round( - cred_global_tokens["input_cached"] - / global_total_input - * 100, - 1, - ) - if global_total_input > 0 - else 0 - ) - - cred_entry["global"] = { - "requests": cred_global_requests, - "tokens": { - "input_cached": cred_global_tokens["input_cached"], - "input_uncached": cred_global_tokens["input_uncached"], - "input_cache_pct": global_cache_pct, - "output": cred_global_tokens["output"], - }, - "approx_cost": cred_global_cost - if cred_global_cost > 0 - else None, - } - - # Add model-specific data for providers with per-model tracking - if models_data: - cred_entry["models"] = {} - for model_name, model_stats in models_data.items(): - if not isinstance(model_stats, dict): - continue - cred_entry["models"][model_name] = { - "requests": model_stats.get("success_count", 0) - + model_stats.get("failure_count", 0), - "request_count": model_stats.get("request_count", 0), - "success_count": model_stats.get("success_count", 0), - "failure_count": model_stats.get("failure_count", 0), - "prompt_tokens": model_stats.get("prompt_tokens", 0), - "prompt_tokens_cached": model_stats.get( - "prompt_tokens_cached", 0 - ), - "completion_tokens": model_stats.get( - "completion_tokens", 0 - ), - "approx_cost": model_stats.get("approx_cost", 0.0), - "window_start_ts": model_stats.get("window_start_ts"), - "quota_reset_ts": model_stats.get("quota_reset_ts"), - # Quota baseline fields (Antigravity-specific) - "baseline_remaining_fraction": model_stats.get( - "baseline_remaining_fraction" - ), - "baseline_fetched_at": model_stats.get( - "baseline_fetched_at" - ), - "quota_max_requests": model_stats.get("quota_max_requests"), - "quota_display": model_stats.get("quota_display"), - } - - prov_stats["credentials"].append(cred_entry) - - # Aggregate to provider totals (current period) - prov_stats["total_requests"] += cred_requests - prov_stats["tokens"]["input_cached"] += cred_tokens["input_cached"] - prov_stats["tokens"]["input_uncached"] += cred_tokens["input_uncached"] - prov_stats["tokens"]["output"] += cred_tokens["output"] - if cred_cost > 0: - prov_stats["approx_cost"] += cred_cost - - # Aggregate to global provider totals - global_providers[provider]["total_requests"] += cred_global_requests - global_providers[provider]["tokens"]["input_cached"] += ( - cred_global_tokens["input_cached"] - ) - global_providers[provider]["tokens"]["input_uncached"] += ( - cred_global_tokens["input_uncached"] - ) - global_providers[provider]["tokens"]["output"] += cred_global_tokens[ - "output" - ] - global_providers[provider]["approx_cost"] += cred_global_cost - - # Calculate cache percentages for each provider - for provider, prov_stats in providers.items(): - total_input = ( - prov_stats["tokens"]["input_cached"] - + prov_stats["tokens"]["input_uncached"] - ) - if total_input > 0: - prov_stats["tokens"]["input_cache_pct"] = round( - prov_stats["tokens"]["input_cached"] / total_input * 100, 1 - ) - # Set cost to None if 0 - if prov_stats["approx_cost"] == 0: - prov_stats["approx_cost"] = None - - # Calculate global cache percentages - if include_global and provider in global_providers: - gp = global_providers[provider] - global_total = ( - gp["tokens"]["input_cached"] + gp["tokens"]["input_uncached"] - ) - if global_total > 0: - gp["tokens"]["input_cache_pct"] = round( - gp["tokens"]["input_cached"] / global_total * 100, 1 - ) - if gp["approx_cost"] == 0: - gp["approx_cost"] = None - prov_stats["global"] = gp - - # Build summary (current period) - total_creds = sum(p["credential_count"] for p in providers.values()) - active_creds = sum(p["active_count"] for p in providers.values()) - exhausted_creds = sum(p["exhausted_count"] for p in providers.values()) - total_requests = sum(p["total_requests"] for p in providers.values()) - total_input_cached = sum( - p["tokens"]["input_cached"] for p in providers.values() - ) - total_input_uncached = sum( - p["tokens"]["input_uncached"] for p in providers.values() - ) - total_output = sum(p["tokens"]["output"] for p in providers.values()) - total_cost = sum(p["approx_cost"] or 0 for p in providers.values()) - - total_input = total_input_cached + total_input_uncached - input_cache_pct = ( - round(total_input_cached / total_input * 100, 1) if total_input > 0 else 0 - ) - - result = { - "providers": providers, - "summary": { - "total_providers": len(providers), - "total_credentials": total_creds, - "active_credentials": active_creds, - "exhausted_credentials": exhausted_creds, - "total_requests": total_requests, - "tokens": { - "input_cached": total_input_cached, - "input_uncached": total_input_uncached, - "input_cache_pct": input_cache_pct, - "output": total_output, - }, - "approx_total_cost": total_cost if total_cost > 0 else None, - }, - "data_source": "cache", - "timestamp": now_ts, - } - - # Build global summary - if include_global: - global_total_requests = sum( - gp["total_requests"] for gp in global_providers.values() - ) - global_total_input_cached = sum( - gp["tokens"]["input_cached"] for gp in global_providers.values() - ) - global_total_input_uncached = sum( - gp["tokens"]["input_uncached"] for gp in global_providers.values() - ) - global_total_output = sum( - gp["tokens"]["output"] for gp in global_providers.values() - ) - global_total_cost = sum( - gp["approx_cost"] or 0 for gp in global_providers.values() - ) - - global_total_input = global_total_input_cached + global_total_input_uncached - global_input_cache_pct = ( - round(global_total_input_cached / global_total_input * 100, 1) - if global_total_input > 0 - else 0 - ) - - result["global_summary"] = { - "total_providers": len(global_providers), - "total_credentials": total_creds, - "total_requests": global_total_requests, - "tokens": { - "input_cached": global_total_input_cached, - "input_uncached": global_total_input_uncached, - "input_cache_pct": global_input_cache_pct, - "output": global_total_output, - }, - "approx_total_cost": global_total_cost - if global_total_cost > 0 - else None, - } - - return result - - async def reload_from_disk(self) -> None: - """ - Force reload usage data from disk. - - Useful when another process may have updated the file. - """ - async with self._init_lock: - self._initialized.clear() - await self._load_usage() - await self._reset_daily_stats_if_needed() - self._initialized.set() +__all__ = ["UsageManager", "CredentialContext"] diff --git a/src/rotator_library/utils/__init__.py b/src/rotator_library/utils/__init__.py index ce8d959d..a51d1db7 100644 --- a/src/rotator_library/utils/__init__.py +++ b/src/rotator_library/utils/__init__.py @@ -17,6 +17,7 @@ ResilientStateWriter, safe_write_json, safe_log_write, + safe_read_json, safe_mkdir, ) from .suppress_litellm_warnings import suppress_litellm_serialization_warnings @@ -34,6 +35,7 @@ "ResilientStateWriter", "safe_write_json", "safe_log_write", + "safe_read_json", "safe_mkdir", "suppress_litellm_serialization_warnings", ] diff --git a/src/rotator_library/utils/resilient_io.py b/src/rotator_library/utils/resilient_io.py index 1125e3b7..91e96f37 100644 --- a/src/rotator_library/utils/resilient_io.py +++ b/src/rotator_library/utils/resilient_io.py @@ -660,6 +660,36 @@ def safe_log_write( return False +def safe_read_json( + path: Union[str, Path], + logger: logging.Logger, + *, + parse_json: bool = True, +) -> Optional[Any]: + """ + Read file contents with error handling. + + Args: + path: File path to read from + logger: Logger for warnings/errors + parse_json: When True, parse JSON; when False, return raw text + + Returns: + Parsed JSON dict, raw text, or None on failure + """ + path = Path(path) + try: + with open(path, "r", encoding="utf-8") as f: + if parse_json: + return json.load(f) + return f.read() + except FileNotFoundError: + return None + except (OSError, PermissionError, IOError, json.JSONDecodeError) as e: + logger.error(f"Failed to read {path}: {e}") + return None + + def safe_mkdir(path: Union[str, Path], logger: logging.Logger) -> bool: """ Create directory with error handling.