Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/browserServerBackend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import { Response } from './response.js';
import { SessionLog } from './sessionLog.js';
import { filteredTools } from './tools.js';
import { packageJSON } from './utils/package.js';
import { toToolDefinition } from './tools/tool.js';
import { toMcpTool } from './tools/tool.js';

import type { Tool } from './tools/tool.js';
import type { BrowserContextFactory } from './browserContextFactory.js';
Expand Down Expand Up @@ -64,12 +64,14 @@ export class BrowserServerBackend implements ServerBackend {
});
}

tools(): mcpServer.ToolDefinition[] {
return this._tools.map(tool => toToolDefinition(tool.schema));
async listTools(): Promise<mcpServer.Tool[]> {
return this._tools.map(tool => toMcpTool(tool.schema));
}

async callTool(name: string, rawArguments: any) {
async callTool(name: string, rawArguments: mcpServer.CallToolRequest['params']['arguments']) {
const tool = this._tools.find(tool => tool.schema.name === name)!;
if (!tool)
throw new Error(`Tool "${name}" not found`);
const parsedArguments = tool.schema.inputSchema.parse(rawArguments || {});
const context = this._context!;
const response = new Response(context, name, parsedArguments);
Expand Down
45 changes: 0 additions & 45 deletions src/inProcessMcpFactrory.ts

This file was deleted.

2 changes: 1 addition & 1 deletion src/loopTools/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ export class Context {
return new Context(config, client);
}

async runTask(task: string, oneShot: boolean = false): Promise<mcpServer.ToolResponse> {
async runTask(task: string, oneShot: boolean = false): Promise<mcpServer.CallToolResult> {
const messages = await runTask(this._delegate, this._client!, task, oneShot);
const lines: string[] = [];

Expand Down
10 changes: 5 additions & 5 deletions src/loopTools/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import { packageJSON } from '../utils/package.js';
import { Context } from './context.js';
import { perform } from './perform.js';
import { snapshot } from './snapshot.js';
import { toToolDefinition } from '../tools/tool.js';
import { toMcpTool } from '../tools/tool.js';

import type { FullConfig } from '../config.js';
import type { ServerBackend } from '../mcp/server.js';
Expand All @@ -49,13 +49,13 @@ class LoopToolsServerBackend implements ServerBackend {
this._context = await Context.create(this._config);
}

tools(): mcpServer.ToolDefinition[] {
return this._tools.map(tool => toToolDefinition(tool.schema));
async listTools(): Promise<mcpServer.Tool[]> {
return this._tools.map(tool => toMcpTool(tool.schema));
}

async callTool(name: string, rawArguments: any): Promise<mcpServer.ToolResponse> {
async callTool(name: string, args: mcpServer.CallToolRequest['params']['arguments']): Promise<mcpServer.CallToolResult> {
const tool = this._tools.find(tool => tool.schema.name === name)!;
const parsedArguments = tool.schema.inputSchema.parse(rawArguments || {});
const parsedArguments = tool.schema.inputSchema.parse(args || {});
return await tool.handle(this._context!, parsedArguments);
}

Expand Down
2 changes: 1 addition & 1 deletion src/loopTools/tool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import type { ToolSchema } from '../tools/tool.js';

export type Tool<Input extends z.Schema = z.Schema> = {
schema: ToolSchema<Input>;
handle: (context: Context, params: z.output<Input>) => Promise<mcpServer.ToolResponse>;
handle: (context: Context, params: z.output<Input>) => Promise<mcpServer.CallToolResult>;
};

export function defineTool<Input extends z.Schema>(tool: Tool<Input>): Tool<Input> {
Expand Down
84 changes: 38 additions & 46 deletions src/mcp/proxyBackend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,65 +24,67 @@ import { packageJSON } from '../utils/package.js';


import type { Server } from '@modelcontextprotocol/sdk/server/index.js';
import type { ToolDefinition, ServerBackend, ToolResponse } from './server.js';
import type { ServerBackend } from './server.js';
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
import type { Root, Tool, CallToolResult, CallToolRequest } from '@modelcontextprotocol/sdk/types.js';

type NonEmptyArray<T> = [T, ...T[]];

export type MCPFactory = {
export type MCPProvider = {
name: string;
description: string;
create(): Promise<Transport>;
connect(): Promise<Transport>;
};

export type MCPFactoryList = NonEmptyArray<MCPFactory>;

export class ProxyBackend implements ServerBackend {
name = 'Playwright MCP Client Switcher';
version = packageJSON.version;

private _mcpFactories: MCPFactoryList;
private _mcpProviders: MCPProvider[];
private _currentClient: Client | undefined;
private _contextSwitchTool: ToolDefinition;
private _tools: ToolDefinition[] = [];
private _server: Server | undefined;
private _contextSwitchTool: Tool;
private _roots: Root[] = [];

constructor(clientFactories: MCPFactoryList) {
this._mcpFactories = clientFactories;
constructor(mcpProviders: MCPProvider[]) {
this._mcpProviders = mcpProviders;
this._contextSwitchTool = this._defineContextSwitchTool();
}

async initialize(server: Server): Promise<void> {
this._server = server;
await this._setCurrentClient(this._mcpFactories[0]);
const version = server.getClientVersion();
const capabilities = server.getClientCapabilities();
if (capabilities?.roots && version && clientsWithRoots.includes(version.name)) {
const { roots } = await server.listRoots();
this._roots = roots;
}

await this._setCurrentClient(this._mcpProviders[0]);
}

tools(): ToolDefinition[] {
if (this._mcpFactories.length === 1)
return this._tools;
async listTools(): Promise<Tool[]> {
const response = await this._currentClient!.listTools();
if (this._mcpProviders.length === 1)
return response.tools;
return [
...this._tools,
...response.tools,
this._contextSwitchTool,
];
}

async callTool(name: string, rawArguments: any): Promise<ToolResponse> {
async callTool(name: string, args: CallToolRequest['params']['arguments']): Promise<CallToolResult> {
if (name === this._contextSwitchTool.name)
return this._callContextSwitchTool(rawArguments);
const result = await this._currentClient!.callTool({
return this._callContextSwitchTool(args);
return await this._currentClient!.callTool({
name,
arguments: rawArguments,
});
return result as unknown as ToolResponse;
arguments: args,
}) as CallToolResult;
}

serverClosed?(): void {
void this._currentClient?.close().catch(logUnhandledError);
}

private async _callContextSwitchTool(params: any): Promise<ToolResponse> {
private async _callContextSwitchTool(params: any): Promise<CallToolResult> {
try {
const factory = this._mcpFactories.find(factory => factory.name === params.name);
const factory = this._mcpProviders.find(factory => factory.name === params.name);
if (!factory)
throw new Error('Unknown connection method: ' + params.name);

Expand All @@ -98,16 +100,16 @@ export class ProxyBackend implements ServerBackend {
}
}

private _defineContextSwitchTool(): ToolDefinition {
private _defineContextSwitchTool(): Tool {
return {
name: 'browser_connect',
description: [
'Connect to a browser using one of the available methods:',
...this._mcpFactories.map(factory => `- "${factory.name}": ${factory.description}`),
...this._mcpProviders.map(factory => `- "${factory.name}": ${factory.description}`),
].join('\n'),
inputSchema: zodToJsonSchema(z.object({
name: z.enum(this._mcpFactories.map(factory => factory.name) as [string, ...string[]]).default(this._mcpFactories[0].name).describe('The method to use to connect to the browser'),
}), { strictUnions: true }) as ToolDefinition['inputSchema'],
name: z.enum(this._mcpProviders.map(factory => factory.name) as [string, ...string[]]).default(this._mcpProviders[0].name).describe('The method to use to connect to the browser'),
}), { strictUnions: true }) as Tool['inputSchema'],
annotations: {
title: 'Connect to a browser context',
readOnlyHint: true,
Expand All @@ -116,7 +118,7 @@ export class ProxyBackend implements ServerBackend {
};
}

private async _setCurrentClient(factory: MCPFactory) {
private async _setCurrentClient(factory: MCPProvider) {
await this._currentClient?.close();
this._currentClient = undefined;

Expand All @@ -126,23 +128,13 @@ export class ProxyBackend implements ServerBackend {
listRoots: true,
},
});
client.setRequestHandler(ListRootsRequestSchema, async () => {
const clientName = this._server!.getClientVersion()?.name;
if (this._server!.getClientCapabilities()?.roots && (
clientName === 'Visual Studio Code' ||
clientName === 'Visual Studio Code - Insiders')) {
const { roots } = await this._server!.listRoots();
return { roots };
}
return { roots: [] };
});
client.setRequestHandler(ListRootsRequestSchema, () => ({ roots: this._roots }));
client.setRequestHandler(PingRequestSchema, () => ({}));

const transport = await factory.create();
const transport = await factory.connect();
await client.connect(transport);

this._currentClient = client;
const tools = await this._currentClient.listTools();
this._tools = tools.tools;
}
}

const clientsWithRoots = ['Visual Studio Code', 'Visual Studio Code - Insiders'];
38 changes: 10 additions & 28 deletions src/mcp/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,31 +20,19 @@ import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprot
import { ManualPromise } from '../utils/manualPromise.js';
import { logUnhandledError } from '../utils/log.js';

import type { ImageContent, TextContent, Tool } from '@modelcontextprotocol/sdk/types.js';
import type { Tool, CallToolResult, CallToolRequest } from '@modelcontextprotocol/sdk/types.js';
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
export type { Server } from '@modelcontextprotocol/sdk/server/index.js';
export type { Tool, CallToolResult, CallToolRequest } from '@modelcontextprotocol/sdk/types.js';

const serverDebug = debug('pw:mcp:server');

export type ClientCapabilities = {
roots?: {
listRoots?: boolean
};
};

export type ToolResponse = {
content: (TextContent | ImageContent)[];
isError?: boolean;
};

export type ToolDefinition = Tool;

export interface ServerBackend {
name: string;
version: string;
initialize?(server: Server): Promise<void>;
tools(): ToolDefinition[];
callTool(name: string, rawArguments: any): Promise<ToolResponse>;
listTools(): Promise<Tool[]>;
callTool(name: string, args: CallToolRequest['params']['arguments']): Promise<CallToolResult>;
serverClosed?(): void;
}

Expand All @@ -66,7 +54,7 @@ export function createServer(backend: ServerBackend, runHeartbeat: boolean): Ser

server.setRequestHandler(ListToolsRequestSchema, async () => {
serverDebug('listTools');
const tools = backend.tools();
const tools = await backend.listTools();
return { tools };
});

Expand All @@ -80,19 +68,13 @@ export function createServer(backend: ServerBackend, runHeartbeat: boolean): Ser
startHeartbeat(server);
}

const errorResult = (...messages: string[]) => ({
content: [{ type: 'text', text: '### Result\n' + messages.join('\n') }],
isError: true,
});
const tools = backend.tools();
const tool = tools.find(tool => tool.name === request.params.name);
if (!tool)
return errorResult(`Error: Tool "${request.params.name}" not found`);

try {
return await backend.callTool(tool.name, request.params.arguments || {});
return await backend.callTool(request.params.name, request.params.arguments || {});
} catch (error) {
return errorResult(String(error));
return {
content: [{ type: 'text', text: '### Result\n' + String(error) }],
isError: true,
};
}
});
addServerListener(server, 'initialized', () => {
Expand Down
Loading
Loading