diff --git a/packages/opencode/src/mcp/auth.ts b/packages/opencode/src/mcp/auth.ts index 385cb3c7339..6ebb95698d7 100644 --- a/packages/opencode/src/mcp/auth.ts +++ b/packages/opencode/src/mcp/auth.ts @@ -24,6 +24,8 @@ export namespace McpAuth { tokens: Tokens.optional(), clientInfo: ClientInfo.optional(), codeVerifier: z.string().optional(), + oauthState: z.string().optional(), + serverUrl: z.string().optional(), // Track the URL these credentials are for }) export type Entry = z.infer @@ -34,14 +36,35 @@ export namespace McpAuth { return data[mcpName] } + /** + * Get auth entry and validate it's for the correct URL. + * Returns undefined if URL has changed (credentials are invalid). + */ + export async function getForUrl(mcpName: string, serverUrl: string): Promise { + const entry = await get(mcpName) + if (!entry) return undefined + + // If no serverUrl is stored, this is from an old version - consider it invalid + if (!entry.serverUrl) return undefined + + // If URL has changed, credentials are invalid + if (entry.serverUrl !== serverUrl) return undefined + + return entry + } + export async function all(): Promise> { const file = Bun.file(filepath) return file.json().catch(() => ({})) } - export async function set(mcpName: string, entry: Entry): Promise { + export async function set(mcpName: string, entry: Entry, serverUrl?: string): Promise { const file = Bun.file(filepath) const data = await all() + // Always update serverUrl if provided + if (serverUrl) { + entry.serverUrl = serverUrl + } await Bun.write(file, JSON.stringify({ ...data, [mcpName]: entry }, null, 2)) await fs.chmod(file.name!, 0o600) } @@ -54,16 +77,16 @@ export namespace McpAuth { await fs.chmod(file.name!, 0o600) } - export async function updateTokens(mcpName: string, tokens: Tokens): Promise { + export async function updateTokens(mcpName: string, tokens: Tokens, serverUrl?: string): Promise { const entry = (await get(mcpName)) ?? {} entry.tokens = tokens - await set(mcpName, entry) + await set(mcpName, entry, serverUrl) } - export async function updateClientInfo(mcpName: string, clientInfo: ClientInfo): Promise { + export async function updateClientInfo(mcpName: string, clientInfo: ClientInfo, serverUrl?: string): Promise { const entry = (await get(mcpName)) ?? {} entry.clientInfo = clientInfo - await set(mcpName, entry) + await set(mcpName, entry, serverUrl) } export async function updateCodeVerifier(mcpName: string, codeVerifier: string): Promise { @@ -79,4 +102,23 @@ export namespace McpAuth { await set(mcpName, entry) } } + + export async function updateOAuthState(mcpName: string, oauthState: string): Promise { + const entry = (await get(mcpName)) ?? {} + entry.oauthState = oauthState + await set(mcpName, entry) + } + + export async function getOAuthState(mcpName: string): Promise { + const entry = await get(mcpName) + return entry?.oauthState + } + + export async function clearOAuthState(mcpName: string): Promise { + const entry = await get(mcpName) + if (entry) { + delete entry.oauthState + await set(mcpName, entry) + } + } } diff --git a/packages/opencode/src/mcp/index.ts b/packages/opencode/src/mcp/index.ts index 41d59097b52..625809af9a8 100644 --- a/packages/opencode/src/mcp/index.ts +++ b/packages/opencode/src/mcp/index.ts @@ -436,6 +436,13 @@ export namespace MCP { // Start the callback server await McpOAuthCallback.ensureRunning() + // Generate and store a cryptographically secure state parameter BEFORE creating the provider + // The SDK will call provider.state() to read this value + const oauthState = Array.from(crypto.getRandomValues(new Uint8Array(32))) + .map((b) => b.toString(16).padStart(2, "0")) + .join("") + await McpAuth.updateOAuthState(mcpName, oauthState) + // Create a new auth provider for this flow // OAuth config is optional - if not provided, we'll use auto-discovery const oauthConfig = typeof mcpConfig.oauth === "object" ? mcpConfig.oauth : undefined @@ -491,25 +498,29 @@ export namespace MCP { return s.status[mcpName] ?? { status: "connected" } } - // Extract state from authorization URL to use as callback key - // If no state parameter, use mcpName as fallback - const authUrl = new URL(authorizationUrl) - let oauthState = mcpName - - if (authUrl.searchParams.has("state")) { - oauthState = authUrl.searchParams.get("state")! - } else { - log.info("no state parameter in authorization URL, using mcpName as state", { mcpName }) - authUrl.searchParams.set("state", oauthState) + // Get the state that was already generated and stored in startAuth() + const oauthState = await McpAuth.getOAuthState(mcpName) + if (!oauthState) { + throw new Error("OAuth state not found - this should not happen") } - // Open browser - log.info("opening browser for oauth", { mcpName, url: authUrl.toString(), state: oauthState }) - await open(authUrl.toString()) + // The SDK has already added the state parameter to the authorization URL + // We just need to open the browser + log.info("opening browser for oauth", { mcpName, url: authorizationUrl, state: oauthState }) + await open(authorizationUrl) - // Wait for callback using the OAuth state parameter (or mcpName as fallback) + // Wait for callback using the OAuth state parameter const code = await McpOAuthCallback.waitForCallback(oauthState) + // Validate and clear the state + const storedState = await McpAuth.getOAuthState(mcpName) + if (storedState !== oauthState) { + await McpAuth.clearOAuthState(mcpName) + throw new Error("OAuth state mismatch - potential CSRF attack") + } + + await McpAuth.clearOAuthState(mcpName) + // Finish auth return finishAuth(mcpName, code) } @@ -561,6 +572,7 @@ export namespace MCP { await McpAuth.remove(mcpName) McpOAuthCallback.cancelPending(mcpName) pendingOAuthTransports.delete(mcpName) + await McpAuth.clearOAuthState(mcpName) log.info("removed oauth credentials", { mcpName }) } diff --git a/packages/opencode/src/mcp/oauth-callback.ts b/packages/opencode/src/mcp/oauth-callback.ts index 67bb5168410..bb3b56f2e95 100644 --- a/packages/opencode/src/mcp/oauth-callback.ts +++ b/packages/opencode/src/mcp/oauth-callback.ts @@ -81,9 +81,19 @@ export namespace McpOAuthCallback { log.info("received oauth callback", { hasCode: !!code, state, error }) + // Enforce state parameter presence + if (!state) { + const errorMsg = "Missing required state parameter - potential CSRF attack" + log.error("oauth callback missing state parameter", { url: url.toString() }) + return new Response(HTML_ERROR(errorMsg), { + status: 400, + headers: { "Content-Type": "text/html" }, + }) + } + if (error) { const errorMsg = errorDescription || error - if (state && pendingAuths.has(state)) { + if (pendingAuths.has(state)) { const pending = pendingAuths.get(state)! clearTimeout(pending.timeout) pendingAuths.delete(state) @@ -101,33 +111,20 @@ export namespace McpOAuthCallback { }) } - // Try to find the pending auth by state parameter, or if no state, use the single pending auth - let pending: PendingAuth | undefined - let pendingKey: string | undefined - - if (state && pendingAuths.has(state)) { - pending = pendingAuths.get(state)! - pendingKey = state - } else if (!state && pendingAuths.size === 1) { - // No state parameter but only one pending auth - use it - const [key, value] = pendingAuths.entries().next().value as [string, PendingAuth] - pending = value - pendingKey = key - log.info("no state parameter, using single pending auth", { key }) - } - - if (!pending || !pendingKey) { - const errorMsg = !state - ? "No state parameter provided and multiple pending authorizations" - : "Unknown or expired authorization request" + // Validate state parameter + if (!pendingAuths.has(state)) { + const errorMsg = "Invalid or expired state parameter - potential CSRF attack" + log.error("oauth callback with invalid state", { state, pendingStates: Array.from(pendingAuths.keys()) }) return new Response(HTML_ERROR(errorMsg), { status: 400, headers: { "Content-Type": "text/html" }, }) } + const pending = pendingAuths.get(state)! + clearTimeout(pending.timeout) - pendingAuths.delete(pendingKey) + pendingAuths.delete(state) pending.resolve(code) return new Response(HTML_SUCCESS, { @@ -139,16 +136,16 @@ export namespace McpOAuthCallback { log.info("oauth callback server started", { port: OAUTH_CALLBACK_PORT }) } - export function waitForCallback(mcpName: string): Promise { + export function waitForCallback(oauthState: string): Promise { return new Promise((resolve, reject) => { const timeout = setTimeout(() => { - if (pendingAuths.has(mcpName)) { - pendingAuths.delete(mcpName) + if (pendingAuths.has(oauthState)) { + pendingAuths.delete(oauthState) reject(new Error("OAuth callback timeout - authorization took too long")) } }, CALLBACK_TIMEOUT_MS) - pendingAuths.set(mcpName, { resolve, reject, timeout }) + pendingAuths.set(oauthState, { resolve, reject, timeout }) }) } diff --git a/packages/opencode/src/mcp/oauth-provider.ts b/packages/opencode/src/mcp/oauth-provider.ts index 584eca8e88a..35ead25e8be 100644 --- a/packages/opencode/src/mcp/oauth-provider.ts +++ b/packages/opencode/src/mcp/oauth-provider.ts @@ -56,7 +56,8 @@ export class McpOAuthProvider implements OAuthClientProvider { } // Check stored client info (from dynamic registration) - const entry = await McpAuth.get(this.mcpName) + // Use getForUrl to validate credentials are for the current server URL + const entry = await McpAuth.getForUrl(this.mcpName, this.serverUrl) if (entry?.clientInfo) { // Check if client secret has expired if (entry.clientInfo.clientSecretExpiresAt && entry.clientInfo.clientSecretExpiresAt < Date.now() / 1000) { @@ -69,17 +70,21 @@ export class McpOAuthProvider implements OAuthClientProvider { } } - // No client info - will trigger dynamic registration + // No client info or URL changed - will trigger dynamic registration return undefined } async saveClientInformation(info: OAuthClientInformationFull): Promise { - await McpAuth.updateClientInfo(this.mcpName, { - clientId: info.client_id, - clientSecret: info.client_secret, - clientIdIssuedAt: info.client_id_issued_at, - clientSecretExpiresAt: info.client_secret_expires_at, - }) + await McpAuth.updateClientInfo( + this.mcpName, + { + clientId: info.client_id, + clientSecret: info.client_secret, + clientIdIssuedAt: info.client_id_issued_at, + clientSecretExpiresAt: info.client_secret_expires_at, + }, + this.serverUrl, + ) log.info("saved dynamically registered client", { mcpName: this.mcpName, clientId: info.client_id, @@ -87,7 +92,8 @@ export class McpOAuthProvider implements OAuthClientProvider { } async tokens(): Promise { - const entry = await McpAuth.get(this.mcpName) + // Use getForUrl to validate tokens are for the current server URL + const entry = await McpAuth.getForUrl(this.mcpName, this.serverUrl) if (!entry?.tokens) return undefined return { @@ -102,12 +108,16 @@ export class McpOAuthProvider implements OAuthClientProvider { } async saveTokens(tokens: OAuthTokens): Promise { - await McpAuth.updateTokens(this.mcpName, { - accessToken: tokens.access_token, - refreshToken: tokens.refresh_token, - expiresAt: tokens.expires_in ? Date.now() / 1000 + tokens.expires_in : undefined, - scope: tokens.scope, - }) + await McpAuth.updateTokens( + this.mcpName, + { + accessToken: tokens.access_token, + refreshToken: tokens.refresh_token, + expiresAt: tokens.expires_in ? Date.now() / 1000 + tokens.expires_in : undefined, + scope: tokens.scope, + }, + this.serverUrl, + ) log.info("saved oauth tokens", { mcpName: this.mcpName }) } @@ -127,6 +137,18 @@ export class McpOAuthProvider implements OAuthClientProvider { } return entry.codeVerifier } + + async saveState(state: string): Promise { + await McpAuth.updateOAuthState(this.mcpName, state) + } + + async state(): Promise { + const entry = await McpAuth.get(this.mcpName) + if (!entry?.oauthState) { + throw new Error(`No OAuth state saved for MCP server: ${this.mcpName}`) + } + return entry.oauthState + } } export { OAUTH_CALLBACK_PORT, OAUTH_CALLBACK_PATH }