Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 29 additions & 37 deletions packages/playwright/src/mcp/sdk/mdb.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,26 +34,27 @@ export class MDBBackend implements mcpServer.ServerBackend {
private _stack: { client: Client, toolNames: string[], resultPromise: ManualPromise<mcpServer.CallToolResult> | undefined }[] = [];
private _interruptPromise: ManualPromise<mcpServer.CallToolResult> | undefined;
private _topLevelBackend: mcpServer.ServerBackend;
private _initialized = false;
private _roots: mcpServer.Root[] | undefined;

constructor(topLevelBackend: mcpServer.ServerBackend) {
this._topLevelBackend = topLevelBackend;
}

async initialize(server: mcpServer.Server): Promise<void> {
if (this._initialized)
return;
this._initialized = true;
const transport = await wrapInProcess(this._topLevelBackend);
await this._pushClient(transport);
async initialize(server: mcpServer.Server, clientVersion: mcpServer.ClientVersion, roots: mcpServer.Root[]): Promise<void> {
if (!this._roots)
this._roots = roots;
}

async listTools(): Promise<mcpServer.Tool[]> {
const response = await this._client().listTools();
const client = await this._client();
const response = await client.listTools();
return response.tools;
}

async callTool(name: string, args: mcpServer.CallToolRequest['params']['arguments']): Promise<mcpServer.CallToolResult> {
// Needs to go first to push the top-level tool first if missing.
await this._client();

if (name === pushToolsSchema.name)
return await this._pushTools(pushToolsSchema.inputSchema.parse(args || {}));

Expand All @@ -65,28 +66,22 @@ export class MDBBackend implements mcpServer.ServerBackend {
while (entry && !entry.toolNames.includes(name)) {
mdbDebug('popping client from stack for ', name);
this._stack.shift();
await entry.client.close();
await entry.client.close().catch(errorsDebug);
entry = this._stack[0];
}
if (!entry)
throw new Error(`Tool ${name} not found in the tool stack`);

const client = await this._client();
const resultPromise = new ManualPromise<mcpServer.CallToolResult>();
entry.resultPromise = resultPromise;

this._client().callTool({
client.callTool({
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is now the same as const resultPromise = client.callTool(...

name,
arguments: args,
}).then(result => {
resultPromise.resolve(result as mcpServer.CallToolResult);
}).catch(e => {
mdbDebug('error in client call', e);
if (this._stack.length < 2)
throw e;
this._stack.shift();
const prevEntry = this._stack[0];
void prevEntry.resultPromise!.then(result => resultPromise.resolve(result));
});
}).catch(e => resultPromise.reject(e));
const result = await Promise.race([interruptPromise, resultPromise]);
if (interruptPromise.isDone())
mdbDebug('client call intercepted', result);
Expand All @@ -95,11 +90,12 @@ export class MDBBackend implements mcpServer.ServerBackend {
return result;
}

private _client(): Client {
const [entry] = this._stack;
if (!entry)
throw new Error('No debugging backend available');
return entry.client;
private async _client(): Promise<Client> {
if (!this._stack.length) {
const transport = await wrapInProcess(this._topLevelBackend);
await this._pushClient(transport);
}
return this._stack[0].client;
}

private async _pushTools(params: { mcpUrl: string, introMessage?: string }): Promise<mcpServer.CallToolResult> {
Expand All @@ -111,7 +107,8 @@ export class MDBBackend implements mcpServer.ServerBackend {

private async _pushClient(transport: Transport, introMessage?: string): Promise<mcpServer.CallToolResult> {
mdbDebug('pushing client to the stack');
const client = new mcpBundle.Client({ name: 'Internal client', version: '0.0.0' });
const client = new mcpBundle.Client({ name: 'Internal client', version: '0.0.0' }, { capabilities: { roots: {} } });
client.setRequestHandler(mcpBundle.ListRootsRequestSchema, () => ({ roots: this._roots || [] }));
client.setRequestHandler(mcpBundle.PingRequestSchema, () => ({}));
await client.connect(transport);
mdbDebug('connected to the new client');
Expand Down Expand Up @@ -141,10 +138,6 @@ const pushToolsSchema = defineToolSchema({
type: 'readOnly',
});

export type ServerBackendOnPause = mcpServer.ServerBackend & {
requestSelfDestruct?: () => void;
};

export async function runMainBackend(backendFactory: mcpServer.ServerBackendFactory, options?: { port?: number }): Promise<string | undefined> {
const mdbBackend = new MDBBackend(backendFactory.create());
// Start HTTP unconditionally.
Expand All @@ -162,8 +155,8 @@ export async function runMainBackend(backendFactory: mcpServer.ServerBackendFact
await mcpServer.connect(factory, new mcpBundle.StdioServerTransport(), false);
}

export async function runOnPauseBackendLoop(backend: ServerBackendOnPause, introMessage: string) {
const wrappedBackend = new OnceTimeServerBackendWrapper(backend);
export async function runOnPauseBackendLoop(backend: mcpServer.ServerBackend, introMessage: string) {
const wrappedBackend = new ServerBackendWithCloseListener(backend);

const factory = {
name: 'on-pause-backend',
Expand Down Expand Up @@ -204,13 +197,12 @@ async function startAsHttp(backendFactory: mcpServer.ServerBackendFactory, optio
}


class OnceTimeServerBackendWrapper implements mcpServer.ServerBackend {
private _backend: ServerBackendOnPause;
private _selfDestructPromise = new ManualPromise<void>();
class ServerBackendWithCloseListener implements mcpServer.ServerBackend {
private _backend: mcpServer.ServerBackend;
private _serverClosedPromise = new ManualPromise<void>();

constructor(backend: ServerBackendOnPause) {
constructor(backend: mcpServer.ServerBackend) {
this._backend = backend;
this._backend.requestSelfDestruct = () => this._selfDestructPromise.resolve();
}

async initialize(server: mcpServer.Server, clientVersion: mcpServer.ClientVersion, roots: mcpServer.Root[]): Promise<void> {
Expand All @@ -227,10 +219,10 @@ class OnceTimeServerBackendWrapper implements mcpServer.ServerBackend {

serverClosed(server: mcpServer.Server) {
this._backend.serverClosed?.(server);
this._selfDestructPromise.resolve();
this._serverClosedPromise.resolve();
}

async waitForClosed() {
await this._selfDestructPromise;
await this._serverClosedPromise;
}
}
14 changes: 11 additions & 3 deletions packages/playwright/src/mcp/sdk/proxyBackend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ export class ProxyBackend implements ServerBackend {

async initialize(server: Server, clientVersion: ClientVersion, roots: Root[]): Promise<void> {
this._roots = roots;
await this._setCurrentClient(this._mcpProviders[0]);
}

async listTools(): Promise<Tool[]> {
const response = await this._currentClient!.listTools();
const currentClient = await this._ensureCurrentClient();
const response = await currentClient.listTools();
if (this._mcpProviders.length === 1)
return response.tools;
return [
Expand All @@ -61,7 +61,8 @@ export class ProxyBackend implements ServerBackend {
async callTool(name: string, args: CallToolRequest['params']['arguments']): Promise<CallToolResult> {
if (name === this._contextSwitchTool.name)
return this._callContextSwitchTool(args);
return await this._currentClient!.callTool({
const currentClient = await this._ensureCurrentClient();
return await currentClient.callTool({
name,
arguments: args,
}) as CallToolResult;
Expand Down Expand Up @@ -107,6 +108,12 @@ export class ProxyBackend implements ServerBackend {
};
}

private async _ensureCurrentClient(): Promise<Client> {
if (this._currentClient)
return this._currentClient;
return await this._setCurrentClient(this._mcpProviders[0]);
}

private async _setCurrentClient(factory: MCPProvider) {
await this._currentClient?.close();
this._currentClient = undefined;
Expand All @@ -123,5 +130,6 @@ export class ProxyBackend implements ServerBackend {
const transport = await factory.connect();
await client.connect(transport);
this._currentClient = client;
return client;
}
}
53 changes: 17 additions & 36 deletions packages/playwright/src/mcp/sdk/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ export type { Tool, CallToolResult, CallToolRequest, Root } from '@modelcontextp
import type { Server } from '@modelcontextprotocol/sdk/server/index.js';

const serverDebug = debug('pw:mcp:server');
const errorsDebug = debug('pw:mcp:errors');

export type ClientVersion = { name: string, version: string };

Expand Down Expand Up @@ -56,8 +55,6 @@ export async function wrapInProcess(backend: ServerBackend): Promise<Transport>
}

export function createServer(name: string, version: string, backend: ServerBackend, runHeartbeat: boolean): Server {
let initializedPromiseResolve = () => {};
const initializedPromise = new Promise<void>(resolve => initializedPromiseResolve = resolve);
const server = new mcpBundle.Server({ name, version }, {
capabilities: {
tools: {},
Expand All @@ -66,22 +63,17 @@ export function createServer(name: string, version: string, backend: ServerBacke

server.setRequestHandler(mcpBundle.ListToolsRequestSchema, async () => {
serverDebug('listTools');
await initializedPromise;
const tools = await backend.listTools();
return { tools };
});

let heartbeatRunning = false;
let initializePromise: Promise<void> | undefined;
server.setRequestHandler(mcpBundle.CallToolRequestSchema, async request => {
serverDebug('callTool', request);
await initializedPromise;

if (runHeartbeat && !heartbeatRunning) {
heartbeatRunning = true;
startHeartbeat(server);
}

try {
if (!initializePromise)
initializePromise = initializeServer(server, backend, runHeartbeat);
await initializePromise;
return await backend.callTool(request.params.name, request.params.arguments || {});
} catch (error) {
return {
Expand All @@ -90,34 +82,23 @@ export function createServer(name: string, version: string, backend: ServerBacke
};
}
});
addServerListener(server, 'initialized', async () => {
try {
const capabilities = server.getClientCapabilities();
let clientRoots: Root[] = [];
if (capabilities?.roots) {
for (let i = 0; i < 2; i++) {
try {
// In the @modelcontextprotocol TypeScript SDK (and Cursor) in the streaming http
// mode, the SSE channel is not ready yet, when `initialized` notification arrives,
// `listRoots` times out in that case and we retry once.
const { roots } = await server.listRoots(undefined, { timeout: 2_000 });
clientRoots = roots;
} catch (e) {
continue;
}
}
}
const clientVersion = server.getClientVersion() ?? { name: 'unknown', version: 'unknown' };
await backend.initialize?.(server, clientVersion, clientRoots);
initializedPromiseResolve();
} catch (e) {
errorsDebug(e);
}
});
addServerListener(server, 'close', () => backend.serverClosed?.(server));
return server;
}

const initializeServer = async (server: Server, backend: ServerBackend, runHeartbeat: boolean) => {
const capabilities = server.getClientCapabilities();
let clientRoots: Root[] = [];
if (capabilities?.roots) {
const { roots } = await server.listRoots();
clientRoots = roots;
}
const clientVersion = server.getClientVersion() ?? { name: 'unknown', version: 'unknown' };
await backend.initialize?.(server, clientVersion, clientRoots);
if (runHeartbeat)
startHeartbeat(server);
};

const startHeartbeat = (server: Server) => {
const beat = () => {
Promise.race([
Expand Down
2 changes: 1 addition & 1 deletion tests/mcp/fixtures.ts
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ async function createTransport(mcpServerType: TestOptions['mcpServerType'], args
stderr: 'pipe',
env: {
...process.env,
DEBUG: 'pw:mcp:test',
DEBUG: process.env.DEBUG ? `${process.env.DEBUG},pw:mcp:test` : 'pw:mcp:test',
DEBUG_COLORS: '0',
DEBUG_HIDE_DATE: '1',
PWMCP_PROFILES_DIR_FOR_TEST: profilesDir,
Expand Down
7 changes: 5 additions & 2 deletions tests/mcp/http.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ test('http transport (default)', async ({ serverEndpoint }) => {
expect(transport.sessionId, 'has session support').toBeDefined();
});

test('client should receive list roots request', async ({ serverEndpoint }) => {
test('client should receive list roots request', async ({ serverEndpoint, server }) => {
const { url } = await serverEndpoint();
const transport = new StreamableHTTPClientTransport(url);
const client = new Client({ name: 'test', version: '1.0.0' }, { capabilities: { roots: {} } });
Expand All @@ -275,6 +275,9 @@ test('client should receive list roots request', async ({ serverEndpoint }) => {
};
});
await client.connect(transport);
await client.ping();
await client.callTool({
name: 'browser_navigate',
arguments: { url: server.HELLO_WORLD },
});
expect(await rootsListedPromise).toBe('success');
});
Loading
Loading