"""Example: Using a ChatGPT subscription for Codex models.

Shows how a ChatGPT Plus/Pro subscription can drive OpenAI's Codex models
without spending API credits.

``LLM.subscription_login()`` takes care of:
- the OAuth PKCE authentication flow
- credential caching (~/.openhands/auth/)
- automatic token refresh

Models available this way:
- gpt-5.2-codex
- gpt-5.2
- gpt-5.1-codex-max
- gpt-5.1-codex-mini

Prerequisites:
- an active ChatGPT Plus or Pro subscription
- a browser for the initial OAuth login
"""

import os

from openhands.sdk import LLM, Agent, Conversation, Tool
from openhands.tools.file_editor import FileEditorTool
from openhands.tools.terminal import TerminalTool


# The first run opens the browser for OAuth; later runs reuse the cached
# credentials and refresh them automatically when they expire.
llm = LLM.subscription_login(
    vendor="openai",
    model="gpt-5.2-codex",  # or "gpt-5.2", "gpt-5.1-codex-max", "gpt-5.1-codex-mini"
)

# Variant: force a fresh login when cached credentials are stale.
# llm = LLM.subscription_login(vendor="openai", model="gpt-5.2-codex", force_login=True)

# Variant: print the login URL to the console instead of auto-opening a browser.
# llm = LLM.subscription_login(
#     vendor="openai", model="gpt-5.2-codex", open_browser=False
# )

# Confirm that subscription transport is in use.
print(f"Using subscription mode: {llm.is_subscription}")

# From here on, the LLM plugs into an agent like any other.
agent = Agent(
    llm=llm,
    tools=[
        Tool(name=TerminalTool.name),
        Tool(name=FileEditorTool.name),
    ],
)

conversation = Conversation(agent=agent, workspace=os.getcwd())

