diff --git a/packages/playwright/src/mcp/sdk/mdb.ts b/packages/playwright/src/mcp/sdk/mdb.ts index bf3108cac7022..b979a90416206 100644 --- a/packages/playwright/src/mcp/sdk/mdb.ts +++ b/packages/playwright/src/mcp/sdk/mdb.ts @@ -34,26 +34,27 @@ export class MDBBackend implements mcpServer.ServerBackend { private _stack: { client: Client, toolNames: string[], resultPromise: ManualPromise | undefined }[] = []; private _interruptPromise: ManualPromise | undefined; private _topLevelBackend: mcpServer.ServerBackend; - private _initialized = false; + private _roots: mcpServer.Root[] | undefined; constructor(topLevelBackend: mcpServer.ServerBackend) { this._topLevelBackend = topLevelBackend; } - async initialize(server: mcpServer.Server): Promise { - if (this._initialized) - return; - this._initialized = true; - const transport = await wrapInProcess(this._topLevelBackend); - await this._pushClient(transport); + async initialize(server: mcpServer.Server, clientVersion: mcpServer.ClientVersion, roots: mcpServer.Root[]): Promise { + if (!this._roots) + this._roots = roots; } async listTools(): Promise { - const response = await this._client().listTools(); + const client = await this._client(); + const response = await client.listTools(); return response.tools; } async callTool(name: string, args: mcpServer.CallToolRequest['params']['arguments']): Promise { + // Needs to go first to push the top-level tool first if missing. + await this._client(); + if (name === pushToolsSchema.name) return await this._pushTools(pushToolsSchema.inputSchema.parse(args || {})); @@ -65,28 +66,22 @@ export class MDBBackend implements mcpServer.ServerBackend { while (entry && !entry.toolNames.includes(name)) { mdbDebug('popping client from stack for ', name); this._stack.shift(); - await entry.client.close(); + await entry.client.close().catch(errorsDebug); entry = this._stack[0]; } if (!entry) throw new Error(`Tool ${name} not found in the tool stack`); + const client = await this._client(); const resultPromise = new ManualPromise(); entry.resultPromise = resultPromise; - this._client().callTool({ + client.callTool({ name, arguments: args, }).then(result => { resultPromise.resolve(result as mcpServer.CallToolResult); - }).catch(e => { - mdbDebug('error in client call', e); - if (this._stack.length < 2) - throw e; - this._stack.shift(); - const prevEntry = this._stack[0]; - void prevEntry.resultPromise!.then(result => resultPromise.resolve(result)); - }); + }).catch(e => resultPromise.reject(e)); const result = await Promise.race([interruptPromise, resultPromise]); if (interruptPromise.isDone()) mdbDebug('client call intercepted', result); @@ -95,11 +90,12 @@ export class MDBBackend implements mcpServer.ServerBackend { return result; } - private _client(): Client { - const [entry] = this._stack; - if (!entry) - throw new Error('No debugging backend available'); - return entry.client; + private async _client(): Promise { + if (!this._stack.length) { + const transport = await wrapInProcess(this._topLevelBackend); + await this._pushClient(transport); + } + return this._stack[0].client; } private async _pushTools(params: { mcpUrl: string, introMessage?: string }): Promise { @@ -111,7 +107,8 @@ export class MDBBackend implements mcpServer.ServerBackend { private async _pushClient(transport: Transport, introMessage?: string): Promise { mdbDebug('pushing client to the stack'); - const client = new mcpBundle.Client({ name: 'Internal client', version: '0.0.0' }); + const client = new mcpBundle.Client({ name: 'Internal client', version: '0.0.0' }, { capabilities: { roots: {} } }); + client.setRequestHandler(mcpBundle.ListRootsRequestSchema, () => ({ roots: this._roots || [] })); client.setRequestHandler(mcpBundle.PingRequestSchema, () => ({})); await client.connect(transport); mdbDebug('connected to the new client'); @@ -141,10 +138,6 @@ const pushToolsSchema = defineToolSchema({ type: 'readOnly', }); -export type ServerBackendOnPause = mcpServer.ServerBackend & { - requestSelfDestruct?: () => void; -}; - export async function runMainBackend(backendFactory: mcpServer.ServerBackendFactory, options?: { port?: number }): Promise { const mdbBackend = new MDBBackend(backendFactory.create()); // Start HTTP unconditionally. @@ -162,8 +155,8 @@ export async function runMainBackend(backendFactory: mcpServer.ServerBackendFact await mcpServer.connect(factory, new mcpBundle.StdioServerTransport(), false); } -export async function runOnPauseBackendLoop(backend: ServerBackendOnPause, introMessage: string) { - const wrappedBackend = new OnceTimeServerBackendWrapper(backend); +export async function runOnPauseBackendLoop(backend: mcpServer.ServerBackend, introMessage: string) { + const wrappedBackend = new ServerBackendWithCloseListener(backend); const factory = { name: 'on-pause-backend', @@ -204,13 +197,12 @@ async function startAsHttp(backendFactory: mcpServer.ServerBackendFactory, optio } -class OnceTimeServerBackendWrapper implements mcpServer.ServerBackend { - private _backend: ServerBackendOnPause; - private _selfDestructPromise = new ManualPromise(); +class ServerBackendWithCloseListener implements mcpServer.ServerBackend { + private _backend: mcpServer.ServerBackend; + private _serverClosedPromise = new ManualPromise(); - constructor(backend: ServerBackendOnPause) { + constructor(backend: mcpServer.ServerBackend) { this._backend = backend; - this._backend.requestSelfDestruct = () => this._selfDestructPromise.resolve(); } async initialize(server: mcpServer.Server, clientVersion: mcpServer.ClientVersion, roots: mcpServer.Root[]): Promise { @@ -227,10 +219,10 @@ class OnceTimeServerBackendWrapper implements mcpServer.ServerBackend { serverClosed(server: mcpServer.Server) { this._backend.serverClosed?.(server); - this._selfDestructPromise.resolve(); + this._serverClosedPromise.resolve(); } async waitForClosed() { - await this._selfDestructPromise; + await this._serverClosedPromise; } } diff --git a/packages/playwright/src/mcp/sdk/proxyBackend.ts b/packages/playwright/src/mcp/sdk/proxyBackend.ts index de55736947a39..c80b30062f57f 100644 --- a/packages/playwright/src/mcp/sdk/proxyBackend.ts +++ b/packages/playwright/src/mcp/sdk/proxyBackend.ts @@ -45,11 +45,11 @@ export class ProxyBackend implements ServerBackend { async initialize(server: Server, clientVersion: ClientVersion, roots: Root[]): Promise { this._roots = roots; - await this._setCurrentClient(this._mcpProviders[0]); } async listTools(): Promise { - const response = await this._currentClient!.listTools(); + const currentClient = await this._ensureCurrentClient(); + const response = await currentClient.listTools(); if (this._mcpProviders.length === 1) return response.tools; return [ @@ -61,7 +61,8 @@ export class ProxyBackend implements ServerBackend { async callTool(name: string, args: CallToolRequest['params']['arguments']): Promise { if (name === this._contextSwitchTool.name) return this._callContextSwitchTool(args); - return await this._currentClient!.callTool({ + const currentClient = await this._ensureCurrentClient(); + return await currentClient.callTool({ name, arguments: args, }) as CallToolResult; @@ -107,6 +108,12 @@ export class ProxyBackend implements ServerBackend { }; } + private async _ensureCurrentClient(): Promise { + if (this._currentClient) + return this._currentClient; + return await this._setCurrentClient(this._mcpProviders[0]); + } + private async _setCurrentClient(factory: MCPProvider) { await this._currentClient?.close(); this._currentClient = undefined; @@ -123,5 +130,6 @@ export class ProxyBackend implements ServerBackend { const transport = await factory.connect(); await client.connect(transport); this._currentClient = client; + return client; } } diff --git a/packages/playwright/src/mcp/sdk/server.ts b/packages/playwright/src/mcp/sdk/server.ts index a50f2d80f5a65..8e54afd91c266 100644 --- a/packages/playwright/src/mcp/sdk/server.ts +++ b/packages/playwright/src/mcp/sdk/server.ts @@ -27,7 +27,6 @@ export type { Tool, CallToolResult, CallToolRequest, Root } from '@modelcontextp import type { Server } from '@modelcontextprotocol/sdk/server/index.js'; const serverDebug = debug('pw:mcp:server'); -const errorsDebug = debug('pw:mcp:errors'); export type ClientVersion = { name: string, version: string }; @@ -56,8 +55,6 @@ export async function wrapInProcess(backend: ServerBackend): Promise } export function createServer(name: string, version: string, backend: ServerBackend, runHeartbeat: boolean): Server { - let initializedPromiseResolve = () => {}; - const initializedPromise = new Promise(resolve => initializedPromiseResolve = resolve); const server = new mcpBundle.Server({ name, version }, { capabilities: { tools: {}, @@ -66,22 +63,17 @@ export function createServer(name: string, version: string, backend: ServerBacke server.setRequestHandler(mcpBundle.ListToolsRequestSchema, async () => { serverDebug('listTools'); - await initializedPromise; const tools = await backend.listTools(); return { tools }; }); - let heartbeatRunning = false; + let initializePromise: Promise | undefined; server.setRequestHandler(mcpBundle.CallToolRequestSchema, async request => { serverDebug('callTool', request); - await initializedPromise; - - if (runHeartbeat && !heartbeatRunning) { - heartbeatRunning = true; - startHeartbeat(server); - } - try { + if (!initializePromise) + initializePromise = initializeServer(server, backend, runHeartbeat); + await initializePromise; return await backend.callTool(request.params.name, request.params.arguments || {}); } catch (error) { return { @@ -90,34 +82,23 @@ export function createServer(name: string, version: string, backend: ServerBacke }; } }); - addServerListener(server, 'initialized', async () => { - try { - const capabilities = server.getClientCapabilities(); - let clientRoots: Root[] = []; - if (capabilities?.roots) { - for (let i = 0; i < 2; i++) { - try { - // In the @modelcontextprotocol TypeScript SDK (and Cursor) in the streaming http - // mode, the SSE channel is not ready yet, when `initialized` notification arrives, - // `listRoots` times out in that case and we retry once. - const { roots } = await server.listRoots(undefined, { timeout: 2_000 }); - clientRoots = roots; - } catch (e) { - continue; - } - } - } - const clientVersion = server.getClientVersion() ?? { name: 'unknown', version: 'unknown' }; - await backend.initialize?.(server, clientVersion, clientRoots); - initializedPromiseResolve(); - } catch (e) { - errorsDebug(e); - } - }); addServerListener(server, 'close', () => backend.serverClosed?.(server)); return server; } +const initializeServer = async (server: Server, backend: ServerBackend, runHeartbeat: boolean) => { + const capabilities = server.getClientCapabilities(); + let clientRoots: Root[] = []; + if (capabilities?.roots) { + const { roots } = await server.listRoots(); + clientRoots = roots; + } + const clientVersion = server.getClientVersion() ?? { name: 'unknown', version: 'unknown' }; + await backend.initialize?.(server, clientVersion, clientRoots); + if (runHeartbeat) + startHeartbeat(server); +}; + const startHeartbeat = (server: Server) => { const beat = () => { Promise.race([ diff --git a/tests/mcp/fixtures.ts b/tests/mcp/fixtures.ts index 9045ca2c3d45a..49f6cd6632d80 100644 --- a/tests/mcp/fixtures.ts +++ b/tests/mcp/fixtures.ts @@ -188,7 +188,7 @@ async function createTransport(mcpServerType: TestOptions['mcpServerType'], args stderr: 'pipe', env: { ...process.env, - DEBUG: 'pw:mcp:test', + DEBUG: process.env.DEBUG ? `${process.env.DEBUG},pw:mcp:test` : 'pw:mcp:test', DEBUG_COLORS: '0', DEBUG_HIDE_DATE: '1', PWMCP_PROFILES_DIR_FOR_TEST: profilesDir, diff --git a/tests/mcp/http.spec.ts b/tests/mcp/http.spec.ts index f533dd8dc88ae..4a498072b12d5 100644 --- a/tests/mcp/http.spec.ts +++ b/tests/mcp/http.spec.ts @@ -254,7 +254,7 @@ test('http transport (default)', async ({ serverEndpoint }) => { expect(transport.sessionId, 'has session support').toBeDefined(); }); -test('client should receive list roots request', async ({ serverEndpoint }) => { +test('client should receive list roots request', async ({ serverEndpoint, server }) => { const { url } = await serverEndpoint(); const transport = new StreamableHTTPClientTransport(url); const client = new Client({ name: 'test', version: '1.0.0' }, { capabilities: { roots: {} } }); @@ -275,6 +275,9 @@ test('client should receive list roots request', async ({ serverEndpoint }) => { }; }); await client.connect(transport); - await client.ping(); + await client.callTool({ + name: 'browser_navigate', + arguments: { url: server.HELLO_WORLD }, + }); expect(await rootsListedPromise).toBe('success'); }); diff --git a/tests/mcp/mdb.spec.ts b/tests/mcp/mdb.spec.ts index 74f240385d067..4ec782eb8a4ef 100644 --- a/tests/mcp/mdb.spec.ts +++ b/tests/mcp/mdb.spec.ts @@ -20,6 +20,7 @@ import { Client } from '@modelcontextprotocol/sdk/client/index.js'; import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; import { runMainBackend, runOnPauseBackendLoop } from '../../packages/playwright/lib/mcp/sdk/mdb'; +import * as mcpBundle from '../../packages/playwright/lib/mcp/sdk/bundle'; import { test, expect } from './fixtures'; @@ -40,8 +41,7 @@ test('call top level tool', async () => { name: 'cli_pause_in_gdb_twice', description: 'Pause in gdb twice', inputSchema: expect.any(Object), - } - ]); + }]); const echoResult = await mdbClient.client.callTool({ name: 'cli_echo', @@ -49,7 +49,7 @@ test('call top level tool', async () => { message: 'Hello, world!', }, }); - expect(echoResult.content).toEqual([{ type: 'text', text: 'Echo: Hello, world!' }]); + expect(echoResult.content).toEqual([{ type: 'text', text: 'Echo: Hello, world!', roots: [] }]); await mdbClient.close(); }); @@ -72,7 +72,7 @@ test('pause on error', async () => { name: 'gdb_bt', }), expect.objectContaining({ - name: 'gdb_continue', + name: 'gdb_echo', }), ]); @@ -83,39 +83,43 @@ test('pause on error', async () => { }); expect(btResult.content).toEqual([{ type: 'text', text: 'Backtrace' }]); - // Continue execution. - const continueResult = await mdbClient.client.callTool({ - name: 'gdb_continue', - arguments: {}, - }); - expect(continueResult.content).toEqual([{ type: 'text', text: 'Done' }]); - await mdbClient.close(); }); -test('pause on error twice', async () => { +test('outer and inner roots available', async () => { const { mdbUrl } = await startMDBAndCLI(); - const mdbClient = await createMDBClient(mdbUrl); + const mdbClient = await createMDBClient(mdbUrl, [{ name: 'test', uri: 'file://tmp/' }]); - // Make a call that results in a recoverable error. - const result = await mdbClient.client.callTool({ - name: 'cli_pause_in_gdb_twice', - arguments: {}, + expect(await mdbClient.client.callTool({ + name: 'cli_echo', + arguments: { + message: 'Hello, cli!', + }, + })).toEqual({ + content: [{ + type: 'text', + text: 'Echo: Hello, cli!', + roots: [{ name: 'test', uri: 'file://tmp/' }] + }] }); - expect(result.content).toEqual([{ type: 'text', text: 'Paused on exception 1' }]); - // Continue execution. - const continueResult1 = await mdbClient.client.callTool({ - name: 'gdb_continue', + await mdbClient.client.callTool({ + name: 'cli_pause_in_gdb', arguments: {}, }); - expect(continueResult1.content).toEqual([{ type: 'text', text: 'Paused on exception 2' }]); - const continueResult2 = await mdbClient.client.callTool({ - name: 'gdb_continue', - arguments: {}, + expect(await mdbClient.client.callTool({ + name: 'gdb_echo', + arguments: { + message: 'Hello, bt!', + }, + })).toEqual({ + content: [{ + type: 'text', + text: 'Echo: Hello, bt!', + roots: [{ name: 'test', uri: 'file://tmp/' }] + }] }); - expect(continueResult2.content).toEqual([{ type: 'text', text: 'Done' }]); await mdbClient.close(); }); @@ -134,8 +138,10 @@ async function startMDBAndCLI(): Promise<{ mdbUrl: string }> { return { mdbUrl }; } -async function createMDBClient(mdbUrl: string): Promise<{ client: Client, close: () => Promise }> { - const client = new Client({ name: 'Internal client', version: '0.0.0' }); +async function createMDBClient(mdbUrl: string, roots: any[] | undefined = undefined): Promise<{ client: Client, close: () => Promise }> { + const client = new Client({ name: 'Internal client', version: '0.0.0' }, roots ? { capabilities: { roots: {} } } : undefined); + if (roots) + client.setRequestHandler(mcpBundle.ListRootsRequestSchema, () => ({ roots })); const transport = new StreamableHTTPClientTransport(new URL(mdbUrl)); await client.connect(transport); return { @@ -148,8 +154,14 @@ async function createMDBClient(mdbUrl: string): Promise<{ client: Client, close: } class CLIBackend { + private _roots: any[] | undefined; + constructor(private readonly mdbUrlBox: { mdbUrl: string | undefined }) {} + async initialize(server, clientVersion, roots) { + this._roots = roots; + } + async listTools() { return [{ name: 'cli_echo', @@ -168,7 +180,7 @@ class CLIBackend { async callTool(name: string, args: any) { if (name === 'cli_echo') - return { content: [{ type: 'text', text: 'Echo: ' + (args?.message as string) }] }; + return { content: [{ type: 'text', text: 'Echo: ' + (args?.message as string), roots: this._roots }] }; if (name === 'cli_pause_in_gdb') { await runOnPauseBackendLoop(new GDBBackend(), 'Paused on exception'); return { content: [{ type: 'text', text: 'Done' }] }; @@ -183,26 +195,29 @@ class CLIBackend { } class GDBBackend { + private _roots: any[] | undefined; + + async initialize(server, clientVersion, roots) { + this._roots = roots; + } + async listTools() { return [{ name: 'gdb_bt', description: 'Print backtrace', inputSchema: zodToJsonSchema(z.object({})) as any, }, { - name: 'gdb_continue', - description: 'Continue execution', - inputSchema: zodToJsonSchema(z.object({})) as any, + name: 'gdb_echo', + description: 'Echo a message', + inputSchema: zodToJsonSchema(z.object({ message: z.string() })) as any, }]; } - async callTool(name: string) { + async callTool(name: string, args: any) { + if (name === 'gdb_echo') + return { content: [{ type: 'text', text: 'Echo: ' + (args?.message as string), roots: this._roots }] }; if (name === 'gdb_bt') return { content: [{ type: 'text', text: 'Backtrace' }] }; - if (name === 'gdb_continue') { - (this as any).requestSelfDestruct?.(); - // Stall - await new Promise(f => setTimeout(f, 1000)); - } throw new Error(`Unknown tool: ${name}`); } } diff --git a/tests/mcp/test-tools.spec.ts b/tests/mcp/test-tools.spec.ts index 7995ff6b252b7..187264e89448c 100644 --- a/tests/mcp/test-tools.spec.ts +++ b/tests/mcp/test-tools.spec.ts @@ -215,6 +215,47 @@ test('playwright_test_browser_snapshot', async ({ startClient }) => { }); }); +test('playwright_test_debug_test (pause/snapshot/resume)', async ({ startClient }) => { + const { client, id } = await prepareDebugTest(startClient); + + expect(await client.callTool({ + name: 'playwright_test_debug_test', + arguments: { + test: { id, title: 'fail' }, + }, + })).toHaveTextResponse(`### Paused on error: +expect(locator).toBeVisible() failed + +Locator: getByRole('button', { name: 'Missing' }) +Expected: visible +Timeout: 1000ms +Error: element(s) not found + +Call log: + - Expect "toBeVisible" with timeout 1000ms + - waiting for getByRole('button', { name: 'Missing' }) + + +### Current page snapshot: +- button "Submit" [ref=e2] + +### Task +Try recovering from the error prior to continuing`); + + expect(await client.callTool({ + name: 'browser_snapshot', + })).toHaveResponse({ + pageState: expect.stringContaining(`- button \"Submit\" [ref=e2]`), + }); + + expect(await client.callTool({ + name: 'playwright_test_run_tests', + arguments: { + locations: ['a.test.ts'], + }, + })).toHaveTextResponse(expect.stringContaining(`1) [id=] a.test.ts:3:11 › fail`)); +}); + test('playwright_test_evaluate_on_pause', async ({ startClient }) => { const { client, id } = await prepareDebugTest(startClient); await client.callTool({