diff --git a/packages/app/src/components/dialog-auth-usage.tsx b/packages/app/src/components/dialog-auth-usage.tsx new file mode 100644 index 00000000000..cefebb561f4 --- /dev/null +++ b/packages/app/src/components/dialog-auth-usage.tsx @@ -0,0 +1,267 @@ +import { Dialog } from "@opencode-ai/ui/dialog" +import { ProviderIcon } from "@opencode-ai/ui/provider-icon" +import type { IconName } from "@opencode-ai/ui/icons/provider" +import { Spinner } from "@opencode-ai/ui/spinner" +import { createResource, For, Show, createMemo, createSignal } from "solid-js" +import { useGlobalSDK } from "@/context/global-sdk" + +interface AccountUsage { + id: string + label?: string + isActive?: boolean + health: { + successCount: number + failureCount: number + lastStatusCode?: number + cooldownUntil?: number + } +} + +interface AnthropicUsage { + fiveHour?: { utilization: number; resetsAt?: string } + sevenDay?: { utilization: number; resetsAt?: string } + sevenDaySonnet?: { utilization: number; resetsAt?: string } +} + +interface ProviderUsage { + accounts: AccountUsage[] + anthropicUsage?: AnthropicUsage +} + +type AuthUsageData = Record + +function formatResetTime(resetAt?: string): string { + if (!resetAt) return "" + const reset = new Date(resetAt) + const now = new Date() + const diffMs = reset.getTime() - now.getTime() + if (diffMs <= 0) return "now" + + const totalMinutes = Math.floor(diffMs / (1000 * 60)) + const hours = Math.floor(totalMinutes / 60) + const minutes = totalMinutes % 60 + + if (hours > 0) return `${hours}h ${minutes}m` + return `${minutes}m` +} + +function getColorClass(percent: number): string { + if (percent <= 50) return "bg-fill-success-base" + if (percent <= 80) return "bg-fill-warning-base" + return "bg-fill-danger-base" +} + +function UsageBarPercent(props: { label: string; utilization: number; resetsAt?: string }) { + return ( +
+
+ {props.label} + {props.utilization}% used +
+
+
+
+ +
Resets in {formatResetTime(props.resetsAt)}
+
+
+ ) +} + +export function DialogAuthUsage() { + const globalSDK = useGlobalSDK() + const [switching, setSwitching] = createSignal(null) + + const [usage, { refetch, mutate }] = createResource(async () => { + const result = await globalSDK.client.auth.usage({}) + return result.data as AuthUsageData + }) + + const providers = createMemo(() => { + const data = usage() + if (!data) return [] + return Object.entries(data).filter(([_, info]) => info.accounts.length > 0) + }) + + const switchAccount = async (providerID: string, recordID: string) => { + setSwitching(recordID) + try { + const result = await fetch(`${globalSDK.url}/auth/active`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ providerID, recordID }), + }).then((r) => r.json()) + + if (result.success) { + const current = usage() + if (current && current[providerID]) { + mutate({ + ...current, + [providerID]: { + ...current[providerID], + accounts: current[providerID].accounts.map((acc) => ({ + ...acc, + isActive: acc.id === recordID, + })), + anthropicUsage: result.anthropicUsage ?? current[providerID].anthropicUsage, + }, + }) + } + } + } finally { + setSwitching(null) + } + } + + return ( + +
+ +
+ +
+
+ + +
+ No OAuth providers configured. Login with Claude Max or another OAuth provider to see usage data. +
+
+ + + {([providerID, info]) => ( +
+
+ + {providerID} + + ({info.accounts.length} account{info.accounts.length > 1 ? "s" : ""}) + +
+ + {/* Anthropic Usage Limits */} + +
+
Usage Limits (Active Account)
+ + + + + + + + + +
+
+ + +
+ Unable to fetch usage limits. Make sure you're logged in with Claude Max. +
+
+ + {/* Account Details */} +
+ Accounts + 1}> + (click to switch) + +
+ + {(account, index) => { + const isInCooldown = () => { + const cooldown = account.health.cooldownUntil + return cooldown && cooldown > Date.now() + } + const cooldownRemaining = () => { + const cooldown = account.health.cooldownUntil + if (!cooldown) return "" + const diff = cooldown - Date.now() + if (diff <= 0) return "" + const secs = Math.ceil(diff / 1000) + return secs > 60 ? `${Math.ceil(secs / 60)}m` : `${secs}s` + } + const isSwitching = () => switching() === account.id + const canSwitch = () => info.accounts.length > 1 && !account.isActive && !isSwitching() + + return ( + + ) + }} + +
+ )} +
+ + 0}> +
+ +
+
+
+
+ ) +} diff --git a/packages/app/src/components/dialog-settings.tsx b/packages/app/src/components/dialog-settings.tsx new file mode 100644 index 00000000000..fc0b587f72b --- /dev/null +++ b/packages/app/src/components/dialog-settings.tsx @@ -0,0 +1,694 @@ +import { Dialog } from "@opencode-ai/ui/dialog" +import { ProviderIcon } from "@opencode-ai/ui/provider-icon" +import type { IconName } from "@opencode-ai/ui/icons/provider" +import { Icon } from "@opencode-ai/ui/icon" +import { Spinner } from "@opencode-ai/ui/spinner" +import { createResource, For, Show, createMemo, createSignal } from "solid-js" +import { useGlobalSDK } from "@/context/global-sdk" +import { useProviders } from "@/hooks/use-providers" +import { useDialog } from "@opencode-ai/ui/context/dialog" +import { DialogConnectProvider } from "./dialog-connect-provider" +import { usePlatform } from "@/context/platform" + +type Tab = "providers" | "about" + +interface AccountUsage { + id: string + label?: string + isActive?: boolean + health: { + successCount: number + failureCount: number + lastStatusCode?: number + cooldownUntil?: number + } +} + +interface AnthropicUsage { + fiveHour?: { utilization: number; resetsAt?: string } + sevenDay?: { utilization: number; resetsAt?: string } + sevenDaySonnet?: { utilization: number; resetsAt?: string } +} + +interface ProviderUsage { + accounts: AccountUsage[] + anthropicUsage?: AnthropicUsage +} + +type AuthUsageData = Record + +function formatResetTime(resetAt?: string): string { + if (!resetAt) return "" + const reset = new Date(resetAt) + const now = new Date() + const diffMs = reset.getTime() - now.getTime() + if (diffMs <= 0) return "now" + + const totalMinutes = Math.floor(diffMs / (1000 * 60)) + const hours = Math.floor(totalMinutes / 60) + const minutes = totalMinutes % 60 + + if (hours > 0) return `${hours}h ${minutes}m` + return `${minutes}m` +} + +function getColorClass(percent: number): string { + if (percent <= 50) return "bg-fill-success-base" + if (percent <= 80) return "bg-fill-warning-base" + return "bg-fill-danger-base" +} + +function UsageBarPercent(props: { label: string; utilization: number; resetsAt?: string }) { + return ( +
+
+ {props.label} + {props.utilization}% used +
+
+
+
+ +
Resets in {formatResetTime(props.resetsAt)}
+
+
+ ) +} + +// Provider OAuth multi-account support status +const OAUTH_MULTI_ACCOUNT_SUPPORT: Record = { + anthropic: { supported: true, note: "Claude Max/Pro subscription" }, + openai: { supported: true, note: "ChatGPT Plus/Pro subscription" }, + "github-copilot": { supported: true, note: "GitHub Copilot subscription" }, + google: { supported: false, note: "Contributions welcome" }, + openrouter: { supported: false, note: "API key only" }, + azure: { supported: false, note: "Service principal auth" }, + "amazon-bedrock": { supported: false, note: "AWS credential chain" }, + mistral: { supported: false, note: "API key only" }, + groq: { supported: false, note: "API key only" }, + xai: { supported: false, note: "API key only" }, + perplexity: { supported: false, note: "API key only" }, + cohere: { supported: false, note: "API key only" }, + deepinfra: { supported: false, note: "API key only" }, + cerebras: { supported: false, note: "API key only" }, + togetherai: { supported: false, note: "API key only" }, + "google-vertex": { supported: false, note: "Service account auth" }, + gitlab: { supported: false, note: "Token auth" }, + vercel: { supported: false, note: "API key only" }, +} + +function TabButton(props: { + active: boolean + onClick: () => void + icon: keyof typeof import("@opencode-ai/ui/icon").Icon extends (p: { name: infer N }) => any ? N : never + label: string +}) { + return ( + + ) +} + +// Provider detail view - shows accounts, usage, switch functionality +function ProviderDetailView(props: { providerID: string; providerName: string; onBack: () => void }) { + const globalSDK = useGlobalSDK() + const platform = usePlatform() + const dialog = useDialog() + const [switching, setSwitching] = createSignal(null) + const [deleting, setDeleting] = createSignal(null) + const [confirmDelete, setConfirmDelete] = createSignal(null) + + const [usage, { refetch }] = createResource(async () => { + const result = await globalSDK.client.auth.usage({}) + const data = result.data as AuthUsageData + return data[props.providerID] + }) + + const switchAccount = async (recordID: string) => { + setSwitching(recordID) + try { + const doFetch = platform.fetch ?? fetch + const response = await doFetch(`${globalSDK.url}/auth/active`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ providerID: props.providerID, recordID }), + }) + if (response.ok) { + await refetch() + } + } catch (e) { + console.error("Failed to switch account:", e) + } finally { + setSwitching(null) + } + } + + const deleteAccount = async (recordID: string) => { + setDeleting(recordID) + try { + const doFetch = platform.fetch ?? fetch + const response = await doFetch(`${globalSDK.url}/auth/account`, { + method: "DELETE", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ providerID: props.providerID, recordID }), + }) + if (response.ok) { + const result = await response.json() + if (result.remaining === 0) { + // Provider was disconnected, go back to list + props.onBack() + } else { + await refetch() + } + } + } catch (e) { + console.error("Failed to delete account:", e) + } finally { + setDeleting(null) + setConfirmDelete(null) + } + } + + const support = OAUTH_MULTI_ACCOUNT_SUPPORT[props.providerID] + const isAnthropic = props.providerID === "anthropic" + + return ( +
+
+ + +

{props.providerName}

+ + + Multi-account + + +
+ + +
+ +
+
+ + + {(data) => ( + <> + {/* Anthropic Usage Stats */} + +
+
Rate Limits (Active Account)
+ + + + + + + + + +
+
+ + {/* Account List */} +
+
+
+ Accounts ({data().accounts.length}) + 1 && support?.supported}> + - click to switch + +
+ 1}> + Auto-rotation enabled + +
+ +
+ + {(account, index) => { + const isInCooldown = () => { + const cooldown = account.health.cooldownUntil + return cooldown && cooldown > Date.now() + } + const cooldownRemaining = () => { + const cooldown = account.health.cooldownUntil + if (!cooldown) return "" + const diff = cooldown - Date.now() + if (diff <= 0) return "" + const secs = Math.ceil(diff / 1000) + return secs > 60 ? `${Math.ceil(secs / 60)}m` : `${secs}s` + } + const isSwitching = () => switching() === account.id + const isDeleting = () => deleting() === account.id + const isConfirming = () => confirmDelete() === account.id + const canSwitch = () => + data().accounts.length > 1 && !account.isActive && !isSwitching() && support?.supported + + return ( +
+ + {/* Delete button */} + +
+ + +
+
+ + + +
+ ) + }} +
+
+
+ + {/* Add Account Button */} + + + {/* Info box for non-Anthropic providers */} + +
+
+ Usage statistics are currently only available for Anthropic. Multi-account switching works for this + provider. Contributions for usage stats are welcome! +
+
+
+ + {/* Refresh button */} + + + )} +
+ + +
+ No account data available. +
+
+
+ ) +} + +function ProvidersTab() { + const dialog = useDialog() + const providers = useProviders() + const [view, setView] = createSignal<"list" | "add" | { detail: string }>("list") + const [search, setSearch] = createSignal("") + + const connected = createMemo(() => + providers + .all() + .filter((p) => providers.connected().some((c) => c.id === p.id)) + .sort((a, b) => a.name.localeCompare(b.name)), + ) + + const available = createMemo(() => { + const query = search().toLowerCase() + return providers + .all() + .filter((p) => !query || p.name.toLowerCase().includes(query) || p.id.toLowerCase().includes(query)) + .sort((a, b) => { + const aPopular = ["anthropic", "openai", "github-copilot", "google", "openrouter"].includes(a.id) + const bPopular = ["anthropic", "openai", "github-copilot", "google", "openrouter"].includes(b.id) + if (aPopular && !bPopular) return -1 + if (!aPopular && bPopular) return 1 + return a.name.localeCompare(b.name) + }) + }) + + const detailProvider = createMemo(() => { + const v = view() + if (typeof v === "object" && "detail" in v) { + return providers.all().find((p) => p.id === v.detail) + } + return undefined + }) + + return ( + <> + {/* Provider detail view */} + + {(provider) => ( + setView("list")} + /> + )} + + + {/* Add provider view */} + +
+
+ +

Add Provider

+
+ + setSearch(e.currentTarget.value)} + autofocus + /> + +
+ + {(provider) => { + const isConnected = providers.connected().some((c) => c.id === provider.id) + const support = OAUTH_MULTI_ACCOUNT_SUPPORT[provider.id] + return ( + + ) + }} + +
+
+
+ + {/* List view (default) */} + +
+
+