conversation.send_message("List the files in the current directory.")
+conversation.run() +print("Done!") diff --git a/openhands-sdk/openhands/sdk/llm/__init__.py b/openhands-sdk/openhands/sdk/llm/__init__.py index 63d8d437e6..b2b5ad3047 100644 --- a/openhands-sdk/openhands/sdk/llm/__init__.py +++ b/openhands-sdk/openhands/sdk/llm/__init__.py @@ -1,3 +1,9 @@ +from openhands.sdk.llm.auth import ( + OPENAI_CODEX_MODELS, + CredentialStore, + OAuthCredentials, + OpenAISubscriptionAuth, +) from openhands.sdk.llm.llm import LLM from openhands.sdk.llm.llm_registry import LLMRegistry, RegistryEvent from openhands.sdk.llm.llm_response import LLMResponse @@ -22,11 +28,18 @@ __all__ = [ + # Auth + "CredentialStore", + "OAuthCredentials", + "OpenAISubscriptionAuth", + "OPENAI_CODEX_MODELS", + # Core "LLMResponse", "LLM", "LLMRegistry", "RouterLLM", "RegistryEvent", + # Messages "Message", "MessageToolCall", "TextContent", @@ -35,10 +48,13 @@ "RedactedThinkingBlock", "ReasoningItemModel", "content_to_str", + # Streaming "LLMStreamChunk", "TokenCallbackType", + # Metrics "Metrics", "MetricsSnapshot", + # Models "VERIFIED_MODELS", "UNVERIFIED_MODELS_EXCLUDING_BEDROCK", "get_unverified_models", diff --git a/openhands-sdk/openhands/sdk/llm/auth/__init__.py b/openhands-sdk/openhands/sdk/llm/auth/__init__.py new file mode 100644 index 0000000000..c67564c5bc --- /dev/null +++ b/openhands-sdk/openhands/sdk/llm/auth/__init__.py @@ -0,0 +1,28 @@ +"""Authentication module for LLM subscription-based access. + +This module provides OAuth-based authentication for LLM providers that support +subscription-based access (e.g., ChatGPT Plus/Pro for OpenAI Codex models). 
+""" + +from openhands.sdk.llm.auth.credentials import ( + CredentialStore, + OAuthCredentials, +) +from openhands.sdk.llm.auth.openai import ( + OPENAI_CODEX_MODELS, + OpenAISubscriptionAuth, + SupportedVendor, + inject_system_prefix, + transform_for_subscription, +) + + +__all__ = [ + "CredentialStore", + "OAuthCredentials", + "OpenAISubscriptionAuth", + "OPENAI_CODEX_MODELS", + "SupportedVendor", + "inject_system_prefix", + "transform_for_subscription", +] diff --git a/openhands-sdk/openhands/sdk/llm/auth/credentials.py b/openhands-sdk/openhands/sdk/llm/auth/credentials.py new file mode 100644 index 0000000000..262ff8f7d1 --- /dev/null +++ b/openhands-sdk/openhands/sdk/llm/auth/credentials.py @@ -0,0 +1,157 @@ +"""Credential storage and retrieval for OAuth-based LLM authentication.""" + +from __future__ import annotations + +import json +import os +import time +import warnings +from pathlib import Path +from typing import Literal + +from pydantic import BaseModel, Field + +from openhands.sdk.logger import get_logger + + +logger = get_logger(__name__) + + +def get_credentials_dir() -> Path: + """Get the directory for storing credentials. + + Uses XDG_DATA_HOME if set, otherwise defaults to ~/.local/share/openhands. 
+ """ + return Path.home() / ".openhands" / "auth" + + +class OAuthCredentials(BaseModel): + """OAuth credentials for subscription-based LLM access.""" + + type: Literal["oauth"] = "oauth" + vendor: str = Field(description="The vendor/provider (e.g., 'openai')") + access_token: str = Field(description="The OAuth access token") + refresh_token: str = Field(description="The OAuth refresh token") + expires_at: int = Field( + description="Unix timestamp (ms) when the access token expires" + ) + + def is_expired(self) -> bool: + """Check if the access token is expired.""" + # Add 60 second buffer to avoid edge cases + # Add 60 second buffer to avoid edge cases where token expires during request + return self.expires_at < (int(time.time() * 1000) + 60_000) + + +class CredentialStore: + """Store and retrieve OAuth credentials for LLM providers.""" + + def __init__(self, credentials_dir: Path | None = None): + """Initialize the credential store. + + Args: + credentials_dir: Optional custom directory for storing credentials. + Defaults to ~/.local/share/openhands/auth/ + """ + self._credentials_dir = credentials_dir or get_credentials_dir() + logger.info(f"Using credentials directory: {self._credentials_dir}") + + @property + def credentials_dir(self) -> Path: + """Get the credentials directory, creating it if necessary.""" + self._credentials_dir.mkdir(parents=True, exist_ok=True) + # Set directory permissions to owner-only (rwx------) + if os.name != "nt": + self._credentials_dir.chmod(0o700) + return self._credentials_dir + + def _get_credentials_file(self, vendor: str) -> Path: + """Get the path to the credentials file for a vendor.""" + return self.credentials_dir / f"{vendor}_oauth.json" + + def get(self, vendor: str) -> OAuthCredentials | None: + """Get stored credentials for a vendor. 
+ + Args: + vendor: The vendor/provider name (e.g., 'openai') + + Returns: + OAuthCredentials if found and valid, None otherwise + """ + creds_file = self._get_credentials_file(vendor) + if not creds_file.exists(): + return None + + try: + with open(creds_file) as f: + data = json.load(f) + return OAuthCredentials.model_validate(data) + except (json.JSONDecodeError, ValueError): + # Invalid credentials file, remove it + creds_file.unlink(missing_ok=True) + return None + + def save(self, credentials: OAuthCredentials) -> None: + """Save credentials for a vendor. + + Args: + credentials: The OAuth credentials to save + """ + creds_file = self._get_credentials_file(credentials.vendor) + with open(creds_file, "w") as f: + json.dump(credentials.model_dump(), f, indent=2) + # Set restrictive permissions (owner read/write only) + # Note: On Windows, NTFS ACLs should be used instead + if os.name != "nt": # Not Windows + creds_file.chmod(0o600) + else: + warnings.warn( + "File permissions on Windows should be manually restricted", + stacklevel=2, + ) + + def delete(self, vendor: str) -> bool: + """Delete stored credentials for a vendor. + + Args: + vendor: The vendor/provider name + + Returns: + True if credentials were deleted, False if they didn't exist + """ + creds_file = self._get_credentials_file(vendor) + if creds_file.exists(): + creds_file.unlink() + return True + return False + + def update_tokens( + self, + vendor: str, + access_token: str, + refresh_token: str | None, + expires_in: int, + ) -> OAuthCredentials | None: + """Update tokens for an existing credential. 
+ + Args: + vendor: The vendor/provider name + access_token: New access token + refresh_token: New refresh token (if provided) + expires_in: Token expiry in seconds + + Returns: + Updated credentials, or None if no existing credentials found + """ + existing = self.get(vendor) + if existing is None: + return None + + updated = OAuthCredentials( + vendor=vendor, + access_token=access_token, + refresh_token=refresh_token or existing.refresh_token, + expires_at=int(time.time() * 1000) + (expires_in * 1000), + ) + self.save(updated) + return updated diff --git a/openhands-sdk/openhands/sdk/llm/auth/openai.py b/openhands-sdk/openhands/sdk/llm/auth/openai.py new file mode 100644 index 0000000000..706559d88c --- /dev/null +++ b/openhands-sdk/openhands/sdk/llm/auth/openai.py @@ -0,0 +1,762 @@ +"""OpenAI subscription-based authentication via OAuth. + +This module implements OAuth PKCE flow for authenticating with OpenAI's ChatGPT +service, allowing users with ChatGPT Plus/Pro subscriptions to use Codex models +without consuming API credits. + +Uses authlib for OAuth handling and aiohttp for the callback server. +""" + +from __future__ import annotations + +import asyncio +import platform +import sys +import threading +import time +import webbrowser +from pathlib import Path +from typing import TYPE_CHECKING, Any, Literal +from urllib.parse import urlencode + +from aiohttp import web +from authlib.common.security import generate_token +from authlib.jose import JsonWebKey, jwt +from authlib.jose.errors import JoseError +from authlib.oauth2.rfc7636 import create_s256_code_challenge +from httpx import AsyncClient, Client + +from openhands.sdk.llm.auth.credentials import ( + CredentialStore, + OAuthCredentials, + get_credentials_dir, +) +from openhands.sdk.logger import get_logger + + +if TYPE_CHECKING: + from openhands.sdk.llm.llm import LLM + +# Supported vendors for subscription-based authentication. +# Add new vendors here as they become supported. 
+SupportedVendor = Literal["openai"] + +logger = get_logger(__name__) + +# ========================================================================= +# Consent banner constants +# ========================================================================= + +CONSENT_BANNER = """\ +Signing in with ChatGPT uses your ChatGPT account. By continuing, you confirm \ +you are a ChatGPT End User and are subject to OpenAI's Terms of Use. +https://openai.com/policies/terms-of-use/ +""" + +CONSENT_MARKER_FILENAME = ".chatgpt_consent_acknowledged" + + +def _get_consent_marker_path() -> Path: + """Get the path to the consent acknowledgment marker file.""" + return get_credentials_dir() / CONSENT_MARKER_FILENAME + + +def _has_acknowledged_consent() -> bool: + """Check if the user has previously acknowledged the consent disclaimer.""" + return _get_consent_marker_path().exists() + + +def _mark_consent_acknowledged() -> None: + """Mark that the user has acknowledged the consent disclaimer.""" + marker_path = _get_consent_marker_path() + marker_path.parent.mkdir(parents=True, exist_ok=True) + marker_path.touch() + + +def _display_consent_and_confirm() -> bool: + """Display consent banner and get user confirmation. + + Returns: + True if user confirms, False otherwise. + + Raises: + RuntimeError: If running in non-interactive mode without prior consent. + """ + is_first_time = not _has_acknowledged_consent() + + # Always show the consent banner + print("\n" + "=" * 70) + print(CONSENT_BANNER) + print("=" * 70 + "\n") + + # Check if we're in an interactive terminal + if not sys.stdin.isatty(): + if is_first_time: + raise RuntimeError( + "Cannot proceed with ChatGPT sign-in: running in non-interactive mode " + "and consent has not been previously acknowledged. Please run " + "interactively first to acknowledge the terms." 
+ ) + # Non-interactive but consent was previously given - proceed + logger.info("Non-interactive mode: using previously acknowledged consent") + return True + + # Interactive mode: prompt for confirmation + try: + response = input("Do you want to continue? [y/N]: ").strip().lower() + if response in ("y", "yes"): + if is_first_time: + _mark_consent_acknowledged() + return True + return False + except (EOFError, KeyboardInterrupt): + print() # Newline after ^C + return False + + +# OAuth configuration for OpenAI Codex +# This is a public client ID for OpenAI's OAuth flow (safe to commit) +CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann" +ISSUER = "https://auth.openai.com" +JWKS_URL = f"{ISSUER}/.well-known/jwks.json" +CODEX_API_ENDPOINT = "https://chatgpt.com/backend-api/codex/responses" +DEFAULT_OAUTH_PORT = 1455 +OAUTH_TIMEOUT_SECONDS = 300 # 5 minutes +JWKS_CACHE_TTL_SECONDS = 3600 # 1 hour + +# Models available via ChatGPT subscription (not API) +OPENAI_CODEX_MODELS = frozenset( + { + "gpt-5.1-codex-max", + "gpt-5.1-codex-mini", + "gpt-5.2", + "gpt-5.2-codex", + } +) + + +# Thread-safe JWKS cache +class _JWKSCache: + """Thread-safe cache for OpenAI's JWKS (JSON Web Key Set).""" + + def __init__(self) -> None: + self._keys: dict[str, Any] = {} + self._fetched_at: float = 0 + self._lock = threading.Lock() + + def get_key_set(self) -> Any: + """Get the JWKS, fetching from OpenAI if cache is stale or empty. + + Returns: + KeySet for verifying JWT signatures. + + Raises: + RuntimeError: If JWKS cannot be fetched. 
+ """ + with self._lock: + now = time.time() + if not self._keys or (now - self._fetched_at) > JWKS_CACHE_TTL_SECONDS: + self._fetch_jwks() + return JsonWebKey.import_key_set(self._keys) + + def _fetch_jwks(self) -> None: + """Fetch JWKS from OpenAI's well-known endpoint.""" + try: + with Client(timeout=10) as client: + response = client.get(JWKS_URL) + response.raise_for_status() + self._keys = response.json() + self._fetched_at = time.time() + logger.debug( + f"Fetched JWKS from OpenAI: {len(self._keys.get('keys', []))} keys" + ) + except Exception as e: + raise RuntimeError(f"Failed to fetch OpenAI JWKS: {e}") from e + + def clear(self) -> None: + """Clear the cache (useful for testing).""" + with self._lock: + self._keys = {} + self._fetched_at = 0 + + +_jwks_cache = _JWKSCache() + + +def _generate_pkce() -> tuple[str, str]: + """Generate PKCE verifier and challenge using authlib.""" + verifier = generate_token(43) + challenge = create_s256_code_challenge(verifier) + return verifier, challenge + + +def _extract_chatgpt_account_id(access_token: str) -> str | None: + """Extract chatgpt_account_id from JWT access token with signature verification. + + Verifies the JWT signature using OpenAI's published JWKS before extracting + claims. This prevents attacks where a manipulated token could be injected + through OAuth callback interception. 
+ + Args: + access_token: The JWT access token from OAuth flow + + Returns: + The chatgpt_account_id if found and signature is valid, None otherwise + """ + try: + # Fetch JWKS and verify JWT signature + key_set = _jwks_cache.get_key_set() + claims = jwt.decode(access_token, key_set) + + # Validate standard claims (issuer) + claims.validate() + + # Extract account ID from nested structure + auth_info = claims.get("https://api.openai.com/auth", {}) + account_id = auth_info.get("chatgpt_account_id") + + if account_id: + logger.debug(f"Extracted chatgpt_account_id: {account_id}") + return account_id + else: + logger.warning("chatgpt_account_id not found in JWT payload") + return None + + except JoseError as e: + logger.warning(f"JWT signature verification failed: {e}") + return None + except RuntimeError as e: + # JWKS fetch failed - log but don't crash + logger.warning(f"Could not verify JWT: {e}") + return None + except Exception as e: + logger.warning(f"Failed to decode JWT: {e}") + return None + + +def _build_authorize_url(redirect_uri: str, code_challenge: str, state: str) -> str: + """Build the OAuth authorization URL.""" + params = { + "response_type": "code", + "client_id": CLIENT_ID, + "redirect_uri": redirect_uri, + "scope": "openid profile email offline_access", + "code_challenge": code_challenge, + "code_challenge_method": "S256", + "id_token_add_organizations": "true", + "codex_cli_simplified_flow": "true", + "state": state, + "originator": "openhands", + } + return f"{ISSUER}/oauth/authorize?{urlencode(params)}" + + +async def _exchange_code_for_tokens( + code: str, redirect_uri: str, code_verifier: str +) -> dict[str, Any]: + """Exchange authorization code for tokens.""" + async with AsyncClient() as client: + response = await client.post( + f"{ISSUER}/oauth/token", + data={ + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_uri, + "client_id": CLIENT_ID, + "code_verifier": code_verifier, + }, + headers={"Content-Type": 
"application/x-www-form-urlencoded"}, + ) + if not response.is_success: + raise RuntimeError(f"Token exchange failed: {response.status_code}") + return response.json() + + +async def _refresh_access_token(refresh_token: str) -> dict[str, Any]: + """Refresh the access token using a refresh token.""" + async with AsyncClient() as client: + response = await client.post( + f"{ISSUER}/oauth/token", + data={ + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": CLIENT_ID, + }, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + if not response.is_success: + raise RuntimeError(f"Token refresh failed: {response.status_code}") + return response.json() + + +# HTML templates for OAuth callback +_HTML_SUCCESS = """ + + + OpenHands - Authorization Successful + + + +
+

