diff --git a/packages/playwright/src/mcp/browser/config.ts b/packages/playwright/src/mcp/browser/config.ts index 1e66feffd79fd..623976490e051 100644 --- a/packages/playwright/src/mcp/browser/config.ts +++ b/packages/playwright/src/mcp/browser/config.ts @@ -29,6 +29,7 @@ import type { Config, ToolCapability } from '../config'; import type { ClientInfo } from '../sdk/server'; export type CLIOptions = { + allowedHosts?: string[]; allowedOrigins?: string[]; blockedOrigins?: string[]; blockServiceWorkers?: boolean; @@ -217,6 +218,7 @@ export function configFromCLIOptions(cliOptions: CLIOptions): Config { server: { port: cliOptions.port, host: cliOptions.host, + allowedHosts: cliOptions.allowedHosts, }, capabilities: cliOptions.caps as ToolCapability[], network: { @@ -240,6 +242,7 @@ export function configFromCLIOptions(cliOptions: CLIOptions): Config { function configFromEnv(): Config { const options: CLIOptions = {}; + options.allowedHosts = commaSeparatedList(process.env.PLAYWRIGHT_MCP_ALLOWED_HOSTNAMES); options.allowedOrigins = semicolonSeparatedList(process.env.PLAYWRIGHT_MCP_ALLOWED_ORIGINS); options.blockedOrigins = semicolonSeparatedList(process.env.PLAYWRIGHT_MCP_BLOCKED_ORIGINS); options.blockServiceWorkers = envToBoolean(process.env.PLAYWRIGHT_MCP_BLOCK_SERVICE_WORKERS); diff --git a/packages/playwright/src/mcp/config.d.ts b/packages/playwright/src/mcp/config.d.ts index 66b015ce7f71a..1da086d4b4537 100644 --- a/packages/playwright/src/mcp/config.d.ts +++ b/packages/playwright/src/mcp/config.d.ts @@ -86,6 +86,12 @@ export type Config = { * The host to bind the server to. Default is localhost. Use 0.0.0.0 to bind to all interfaces. */ host?: string; + + /** + * The hosts this server is allowed to serve from. Defaults to the host server is bound to. + * This is not for CORS, but rather for the DNS rebinding protection. + */ + allowedHosts?: string[]; }, /** diff --git a/packages/playwright/src/mcp/program.ts b/packages/playwright/src/mcp/program.ts index 1af6c0e587ac2..45e1a3f05a414 100644 --- a/packages/playwright/src/mcp/program.ts +++ b/packages/playwright/src/mcp/program.ts @@ -28,7 +28,9 @@ import type { Command } from 'playwright-core/lib/utilsBundle'; import type { MCPProvider } from './sdk/proxyBackend'; export function decorateCommand(command: Command, version: string) { - command.option('--allowed-origins ', 'semicolon-separated list of origins to allow the browser to request. Default is to allow all.', semicolonSeparatedList) + command + .option('--allowed-hosts ', 'comma-separated list of hosts this server is allowed to serve from. Defaults to the host the server is bound to.', commaSeparatedList) + .option('--allowed-origins ', 'semicolon-separated list of origins to allow the browser to request. Default is to allow all.', semicolonSeparatedList) .option('--blocked-origins ', 'semicolon-separated list of origins to block the browser from requesting. Blocklist is evaluated before allowlist. If used without the allowlist, requests not matching the blocklist are still allowed.', semicolonSeparatedList) .option('--block-service-workers', 'block service workers') .option('--browser ', 'browser or chrome channel to use, possible values: chrome, firefox, webkit, msedge.') diff --git a/packages/playwright/src/mcp/sdk/http.ts b/packages/playwright/src/mcp/sdk/http.ts index 4d941d0b0dfa0..59fb59a8b1bb5 100644 --- a/packages/playwright/src/mcp/sdk/http.ts +++ b/packages/playwright/src/mcp/sdk/http.ts @@ -59,10 +59,27 @@ export function httpAddressToString(address: string | net.AddressInfo | null): s return `http://${resolvedHost}:${resolvedPort}`; } -export async function installHttpTransport(httpServer: http.Server, serverBackendFactory: ServerBackendFactory) { +export async function installHttpTransport(httpServer: http.Server, serverBackendFactory: ServerBackendFactory, allowedHosts?: string[]) { + const url = httpAddressToString(httpServer.address()); + const host = new URL(url).host; + allowedHosts = (allowedHosts || [host]).map(h => h.toLowerCase()); + const sseSessions = new Map(); const streamableSessions = new Map(); httpServer.on('request', async (req, res) => { + const host = req.headers.host?.toLowerCase(); + if (!host) { + res.statusCode = 400; + return res.end('Missing host'); + } + + // Prevent DNS evil.com -> localhost rebind. + if (!allowedHosts.includes(host)) { + // Access from the browser is forbidden. + res.statusCode = 403; + return res.end('Access is only allowed at ' + allowedHosts.join(', ')); + } + const url = new URL(`http://localhost${req.url}`); if (url.pathname === '/killkillkill' && req.method === 'GET') { res.statusCode = 200; diff --git a/packages/playwright/src/mcp/sdk/server.ts b/packages/playwright/src/mcp/sdk/server.ts index b95c5759253d0..6c1b1acfc76c0 100644 --- a/packages/playwright/src/mcp/sdk/server.ts +++ b/packages/playwright/src/mcp/sdk/server.ts @@ -139,15 +139,15 @@ function addServerListener(server: Server, event: 'close' | 'initialized', liste }; } -export async function start(serverBackendFactory: ServerBackendFactory, options: { host?: string; port?: number }) { +export async function start(serverBackendFactory: ServerBackendFactory, options: { host?: string; port?: number, allowedHosts?: string[] }) { if (options.port === undefined) { await connect(serverBackendFactory, new mcpBundle.StdioServerTransport(), false); return; } const httpServer = await startHttpServer(options); - await installHttpTransport(httpServer, serverBackendFactory); const url = httpAddressToString(httpServer.address()); + await installHttpTransport(httpServer, serverBackendFactory, options.allowedHosts); const mcpConfig: any = { mcpServers: { } }; mcpConfig.mcpServers[serverBackendFactory.nameInConfig] = { diff --git a/tests/mcp/http.spec.ts b/tests/mcp/http.spec.ts index af43cc4bd339b..01ef8ac4708b1 100644 --- a/tests/mcp/http.spec.ts +++ b/tests/mcp/http.spec.ts @@ -15,6 +15,7 @@ */ import fs from 'fs'; +import net from 'net'; import { ChildProcess, spawn } from 'child_process'; import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; @@ -348,3 +349,42 @@ test('client should receive list roots request', async ({ serverEndpoint, server }); expect(await rootsListedPromise).toBe('success'); }); + +test('should not allow rebinding to localhost', async ({ serverEndpoint }) => { + const { url } = await serverEndpoint(); + const response = await fetch(url.href.replace('localhost', '127.0.0.1')); + expect(response.status).toBe(403); + expect(await response.text()).toContain('Access is only allowed at localhost'); +}); + +test('should respect allowed hosts (negative)', async ({ serverEndpoint }) => { + const { url } = await serverEndpoint({ args: ['--allowed-hosts=example.com'] }); + const response = await fetch(url.href); + expect(response.status).toBe(403); + expect(await response.text()).toContain('Access is only allowed at example.com'); +}); + +test('should respect allowed hosts (positive)', async ({ serverEndpoint }) => { + const port = await findFreePort(); + await serverEndpoint({ + args: [ + '--host=127.0.0.1', + '--port=' + port, + '--allowed-hosts=localhost:' + port, + ] + }); + const response = await fetch('http://localhost:' + port); + // 400 is expected for the mcp fetch. + expect(response.status).toBe(400); +}); + +async function findFreePort(): Promise { + return new Promise((resolve, reject) => { + const server = net.createServer(); + server.listen(0, () => { + const { port } = server.address() as net.AddressInfo; + server.close(() => resolve(port)); + }); + server.on('error', reject); + }); +}