Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 47 additions & 5 deletions packages/opencode/src/mcp/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ export namespace McpAuth {
tokens: Tokens.optional(),
clientInfo: ClientInfo.optional(),
codeVerifier: z.string().optional(),
oauthState: z.string().optional(),
serverUrl: z.string().optional(), // Track the URL these credentials are for
})
export type Entry = z.infer<typeof Entry>

Expand All @@ -34,14 +36,35 @@ export namespace McpAuth {
return data[mcpName]
}

/**
* Get auth entry and validate it's for the correct URL.
* Returns undefined if URL has changed (credentials are invalid).
*/
export async function getForUrl(mcpName: string, serverUrl: string): Promise<Entry | undefined> {
const entry = await get(mcpName)
if (!entry) return undefined

// If no serverUrl is stored, this is from an old version - consider it invalid
if (!entry.serverUrl) return undefined

// If URL has changed, credentials are invalid
if (entry.serverUrl !== serverUrl) return undefined

return entry
}

export async function all(): Promise<Record<string, Entry>> {
const file = Bun.file(filepath)
return file.json().catch(() => ({}))
}

export async function set(mcpName: string, entry: Entry): Promise<void> {
export async function set(mcpName: string, entry: Entry, serverUrl?: string): Promise<void> {
const file = Bun.file(filepath)
const data = await all()
// Always update serverUrl if provided
if (serverUrl) {
entry.serverUrl = serverUrl
}
await Bun.write(file, JSON.stringify({ ...data, [mcpName]: entry }, null, 2))
await fs.chmod(file.name!, 0o600)
}
Expand All @@ -54,16 +77,16 @@ export namespace McpAuth {
await fs.chmod(file.name!, 0o600)
}

export async function updateTokens(mcpName: string, tokens: Tokens): Promise<void> {
export async function updateTokens(mcpName: string, tokens: Tokens, serverUrl?: string): Promise<void> {
const entry = (await get(mcpName)) ?? {}
entry.tokens = tokens
await set(mcpName, entry)
await set(mcpName, entry, serverUrl)
}

export async function updateClientInfo(mcpName: string, clientInfo: ClientInfo): Promise<void> {
export async function updateClientInfo(mcpName: string, clientInfo: ClientInfo, serverUrl?: string): Promise<void> {
const entry = (await get(mcpName)) ?? {}
entry.clientInfo = clientInfo
await set(mcpName, entry)
await set(mcpName, entry, serverUrl)
}

export async function updateCodeVerifier(mcpName: string, codeVerifier: string): Promise<void> {
Expand All @@ -79,4 +102,23 @@ export namespace McpAuth {
await set(mcpName, entry)
}
}

export async function updateOAuthState(mcpName: string, oauthState: string): Promise<void> {
const entry = (await get(mcpName)) ?? {}
entry.oauthState = oauthState
await set(mcpName, entry)
}

export async function getOAuthState(mcpName: string): Promise<string | undefined> {
const entry = await get(mcpName)
return entry?.oauthState
}

export async function clearOAuthState(mcpName: string): Promise<void> {
const entry = await get(mcpName)
if (entry) {
delete entry.oauthState
await set(mcpName, entry)
}
}
}
40 changes: 26 additions & 14 deletions packages/opencode/src/mcp/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,13 @@ export namespace MCP {
// Start the callback server
await McpOAuthCallback.ensureRunning()

// Generate and store a cryptographically secure state parameter BEFORE creating the provider
// The SDK will call provider.state() to read this value
const oauthState = Array.from(crypto.getRandomValues(new Uint8Array(32)))
.map((b) => b.toString(16).padStart(2, "0"))
.join("")
await McpAuth.updateOAuthState(mcpName, oauthState)

// Create a new auth provider for this flow
// OAuth config is optional - if not provided, we'll use auto-discovery
const oauthConfig = typeof mcpConfig.oauth === "object" ? mcpConfig.oauth : undefined
Expand Down Expand Up @@ -491,25 +498,29 @@ export namespace MCP {
return s.status[mcpName] ?? { status: "connected" }
}

// Extract state from authorization URL to use as callback key
// If no state parameter, use mcpName as fallback
const authUrl = new URL(authorizationUrl)
let oauthState = mcpName

if (authUrl.searchParams.has("state")) {
oauthState = authUrl.searchParams.get("state")!
} else {
log.info("no state parameter in authorization URL, using mcpName as state", { mcpName })
authUrl.searchParams.set("state", oauthState)
// Get the state that was already generated and stored in startAuth()
const oauthState = await McpAuth.getOAuthState(mcpName)
if (!oauthState) {
throw new Error("OAuth state not found - this should not happen")
}

// Open browser
log.info("opening browser for oauth", { mcpName, url: authUrl.toString(), state: oauthState })
await open(authUrl.toString())
// The SDK has already added the state parameter to the authorization URL
// We just need to open the browser
log.info("opening browser for oauth", { mcpName, url: authorizationUrl, state: oauthState })
await open(authorizationUrl)

// Wait for callback using the OAuth state parameter (or mcpName as fallback)
// Wait for callback using the OAuth state parameter
const code = await McpOAuthCallback.waitForCallback(oauthState)

// Validate and clear the state
const storedState = await McpAuth.getOAuthState(mcpName)
if (storedState !== oauthState) {
await McpAuth.clearOAuthState(mcpName)
throw new Error("OAuth state mismatch - potential CSRF attack")
}

await McpAuth.clearOAuthState(mcpName)

// Finish auth
return finishAuth(mcpName, code)
}
Expand Down Expand Up @@ -561,6 +572,7 @@ export namespace MCP {
await McpAuth.remove(mcpName)
McpOAuthCallback.cancelPending(mcpName)
pendingOAuthTransports.delete(mcpName)
await McpAuth.clearOAuthState(mcpName)
log.info("removed oauth credentials", { mcpName })
}

Expand Down
47 changes: 22 additions & 25 deletions packages/opencode/src/mcp/oauth-callback.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,19 @@ export namespace McpOAuthCallback {

log.info("received oauth callback", { hasCode: !!code, state, error })

// Enforce state parameter presence
if (!state) {
const errorMsg = "Missing required state parameter - potential CSRF attack"
log.error("oauth callback missing state parameter", { url: url.toString() })
return new Response(HTML_ERROR(errorMsg), {
status: 400,
headers: { "Content-Type": "text/html" },
})
}

if (error) {
const errorMsg = errorDescription || error
if (state && pendingAuths.has(state)) {
if (pendingAuths.has(state)) {
const pending = pendingAuths.get(state)!
clearTimeout(pending.timeout)
pendingAuths.delete(state)
Expand All @@ -101,33 +111,20 @@ export namespace McpOAuthCallback {
})
}

// Try to find the pending auth by state parameter, or if no state, use the single pending auth
let pending: PendingAuth | undefined
let pendingKey: string | undefined

if (state && pendingAuths.has(state)) {
pending = pendingAuths.get(state)!
pendingKey = state
} else if (!state && pendingAuths.size === 1) {
// No state parameter but only one pending auth - use it
const [key, value] = pendingAuths.entries().next().value as [string, PendingAuth]
pending = value
pendingKey = key
log.info("no state parameter, using single pending auth", { key })
}

if (!pending || !pendingKey) {
const errorMsg = !state
? "No state parameter provided and multiple pending authorizations"
: "Unknown or expired authorization request"
// Validate state parameter
if (!pendingAuths.has(state)) {
const errorMsg = "Invalid or expired state parameter - potential CSRF attack"
log.error("oauth callback with invalid state", { state, pendingStates: Array.from(pendingAuths.keys()) })
return new Response(HTML_ERROR(errorMsg), {
status: 400,
headers: { "Content-Type": "text/html" },
})
}

const pending = pendingAuths.get(state)!

clearTimeout(pending.timeout)
pendingAuths.delete(pendingKey)
pendingAuths.delete(state)
pending.resolve(code)

return new Response(HTML_SUCCESS, {
Expand All @@ -139,16 +136,16 @@ export namespace McpOAuthCallback {
log.info("oauth callback server started", { port: OAUTH_CALLBACK_PORT })
}

export function waitForCallback(mcpName: string): Promise<string> {
export function waitForCallback(oauthState: string): Promise<string> {
return new Promise((resolve, reject) => {
const timeout = setTimeout(() => {
if (pendingAuths.has(mcpName)) {
pendingAuths.delete(mcpName)
if (pendingAuths.has(oauthState)) {
pendingAuths.delete(oauthState)
reject(new Error("OAuth callback timeout - authorization took too long"))
}
}, CALLBACK_TIMEOUT_MS)

pendingAuths.set(mcpName, { resolve, reject, timeout })
pendingAuths.set(oauthState, { resolve, reject, timeout })
})
}

Expand Down
52 changes: 37 additions & 15 deletions packages/opencode/src/mcp/oauth-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ export class McpOAuthProvider implements OAuthClientProvider {
}

// Check stored client info (from dynamic registration)
const entry = await McpAuth.get(this.mcpName)
// Use getForUrl to validate credentials are for the current server URL
const entry = await McpAuth.getForUrl(this.mcpName, this.serverUrl)
if (entry?.clientInfo) {
// Check if client secret has expired
if (entry.clientInfo.clientSecretExpiresAt && entry.clientInfo.clientSecretExpiresAt < Date.now() / 1000) {
Expand All @@ -69,25 +70,30 @@ export class McpOAuthProvider implements OAuthClientProvider {
}
}

// No client info - will trigger dynamic registration
// No client info or URL changed - will trigger dynamic registration
return undefined
}

async saveClientInformation(info: OAuthClientInformationFull): Promise<void> {
await McpAuth.updateClientInfo(this.mcpName, {
clientId: info.client_id,
clientSecret: info.client_secret,
clientIdIssuedAt: info.client_id_issued_at,
clientSecretExpiresAt: info.client_secret_expires_at,
})
await McpAuth.updateClientInfo(
this.mcpName,
{
clientId: info.client_id,
clientSecret: info.client_secret,
clientIdIssuedAt: info.client_id_issued_at,
clientSecretExpiresAt: info.client_secret_expires_at,
},
this.serverUrl,
)
log.info("saved dynamically registered client", {
mcpName: this.mcpName,
clientId: info.client_id,
})
}

async tokens(): Promise<OAuthTokens | undefined> {
const entry = await McpAuth.get(this.mcpName)
// Use getForUrl to validate tokens are for the current server URL
const entry = await McpAuth.getForUrl(this.mcpName, this.serverUrl)
if (!entry?.tokens) return undefined

return {
Expand All @@ -102,12 +108,16 @@ export class McpOAuthProvider implements OAuthClientProvider {
}

async saveTokens(tokens: OAuthTokens): Promise<void> {
await McpAuth.updateTokens(this.mcpName, {
accessToken: tokens.access_token,
refreshToken: tokens.refresh_token,
expiresAt: tokens.expires_in ? Date.now() / 1000 + tokens.expires_in : undefined,
scope: tokens.scope,
})
await McpAuth.updateTokens(
this.mcpName,
{
accessToken: tokens.access_token,
refreshToken: tokens.refresh_token,
expiresAt: tokens.expires_in ? Date.now() / 1000 + tokens.expires_in : undefined,
scope: tokens.scope,
},
this.serverUrl,
)
log.info("saved oauth tokens", { mcpName: this.mcpName })
}

Expand All @@ -127,6 +137,18 @@ export class McpOAuthProvider implements OAuthClientProvider {
}
return entry.codeVerifier
}

async saveState(state: string): Promise<void> {
await McpAuth.updateOAuthState(this.mcpName, state)
}

async state(): Promise<string> {
const entry = await McpAuth.get(this.mcpName)
if (!entry?.oauthState) {
throw new Error(`No OAuth state saved for MCP server: ${this.mcpName}`)
}
return entry.oauthState
}
}

export { OAUTH_CALLBACK_PORT, OAUTH_CALLBACK_PATH }