Authorization Successful

+

You can close this window and return to OpenHands.

+
+ + +""" + +_HTML_ERROR = """ + + + OpenHands - Authorization Failed + + + +
+

Authorization Failed

+

An error occurred during authorization.

+
{error}
+
# --- openhands-sdk/openhands/sdk/llm/auth/openai.py (auth class & entry points) ---


class OpenAISubscriptionAuth:
    """Handle OAuth authentication for OpenAI ChatGPT subscription access."""

    def __init__(
        self,
        credential_store: CredentialStore | None = None,
        oauth_port: int | None = None,
    ):
        """Initialize the OpenAI subscription auth handler.

        Args:
            credential_store: Optional custom credential store.
            oauth_port: Port for the local OAuth callback server. When None,
                the OPENHANDS_OAUTH_PORT environment variable is consulted,
                falling back to DEFAULT_OAUTH_PORT. (The port-in-use error
                message advertises OPENHANDS_OAUTH_PORT, so it is honored
                here; previously nothing read it.)
        """
        import os  # local import: this module has no top-level os import

        self._credential_store = credential_store or CredentialStore()
        if oauth_port is None:
            oauth_port = int(
                os.environ.get("OPENHANDS_OAUTH_PORT", DEFAULT_OAUTH_PORT)
            )
        self._oauth_port = oauth_port

    @property
    def vendor(self) -> str:
        """Get the vendor name."""
        return "openai"

    def get_credentials(self) -> OAuthCredentials | None:
        """Get stored credentials if they exist."""
        return self._credential_store.get(self.vendor)

    def has_valid_credentials(self) -> bool:
        """Check if valid (non-expired) credentials exist."""
        creds = self.get_credentials()
        return creds is not None and not creds.is_expired()

    async def refresh_if_needed(self) -> OAuthCredentials | None:
        """Refresh credentials if they are expired.

        Returns:
            Updated credentials, or None if no credentials exist.

        Raises:
            RuntimeError: If token refresh fails.
        """
        creds = self.get_credentials()
        if creds is None:
            return None

        if not creds.is_expired():
            return creds

        logger.info("Refreshing OpenAI access token")
        tokens = await _refresh_access_token(creds.refresh_token)
        return self._credential_store.update_tokens(
            vendor=self.vendor,
            access_token=tokens["access_token"],
            refresh_token=tokens.get("refresh_token"),
            expires_in=tokens.get("expires_in", 3600),
        )

    async def login(self, open_browser: bool = True) -> OAuthCredentials:
        """Perform OAuth login flow.

        Starts a local HTTP server for the OAuth callback, opens the browser
        for user authentication, and waits for the callback carrying the
        authorization code.

        Args:
            open_browser: Whether to automatically open the browser.

        Returns:
            The obtained OAuth credentials.

        Raises:
            RuntimeError: If the OAuth flow fails or times out.
        """
        code_verifier, code_challenge = _generate_pkce()
        state = generate_token(32)
        redirect_uri = f"http://localhost:{self._oauth_port}/auth/callback"
        auth_url = _build_authorize_url(redirect_uri, code_challenge, state)

        # The callback handler resolves this future with the token payload.
        callback_future: asyncio.Future[dict[str, Any]] = asyncio.Future()

        app = web.Application()

        async def handle_callback(request: web.Request) -> web.Response:
            params = request.query

            if "error" in params:
                error_msg = params.get("error_description", params["error"])
                if not callback_future.done():
                    callback_future.set_exception(RuntimeError(error_msg))
                return web.Response(
                    text=_HTML_ERROR.format(error=error_msg),
                    content_type="text/html",
                )

            code = params.get("code")
            if not code:
                error_msg = "Missing authorization code"
                if not callback_future.done():
                    callback_future.set_exception(RuntimeError(error_msg))
                return web.Response(
                    text=_HTML_ERROR.format(error=error_msg),
                    content_type="text/html",
                    status=400,
                )

            # Reject mismatched state before exchanging the code (CSRF guard).
            if params.get("state") != state:
                error_msg = "Invalid state - potential CSRF attack"
                if not callback_future.done():
                    callback_future.set_exception(RuntimeError(error_msg))
                return web.Response(
                    text=_HTML_ERROR.format(error=error_msg),
                    content_type="text/html",
                    status=400,
                )

            try:
                tokens = await _exchange_code_for_tokens(
                    code, redirect_uri, code_verifier
                )
                if not callback_future.done():
                    callback_future.set_result(tokens)
                return web.Response(text=_HTML_SUCCESS, content_type="text/html")
            except Exception as e:
                if not callback_future.done():
                    callback_future.set_exception(e)
                return web.Response(
                    text=_HTML_ERROR.format(error=str(e)),
                    content_type="text/html",
                    status=500,
                )

        app.router.add_get("/auth/callback", handle_callback)

        runner = web.AppRunner(app)
        await runner.setup()
        site = web.TCPSite(runner, "localhost", self._oauth_port)

        try:
            try:
                await site.start()
            except OSError as exc:
                if "address already in use" in str(exc).lower():
                    raise RuntimeError(
                        "OAuth callback server port "
                        f"{self._oauth_port} is already in use. "
                        "Please free the port or set a different one via "
                        "OPENHANDS_OAUTH_PORT."
                    ) from exc
                raise

            logger.debug(f"OAuth callback server started on port {self._oauth_port}")

            if open_browser:
                logger.info("Opening browser for OpenAI authentication...")
                webbrowser.open(auth_url)
            else:
                logger.info(
                    f"Please open the following URL in your browser:\n{auth_url}"
                )

            try:
                tokens = await asyncio.wait_for(
                    callback_future, timeout=OAUTH_TIMEOUT_SECONDS
                )
            except (asyncio.TimeoutError, TimeoutError):
                # asyncio.TimeoutError is only an alias of the builtin
                # TimeoutError on Python 3.11+; catch both so older
                # interpreters also surface the friendly error.
                raise RuntimeError(
                    "OAuth callback timeout - authorization took too long"
                )

            expires_at = int(time.time() * 1000) + (
                tokens.get("expires_in", 3600) * 1000
            )
            credentials = OAuthCredentials(
                vendor=self.vendor,
                access_token=tokens["access_token"],
                refresh_token=tokens["refresh_token"],
                expires_at=expires_at,
            )
            self._credential_store.save(credentials)
            logger.info("OpenAI OAuth login successful")
            return credentials

        finally:
            await runner.cleanup()

    def logout(self) -> bool:
        """Remove stored credentials.

        Returns:
            True if credentials were removed, False if none existed.
        """
        return self._credential_store.delete(self.vendor)

    def create_llm(
        self,
        model: str = "gpt-5.2-codex",
        credentials: OAuthCredentials | None = None,
        instructions: str | None = None,
        **llm_kwargs: Any,
    ) -> LLM:
        """Create an LLM instance configured for Codex subscription access.

        Args:
            model: The model to use (must be in OPENAI_CODEX_MODELS).
            credentials: OAuth credentials to use. If None, uses stored
                credentials.
            instructions: Optional instructions for the Codex model.
            **llm_kwargs: Additional arguments to pass to LLM constructor.

        Returns:
            An LLM instance configured for Codex access.

        Raises:
            ValueError: If the model is not supported or no credentials
                available.
        """
        from openhands.sdk.llm.llm import LLM

        if model not in OPENAI_CODEX_MODELS:
            raise ValueError(
                f"Model '{model}' is not supported for subscription access. "
                f"Supported models: {', '.join(sorted(OPENAI_CODEX_MODELS))}"
            )

        creds = credentials or self.get_credentials()
        if creds is None:
            raise ValueError(
                "No credentials available. Call login() first or provide credentials."
            )

        account_id = _extract_chatgpt_account_id(creds.access_token)
        if not account_id:
            logger.warning(
                "Could not extract chatgpt_account_id from access token. "
                "API requests may fail."
            )

        # Codex-specific request body parameters.
        extra_body: dict[str, Any] = {"store": False}
        if instructions:
            extra_body["instructions"] = instructions
        if "litellm_extra_body" in llm_kwargs:
            extra_body.update(llm_kwargs.pop("litellm_extra_body"))

        # Headers matching OpenAI's official Codex CLI.
        extra_headers: dict[str, str] = {
            "originator": "codex_cli_rs",
            "OpenAI-Beta": "responses=experimental",
            "User-Agent": f"openhands-sdk ({platform.system()}; {platform.machine()})",
        }
        if account_id:
            extra_headers["chatgpt-account-id"] = account_id

        # The Codex backend requires streaming and rejects
        # temperature/max_output_tokens.
        llm = LLM(
            model=f"openai/{model}",
            base_url=CODEX_API_ENDPOINT.rsplit("/", 1)[0],
            api_key=creds.access_token,
            extra_headers=extra_headers,
            litellm_extra_body=extra_body,
            temperature=None,
            max_output_tokens=None,
            stream=True,
            **llm_kwargs,
        )
        llm._is_subscription = True
        # Keep these None even if model-info lookup tried to set them.
        llm.max_output_tokens = None
        llm.temperature = None
        return llm