Providers

+

+ Manage your AI provider connections. Click on a provider to view accounts and usage. +

+
+ + 0} + fallback={ +
+ No providers connected yet. Add a provider to get started. +
+ } + > +
+ + {(provider) => { + const support = OAUTH_MULTI_ACCOUNT_SUPPORT[provider.id] + return ( + + ) + }} + +
+
+ + + +
+
Multi-Account OAuth Rotation
+

+ For supported providers (Anthropic, OpenAI, GitHub Copilot), you can login with multiple accounts. + OpenCode will automatically rotate between them when one account hits rate limits. +

+
+
+
+ + ) +} + +function AboutTab() { + const platform = usePlatform() + + return ( +
+
+

About OpenCode

+

+ OpenCode is an open-source AI coding assistant that runs in your terminal and desktop. +

+
+ +
+
+
+ OC +
+
+ OpenCode + Community-driven AI coding assistant +
+
+ +
+ + + + +
+
+ +
+
Keyboard Shortcuts
+
+
+ Toggle sidebar + ⌘B +
+
+ Open project + ⌘O +
+
+ Previous session + ⌥↑ +
+
+ Next session + ⌥↓ +
+
+ Cycle theme + ⌘⇧T +
+
+
+ +
+
Contributing
+

+ OpenCode is a community project. Contributions for new provider integrations, multi-account OAuth support, and + other features are welcome! Check out our GitHub repository to get started. +

+
+
+ ) +} + +export function DialogSettings(props: { initialTab?: Tab }) { + const [activeTab, setActiveTab] = createSignal(props.initialTab ?? "providers") + + return ( + +
+
+ setActiveTab("providers")} + icon="brain" + label="Providers" + /> + setActiveTab("about")} icon="help" label="About" /> +
+ +
+ + + + + + +
+
+
+ ) +} diff --git a/packages/app/src/components/session/session-context-tab.tsx b/packages/app/src/components/session/session-context-tab.tsx index a975f9fa56f..e2b13db9150 100644 --- a/packages/app/src/components/session/session-context-tab.tsx +++ b/packages/app/src/components/session/session-context-tab.tsx @@ -1,17 +1,38 @@ -import { createMemo, createEffect, on, onCleanup, For, Show } from "solid-js" +import { createMemo, createEffect, on, onCleanup, For, Show, createSignal, createResource } from "solid-js" import type { JSX } from "solid-js" import { useParams } from "@solidjs/router" import { DateTime } from "luxon" import { useSync } from "@/context/sync" import { useLayout } from "@/context/layout" +import { useGlobalSDK } from "@/context/global-sdk" +import { usePlatform } from "@/context/platform" import { checksum } from "@opencode-ai/util/encode" import { Icon } from "@opencode-ai/ui/icon" import { Accordion } from "@opencode-ai/ui/accordion" import { StickyAccordionHeader } from "@opencode-ai/ui/sticky-accordion-header" import { Code } from "@opencode-ai/ui/code" import { Markdown } from "@opencode-ai/ui/markdown" +import { Spinner } from "@opencode-ai/ui/spinner" import type { AssistantMessage, Message, Part, UserMessage } from "@opencode-ai/sdk/v2/client" +interface AnthropicUsage { + fiveHour?: { utilization: number; resetsAt?: string } + sevenDay?: { utilization: number; resetsAt?: string } + sevenDaySonnet?: { utilization: number; resetsAt?: string } +} + +interface AccountUsage { + id: string + label?: string + isActive?: boolean + health: { successCount: number; failureCount: number; cooldownUntil?: number } +} + +interface ProviderUsageData { + accounts: AccountUsage[] + anthropicUsage?: AnthropicUsage +} + interface SessionContextTabProps { messages: () => Message[] visibleUserMessages: () => UserMessage[] @@ -19,6 +40,202 @@ interface SessionContextTabProps { info: () => ReturnType["session"]["get"]> } +function formatResetTime(resetAt?: string): string { + if (!resetAt) return "" + const reset = new Date(resetAt) + const now = new Date() + const diffMs = reset.getTime() - now.getTime() + if (diffMs <= 0) return "now" + const totalMinutes = Math.floor(diffMs / (1000 * 60)) + const hours = Math.floor(totalMinutes / 60) + const minutes = totalMinutes % 60 + if (hours > 0) return `${hours}h ${minutes}m` + return `${minutes}m` +} + +function getUsageColor(percent: number): string { + if (percent <= 50) return "var(--syntax-success)" + if (percent <= 80) return "var(--syntax-warning)" + return "var(--syntax-danger)" +} + +function AnthropicUsageSection() { + const globalSDK = useGlobalSDK() + const platform = usePlatform() + const [switching, setSwitching] = createSignal(null) + + const [usage, { refetch, mutate }] = createResource(async () => { + const result = await globalSDK.client.auth.usage({}) + const data = result.data as Record + return data["anthropic"] + }) + + const switchAccount = async (recordID: string) => { + setSwitching(recordID) + try { + const doFetch = platform.fetch ?? fetch + const response = await doFetch(`${globalSDK.url}/auth/active`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ providerID: "anthropic", recordID }), + }) + if (response.ok) { + const result = await response.json() + const current = usage() + if (current && result.success) { + // Update local state directly without full refetch + mutate({ + ...current, + accounts: current.accounts.map((acc) => ({ + ...acc, + isActive: acc.id === recordID, + })), + anthropicUsage: result.anthropicUsage ?? current.anthropicUsage, + }) + } + } + } catch (e) { + console.error("Failed to switch account:", e) + } finally { + setSwitching(null) + } + } + + const rateLimits = createMemo(() => { + const data = usage() + if (!data?.anthropicUsage) return [] + + const limits: { key: string; label: string; utilization: number; resetsAt?: string; color: string }[] = [] + + if (data.anthropicUsage.fiveHour) { + limits.push({ + key: "5h", + label: "5-Hour", + utilization: data.anthropicUsage.fiveHour.utilization, + resetsAt: data.anthropicUsage.fiveHour.resetsAt, + color: getUsageColor(data.anthropicUsage.fiveHour.utilization), + }) + } + if (data.anthropicUsage.sevenDay) { + limits.push({ + key: "7d", + label: "Weekly (All)", + utilization: data.anthropicUsage.sevenDay.utilization, + resetsAt: data.anthropicUsage.sevenDay.resetsAt, + color: getUsageColor(data.anthropicUsage.sevenDay.utilization), + }) + } + if (data.anthropicUsage.sevenDaySonnet) { + limits.push({ + key: "7d-sonnet", + label: "Weekly (Sonnet)", + utilization: data.anthropicUsage.sevenDaySonnet.utilization, + resetsAt: data.anthropicUsage.sevenDaySonnet.resetsAt, + color: getUsageColor(data.anthropicUsage.sevenDaySonnet.utilization), + }) + } + + return limits + }) + + return ( +
+
Anthropic Rate Limits
+ + +
+ +
+
+ + + {(data) => ( + <> + 0}> + + {(limit) => ( +
+
+
+
+
+
+
{limit.label}
+
{limit.utilization}%
+ +
resets {formatResetTime(limit.resetsAt)}
+
+
+
+ )} + + + + 1}> +
+
Accounts ({data().accounts.length}) - click to switch
+
+ + {(account, index) => { + const isSwitching = () => switching() === account.id + const canSwitch = () => !account.isActive && !isSwitching() + + return ( + + ) + }} + +
+
+
+ + + + )} + + + +
+ No Anthropic OAuth account connected. +
+
+
+ ) +} + export function SessionContextTab(props: SessionContextTabProps) { const params = useParams() const sync = useSync() @@ -402,6 +619,11 @@ export function SessionContextTab(props: SessionContextTabProps) {
+ {/* Anthropic Rate Limits - only show when provider is Anthropic */} + + + + {(prompt) => (
diff --git a/packages/app/src/pages/layout.tsx b/packages/app/src/pages/layout.tsx index 56d6bfbf8ca..0614aff1211 100644 --- a/packages/app/src/pages/layout.tsx +++ b/packages/app/src/pages/layout.tsx @@ -62,6 +62,7 @@ import { ConstrainDragXAxis } from "@/utils/solid-dnd" import { navStart } from "@/utils/perf" import { DialogSelectDirectory } from "@/components/dialog-select-directory" import { DialogEditProject } from "@/components/dialog-edit-project" +import { DialogSettings } from "@/components/dialog-settings" import { Titlebar } from "@/components/titlebar" import { useServer } from "@/context/server" @@ -1591,7 +1592,12 @@ export default function Layout(props: ParentProps) {
- + dialog.show(() => )} + /> +} + +const storage = new AsyncLocalStorage() + +export function getOAuthRecordID(providerID: string): string | undefined { + return storage.getStore()?.oauthRecordByProvider.get(providerID) +} + +export function withOAuthRecord(providerID: string, recordID: string, fn: () => T): T { + const current = storage.getStore() + const next: Store = { + oauthRecordByProvider: new Map(current?.oauthRecordByProvider ?? []), + } + next.oauthRecordByProvider.set(providerID, recordID) + + return storage.run(next, fn) +} diff --git a/packages/opencode/src/auth/credential-manager.ts b/packages/opencode/src/auth/credential-manager.ts new file mode 100644 index 00000000000..c7b499f953d --- /dev/null +++ b/packages/opencode/src/auth/credential-manager.ts @@ -0,0 +1,61 @@ +import z from "zod" +import { Bus } from "../bus" +import { BusEvent } from "../bus/bus-event" +import { Log } from "../util/log" +import { TuiEvent } from "../cli/cmd/tui/event" + +const log = Log.create({ service: "credential-manager" }) +const DEFAULT_FAILOVER_TOAST_MS = 8000 + +export namespace CredentialManager { + export const Event = { + Failover: BusEvent.define( + "credential.failover", + z.object({ + providerID: z.string(), + fromRecordID: z.string(), + toRecordID: z.string().optional(), + statusCode: z.number(), + message: z.string(), + }), + ), + } + + export async function notifyFailover(input: { + providerID: string + fromRecordID: string + toRecordID?: string + statusCode: number + toastDurationMs?: number + }): Promise { + const isRateLimit = input.statusCode === 429 + const message = isRateLimit + ? `Rate limited on "${input.providerID}". Switching OAuth credential...` + : input.statusCode === 0 + ? `Request failed on "${input.providerID}". Switching OAuth credential...` + : `Auth error on "${input.providerID}". Switching OAuth credential...` + const duration = Math.max(0, input.toastDurationMs ?? DEFAULT_FAILOVER_TOAST_MS) + + log.info("oauth credential failover", { + providerID: input.providerID, + fromRecordID: input.fromRecordID, + toRecordID: input.toRecordID, + statusCode: input.statusCode, + }) + + await Bus.publish(Event.Failover, { + providerID: input.providerID, + fromRecordID: input.fromRecordID, + toRecordID: input.toRecordID, + statusCode: input.statusCode, + message, + }).catch((error) => log.debug("failed to publish credential failover event", { error })) + + await Bus.publish(TuiEvent.ToastShow, { + title: "OAuth Credential Failover", + message, + variant: "warning", + duration, + }).catch((error) => log.debug("failed to show failover toast", { error })) + } +} diff --git a/packages/opencode/src/auth/index.ts b/packages/opencode/src/auth/index.ts index 3fd28305368..4e8f9a64c8c 100644 --- a/packages/opencode/src/auth/index.ts +++ b/packages/opencode/src/auth/index.ts @@ -2,6 +2,9 @@ import path from "path" import { Global } from "../global" import fs from "fs/promises" import z from "zod" +import { ulid } from "ulid" +import { getOAuthRecordID } from "./context" +import { Log } from "../util/log" export const OAUTH_DUMMY_KEY = "opencode-oauth-dummy-key" @@ -36,38 +39,712 @@ export namespace Auth { export type Info = z.infer const filepath = path.join(Global.Path.data, "auth.json") + const lockpath = `${filepath}.lock` + const STORE_LOCK_TIMEOUT_MS = 5_000 + const STORE_LOCK_STALE_MS = 30_000 + const STORE_LOCK_RETRY_MS = 25 + const STORE_LOCK_BEST_EFFORT_TIMEOUT_MS = 250 + const STORE_LOCK_BEST_EFFORT_RETRY_MS = 10 - export async function get(providerID: string) { - const auth = await all() - return auth[providerID] + const log = Log.create({ service: "auth.store" }) + + class StoreLockTimeoutError extends Error { + constructor() { + super("Timed out waiting for auth store lock") + this.name = "StoreLockTimeoutError" + } } - export async function all(): Promise> { + const Health = z + .object({ + cooldownUntil: z.number().optional(), + lastStatusCode: z.number().optional(), + lastErrorAt: z.number().optional(), + successCount: z.number().default(0), + failureCount: z.number().default(0), + }) + .strict() + .default(() => ({ successCount: 0, failureCount: 0 })) + type Health = z.infer + + const OAuthRecord = z + .object({ + id: z.string(), + namespace: z.string().default("default"), + label: z.string().optional(), + accountId: z.string().optional(), + enterpriseUrl: z.string().optional(), + refresh: z.string(), + access: z.string(), + expires: z.number(), + createdAt: z.number(), + updatedAt: z.number(), + health: Health, + }) + .strict() + type OAuthRecord = z.infer + + export type OAuthRecordMeta = Omit + + const OAuthProvider = z + .object({ + type: z.literal("oauth"), + active: z.record(z.string(), z.string()).default({}), + order: z.record(z.string(), z.array(z.string())).default({}), + records: z.array(OAuthRecord).default([]), + }) + .strict() + type OAuthProvider = z.infer + + const ApiProvider = z + .object({ + type: z.literal("api"), + key: z.string(), + }) + .strict() + + const WellKnownProvider = z + .object({ + type: z.literal("wellknown"), + key: z.string(), + token: z.string(), + }) + .strict() + + const ProviderEntry = z.union([OAuthProvider, ApiProvider, WellKnownProvider]) + type ProviderEntry = z.infer + + const StoreFile = z + .object({ + version: z.literal(2), + providers: z.record(z.string(), ProviderEntry).default({}), + }) + .strict() + type StoreFile = z.infer + + function toMeta(record: OAuthRecord): OAuthRecordMeta { + const { refresh: _refresh, access: _access, expires: _expires, ...meta } = record + return meta + } + + async function ensureDataDir(): Promise { + await fs.mkdir(path.dirname(filepath), { recursive: true }) + } + + async function withStoreLock( + fn: () => Promise, + options: { timeoutMs?: number; staleMs?: number; retryMs?: number } = {}, + ): Promise { + await ensureDataDir() + const timeoutMs = options.timeoutMs ?? STORE_LOCK_TIMEOUT_MS + const staleMs = options.staleMs ?? STORE_LOCK_STALE_MS + const retryMs = options.retryMs ?? STORE_LOCK_RETRY_MS + const start = Date.now() + while (true) { + try { + const handle = await fs.open(lockpath, "wx") + await handle.close() + break + } catch (error) { + const code = (error as { code?: string }).code + if (code !== "EEXIST") throw error + const stat = await fs.stat(lockpath).catch(() => undefined) + if (stat && Date.now() - stat.mtimeMs > staleMs) { + await fs.rm(lockpath).catch(() => {}) + continue + } + if (Date.now() - start > timeoutMs) { + throw new StoreLockTimeoutError() + } + await Bun.sleep(retryMs + Math.random() * retryMs) + } + } + + try { + return await fn() + } finally { + await fs.rm(lockpath).catch(() => {}) + } + } + + async function writeStoreFile(store: StoreFile): Promise { + await ensureDataDir() + const tempPath = `${filepath}.tmp` + const tempFile = Bun.file(tempPath) + await Bun.write(tempFile, JSON.stringify(store, null, 2)) + await fs.rename(tempPath, filepath) + await fs.chmod(filepath, 0o600).catch(() => {}) + } + + async function readStoreFile(): Promise<{ store: StoreFile; needsWrite: boolean }> { const file = Bun.file(filepath) - const data = await file.json().catch(() => ({}) as Record) - return Object.entries(data).reduce( - (acc, [key, value]) => { - const parsed = Info.safeParse(value) - if (!parsed.success) return acc - acc[key] = parsed.data - return acc - }, - {} as Record, - ) + const exists = await file.exists() + const raw = await file.json().catch(() => undefined) + + const parsed = StoreFile.safeParse(raw) + if (parsed.success) return { store: parsed.data, needsWrite: false } + + const legacyParsed = z.record(z.string(), Info).safeParse(raw) + if (legacyParsed.success) { + const now = Date.now() + const next: StoreFile = { version: 2, providers: {} } + + for (const [providerID, info] of Object.entries(legacyParsed.data)) { + if (info.type === "api") { + next.providers[providerID] = { type: "api", key: info.key } + continue + } + + if (info.type === "wellknown") { + next.providers[providerID] = { type: "wellknown", key: info.key, token: info.token } + continue + } + + const recordID = ulid() + next.providers[providerID] = { + type: "oauth", + active: { default: recordID }, + order: { default: [recordID] }, + records: [ + { + id: recordID, + namespace: "default", + label: "default", + accountId: info.accountId, + enterpriseUrl: info.enterpriseUrl, + refresh: info.refresh, + access: info.access, + expires: info.expires, + createdAt: now, + updatedAt: now, + health: { successCount: 0, failureCount: 0 }, + }, + ], + } + } + + return { store: next, needsWrite: true } + } + + return { store: { version: 2, providers: {} }, needsWrite: exists } + } + + async function loadStoreFile(): Promise { + const result = await readStoreFile() + return result.store + } + + type StoreUpdateResult = { + value: T + changed: boolean + } + + async function updateStoreWithLock( + fn: (store: StoreFile) => Promise> | StoreUpdateResult, + lockOptions?: { timeoutMs?: number; staleMs?: number; retryMs?: number }, + ) { + return withStoreLock(async () => { + const { store, needsWrite } = await readStoreFile() + const result = await fn(store) + if (result.changed || needsWrite) { + await writeStoreFile(store) + } + return result.value + }, lockOptions) + } + + async function updateStore(fn: (store: StoreFile) => Promise> | StoreUpdateResult) { + return updateStoreWithLock(fn) + } + + async function updateStoreBestEffort( + fn: (store: StoreFile) => Promise> | StoreUpdateResult, + ): Promise { + try { + await updateStoreWithLock(fn, { + timeoutMs: STORE_LOCK_BEST_EFFORT_TIMEOUT_MS, + retryMs: STORE_LOCK_BEST_EFFORT_RETRY_MS, + }) + } catch (error) { + if (error instanceof StoreLockTimeoutError) { + log.warn("auth store lock busy, skipping update", { timeoutMs: STORE_LOCK_BEST_EFFORT_TIMEOUT_MS }) + return + } + throw error + } + } + + function ensureOAuthProvider(store: StoreFile, providerID: string): OAuthProvider { + const existing = store.providers[providerID] + if (existing && existing.type === "oauth") return existing + + const next: OAuthProvider = { + type: "oauth", + active: {}, + order: {}, + records: [], + } + store.providers[providerID] = next + return next + } + + function findOAuthRecord(provider: OAuthProvider, recordID: string): OAuthRecord | undefined { + return provider.records.find((record) => record.id === recordID) + } + + function normalizeOrder(ids: string[], order: string[]): string[] { + const ordered: string[] = [] + for (const id of order) { + if (ids.includes(id) && !ordered.includes(id)) ordered.push(id) + } + for (const id of ids) { + if (!ordered.includes(id)) ordered.push(id) + } + return ordered + } + + function recordIDsForNamespace(provider: OAuthProvider, namespace: string): string[] { + const ids = provider.records.filter((record) => record.namespace === namespace).map((record) => record.id) + const order = provider.order[namespace] ?? [] + return normalizeOrder(ids, order) + } + + async function findOAuthRecordIDByRefreshToken(input: { + providerID: string + namespace: string + refresh: string + provider: OAuthProvider + }): Promise { + for (const record of input.provider.records) { + if (record.namespace !== input.namespace) continue + if (record.refresh === input.refresh) return record.id + } + return undefined + } + + export async function get(providerID: string): Promise { + const store = await loadStoreFile() + const entry = store.providers[providerID] + if (!entry) return undefined + + if (entry.type === "api") { + return { type: "api", key: entry.key } + } + + if (entry.type === "wellknown") { + return { type: "wellknown", key: entry.key, token: entry.token } + } + + const namespace = "default" + const contextID = getOAuthRecordID(providerID) + const active = contextID ?? entry.active[namespace] + const ordered = recordIDsForNamespace(entry, namespace) + const recordID = active && ordered.includes(active) ? active : ordered[0] + if (!recordID) return undefined + + const record = findOAuthRecord(entry, recordID) + if (!record) return undefined + return { + type: "oauth", + refresh: record.refresh, + access: record.access, + expires: record.expires, + accountId: record.accountId, + enterpriseUrl: record.enterpriseUrl, + } + } + + export async function all(): Promise> { + const store = await loadStoreFile() + const out: Record = {} + + for (const providerID of Object.keys(store.providers)) { + const info = await get(providerID) + if (!info) continue + out[providerID] = info + } + + return out } export async function set(key: string, info: Info) { - const file = Bun.file(filepath) - const data = await all() - await Bun.write(file, JSON.stringify({ ...data, [key]: info }, null, 2)) - await fs.chmod(file.name!, 0o600) + return updateStore(async (store) => { + if (info.type === "api") { + store.providers[key] = { type: "api", key: info.key } + return { value: undefined, changed: true } + } + + if (info.type === "wellknown") { + store.providers[key] = { type: "wellknown", key: info.key, token: info.token } + return { value: undefined, changed: true } + } + + const namespace = "default" + const provider = ensureOAuthProvider(store, key) + const recordID = + getOAuthRecordID(key) ?? + (await findOAuthRecordIDByRefreshToken({ providerID: key, namespace, refresh: info.refresh, provider })) ?? + provider.active[namespace] ?? + recordIDsForNamespace(provider, namespace)[0] ?? + ulid() + + const now = Date.now() + const existing = findOAuthRecord(provider, recordID) + if (!existing) { + provider.records.push({ + id: recordID, + namespace, + label: "default", + accountId: info.accountId, + enterpriseUrl: info.enterpriseUrl, + refresh: info.refresh, + access: info.access, + expires: info.expires, + createdAt: now, + updatedAt: now, + health: { successCount: 0, failureCount: 0 }, + }) + provider.order[namespace] = [...(provider.order[namespace] ?? []), recordID] + } else { + existing.refresh = info.refresh + existing.access = info.access + existing.expires = info.expires + existing.updatedAt = now + if (info.accountId !== undefined) existing.accountId = info.accountId + if (info.enterpriseUrl !== undefined) existing.enterpriseUrl = info.enterpriseUrl + const order = provider.order[namespace] ?? [] + if (!order.includes(recordID)) { + provider.order[namespace] = [...order, recordID] + } + } + provider.active[namespace] = recordID + + return { value: undefined, changed: true } + }) } export async function remove(key: string) { - const file = Bun.file(filepath) - const data = await all() - delete data[key] - await Bun.write(file, JSON.stringify(data, null, 2)) - await fs.chmod(file.name!, 0o600) + return updateStore((store) => { + const existing = store.providers[key] + if (!existing) return { value: undefined, changed: false } + + delete store.providers[key] + return { value: undefined, changed: true } + }) + } + + export async function addOAuth( + providerID: string, + input: Omit, "type"> & { namespace?: string; label?: string }, + ) { + const namespace = (input.namespace ?? "default").trim() || "default" + return updateStore(async (store) => { + const provider = ensureOAuthProvider(store, providerID) + const now = Date.now() + const existingRecordID = await findOAuthRecordIDByRefreshToken({ + providerID, + namespace, + refresh: input.refresh, + provider, + }) + + if (existingRecordID) { + const existing = findOAuthRecord(provider, existingRecordID) + if (existing) { + existing.refresh = input.refresh + existing.access = input.access + existing.expires = input.expires + existing.updatedAt = now + if (input.accountId !== undefined) existing.accountId = input.accountId + if (input.enterpriseUrl !== undefined) existing.enterpriseUrl = input.enterpriseUrl + if (input.label) existing.label = input.label + } + const order = provider.order[namespace] ?? [] + if (!order.includes(existingRecordID)) { + provider.order[namespace] = [...order, existingRecordID] + } + provider.active[namespace] = existingRecordID + + return { value: { providerID, namespace, recordID: existingRecordID }, changed: true } + } + + const recordID = ulid() + + provider.records.push({ + id: recordID, + namespace, + label: input.label ?? "default", + accountId: input.accountId, + enterpriseUrl: input.enterpriseUrl, + refresh: input.refresh, + access: input.access, + expires: input.expires, + createdAt: now, + updatedAt: now, + health: { successCount: 0, failureCount: 0 }, + }) + + provider.order[namespace] = [...(provider.order[namespace] ?? []), recordID] + provider.active[namespace] = recordID + + return { value: { providerID, namespace, recordID }, changed: true } + }) + } + + export namespace OAuthPool { + export async function snapshot( + providerID: string, + namespace = "default", + ): Promise<{ records: OAuthRecordMeta[]; orderedIDs: string[]; activeID?: string }> { + const store = await loadStoreFile() + const provider = store.providers[providerID] + if (!provider || provider.type !== "oauth") return { records: [], orderedIDs: [] } + + const normalized = namespace.trim() || "default" + const records = provider.records.filter((record) => record.namespace === normalized).map(toMeta) + const orderedIDs = recordIDsForNamespace(provider, normalized) + const activeID = provider.active[normalized] + + return { records, orderedIDs, activeID } + } + + export async function list(providerID: string, namespace = "default"): Promise { + return snapshot(providerID, namespace).then((result) => result.records) + } + + export async function orderedIDs(providerID: string, namespace = "default"): Promise { + return snapshot(providerID, namespace).then((result) => result.orderedIDs) + } + + export async function moveToBack(providerID: string, namespace: string, recordID: string): Promise { + await updateStoreBestEffort((store) => { + const provider = store.providers[providerID] + if (!provider || provider.type !== "oauth") return { value: undefined, changed: false } + const order = recordIDsForNamespace(provider, namespace) + provider.order[namespace] = order.filter((id) => id !== recordID).concat(recordID) + provider.active[namespace] = provider.order[namespace][0] ?? provider.active[namespace] + return { value: undefined, changed: true } + }) + } + + export async function recordOutcome(input: { + providerID: string + recordID: string + statusCode: number + ok: boolean + cooldownUntil?: number + }): Promise { + await updateStoreBestEffort((store) => { + const provider = store.providers[input.providerID] + if (!provider || provider.type !== "oauth") return { value: undefined, changed: false } + + const record = findOAuthRecord(provider, input.recordID) + if (!record) return { value: undefined, changed: false } + + const now = Date.now() + const prevCooldown = + record.health.cooldownUntil && record.health.cooldownUntil > now ? record.health.cooldownUntil : undefined + const cooldownUntil = input.ok ? undefined : (input.cooldownUntil ?? prevCooldown) + + record.health = { + ...record.health, + cooldownUntil, + lastStatusCode: input.statusCode, + lastErrorAt: input.ok ? undefined : now, + successCount: record.health.successCount + (input.ok ? 1 : 0), + failureCount: record.health.failureCount + (input.ok ? 0 : 1), + } + record.updatedAt = now + return { value: undefined, changed: true } + }) + } + + export async function markAccessExpired(providerID: string, namespace: string, recordID: string): Promise { + await updateStoreBestEffort((store) => { + const provider = store.providers[providerID] + if (!provider || provider.type !== "oauth") return { value: undefined, changed: false } + const record = findOAuthRecord(provider, recordID) + if (!record || record.namespace !== namespace) return { value: undefined, changed: false } + record.access = "" + record.expires = 0 + record.updatedAt = Date.now() + return { value: undefined, changed: true } + }) + } + + export async function getUsage( + providerID: string, + namespace = "default", + ): Promise< + Array<{ + id: string + label?: string + isActive: boolean + health: { + successCount: number + failureCount: number + lastStatusCode?: number + cooldownUntil?: number + } + }> + > { + const store = await loadStoreFile() + const provider = store.providers[providerID] + if (!provider || provider.type !== "oauth") return [] + + const orderedIDs = recordIDsForNamespace(provider, namespace) + const now = Date.now() + // Use explicitly set active account if it exists, otherwise fall back to first non-cooldown + const activeID = + provider.active[namespace] ?? + orderedIDs.find((id) => { + const record = provider.records.find((r) => r.id === id) + const cooldownUntil = record?.health.cooldownUntil + return !cooldownUntil || cooldownUntil <= now + }) ?? + orderedIDs[0] + + return provider.records + .filter((record) => record.namespace === namespace) + .map((record) => ({ + id: record.id, + label: record.label, + isActive: record.id === activeID, + health: { + successCount: record.health.successCount, + failureCount: record.health.failureCount, + lastStatusCode: record.health.lastStatusCode, + cooldownUntil: record.health.cooldownUntil, + }, + })) + } + + export async function setActive(providerID: string, namespace: string, recordID: string): Promise { + return updateStore((store) => { + const provider = store.providers[providerID] + if (!provider || provider.type !== "oauth") return { value: false, changed: false } + + const record = findOAuthRecord(provider, recordID) + if (!record || record.namespace !== namespace) return { value: false, changed: false } + + const order = recordIDsForNamespace(provider, namespace) + provider.order[namespace] = [recordID, ...order.filter((id) => id !== recordID)] + provider.active[namespace] = recordID + + return { value: true, changed: true } + }) + } + + export async function removeRecord( + providerID: string, + recordID: string, + namespace = "default", + ): Promise<{ removed: boolean; remaining: number }> { + return updateStore<{ removed: boolean; remaining: number }>((store) => { + const provider = store.providers[providerID] + if (!provider || provider.type !== "oauth") return { value: { removed: false, remaining: 0 }, changed: false } + + const index = provider.records.findIndex((r) => r.id === recordID && r.namespace === namespace) + if (index === -1) return { value: { removed: false, remaining: provider.records.length }, changed: false } + + // Remove the record + provider.records.splice(index, 1) + + // Update order array + const order = provider.order[namespace] ?? [] + provider.order[namespace] = order.filter((id) => id !== recordID) + + // If the removed record was active, set a new active + if (provider.active[namespace] === recordID) { + const remaining = recordIDsForNamespace(provider, namespace) + provider.active[namespace] = remaining[0] + } + + // If no records left for this namespace, clean up + const remaining = provider.records.filter((r) => r.namespace === namespace).length + if (remaining === 0) { + delete provider.order[namespace] + delete provider.active[namespace] + } + + // If no records left at all, remove the provider entry + if (provider.records.length === 0) { + delete store.providers[providerID] + } + + return { value: { removed: true, remaining }, changed: true } + }) + } + + export async function fetchAnthropicUsage( + providerID: string, + namespace = "default", + recordID?: string, + ): Promise<{ + fiveHour?: { utilization: number; resetsAt?: string } + sevenDay?: { utilization: number; resetsAt?: string } + sevenDaySonnet?: { utilization: number; resetsAt?: string } + } | null> { + if (providerID !== "anthropic") return null + + const store = await loadStoreFile() + const provider = store.providers[providerID] + if (!provider || provider.type !== "oauth") return null + + const orderedIDs = recordIDsForNamespace(provider, namespace) + const now = Date.now() + // Use explicit recordID, then provider.active, then first non-cooldown, then first in order + const activeID = + recordID ?? + provider.active[namespace] ?? + orderedIDs.find((id) => { + const rec = provider.records.find((r) => r.id === id) + const cooldownUntil = rec?.health.cooldownUntil + return !cooldownUntil || cooldownUntil <= now + }) ?? + orderedIDs[0] + const record = provider.records.find((r) => r.id === activeID && r.namespace === namespace) + if (!record?.access) return null + + try { + const controller = new AbortController() + const timeout = setTimeout(() => controller.abort(), 5000) + + const response = await fetch("https://api.anthropic.com/api/oauth/usage", { + method: "GET", + headers: { + Accept: "application/json", + "Content-Type": "application/json", + Authorization: `Bearer ${record.access}`, + "anthropic-beta": "oauth-2025-04-20", + }, + signal: controller.signal, + }) + + clearTimeout(timeout) + + if (!response.ok) return null + + const data = (await response.json()) as { + five_hour?: { utilization: number; resets_at?: string } + seven_day?: { utilization: number; resets_at?: string } + seven_day_sonnet?: { utilization: number; resets_at?: string } + } + + return { + fiveHour: data.five_hour + ? { utilization: Math.round(data.five_hour.utilization), resetsAt: data.five_hour.resets_at } + : undefined, + sevenDay: data.seven_day + ? { utilization: Math.round(data.seven_day.utilization), resetsAt: data.seven_day.resets_at } + : undefined, + sevenDaySonnet: data.seven_day_sonnet + ? { utilization: Math.round(data.seven_day_sonnet.utilization), resetsAt: data.seven_day_sonnet.resets_at } + : undefined, + } + } catch { + return null + } + } } } diff --git a/packages/opencode/src/auth/rotating-fetch.ts b/packages/opencode/src/auth/rotating-fetch.ts new file mode 100644 index 00000000000..4127b1e3622 --- /dev/null +++ b/packages/opencode/src/auth/rotating-fetch.ts @@ -0,0 +1,332 @@ +import { Auth } from "./index" +import { withOAuthRecord } from "./context" +import { CredentialManager } from "./credential-manager" + +const DEFAULT_RATE_LIMIT_COOLDOWN_MS = 30_000 +const DEFAULT_AUTH_FAILURE_COOLDOWN_MS = 5 * 60_000 +const DEFAULT_NETWORK_RETRY_ATTEMPTS = 1 + +function isReadableStream(value: unknown): value is ReadableStream { + return typeof ReadableStream !== "undefined" && value instanceof ReadableStream +} + +function isAsyncIterable(value: unknown): value is AsyncIterable { + return typeof value === "object" && value !== null && Symbol.asyncIterator in value +} + +function isReplayableBody(body: unknown): boolean { + if (!body) return true + if (isReadableStream(body)) return false + if (isAsyncIterable(body)) return false + return true +} + +function isRequest(value: unknown): value is Request { + return typeof Request !== "undefined" && value instanceof Request +} + +async function drainResponse(response: Response): Promise { + try { + await response.body?.cancel() + } catch {} +} + +function parseRetryAfterMs(response: Response): number | undefined { + const value = response.headers.get("retry-after") ?? response.headers.get("Retry-After") + if (!value) return undefined + + const seconds = Number(value) + if (Number.isFinite(seconds)) return Math.max(0, seconds) * 1000 + + const dateMs = Date.parse(value) + if (!Number.isNaN(dateMs)) return Math.max(0, dateMs - Date.now()) + + return undefined +} + +const NETWORK_ERROR_CODES = new Set([ + "ECONNRESET", + "ECONNREFUSED", + "EHOSTUNREACH", + "ENETUNREACH", + "ENOTFOUND", + "EAI_AGAIN", + "ETIMEDOUT", + "ECONNABORTED", + "EPIPE", + "UND_ERR_CONNECT_TIMEOUT", + "UND_ERR_HEADERS_TIMEOUT", + "UND_ERR_BODY_TIMEOUT", + "UND_ERR_SOCKET", +]) +const NETWORK_ERROR_NAMES = new Set(["AbortError", "TimeoutError", "FetchError"]) + +function extractErrorCode(error: unknown): string | undefined { + if (!error || typeof error !== "object") return undefined + const code = (error as { code?: unknown }).code + return typeof code === "string" ? code : undefined +} + +function extractErrorName(error: unknown): string | undefined { + if (!error || typeof error !== "object") return undefined + const name = (error as { name?: unknown }).name + return typeof name === "string" ? name : undefined +} + +function extractErrorMessage(error: unknown): string | undefined { + if (!error || typeof error !== "object") return undefined + const message = (error as { message?: unknown }).message + return typeof message === "string" ? message : undefined +} + +function isNetworkError(error: unknown): boolean { + const directCode = extractErrorCode(error) + const cause = typeof error === "object" && error !== null ? (error as { cause?: unknown }).cause : undefined + const causeCode = extractErrorCode(cause) + const code = directCode ?? causeCode + if (code && NETWORK_ERROR_CODES.has(code)) return true + + const name = extractErrorName(error) + if (name && NETWORK_ERROR_NAMES.has(name)) return true + + const message = extractErrorMessage(error)?.toLowerCase() + if (!message) return false + return message.includes("fetch failed") || message.includes("network error") || message.includes("network down") +} + +function isAuthExpiredStatus(status: number): boolean { + return status === 401 || status === 403 +} + +export function createOAuthRotatingFetch Promise>( + fetchFn: TFetch, + opts: { + providerID: string + namespace?: string + maxAttempts?: number + rateLimitCooldownMs?: number + authFailureCooldownMs?: number + networkRetryAttempts?: number + toastDurationMs?: number + }, +): TFetch { + const namespace = (opts.namespace ?? "default").trim() || "default" + + return (async (input: any, init?: any) => { + const { records, orderedIDs, activeID } = await Auth.OAuthPool.snapshot(opts.providerID, namespace) + if (records.length === 0) return fetchFn(input, init) + + const recordByID = new Map(records.map((record) => [record.id, record])) + // Prefer activeID first, then follow the order + const candidates = + activeID && recordByID.has(activeID) + ? [activeID, ...orderedIDs.filter((id) => id !== activeID && recordByID.has(id))] + : orderedIDs.filter((id) => recordByID.has(id)) + if (candidates.length === 0) return fetchFn(input, init) + const inputIsRequest = isRequest(input) + let allowRetry = + isReplayableBody(init?.body) && (!inputIsRequest || (!input.bodyUsed && !isReadableStream(input.body))) + + const rateLimitCooldownMs = opts.rateLimitCooldownMs ?? DEFAULT_RATE_LIMIT_COOLDOWN_MS + const authFailureCooldownMs = opts.authFailureCooldownMs ?? DEFAULT_AUTH_FAILURE_COOLDOWN_MS + const configuredNetworkRetryAttempts = Math.max(0, opts.networkRetryAttempts ?? DEFAULT_NETWORK_RETRY_ATTEMPTS) + const maxAttemptBudget = opts.maxAttempts ?? candidates.length + let maxAttempts = Math.max(1, maxAttemptBudget) + if (!allowRetry) { + maxAttempts = 1 + } else if (maxAttempts > candidates.length) { + maxAttempts = candidates.length + } + + const attempted = new Set() + const refreshed = new Set() + let lastError: unknown + + const pickNextCandidate = (now: number) => + candidates.find((id) => { + if (attempted.has(id)) return false + const cooldownUntil = recordByID.get(id)?.health.cooldownUntil + return !cooldownUntil || cooldownUntil <= now + }) ?? candidates.find((id) => !attempted.has(id)) + + for (let attempt = 0; attempt < maxAttempts; attempt++) { + const now = Date.now() + + const nextID = pickNextCandidate(now) + + if (!nextID) break + attempted.add(nextID) + + const hasMoreAttempts = () => attempt + 1 < maxAttempts + let networkRetryAttempts = allowRetry ? configuredNetworkRetryAttempts : 0 + + const runWithNetworkRetry = async (): Promise => { + for (let networkAttempt = 0; ; networkAttempt++) { + let attemptInput = input + if (inputIsRequest && allowRetry) { + try { + attemptInput = input.clone() + } catch (e) { + lastError = e + allowRetry = false + networkRetryAttempts = 0 + maxAttempts = attempt + 1 + } + } + + try { + return await withOAuthRecord(opts.providerID, nextID, () => fetchFn(attemptInput, init)) + } catch (e) { + lastError = e + await Auth.OAuthPool.recordOutcome({ + providerID: opts.providerID, + recordID: nextID, + statusCode: 0, + ok: false, + }) + const networkError = isNetworkError(e) + if (networkError && allowRetry && networkAttempt < networkRetryAttempts) { + continue + } + throw e + } + } + } + const notifyFailover = async (statusCode: number) => { + const candidate = pickNextCandidate(Date.now()) + if (!candidate) return + await CredentialManager.notifyFailover({ + providerID: opts.providerID, + fromRecordID: nextID, + toRecordID: candidate, + statusCode, + toastDurationMs: opts.toastDurationMs, + }) + } + + let response: Response + try { + response = await runWithNetworkRetry() + } catch (e) { + if (isNetworkError(e)) throw e + + await Auth.OAuthPool.moveToBack(opts.providerID, namespace, nextID) + await notifyFailover(0) + if (!hasMoreAttempts()) throw e + continue + } + + if (response.ok) { + await Auth.OAuthPool.recordOutcome({ + providerID: opts.providerID, + recordID: nextID, + statusCode: response.status, + ok: true, + }) + return response + } + + if (response.status === 429) { + const cooldownMs = parseRetryAfterMs(response) ?? rateLimitCooldownMs + await Auth.OAuthPool.recordOutcome({ + providerID: opts.providerID, + recordID: nextID, + statusCode: response.status, + ok: false, + cooldownUntil: Date.now() + cooldownMs, + }) + await Auth.OAuthPool.moveToBack(opts.providerID, namespace, nextID) + await notifyFailover(response.status) + if (!hasMoreAttempts()) return response + await drainResponse(response) + continue + } + + if (isAuthExpiredStatus(response.status) && !refreshed.has(nextID)) { + refreshed.add(nextID) + + await Auth.OAuthPool.markAccessExpired(opts.providerID, namespace, nextID) + if (!allowRetry) { + const cooldownUntil = Date.now() + authFailureCooldownMs + await Auth.OAuthPool.recordOutcome({ + providerID: opts.providerID, + recordID: nextID, + statusCode: response.status, + ok: false, + cooldownUntil, + }) + await Auth.OAuthPool.moveToBack(opts.providerID, namespace, nextID) + await notifyFailover(response.status) + return response + } + + await drainResponse(response) + + try { + const retry = await runWithNetworkRetry() + if (retry.ok) { + await Auth.OAuthPool.recordOutcome({ + providerID: opts.providerID, + recordID: nextID, + statusCode: retry.status, + ok: true, + }) + return retry + } + + if (retry.status === 429) { + const cooldownMs = parseRetryAfterMs(retry) ?? rateLimitCooldownMs + await Auth.OAuthPool.recordOutcome({ + providerID: opts.providerID, + recordID: nextID, + statusCode: retry.status, + ok: false, + cooldownUntil: Date.now() + cooldownMs, + }) + await Auth.OAuthPool.moveToBack(opts.providerID, namespace, nextID) + await notifyFailover(retry.status) + if (!hasMoreAttempts()) return retry + await drainResponse(retry) + continue + } + + const cooldownUntil = Date.now() + authFailureCooldownMs + await Auth.OAuthPool.recordOutcome({ + providerID: opts.providerID, + recordID: nextID, + statusCode: retry.status, + ok: false, + cooldownUntil, + }) + await Auth.OAuthPool.moveToBack(opts.providerID, namespace, nextID) + await notifyFailover(retry.status) + if (!hasMoreAttempts()) return retry + await drainResponse(retry) + continue + } catch (e) { + if (isNetworkError(e)) throw e + await notifyFailover(0) + if (!hasMoreAttempts()) throw e + } + + await Auth.OAuthPool.moveToBack(opts.providerID, namespace, nextID) + continue + } + + await Auth.OAuthPool.recordOutcome({ + providerID: opts.providerID, + recordID: nextID, + statusCode: response.status, + ok: false, + }) + await Auth.OAuthPool.moveToBack(opts.providerID, namespace, nextID) + await notifyFailover(response.status) + if (!hasMoreAttempts()) return response + await drainResponse(response) + continue + } + + if (lastError) throw lastError + return fetchFn(input, init) + }) as TFetch +} diff --git a/packages/opencode/src/cli/cmd/auth.ts b/packages/opencode/src/cli/cmd/auth.ts index bbaecfd8c71..9f2cc8e2c59 100644 --- a/packages/opencode/src/cli/cmd/auth.ts +++ b/packages/opencode/src/cli/cmd/auth.ts @@ -82,13 +82,12 @@ async function handlePluginAuth(plugin: { auth: PluginAuth }, provider: string): if (result.type === "success") { const saveProvider = result.provider ?? provider if ("refresh" in result) { - const { type: _, provider: __, refresh, access, expires, ...extraFields } = result - await Auth.set(saveProvider, { - type: "oauth", - refresh, - access, - expires, - ...extraFields, + await Auth.addOAuth(saveProvider, { + refresh: result.refresh, + access: result.access, + expires: result.expires, + accountId: result.accountId, + enterpriseUrl: result.enterpriseUrl, }) } if ("key" in result) { @@ -114,13 +113,12 @@ async function handlePluginAuth(plugin: { auth: PluginAuth }, provider: string): if (result.type === "success") { const saveProvider = result.provider ?? provider if ("refresh" in result) { - const { type: _, provider: __, refresh, access, expires, ...extraFields } = result - await Auth.set(saveProvider, { - type: "oauth", - refresh, - access, - expires, - ...extraFields, + await Auth.addOAuth(saveProvider, { + refresh: result.refresh, + access: result.access, + expires: result.expires, + accountId: result.accountId, + enterpriseUrl: result.enterpriseUrl, }) } if ("key" in result) { @@ -159,11 +157,153 @@ async function handlePluginAuth(plugin: { auth: PluginAuth }, provider: string): return false } +export const AuthUsageCommand = cmd({ + command: "usage", + describe: "show rate limit usage for providers", + async handler() { + UI.empty() + prompts.intro("Usage") + const all = await Auth.all() + const database = await ModelsDev.get() + const sorted = Object.entries(all).sort((a, b) => { + const nameA = database[a[0]]?.name || a[0] + const nameB = database[b[0]]?.name || b[0] + return nameA.localeCompare(nameB) + }) + + let hasOAuth = false + for (const [providerID, info] of sorted) { + if (info.type !== "oauth") continue + hasOAuth = true + + const name = database[providerID]?.name || providerID + const accounts = await Auth.OAuthPool.getUsage(providerID) + + prompts.log.step(`${name} (${accounts.length} account${accounts.length !== 1 ? "s" : ""})`) + + for (const account of accounts) { + const label = account.label || "default" + const status = account.isActive ? `${UI.Style.TEXT_SUCCESS}active` : UI.Style.TEXT_DIM + "inactive" + prompts.log.info(` Account: ${label} - ${status}`) + + if (account.health.cooldownUntil && account.health.cooldownUntil > Date.now()) { + const remaining = Math.ceil((account.health.cooldownUntil - Date.now()) / 1000) + prompts.log.warn(` In cooldown for ${remaining}s`) + } + + prompts.log.info( + ` ${account.health.successCount} successful, ${account.health.failureCount} failed requests`, + ) + + if (providerID === "anthropic") { + const usage = await Auth.OAuthPool.fetchAnthropicUsage(providerID, "default", account.id) + if (usage) { + const parts: string[] = [] + if (usage.fiveHour) parts.push(`5h: ${usage.fiveHour.utilization}%`) + if (usage.sevenDay) parts.push(`7d: ${usage.sevenDay.utilization}%`) + if (usage.sevenDaySonnet) parts.push(`7d-sonnet: ${usage.sevenDaySonnet.utilization}%`) + if (parts.length > 0) { + prompts.log.info(` Rate Limits: ${parts.join(", ")}`) + } + } + } + } + } + + if (!hasOAuth) { + prompts.log.warn("No OAuth providers configured") + } + + prompts.outro("") + }, +}) + +export const AuthSwitchCommand = cmd({ + command: "switch", + describe: "switch active OAuth account for a provider", + async handler() { + UI.empty() + prompts.intro("Switch Account") + const all = await Auth.all() + const database = await ModelsDev.get() + + const oauthProviders = Object.entries(all) + .filter(([, info]) => info.type === "oauth") + .sort((a, b) => { + const nameA = database[a[0]]?.name || a[0] + const nameB = database[b[0]]?.name || b[0] + return nameA.localeCompare(nameB) + }) + if (oauthProviders.length === 0) { + prompts.log.warn("No OAuth providers configured") + prompts.outro("") + return + } + + const providerID = await prompts.select({ + message: "Select provider", + options: oauthProviders.map(([id]) => ({ + label: database[id]?.name || id, + value: id, + })), + }) + if (prompts.isCancel(providerID)) throw new UI.CancelledError() + + const accounts = await Auth.OAuthPool.getUsage(providerID) + if (accounts.length < 2) { + prompts.log.warn("Only one account configured for this provider") + prompts.outro("") + return + } + + const accountOptions = [] + for (const account of accounts) { + const label = account.label || "default" + const status = account.isActive ? " (active)" : "" + let hint = `${account.health.successCount} requests` + + if (providerID === "anthropic") { + const usage = await Auth.OAuthPool.fetchAnthropicUsage(providerID, "default", account.id) + if (usage?.fiveHour) { + hint = `5h: ${usage.fiveHour.utilization}%` + } + } + + accountOptions.push({ + label: `${label}${status}`, + value: account.id, + hint, + }) + } + + const recordID = await prompts.select({ + message: "Select account to activate", + options: accountOptions, + }) + if (prompts.isCancel(recordID)) throw new UI.CancelledError() + + const success = await Auth.OAuthPool.setActive(providerID, "default", recordID) + if (success) { + prompts.log.success("Account switched successfully") + } else { + prompts.log.error("Failed to switch account") + } + + prompts.outro("") + }, +}) + export const AuthCommand = cmd({ command: "auth", describe: "manage credentials", builder: (yargs) => - yargs.command(AuthLoginCommand).command(AuthLogoutCommand).command(AuthListCommand).demandCommand(), + yargs + .command(AuthLoginCommand) + .command(AuthLogoutCommand) + .command(AuthListCommand) + .command(AuthUsageCommand) + .command(AuthSwitchCommand) + .demandCommand(), async handler() {}, }) @@ -177,12 +317,21 @@ export const AuthListCommand = cmd({ const homedir = os.homedir() const displayPath = authPath.startsWith(homedir) ? authPath.replace(homedir, "~") : authPath prompts.intro(`Credentials ${UI.Style.TEXT_DIM}${displayPath}`) - const results = Object.entries(await Auth.all()) const database = await ModelsDev.get() + const results = Object.entries(await Auth.all()).sort((a, b) => { + const nameA = database[a[0]]?.name || a[0] + const nameB = database[b[0]]?.name || b[0] + return nameA.localeCompare(nameB) + }) for (const [providerID, result] of results) { const name = database[providerID]?.name || providerID - prompts.log.info(`${name} ${UI.Style.TEXT_DIM}${result.type}`) + if (result.type === "oauth") { + const count = await Auth.OAuthPool.list(providerID).then((accounts) => accounts.length) + prompts.log.info(`${name} ${UI.Style.TEXT_DIM}oauth${count > 1 ? ` (${count} accounts)` : ""}`) + } else { + prompts.log.info(`${name} ${UI.Style.TEXT_DIM}${result.type}`) + } } prompts.outro(`${results.length} credentials`) @@ -205,7 +354,7 @@ export const AuthListCommand = cmd({ UI.empty() prompts.intro("Environment") - for (const { provider, envVar } of activeEnvVars) { + for (const { provider, envVar } of activeEnvVars.sort((a, b) => a.provider.localeCompare(b.provider))) { prompts.log.info(`${provider} ${UI.Style.TEXT_DIM}${envVar}`) } @@ -376,7 +525,7 @@ export const AuthLoginCommand = cmd({ export const AuthLogoutCommand = cmd({ command: "logout", - describe: "log out from a configured provider", + describe: "log out from a configured provider or individual account", async handler() { UI.empty() const credentials = await Auth.all().then((x) => Object.entries(x)) @@ -394,6 +543,44 @@ export const AuthLogoutCommand = cmd({ })), }) if (prompts.isCancel(providerID)) throw new UI.CancelledError() + + const info = credentials.find(([key]) => key === providerID)?.[1] + if (info?.type === "oauth") { + const accounts = await Auth.OAuthPool.list(providerID) + if (accounts.length > 1) { + const options = [ + { label: "Remove all accounts", value: "__all__", hint: `${accounts.length} accounts` }, + ...accounts.map((account, index) => ({ + label: `Account ${index + 1}${account.label && account.label !== "default" ? ` (${account.label})` : ""}`, + value: account.id, + hint: `ID: ${account.id.slice(0, 8)}...`, + })), + ] + + const selection = await prompts.select({ + message: "Remove which account?", + options, + }) + if (prompts.isCancel(selection)) throw new UI.CancelledError() + + if (selection === "__all__") { + await Auth.remove(providerID) + prompts.log.success(`Removed all ${accounts.length} accounts`) + } else { + const result = await Auth.OAuthPool.removeRecord(providerID, selection) + if (result.removed) { + prompts.log.success( + `Account removed. ${result.remaining} account${result.remaining !== 1 ? "s" : ""} remaining.`, + ) + } else { + prompts.log.error("Failed to remove account") + } + } + prompts.outro("Done") + return + } + } + await Auth.remove(providerID) prompts.outro("Logout successful") }, diff --git a/packages/opencode/src/config/config.ts b/packages/opencode/src/config/config.ts index ddb3af4b0a8..a7154c0ea06 100644 --- a/packages/opencode/src/config/config.ts +++ b/packages/opencode/src/config/config.ts @@ -837,6 +837,26 @@ export namespace Config { }), ) .optional(), + oauth: z + .object({ + rateLimitCooldownMs: z.number().int().positive().optional().describe("Rate limit cooldown in milliseconds"), + authFailureCooldownMs: z.number().int().positive().optional().describe("Auth failure cooldown in milliseconds"), + networkRetryAttempts: z + .number() + .int() + .nonnegative() + .optional() + .describe("Network retry attempts per OAuth credential before failing"), + maxAttempts: z + .number() + .int() + .positive() + .optional() + .describe("Maximum OAuth credential attempts per request"), + toastDurationMs: z.number().int().positive().optional().describe("Failover toast duration in milliseconds"), + }) + .optional() + .describe("OAuth rotation settings"), options: z .object({ apiKey: z.string().optional(), diff --git a/packages/opencode/src/provider/auth.ts b/packages/opencode/src/provider/auth.ts index e6681ff0891..283957b09b2 100644 --- a/packages/opencode/src/provider/auth.ts +++ b/packages/opencode/src/provider/auth.ts @@ -99,16 +99,13 @@ export namespace ProviderAuth { }) } if ("refresh" in result) { - const info: Auth.Info = { - type: "oauth", - access: result.access, + await Auth.addOAuth(input.providerID, { refresh: result.refresh, + access: result.access, expires: result.expires, - } - if (result.accountId) { - info.accountId = result.accountId - } - await Auth.set(input.providerID, info) + accountId: result.accountId, + enterpriseUrl: result.enterpriseUrl, + }) } return } diff --git a/packages/opencode/src/provider/provider.ts b/packages/opencode/src/provider/provider.ts index bcb115edf41..156b2843ee4 100644 --- a/packages/opencode/src/provider/provider.ts +++ b/packages/opencode/src/provider/provider.ts @@ -9,6 +9,7 @@ import { Plugin } from "../plugin" import { ModelsDev } from "./models" import { NamedError } from "@opencode-ai/util/error" import { Auth } from "../auth" +import { createOAuthRotatingFetch } from "../auth/rotating-fetch" import { Env } from "../env" import { Instance } from "../project/instance" import { Flag } from "../flag/flag" @@ -963,6 +964,7 @@ export namespace Provider { providerID: model.providerID, }) const s = await state() + const config = await Config.get() const provider = s.providers[model.providerID] const options = { ...provider.options } @@ -978,13 +980,13 @@ export namespace Provider { ...model.headers, } - const key = Bun.hash.xxHash32(JSON.stringify({ npm: model.api.npm, options })) + const key = Bun.hash.xxHash32(JSON.stringify({ providerID: model.providerID, npm: model.api.npm, options })) const existing = s.sdk.get(key) if (existing) return existing const customFetch = options["fetch"] - options["fetch"] = async (input: any, init?: BunFetchRequestInit) => { + const fetchWithTimeout = async (input: any, init?: BunFetchRequestInit) => { // Preserve custom fetch if it exists, wrap it with timeout logic const fetchFn = customFetch ?? fetch const opts = init ?? {} @@ -1024,6 +1026,16 @@ export namespace Provider { }) } + const oauthConfig = config.provider?.[model.providerID]?.oauth + options["fetch"] = createOAuthRotatingFetch(fetchWithTimeout, { + providerID: model.providerID, + maxAttempts: oauthConfig?.maxAttempts, + rateLimitCooldownMs: oauthConfig?.rateLimitCooldownMs, + authFailureCooldownMs: oauthConfig?.authFailureCooldownMs, + networkRetryAttempts: oauthConfig?.networkRetryAttempts, + toastDurationMs: oauthConfig?.toastDurationMs, + }) + // Special case: google-vertex-anthropic uses a subpath import const bundledKey = model.providerID === "google-vertex-anthropic" ? "@ai-sdk/google-vertex/anthropic" : model.api.npm diff --git a/packages/opencode/src/server/server.ts b/packages/opencode/src/server/server.ts index 28dec7f4043..8557aed2686 100644 --- a/packages/opencode/src/server/server.ts +++ b/packages/opencode/src/server/server.ts @@ -167,6 +167,126 @@ export namespace Server { .route("/", FileRoutes()) .route("/mcp", McpRoutes()) .route("/tui", TuiRoutes()) + .get( + "/auth/usage", + describeRoute({ + summary: "Get auth usage", + description: "Get rate limit and usage information for authenticated providers.", + operationId: "auth.usage", + responses: { + 200: { + description: "Usage information per provider and account", + content: { + "application/json": { + schema: resolver(z.any().meta({ ref: "AuthUsage" })), + }, + }, + }, + ...errors(400), + }, + }), + async (c) => { + const all = await Auth.all() + const result: Record< + string, + { + accounts: Awaited> + anthropicUsage?: Awaited> + } + > = {} + + for (const [providerID, info] of Object.entries(all)) { + if (info.type === "oauth") { + const accounts = await Auth.OAuthPool.getUsage(providerID) + const anthropicUsage = await Auth.OAuthPool.fetchAnthropicUsage(providerID) + result[providerID] = { accounts, anthropicUsage: anthropicUsage ?? undefined } + } + } + + return c.json(result) + }, + ) + .post( + "/auth/active", + describeRoute({ + summary: "Set active OAuth account", + description: + "Set the active OAuth account for a provider. This account will be used for requests until rate limited.", + operationId: "auth.setActive", + responses: { + 200: { + description: "Active account updated", + content: { + "application/json": { + schema: resolver( + z.object({ + success: z.boolean(), + anthropicUsage: z.any().optional(), + }), + ), + }, + }, + }, + ...errors(400), + }, + }), + validator( + "json", + z.object({ + providerID: z.string(), + recordID: z.string(), + namespace: z.string().optional(), + }), + ), + async (c) => { + const body = c.req.valid("json") + const namespace = body.namespace ?? "default" + const success = await Auth.OAuthPool.setActive(body.providerID, namespace, body.recordID) + const anthropicUsage = success + ? await Auth.OAuthPool.fetchAnthropicUsage(body.providerID, namespace, body.recordID) + : null + return c.json({ success, anthropicUsage: anthropicUsage ?? undefined }) + }, + ) + .delete( + "/auth/account", + describeRoute({ + summary: "Remove OAuth account", + description: + "Remove an OAuth account from a provider. If this is the last account, the provider will be disconnected.", + operationId: "auth.removeAccount", + responses: { + 200: { + description: "Account removed", + content: { + "application/json": { + schema: resolver( + z.object({ + removed: z.boolean(), + remaining: z.number(), + }), + ), + }, + }, + }, + ...errors(400), + }, + }), + validator( + "json", + z.object({ + providerID: z.string(), + recordID: z.string(), + namespace: z.string().optional(), + }), + ), + async (c) => { + const body = c.req.valid("json") + const namespace = body.namespace ?? "default" + const result = await Auth.OAuthPool.removeRecord(body.providerID, body.recordID, namespace) + return c.json(result) + }, + ) .post( "/instance/dispose", describeRoute({ diff --git a/packages/opencode/src/session/message-v2.ts b/packages/opencode/src/session/message-v2.ts index d326976f1ae..a3919af832d 100644 --- a/packages/opencode/src/session/message-v2.ts +++ b/packages/opencode/src/session/message-v2.ts @@ -497,10 +497,8 @@ export namespace MessageV2 { text: part.text, providerMetadata: part.metadata, }) - if (part.type === "step-start") - assistantMessage.parts.push({ - type: "step-start", - }) + // step-start parts are not added to UIMessage since "step-start" is not a valid + // UIMessagePart type in the AI SDK - they are only used internally for tracking if (part.type === "tool") { if (part.state.status === "completed") { if (part.state.attachments?.length) { diff --git a/packages/opencode/test/auth/oauth-rotation.test.ts b/packages/opencode/test/auth/oauth-rotation.test.ts new file mode 100644 index 00000000000..6b3a1757575 --- /dev/null +++ b/packages/opencode/test/auth/oauth-rotation.test.ts @@ -0,0 +1,461 @@ +import { describe, expect, test } from "bun:test" +import { Auth } from "../../src/auth" +import { createOAuthRotatingFetch } from "../../src/auth/rotating-fetch" +import { withOAuthRecord } from "../../src/auth/context" + +describe("OAuth subscription failover", () => { + const providerID = "oauth-rotation-test" + + test("rotates on 429 (Retry-After) and succeeds with next account", async () => { + await Auth.remove(providerID) + + const a1 = await Auth.addOAuth(providerID, { + refresh: "r1", + access: "a1", + expires: Date.now() + 60_000, + }) + const a2 = await Auth.addOAuth(providerID, { + refresh: "r2", + access: "a2", + expires: Date.now() + 60_000, + }) + + const baseFetch = async (_input: RequestInfo | URL, _init?: RequestInit) => { + const auth = await Auth.get(providerID) + expect(auth?.type).toBe("oauth") + if (!auth || auth.type !== "oauth") return new Response("no auth", { status: 500 }) + + if (auth.refresh === "r1") { + return new Response("rate limited", { + status: 429, + headers: { + "Retry-After": "1", + }, + }) + } + + return new Response("ok", { status: 200 }) + } + + const fetchWithFailover = createOAuthRotatingFetch(baseFetch, { providerID }) + const response = await fetchWithFailover("https://example.com", { method: "POST", body: "{}" }) + + expect(response.status).toBe(200) + + const order = await Auth.OAuthPool.orderedIDs(providerID) + expect(order[0]).toBe(a2.recordID) + expect(order[1]).toBe(a1.recordID) + }) + + test("updates the correct OAuth record by refresh token without record context", async () => { + await Auth.remove(providerID) + + const a1 = await Auth.addOAuth(providerID, { + refresh: "r1", + access: "a1", + expires: Date.now() + 60_000, + }) + const a2 = await Auth.addOAuth(providerID, { + refresh: "r2", + access: "a2", + expires: Date.now() + 60_000, + }) + + await Auth.set(providerID, { + type: "oauth", + refresh: "r1", + access: "updated-a1", + expires: Date.now() + 60_000, + }) + + const record1 = await withOAuthRecord(providerID, a1.recordID, async () => Auth.get(providerID)) + const record2 = await withOAuthRecord(providerID, a2.recordID, async () => Auth.get(providerID)) + + expect(record1?.type).toBe("oauth") + expect(record1 && record1.type === "oauth" ? record1.access : "").toBe("updated-a1") + + expect(record2?.type).toBe("oauth") + expect(record2 && record2.type === "oauth" ? record2.access : "").toBe("a2") + }) + + test("retries once on 401/403 by forcing refresh, then succeeds", async () => { + await Auth.remove(providerID) + + const a1 = await Auth.addOAuth(providerID, { + refresh: "r1", + access: "bad", + expires: Date.now() + 60_000, + }) + await Auth.addOAuth(providerID, { + refresh: "r2", + access: "ok", + expires: Date.now() + 60_000, + }) + + const baseFetch = async (_input: RequestInfo | URL, _init?: RequestInit) => { + const auth = await Auth.get(providerID) + expect(auth?.type).toBe("oauth") + if (!auth || auth.type !== "oauth") return new Response("no auth", { status: 500 }) + + // Simulate plugin refresh behavior: when access is cleared/expired, + // it refreshes and persists via Auth.set(). + if (!auth.access) { + await Auth.set(providerID, { + type: "oauth", + refresh: auth.refresh, + access: `refreshed-${auth.refresh}`, + expires: Date.now() + 60_000, + }) + return new Response("ok", { status: 200 }) + } + + if (auth.access === "bad") { + return new Response("unauthorized", { status: 401 }) + } + + return new Response("ok", { status: 200 }) + } + + const fetchWithFailover = createOAuthRotatingFetch(baseFetch, { providerID }) + const response = await fetchWithFailover("https://example.com", { method: "POST", body: "{}" }) + expect(response.status).toBe(200) + + const record1 = await withOAuthRecord(providerID, a1.recordID, async () => Auth.get(providerID)) + expect(record1?.type).toBe("oauth") + expect(record1 && record1.type === "oauth" ? record1.access : "").toBe("refreshed-r1") + }) + + test("fails over on 401/403 when refresh does not fix the credential", async () => { + await Auth.remove(providerID) + + const a1 = await Auth.addOAuth(providerID, { + refresh: "r1", + access: "bad", + expires: Date.now() + 60_000, + }) + const a2 = await Auth.addOAuth(providerID, { + refresh: "r2", + access: "ok", + expires: Date.now() + 60_000, + }) + + const baseFetch = async (_input: RequestInfo | URL, _init?: RequestInit) => { + const auth = await Auth.get(providerID) + expect(auth?.type).toBe("oauth") + if (!auth || auth.type !== "oauth") return new Response("no auth", { status: 500 }) + + if (auth.refresh === "r1") { + return new Response("unauthorized", { status: 401 }) + } + + return new Response("ok", { status: 200 }) + } + + const fetchWithFailover = createOAuthRotatingFetch(baseFetch, { providerID }) + const response = await fetchWithFailover("https://example.com", { method: "POST", body: "{}" }) + + expect(response.status).toBe(200) + + const order = await Auth.OAuthPool.orderedIDs(providerID) + expect(order[0]).toBe(a2.recordID) + expect(order[1]).toBe(a1.recordID) + }) + + test("fails over on non-auth errors", async () => { + await Auth.remove(providerID) + + const a1 = await Auth.addOAuth(providerID, { + refresh: "r1", + access: "a1", + expires: Date.now() + 60_000, + }) + const a2 = await Auth.addOAuth(providerID, { + refresh: "r2", + access: "a2", + expires: Date.now() + 60_000, + }) + + const baseFetch = async (_input: RequestInfo | URL, _init?: RequestInit) => { + const auth = await Auth.get(providerID) + expect(auth?.type).toBe("oauth") + if (!auth || auth.type !== "oauth") return new Response("no auth", { status: 500 }) + + if (auth.refresh === "r1") { + return new Response("payment required", { status: 402 }) + } + + return new Response("ok", { status: 200 }) + } + + const fetchWithFailover = createOAuthRotatingFetch(baseFetch, { providerID }) + const response = await fetchWithFailover("https://example.com", { method: "POST", body: "{}" }) + + expect(response.status).toBe(200) + + const order = await Auth.OAuthPool.orderedIDs(providerID) + expect(order[0]).toBe(a2.recordID) + expect(order[1]).toBe(a1.recordID) + }) + + test("sticks to the active credential until rate limited", async () => { + await Auth.remove(providerID) + + const a1 = await Auth.addOAuth(providerID, { + refresh: "r1", + access: "a1", + expires: Date.now() + 60_000, + }) + const a2 = await Auth.addOAuth(providerID, { + refresh: "r2", + access: "a2", + expires: Date.now() + 60_000, + }) + + const counts = new Map() + const baseFetch = async (_input: RequestInfo | URL, _init?: RequestInit) => { + const auth = await Auth.get(providerID) + expect(auth?.type).toBe("oauth") + if (!auth || auth.type !== "oauth") return new Response("no auth", { status: 500 }) + + const refresh = auth.refresh + counts.set(refresh, (counts.get(refresh) ?? 0) + 1) + + if (refresh === "r1" && (counts.get(refresh) ?? 0) >= 3) { + return new Response("rate limited", { status: 429 }) + } + + return new Response("ok", { status: 200 }) + } + + const fetchWithFailover = createOAuthRotatingFetch(baseFetch, { providerID }) + + const first = await fetchWithFailover("https://example.com", { method: "POST", body: "{}" }) + expect(first.status).toBe(200) + + const second = await fetchWithFailover("https://example.com", { method: "POST", body: "{}" }) + expect(second.status).toBe(200) + + const beforeRateLimit = await Auth.OAuthPool.orderedIDs(providerID) + expect(beforeRateLimit[0]).toBe(a1.recordID) + expect(beforeRateLimit[1]).toBe(a2.recordID) + + const third = await fetchWithFailover("https://example.com", { method: "POST", body: "{}" }) + expect(third.status).toBe(200) + + const afterRateLimit = await Auth.OAuthPool.orderedIDs(providerID) + expect(afterRateLimit[0]).toBe(a2.recordID) + expect(afterRateLimit[1]).toBe(a1.recordID) + }) + + test("does not retry non-replayable bodies but rotates for the next request", async () => { + await Auth.remove(providerID) + + const a1 = await Auth.addOAuth(providerID, { + refresh: "r1", + access: "a1", + expires: Date.now() + 60_000, + }) + const a2 = await Auth.addOAuth(providerID, { + refresh: "r2", + access: "a2", + expires: Date.now() + 60_000, + }) + + const counts = new Map() + const baseFetch = async (_input: RequestInfo | URL, _init?: RequestInit) => { + const auth = await Auth.get(providerID) + expect(auth?.type).toBe("oauth") + if (!auth || auth.type !== "oauth") return new Response("no auth", { status: 500 }) + + counts.set(auth.refresh, (counts.get(auth.refresh) ?? 0) + 1) + return new Response("rate limited", { status: 429 }) + } + + const fetchWithFailover = createOAuthRotatingFetch(baseFetch, { providerID }) + const body = new ReadableStream({ + start(controller) { + controller.enqueue(new TextEncoder().encode("payload")) + controller.close() + }, + }) + const response = await fetchWithFailover("https://example.com", { method: "POST", body }) + + expect(response.status).toBe(429) + expect(counts.get("r1") ?? 0).toBe(1) + expect(counts.get("r2") ?? 0).toBe(0) + + const order = await Auth.OAuthPool.orderedIDs(providerID) + expect(order[0]).toBe(a2.recordID) + expect(order[1]).toBe(a1.recordID) + }) + + test("returns the last response when all credentials are exhausted", async () => { + await Auth.remove(providerID) + + const a1 = await Auth.addOAuth(providerID, { + refresh: "r1", + access: "a1", + expires: Date.now() + 60_000, + }) + const a2 = await Auth.addOAuth(providerID, { + refresh: "r2", + access: "a2", + expires: Date.now() + 60_000, + }) + + const baseFetch = async (_input: RequestInfo | URL, _init?: RequestInit) => { + const auth = await Auth.get(providerID) + expect(auth?.type).toBe("oauth") + if (!auth || auth.type !== "oauth") return new Response("no auth", { status: 500 }) + return new Response("rate limited", { status: 429 }) + } + + const fetchWithFailover = createOAuthRotatingFetch(baseFetch, { providerID }) + const response = await fetchWithFailover("https://example.com", { method: "POST", body: "{}" }) + + expect(response.status).toBe(429) + + const records = await Auth.OAuthPool.list(providerID) + const recordByID = new Map(records.map((record) => [record.id, record])) + expect(recordByID.get(a1.recordID)?.health.failureCount ?? 0).toBe(1) + expect(recordByID.get(a2.recordID)?.health.failureCount ?? 0).toBe(1) + }) + + test("retries network errors without rotating", async () => { + await Auth.remove(providerID) + + const a1 = await Auth.addOAuth(providerID, { + refresh: "r1", + access: "a1", + expires: Date.now() + 60_000, + }) + const a2 = await Auth.addOAuth(providerID, { + refresh: "r2", + access: "a2", + expires: Date.now() + 60_000, + }) + + const counts = new Map() + let failures = 0 + const baseFetch = async (_input: RequestInfo | URL, _init?: RequestInit) => { + const auth = await Auth.get(providerID) + expect(auth?.type).toBe("oauth") + if (!auth || auth.type !== "oauth") return new Response("no auth", { status: 500 }) + + counts.set(auth.refresh, (counts.get(auth.refresh) ?? 0) + 1) + + if (auth.refresh === "r1" && failures < 1) { + failures += 1 + const error = new Error("network down") + ;(error as { code?: string }).code = "ECONNRESET" + throw error + } + + return new Response("ok", { status: 200 }) + } + + const fetchWithFailover = createOAuthRotatingFetch(baseFetch, { providerID }) + const response = await fetchWithFailover("https://example.com", { method: "POST", body: "{}" }) + + expect(response.status).toBe(200) + + expect(counts.get("r1") ?? 0).toBe(2) + expect(counts.get("r2") ?? 0).toBe(0) + + const order = await Auth.OAuthPool.orderedIDs(providerID) + expect(order[0]).toBe(a1.recordID) + expect(order[1]).toBe(a2.recordID) + }) + + test("respects Retry-After HTTP date headers", async () => { + await Auth.remove(providerID) + + const a1 = await Auth.addOAuth(providerID, { + refresh: "r1", + access: "a1", + expires: Date.now() + 60_000, + }) + await Auth.addOAuth(providerID, { + refresh: "r2", + access: "a2", + expires: Date.now() + 60_000, + }) + + const originalNow = Date.now + const now = 1_700_000_000_000 + Date.now = () => now + + try { + const retryAt = new Date(now + 5_000).toUTCString() + const baseFetch = async (_input: RequestInfo | URL, _init?: RequestInit) => { + const auth = await Auth.get(providerID) + expect(auth?.type).toBe("oauth") + if (!auth || auth.type !== "oauth") return new Response("no auth", { status: 500 }) + + if (auth.refresh === "r1") { + return new Response("rate limited", { + status: 429, + headers: { + "Retry-After": retryAt, + }, + }) + } + + return new Response("ok", { status: 200 }) + } + + const fetchWithFailover = createOAuthRotatingFetch(baseFetch, { providerID }) + const response = await fetchWithFailover("https://example.com", { method: "POST", body: "{}" }) + + expect(response.status).toBe(200) + + const records = await Auth.OAuthPool.list(providerID) + const recordByID = new Map(records.map((record) => [record.id, record])) + expect(recordByID.get(a1.recordID)?.health.cooldownUntil).toBe(now + 5_000) + } finally { + Date.now = originalNow + } + }) + + test("falls back when Request.clone throws", async () => { + await Auth.remove(providerID) + + await Auth.addOAuth(providerID, { + refresh: "r1", + access: "a1", + expires: Date.now() + 60_000, + }) + await Auth.addOAuth(providerID, { + refresh: "r2", + access: "a2", + expires: Date.now() + 60_000, + }) + + const counts = new Map() + const baseFetch = async (_input: RequestInfo | URL, _init?: RequestInit) => { + const auth = await Auth.get(providerID) + expect(auth?.type).toBe("oauth") + if (!auth || auth.type !== "oauth") return new Response("no auth", { status: 500 }) + + counts.set(auth.refresh, (counts.get(auth.refresh) ?? 0) + 1) + + if (auth.refresh === "r1") { + return new Response("rate limited", { status: 429 }) + } + + return new Response("ok", { status: 200 }) + } + + const request = new Request("https://example.com", { method: "POST" }) + ;(request as { clone: () => Request }).clone = () => { + throw new Error("clone failed") + } + + const fetchWithFailover = createOAuthRotatingFetch(baseFetch, { providerID }) + const response = await fetchWithFailover(request) + + expect(response.status).toBe(429) + expect(counts.get("r1") ?? 0).toBe(1) + expect(counts.get("r2") ?? 0).toBe(0) + }) +}) diff --git a/packages/plugin/src/index.ts b/packages/plugin/src/index.ts index e57eff579e6..b95b44fe573 100644 --- a/packages/plugin/src/index.ts +++ b/packages/plugin/src/index.ts @@ -115,6 +115,7 @@ export type AuthOuathResult = { url: string; instructions: string } & ( access: string expires: number accountId?: string + enterpriseUrl?: string // Used for GitHub Copilot Enterprise auth flows. } | { key: string } )) @@ -135,6 +136,7 @@ export type AuthOuathResult = { url: string; instructions: string } & ( access: string expires: number accountId?: string + enterpriseUrl?: string // Used for GitHub Copilot Enterprise auth flows. } | { key: string } )) diff --git a/packages/sdk/js/src/v2/gen/sdk.gen.ts b/packages/sdk/js/src/v2/gen/sdk.gen.ts index 6f699319965..132aeb1d97c 100644 --- a/packages/sdk/js/src/v2/gen/sdk.gen.ts +++ b/packages/sdk/js/src/v2/gen/sdk.gen.ts @@ -9,8 +9,12 @@ import type { AppLogResponses, AppSkillsResponses, Auth as Auth3, + AuthSetActiveErrors, + AuthSetActiveResponses, AuthSetErrors, AuthSetResponses, + AuthUsageErrors, + AuthUsageResponses, CommandListResponses, Config as Config2, ConfigGetResponses, @@ -2740,6 +2744,103 @@ export class Tui extends HeyApiClient { } } +export class Auth2 extends HeyApiClient { + /** + * Get auth usage + * + * Get rate limit and usage information for authenticated providers. + */ + public usage( + parameters?: { + directory?: string + }, + options?: Options, + ) { + const params = buildClientParams([parameters], [{ args: [{ in: "query", key: "directory" }] }]) + return (options?.client ?? this.client).get({ + url: "/auth/usage", + ...options, + ...params, + }) + } + + /** + * Set active OAuth account + * + * Set the active OAuth account for a provider. This account will be used for requests until rate limited. + */ + public setActive( + parameters?: { + directory?: string + providerID?: string + recordID?: string + namespace?: string + }, + options?: Options, + ) { + const params = buildClientParams( + [parameters], + [ + { + args: [ + { in: "query", key: "directory" }, + { in: "body", key: "providerID" }, + { in: "body", key: "recordID" }, + { in: "body", key: "namespace" }, + ], + }, + ], + ) + return (options?.client ?? this.client).post({ + url: "/auth/active", + ...options, + ...params, + headers: { + "Content-Type": "application/json", + ...options?.headers, + ...params.headers, + }, + }) + } + + /** + * Set auth credentials + * + * Set authentication credentials + */ + public set( + parameters: { + providerID: string + directory?: string + auth?: Auth3 + }, + options?: Options, + ) { + const params = buildClientParams( + [parameters], + [ + { + args: [ + { in: "path", key: "providerID" }, + { in: "query", key: "directory" }, + { key: "auth", map: "body" }, + ], + }, + ], + ) + return (options?.client ?? this.client).put({ + url: "/auth/{providerID}", + ...options, + ...params, + headers: { + "Content-Type": "application/json", + ...options?.headers, + ...params.headers, + }, + }) + } +} + export class Instance extends HeyApiClient { /** * Dispose instance @@ -2949,45 +3050,6 @@ export class Formatter extends HeyApiClient { } } -export class Auth2 extends HeyApiClient { - /** - * Set auth credentials - * - * Set authentication credentials - */ - public set( - parameters: { - providerID: string - directory?: string - auth?: Auth3 - }, - options?: Options, - ) { - const params = buildClientParams( - [parameters], - [ - { - args: [ - { in: "path", key: "providerID" }, - { in: "query", key: "directory" }, - { key: "auth", map: "body" }, - ], - }, - ], - ) - return (options?.client ?? this.client).put({ - url: "/auth/{providerID}", - ...options, - ...params, - headers: { - "Content-Type": "application/json", - ...options?.headers, - ...params.headers, - }, - }) - } -} - export class Event extends HeyApiClient { /** * Subscribe to events @@ -3097,6 +3159,11 @@ export class OpencodeClient extends HeyApiClient { return (this._tui ??= new Tui({ client: this.client })) } + private _auth?: Auth2 + get auth(): Auth2 { + return (this._auth ??= new Auth2({ client: this.client })) + } + private _instance?: Instance get instance(): Instance { return (this._instance ??= new Instance({ client: this.client })) @@ -3132,11 +3199,6 @@ export class OpencodeClient extends HeyApiClient { return (this._formatter ??= new Formatter({ client: this.client })) } - private _auth?: Auth2 - get auth(): Auth2 { - return (this._auth ??= new Auth2({ client: this.client })) - } - private _event?: Event get event(): Event { return (this._event ??= new Event({ client: this.client }))