async def subscription_login_async(
    vendor: SupportedVendor = "openai",
    model: str = "gpt-5.2-codex",
    force_login: bool = False,
    open_browser: bool = True,
    skip_consent: bool = False,
    **llm_kwargs: Any,
) -> LLM:
    """Authenticate with a subscription and return an LLM instance.

    This is the main entry point for subscription-based LLM access.
    It handles credential caching, token refresh, and login flow.

    Args:
        vendor: The vendor/provider (currently only "openai" is supported).
        model: The model to use.
        force_login: If True, always perform a fresh login.
        open_browser: Whether to automatically open the browser for login.
        skip_consent: If True, skip the consent prompt (for programmatic use
            where consent has been obtained through other means).
        **llm_kwargs: Additional arguments to pass to LLM constructor.

    Returns:
        An LLM instance configured for subscription access.

    Raises:
        ValueError: If the vendor is not supported.
        RuntimeError: If authentication fails or user declines consent.

    Example:
        >>> import asyncio
        >>> from openhands.sdk.llm.auth import subscription_login_async
        >>> llm = asyncio.run(subscription_login_async(model="gpt-5.2-codex"))
    """
    if vendor != "openai":
        raise ValueError(
            f"Vendor '{vendor}' is not supported. Only 'openai' is supported."
        )

    auth = OpenAISubscriptionAuth()

    # Reuse cached credentials (refreshing if needed) unless a fresh login
    # was explicitly requested.
    if not force_login:
        creds = await auth.refresh_if_needed()
        if creds is not None:
            logger.info("Using existing OpenAI credentials")
            return auth.create_llm(model=model, credentials=creds, **llm_kwargs)

    # Display consent banner and get confirmation before login.
    if not skip_consent:
        if not _display_consent_and_confirm():
            raise RuntimeError("User declined to continue with ChatGPT sign-in")

    creds = await auth.login(open_browser=open_browser)
    return auth.create_llm(model=model, credentials=creds, **llm_kwargs)


def subscription_login(
    vendor: SupportedVendor = "openai",
    model: str = "gpt-5.2-codex",
    force_login: bool = False,
    open_browser: bool = True,
    skip_consent: bool = False,
    **llm_kwargs: Any,
) -> LLM:
    """Synchronous wrapper for subscription_login_async.

    See subscription_login_async for full documentation. Must not be called
    from inside a running event loop (asyncio.run would raise RuntimeError).
    """
    return asyncio.run(
        subscription_login_async(
            vendor=vendor,
            model=model,
            force_login=force_login,
            open_browser=open_browser,
            skip_consent=skip_consent,
            **llm_kwargs,
        )
    )


# =========================================================================
# Message transformation utilities for subscription mode
# =========================================================================

DEFAULT_SYSTEM_MESSAGE = (
    "You are OpenHands agent, a helpful AI assistant that can interact "
    "with a computer to solve tasks."
)


def inject_system_prefix(
    input_items: list[dict[str, Any]], prefix_content: dict[str, Any]
) -> None:
    """Inject system prefix into the first user message, or create one.

    This modifies input_items in place.

    Args:
        input_items: List of input items (messages) to modify.
        prefix_content: The content dict to prepend
            (e.g., {"type": "input_text", "text": "..."}).
    """
    for item in input_items:
        if item.get("type") == "message" and item.get("role") == "user":
            content = item.get("content")
            if not isinstance(content, list):
                # NOTE(review): a plain-str content ends up as a bare str
                # next to the dict prefix — confirm upstream normalizes this.
                content = [content] if content else []
            item["content"] = [prefix_content] + content
            return

    # No user message found, create a synthetic one at the front.
    input_items.insert(0, {"role": "user", "content": [prefix_content]})


def transform_for_subscription(
    system_chunks: list[str], input_items: list[dict[str, Any]]
) -> tuple[str, list[dict[str, Any]]]:
    """Transform messages for Codex subscription transport.

    Codex subscription endpoints reject complex/long `instructions`, so we:
    1. Use a minimal default instruction string
    2. Prepend system prompts to the first user message
    3. Normalize message format to match OpenCode's Codex client

    Args:
        system_chunks: List of system prompt strings to merge.
        input_items: List of input items (messages) to transform.

    Returns:
        A tuple of (instructions, normalized_input_items).
    """
    # Prepend merged system prompts to the first user message.
    if system_chunks:
        merged = "\n\n---\n\n".join(system_chunks)
        prefix_content = {
            "type": "input_text",
            "text": f"Context (system prompt):\n{merged}\n\n",
        }
        inject_system_prefix(input_items, prefix_content)

    # Normalize: {"type": "message", ...} -> {"role": ..., "content": ...}
    normalized = [
        {"role": item.get("role"), "content": item.get("content") or []}
        if item.get("type") == "message"
        else item
        for item in input_items
    ]
    return DEFAULT_SYSTEM_MESSAGE, normalized


# --- (diff continues into openhands-sdk/openhands/sdk/llm/llm.py) ---
_is_subscription: bool = PrivateAttr(default=False) model_config: ClassVar[ConfigDict] = ConfigDict( extra="ignore", arbitrary_types_allowed=True @@ -499,6 +515,19 @@ def telemetry(self) -> Telemetry: ) return self._telemetry + @property + def is_subscription(self) -> bool: + """Check if this LLM uses subscription-based authentication. + + Returns True when the LLM was created via `LLM.subscription_login()`, + which uses the ChatGPT subscription Codex backend rather than the + standard OpenAI API. + + Returns: + bool: True if using subscription-based transport, False otherwise. + """ + return self._is_subscription + def restore_metrics(self, metrics: Metrics) -> None: # Only used by ConversationStats to seed metrics self._metrics = metrics @@ -662,7 +691,7 @@ def _one_attempt(**retry_kwargs) -> ModelResponse: raise # ========================================================================= - # Responses API (non-stream, v1) + # Responses API (v1) # ========================================================================= def responses( self, @@ -686,16 +715,19 @@ def responses( store: Whether to store the conversation _return_metrics: Whether to return usage metrics add_security_risk_prediction: Add security_risk field to tool schemas - on_token: Optional callback for streaming tokens (not yet supported) + on_token: Optional callback for streaming deltas **kwargs: Additional arguments passed to the API Note: Summary field is always added to tool schemas for transparency and explainability of agent actions. 
""" - # Streaming not yet supported - if kwargs.get("stream", False) or self.stream or on_token is not None: - raise ValueError("Streaming is not supported for Responses API yet") + user_enable_streaming = bool(kwargs.get("stream", False)) or self.stream + if user_enable_streaming: + if on_token is None and not self.is_subscription: + # We allow on_token to be None for subscription mode + raise ValueError("Streaming requires an on_token callback") + kwargs["stream"] = True # Build instructions + input list using dedicated Responses formatter instructions, input_items = self.format_messages_for_responses(messages) @@ -771,12 +803,67 @@ def _one_attempt(**retry_kwargs) -> ResponsesAPIResponse: seed=self.seed, **final_kwargs, ) - assert isinstance(ret, ResponsesAPIResponse), ( + if isinstance(ret, ResponsesAPIResponse): + if user_enable_streaming: + logger.warning( + "Responses streaming was requested, but the provider " + "returned a non-streaming response; no on_token deltas " + "will be emitted." + ) + self._telemetry.on_response(ret) + return ret + + # When stream=True, LiteLLM returns a streaming iterator rather than + # a single ResponsesAPIResponse. Drain the iterator and use the + # completed response. 
+ if final_kwargs.get("stream", False): + if not isinstance(ret, SyncResponsesAPIStreamingIterator): + raise AssertionError( + f"Expected Responses stream iterator, got {type(ret)}" + ) + + stream_callback = on_token if user_enable_streaming else None + for event in ret: + if stream_callback is None: + continue + if isinstance( + event, + ( + OutputTextDeltaEvent, + RefusalDeltaEvent, + ReasoningSummaryTextDeltaEvent, + ), + ): + delta = event.delta + if delta: + stream_callback( + ModelResponseStream( + choices=[ + StreamingChoices( + delta=Delta(content=delta) + ) + ] + ) + ) + + completed_event = ret.completed_response + if completed_event is None: + raise LLMNoResponseError( + "Responses stream finished without a completed response" + ) + if not isinstance(completed_event, ResponseCompletedEvent): + raise LLMNoResponseError( + f"Unexpected completed event: {type(completed_event)}" + ) + + completed_resp = completed_event.response + + self._telemetry.on_response(completed_resp) + return completed_resp + + raise AssertionError( f"Expected ResponsesAPIResponse, got {type(ret)}" ) - # telemetry (latency, cost). Token usage mapping we handle after. 
- self._telemetry.on_response(ret) - return ret try: resp: ResponsesAPIResponse = _one_attempt() @@ -1046,8 +1133,9 @@ def format_messages_for_responses( - Skips prompt caching flags and string serializer concerns - Uses Message.to_responses_value to get either instructions (system) - or input items (others) + or input items (others) - Concatenates system instructions into a single instructions string + - For subscription mode, system prompts are prepended to user content """ msgs = copy.deepcopy(messages) @@ -1057,18 +1145,26 @@ def format_messages_for_responses( # Assign system instructions as a string, collect input items instructions: str | None = None input_items: list[dict[str, Any]] = [] + system_chunks: list[str] = [] + for m in msgs: val = m.to_responses_value(vision_enabled=vision_active) if isinstance(val, str): s = val.strip() - if not s: - continue - instructions = ( - s if instructions is None else f"{instructions}\n\n---\n\n{s}" - ) - else: - if val: - input_items.extend(val) + if s: + if self.is_subscription: + system_chunks.append(s) + else: + instructions = ( + s + if instructions is None + else f"{instructions}\n\n---\n\n{s}" + ) + elif val: + input_items.extend(val) + + if self.is_subscription: + return transform_for_subscription(system_chunks, input_items) return instructions, input_items def get_token_count(self, messages: list[Message]) -> int: @@ -1159,3 +1255,62 @@ def _cast_value(raw: str, t: Any) -> Any: if v is not None: data[field_name] = v return cls(**data) + + @classmethod + def subscription_login( + cls, + vendor: SupportedVendor, + model: str, + force_login: bool = False, + open_browser: bool = True, + **llm_kwargs, + ) -> LLM: + """Authenticate with a subscription service and return an LLM instance. + + This method provides subscription-based access to LLM models that are + available through chat subscriptions (e.g., ChatGPT Plus/Pro) rather + than API credits. 
It handles credential caching, token refresh, and + the OAuth login flow. + + Currently supported vendors: + - "openai": ChatGPT Plus/Pro subscription for Codex models + + Supported OpenAI models: + - gpt-5.1-codex-max + - gpt-5.1-codex-mini + - gpt-5.2 + - gpt-5.2-codex + + Args: + vendor: The vendor/provider. Currently only "openai" is supported. + model: The model to use. Must be supported by the vendor's + subscription service. + force_login: If True, always perform a fresh login even if valid + credentials exist. + open_browser: Whether to automatically open the browser for the + OAuth login flow. + **llm_kwargs: Additional arguments to pass to the LLM constructor. + + Returns: + An LLM instance configured for subscription-based access. + + Raises: + ValueError: If the vendor or model is not supported. + RuntimeError: If authentication fails. + + Example: + >>> from openhands.sdk import LLM + >>> # First time: opens browser for OAuth login + >>> llm = LLM.subscription_login(vendor="openai", model="gpt-5.2-codex") + >>> # Subsequent calls: reuses cached credentials + >>> llm = LLM.subscription_login(vendor="openai", model="gpt-5.2-codex") + """ + from openhands.sdk.llm.auth.openai import subscription_login + + return subscription_login( + vendor=vendor, + model=model, + force_login=force_login, + open_browser=open_browser, + **llm_kwargs, + ) diff --git a/openhands-sdk/openhands/sdk/llm/options/responses_options.py b/openhands-sdk/openhands/sdk/llm/options/responses_options.py index f2343906c9..562ec66c12 100644 --- a/openhands-sdk/openhands/sdk/llm/options/responses_options.py +++ b/openhands-sdk/openhands/sdk/llm/options/responses_options.py @@ -15,15 +15,16 @@ def select_responses_options( ) -> dict[str, Any]: """Behavior-preserving extraction of _normalize_responses_kwargs.""" # Apply defaults for keys that are not forced by policy - out = apply_defaults_if_absent( - user_kwargs, - { - "max_output_tokens": llm.max_output_tokens, - }, - ) + # Note: 
max_output_tokens is not supported in subscription mode + defaults = {} + if not llm.is_subscription: + defaults["max_output_tokens"] = llm.max_output_tokens + out = apply_defaults_if_absent(user_kwargs, defaults) # Enforce sampling/tool behavior for Responses path - out["temperature"] = 1.0 + # Note: temperature is not supported in subscription mode + if not llm.is_subscription: + out["temperature"] = 1.0 out["tool_choice"] = "auto" # If user didn't set extra_headers, propagate from llm config diff --git a/tests/examples/test_examples.py b/tests/examples/test_examples.py index 9b569ebfcc..877a43b826 100644 --- a/tests/examples/test_examples.py +++ b/tests/examples/test_examples.py @@ -49,6 +49,7 @@ "examples/01_standalone_sdk/15_browser_use.py", "examples/01_standalone_sdk/16_llm_security_analyzer.py", "examples/01_standalone_sdk/27_observability_laminar.py", + "examples/01_standalone_sdk/35_subscription_login.py", "examples/02_remote_agent_server/04_vscode_with_docker_sandboxed_server.py", } diff --git a/tests/sdk/llm/auth/__init__.py b/tests/sdk/llm/auth/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/sdk/llm/auth/test_credentials.py b/tests/sdk/llm/auth/test_credentials.py new file mode 100644 index 0000000000..c9e0f5f149 --- /dev/null +++ b/tests/sdk/llm/auth/test_credentials.py @@ -0,0 +1,194 @@ +"""Tests for credential storage and retrieval.""" + +import time +from pathlib import Path + +from openhands.sdk.llm.auth.credentials import ( + CredentialStore, + OAuthCredentials, + get_credentials_dir, +) + + +def test_oauth_credentials_model(): + """Test OAuthCredentials model creation and validation.""" + expires_at = int(time.time() * 1000) + 3600_000 # 1 hour from now + creds = OAuthCredentials( + vendor="openai", + access_token="test_access_token", + refresh_token="test_refresh_token", + expires_at=expires_at, + ) + assert creds.vendor == "openai" + assert creds.access_token == "test_access_token" + assert creds.refresh_token 
== "test_refresh_token" + assert creds.expires_at == expires_at + assert creds.type == "oauth" + + +def test_oauth_credentials_is_expired(): + """Test OAuthCredentials expiration check.""" + # Not expired (1 hour from now) + future_creds = OAuthCredentials( + vendor="openai", + access_token="test", + refresh_token="test", + expires_at=int(time.time() * 1000) + 3600_000, + ) + assert not future_creds.is_expired() + + # Expired (1 hour ago) + past_creds = OAuthCredentials( + vendor="openai", + access_token="test", + refresh_token="test", + expires_at=int(time.time() * 1000) - 3600_000, + ) + assert past_creds.is_expired() + + +def test_get_credentials_dir_default(monkeypatch): + """Test default credentials directory.""" + monkeypatch.delenv("XDG_DATA_HOME", raising=False) + creds_dir = get_credentials_dir() + assert creds_dir == Path.home() / ".openhands" / "auth" + + +def test_get_credentials_dir_xdg(monkeypatch, tmp_path): + """Test credentials directory ignores XDG_DATA_HOME (uses ~/.openhands/auth).""" + monkeypatch.setenv("XDG_DATA_HOME", str(tmp_path)) + creds_dir = get_credentials_dir() + # Implementation uses ~/.openhands/auth regardless of XDG_DATA_HOME + assert creds_dir == Path.home() / ".openhands" / "auth" + + +def test_credential_store_save_and_get(tmp_path): + """Test saving and retrieving credentials.""" + store = CredentialStore(credentials_dir=tmp_path) + creds = OAuthCredentials( + vendor="openai", + access_token="test_access", + refresh_token="test_refresh", + expires_at=int(time.time() * 1000) + 3600_000, + ) + + store.save(creds) + + # Verify file was created + creds_file = tmp_path / "openai_oauth.json" + assert creds_file.exists() + + # Verify file permissions (owner read/write only) + assert (creds_file.stat().st_mode & 0o777) == 0o600 + + # Retrieve and verify + retrieved = store.get("openai") + assert retrieved is not None + assert retrieved.vendor == creds.vendor + assert retrieved.access_token == creds.access_token + assert 
retrieved.refresh_token == creds.refresh_token + assert retrieved.expires_at == creds.expires_at + + +def test_credential_store_get_nonexistent(tmp_path): + """Test getting credentials that don't exist.""" + store = CredentialStore(credentials_dir=tmp_path) + assert store.get("nonexistent") is None + + +def test_credential_store_get_invalid_json(tmp_path): + """Test getting credentials from invalid JSON file.""" + store = CredentialStore(credentials_dir=tmp_path) + tmp_path.mkdir(parents=True, exist_ok=True) + + # Create invalid JSON file + creds_file = tmp_path / "openai_oauth.json" + creds_file.write_text("invalid json") + + # Should return None and delete the invalid file + assert store.get("openai") is None + assert not creds_file.exists() + + +def test_credential_store_delete(tmp_path): + """Test deleting credentials.""" + store = CredentialStore(credentials_dir=tmp_path) + creds = OAuthCredentials( + vendor="openai", + access_token="test", + refresh_token="test", + expires_at=int(time.time() * 1000) + 3600_000, + ) + store.save(creds) + + # Delete and verify + assert store.delete("openai") is True + assert store.get("openai") is None + + # Delete again should return False + assert store.delete("openai") is False + + +def test_credential_store_update_tokens(tmp_path): + """Test updating tokens for existing credentials.""" + store = CredentialStore(credentials_dir=tmp_path) + original = OAuthCredentials( + vendor="openai", + access_token="old_access", + refresh_token="old_refresh", + expires_at=int(time.time() * 1000) + 3600_000, + ) + store.save(original) + + # Update tokens + updated = store.update_tokens( + vendor="openai", + access_token="new_access", + refresh_token="new_refresh", + expires_in=7200, # 2 hours + ) + + assert updated is not None + assert updated.access_token == "new_access" + assert updated.refresh_token == "new_refresh" + + # Verify persisted + retrieved = store.get("openai") + assert retrieved is not None + assert retrieved.access_token == 
"new_access" + + +def test_credential_store_update_tokens_keeps_refresh_if_not_provided(tmp_path): + """Test that update_tokens keeps old refresh token if new one not provided.""" + store = CredentialStore(credentials_dir=tmp_path) + original = OAuthCredentials( + vendor="openai", + access_token="old_access", + refresh_token="original_refresh", + expires_at=int(time.time() * 1000) + 3600_000, + ) + store.save(original) + + # Update without new refresh token + updated = store.update_tokens( + vendor="openai", + access_token="new_access", + refresh_token=None, + expires_in=3600, + ) + + assert updated is not None + assert updated.access_token == "new_access" + assert updated.refresh_token == "original_refresh" + + +def test_credential_store_update_tokens_nonexistent(tmp_path): + """Test updating tokens for non-existent credentials.""" + store = CredentialStore(credentials_dir=tmp_path) + result = store.update_tokens( + vendor="openai", + access_token="new_access", + refresh_token="new_refresh", + expires_in=3600, + ) + assert result is None diff --git a/tests/sdk/llm/auth/test_openai.py b/tests/sdk/llm/auth/test_openai.py new file mode 100644 index 0000000000..69cee90487 --- /dev/null +++ b/tests/sdk/llm/auth/test_openai.py @@ -0,0 +1,386 @@ +"""Tests for OpenAI subscription authentication. + +Note: Tests for JWT verification and JWKS caching have been removed as they +require real OAuth tokens to be meaningful. See GitHub issue #1806 for tracking +integration test requirements. 
+""" + +import time +from unittest.mock import AsyncMock, patch + +import pytest + +from openhands.sdk.llm.auth.credentials import CredentialStore, OAuthCredentials +from openhands.sdk.llm.auth.openai import ( + CLIENT_ID, + CONSENT_BANNER, + ISSUER, + OPENAI_CODEX_MODELS, + OpenAISubscriptionAuth, + _build_authorize_url, + _display_consent_and_confirm, + _generate_pkce, + _get_consent_marker_path, + _has_acknowledged_consent, + _mark_consent_acknowledged, +) + + +def test_generate_pkce(): + """Test PKCE code generation using authlib.""" + verifier, challenge = _generate_pkce() + assert verifier is not None + assert challenge is not None + assert len(verifier) > 0 + assert len(challenge) > 0 + # Verifier and challenge should be different + assert verifier != challenge + + +def test_pkce_codes_are_unique(): + """Test that PKCE codes are unique each time.""" + verifier1, challenge1 = _generate_pkce() + verifier2, challenge2 = _generate_pkce() + assert verifier1 != verifier2 + assert challenge1 != challenge2 + + +def test_build_authorize_url(): + """Test building the OAuth authorization URL.""" + code_challenge = "test_challenge" + state = "test_state" + redirect_uri = "http://localhost:1455/auth/callback" + + url = _build_authorize_url(redirect_uri, code_challenge, state) + + assert url.startswith(f"{ISSUER}/oauth/authorize?") + assert f"client_id={CLIENT_ID}" in url + assert "redirect_uri=http%3A%2F%2Flocalhost%3A1455%2Fauth%2Fcallback" in url + assert "code_challenge=test_challenge" in url + assert "code_challenge_method=S256" in url + assert "state=test_state" in url + assert "originator=openhands" in url + assert "response_type=code" in url + + +def test_openai_codex_models(): + """Test that OPENAI_CODEX_MODELS contains expected models.""" + assert "gpt-5.2-codex" in OPENAI_CODEX_MODELS + assert "gpt-5.2" in OPENAI_CODEX_MODELS + assert "gpt-5.1-codex-max" in OPENAI_CODEX_MODELS + assert "gpt-5.1-codex-mini" in OPENAI_CODEX_MODELS + + +def 
test_openai_subscription_auth_vendor(): + """Test OpenAISubscriptionAuth vendor property.""" + auth = OpenAISubscriptionAuth() + assert auth.vendor == "openai" + + +def test_openai_subscription_auth_get_credentials(tmp_path): + """Test getting credentials from store.""" + store = CredentialStore(credentials_dir=tmp_path) + auth = OpenAISubscriptionAuth(credential_store=store) + + # No credentials initially + assert auth.get_credentials() is None + + # Save credentials + creds = OAuthCredentials( + vendor="openai", + access_token="test_access", + refresh_token="test_refresh", + expires_at=int(time.time() * 1000) + 3600_000, + ) + store.save(creds) + + # Now should return credentials + retrieved = auth.get_credentials() + assert retrieved is not None + assert retrieved.access_token == "test_access" + + +def test_openai_subscription_auth_has_valid_credentials(tmp_path): + """Test checking for valid credentials.""" + store = CredentialStore(credentials_dir=tmp_path) + auth = OpenAISubscriptionAuth(credential_store=store) + + # No credentials + assert not auth.has_valid_credentials() + + # Valid credentials + valid_creds = OAuthCredentials( + vendor="openai", + access_token="test", + refresh_token="test", + expires_at=int(time.time() * 1000) + 3600_000, + ) + store.save(valid_creds) + assert auth.has_valid_credentials() + + # Expired credentials + expired_creds = OAuthCredentials( + vendor="openai", + access_token="test", + refresh_token="test", + expires_at=int(time.time() * 1000) - 3600_000, + ) + store.save(expired_creds) + assert not auth.has_valid_credentials() + + +def test_openai_subscription_auth_logout(tmp_path): + """Test logout removes credentials.""" + store = CredentialStore(credentials_dir=tmp_path) + auth = OpenAISubscriptionAuth(credential_store=store) + + # Save credentials + creds = OAuthCredentials( + vendor="openai", + access_token="test", + refresh_token="test", + expires_at=int(time.time() * 1000) + 3600_000, + ) + store.save(creds) + assert 
auth.has_valid_credentials() + + # Logout + assert auth.logout() is True + assert not auth.has_valid_credentials() + + # Logout again should return False + assert auth.logout() is False + + +def test_openai_subscription_auth_create_llm_invalid_model(tmp_path): + """Test create_llm raises error for invalid model.""" + store = CredentialStore(credentials_dir=tmp_path) + auth = OpenAISubscriptionAuth(credential_store=store) + + # Save valid credentials + creds = OAuthCredentials( + vendor="openai", + access_token="test", + refresh_token="test", + expires_at=int(time.time() * 1000) + 3600_000, + ) + store.save(creds) + + with pytest.raises(ValueError, match="not supported for subscription access"): + auth.create_llm(model="gpt-4") + + +def test_openai_subscription_auth_create_llm_no_credentials(tmp_path): + """Test create_llm raises error when no credentials available.""" + store = CredentialStore(credentials_dir=tmp_path) + auth = OpenAISubscriptionAuth(credential_store=store) + + with pytest.raises(ValueError, match="No credentials available"): + auth.create_llm(model="gpt-5.2-codex") + + +def test_openai_subscription_auth_create_llm_success(tmp_path): + """Test create_llm creates LLM with correct configuration.""" + store = CredentialStore(credentials_dir=tmp_path) + auth = OpenAISubscriptionAuth(credential_store=store) + + # Save valid credentials + creds = OAuthCredentials( + vendor="openai", + access_token="test_access_token", + refresh_token="test_refresh", + expires_at=int(time.time() * 1000) + 3600_000, + ) + store.save(creds) + + llm = auth.create_llm(model="gpt-5.2-codex") + + assert llm.model == "openai/gpt-5.2-codex" + assert llm.api_key is not None + assert llm.extra_headers is not None + # Uses codex_cli_rs to match official Codex CLI for compatibility + assert llm.extra_headers.get("originator") == "codex_cli_rs" + + +@pytest.mark.asyncio +async def test_openai_subscription_auth_refresh_if_needed_no_creds(tmp_path): + """Test refresh_if_needed returns 
None when no credentials.""" + store = CredentialStore(credentials_dir=tmp_path) + auth = OpenAISubscriptionAuth(credential_store=store) + + result = await auth.refresh_if_needed() + assert result is None + + +@pytest.mark.asyncio +async def test_openai_subscription_auth_refresh_if_needed_valid_creds(tmp_path): + """Test refresh_if_needed returns existing creds when not expired.""" + store = CredentialStore(credentials_dir=tmp_path) + auth = OpenAISubscriptionAuth(credential_store=store) + + # Save valid credentials + creds = OAuthCredentials( + vendor="openai", + access_token="test_access", + refresh_token="test_refresh", + expires_at=int(time.time() * 1000) + 3600_000, + ) + store.save(creds) + + result = await auth.refresh_if_needed() + assert result is not None + assert result.access_token == "test_access" + + +@pytest.mark.asyncio +async def test_openai_subscription_auth_refresh_if_needed_expired_creds(tmp_path): + """Test refresh_if_needed refreshes expired credentials.""" + store = CredentialStore(credentials_dir=tmp_path) + auth = OpenAISubscriptionAuth(credential_store=store) + + # Save expired credentials + creds = OAuthCredentials( + vendor="openai", + access_token="old_access", + refresh_token="test_refresh", + expires_at=int(time.time() * 1000) - 3600_000, + ) + store.save(creds) + + # Mock the refresh function + with patch( + "openhands.sdk.llm.auth.openai._refresh_access_token", + new_callable=AsyncMock, + ) as mock_refresh: + mock_refresh.return_value = { + "access_token": "new_access", + "refresh_token": "new_refresh", + "expires_in": 3600, + } + + result = await auth.refresh_if_needed() + + assert result is not None + assert result.access_token == "new_access" + mock_refresh.assert_called_once_with("test_refresh") + + +# ========================================================================= +# Tests for consent banner system +# ========================================================================= + + +class TestConsentBannerSystem: + 
"""Tests for the consent banner and acknowledgment system.""" + + def test_consent_banner_content(self): + """Test that consent banner contains required text.""" + assert "ChatGPT" in CONSENT_BANNER + assert "Terms of Use" in CONSENT_BANNER + assert "openai.com/policies/terms-of-use" in CONSENT_BANNER + + def test_consent_marker_path(self, tmp_path): + """Test that consent marker path is in credentials directory.""" + with patch( + "openhands.sdk.llm.auth.openai.get_credentials_dir", return_value=tmp_path + ): + marker_path = _get_consent_marker_path() + assert marker_path.parent == tmp_path + assert ".chatgpt_consent_acknowledged" in str(marker_path) + + def test_has_acknowledged_consent_false_initially(self, tmp_path): + """Test that consent is not acknowledged initially.""" + with patch( + "openhands.sdk.llm.auth.openai.get_credentials_dir", return_value=tmp_path + ): + assert not _has_acknowledged_consent() + + def test_mark_consent_acknowledged(self, tmp_path): + """Test marking consent as acknowledged.""" + with patch( + "openhands.sdk.llm.auth.openai.get_credentials_dir", return_value=tmp_path + ): + assert not _has_acknowledged_consent() + _mark_consent_acknowledged() + assert _has_acknowledged_consent() + + def test_display_consent_user_accepts(self, tmp_path, capsys): + """Test consent display when user accepts.""" + with ( + patch( + "openhands.sdk.llm.auth.openai.get_credentials_dir", + return_value=tmp_path, + ), + patch("sys.stdin.isatty", return_value=True), + patch("builtins.input", return_value="y"), + ): + result = _display_consent_and_confirm() + assert result is True + + # Check banner was printed + captured = capsys.readouterr() + assert "ChatGPT" in captured.out + assert "Terms of Use" in captured.out + + def test_display_consent_user_declines(self, tmp_path, capsys): + """Test consent display when user declines.""" + with ( + patch( + "openhands.sdk.llm.auth.openai.get_credentials_dir", + return_value=tmp_path, + ), + 
patch("sys.stdin.isatty", return_value=True), + patch("builtins.input", return_value="n"), + ): + result = _display_consent_and_confirm() + assert result is False + + def test_display_consent_non_interactive_first_time_raises(self, tmp_path): + """Test that non-interactive mode raises error on first time.""" + with ( + patch( + "openhands.sdk.llm.auth.openai.get_credentials_dir", + return_value=tmp_path, + ), + patch("sys.stdin.isatty", return_value=False), + ): + with pytest.raises(RuntimeError, match="non-interactive mode"): + _display_consent_and_confirm() + + def test_display_consent_non_interactive_after_acknowledgment(self, tmp_path): + """Test that non-interactive mode works after prior acknowledgment.""" + with patch( + "openhands.sdk.llm.auth.openai.get_credentials_dir", return_value=tmp_path + ): + # Mark consent as acknowledged + _mark_consent_acknowledged() + + with patch("sys.stdin.isatty", return_value=False): + result = _display_consent_and_confirm() + assert result is True + + def test_display_consent_keyboard_interrupt(self, tmp_path): + """Test handling of keyboard interrupt during consent.""" + with ( + patch( + "openhands.sdk.llm.auth.openai.get_credentials_dir", + return_value=tmp_path, + ), + patch("sys.stdin.isatty", return_value=True), + patch("builtins.input", side_effect=KeyboardInterrupt), + ): + result = _display_consent_and_confirm() + assert result is False + + def test_display_consent_eof_error(self, tmp_path): + """Test handling of EOF during consent.""" + with ( + patch( + "openhands.sdk.llm.auth.openai.get_credentials_dir", + return_value=tmp_path, + ), + patch("sys.stdin.isatty", return_value=True), + patch("builtins.input", side_effect=EOFError), + ): + result = _display_consent_and_confirm() + assert result is False diff --git a/tests/sdk/llm/test_responses_serialization.py b/tests/sdk/llm/test_responses_serialization.py index 1e0ce73abb..8c39e081a0 100644 --- a/tests/sdk/llm/test_responses_serialization.py +++ 
b/tests/sdk/llm/test_responses_serialization.py @@ -48,6 +48,49 @@ def test_system_to_responses_value_instructions_concat(): assert inputs == [] +def test_subscription_codex_transport_does_not_use_top_level_instructions_and_prepend_system_to_user(): # noqa: E501 + m_sys = Message(role="system", content=[TextContent(text="SYS")]) + m_user = Message(role="user", content=[TextContent(text="USER")]) + + llm = LLM(model="gpt-5.1-codex", base_url="https://chatgpt.com/backend-api/codex") + llm._is_subscription = True # Mark as subscription-based + instr, inputs = llm.format_messages_for_responses([m_sys, m_user]) + + assert instr is not None + assert "OpenHands agent" in instr + assert len(inputs) >= 1 + first_user = next(it for it in inputs if it.get("role") == "user") + content = first_user.get("content") + assert isinstance(content, list) + assert content[0]["type"] == "input_text" + assert "SYS" in content[0]["text"] + + +def test_subscription_codex_transport_injects_synthetic_user_message_when_none_exists(): + m_sys = Message(role="system", content=[TextContent(text="SYS")]) + m_asst = Message(role="assistant", content=[TextContent(text="ASST")]) + + llm = LLM(model="gpt-5.1-codex", base_url="https://chatgpt.com/backend-api/codex") + llm._is_subscription = True # Mark as subscription-based + instr, inputs = llm.format_messages_for_responses([m_sys, m_asst]) + + assert instr is not None + assert "OpenHands agent" in instr + assert len(inputs) >= 1 + first = inputs[0] + assert first.get("role") == "user" + assert "SYS" in first["content"][0]["text"] + + +def test_api_codex_models_keep_system_as_instructions(): + m_sys = Message(role="system", content=[TextContent(text="SYS")]) + llm = LLM(model="gpt-5.1-codex") + instr, inputs = llm.format_messages_for_responses([m_sys]) + + assert instr == "SYS" + assert inputs == [] + + def test_user_to_responses_dict_with_and_without_vision(): m = Message( role="user",