Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/browserContextFactory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@ export function contextFactory(browserConfig: FullConfig['browser']): BrowserCon
return new PersistentContextFactory(browserConfig);
}

export type ClientInfo = { name: string, version: string };

export interface BrowserContextFactory {
createContext(clientInfo: { name: string, version: string }): Promise<{ browserContext: playwright.BrowserContext, close: () => Promise<void> }>;
createContext(clientInfo: ClientInfo, abortSignal: AbortSignal): Promise<{ browserContext: playwright.BrowserContext, close: () => Promise<void> }>;
}

class BaseContextFactory implements BrowserContextFactory {
Expand Down
2 changes: 0 additions & 2 deletions src/browserServerBackend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import type { Tool } from './tools/tool.js';
export class BrowserServerBackend implements ServerBackend {
name = 'Playwright';
version = packageJSON.version;
onclose?: () => void;

private _tools: Tool[];
private _context: Context | undefined;
Expand Down Expand Up @@ -75,7 +74,6 @@ export class BrowserServerBackend implements ServerBackend {
}

serverClosed() {
this.onclose?.();
void this._context!.dispose().catch(logUnhandledError);
}
}
4 changes: 3 additions & 1 deletion src/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ export class Context {
private static _allContexts: Set<Context> = new Set();
private _closeBrowserContextPromise: Promise<void> | undefined;
private _isRunningTool: boolean = false;
private _abortController = new AbortController();

constructor(tools: Tool[], config: FullConfig, browserContextFactory: BrowserContextFactory, sessionLog: SessionLog | undefined) {
this.tools = tools;
Expand Down Expand Up @@ -154,6 +155,7 @@ export class Context {
}

async dispose() {
this._abortController.abort('MCP context disposed');
await this.closeBrowserContext();
Context._allContexts.delete(this);
}
Expand Down Expand Up @@ -186,7 +188,7 @@ export class Context {
if (this._closeBrowserContextPromise)
throw new Error('Another browser context is being closed.');
// TODO: move to the browser context factory to make it based on isolation mode.
const result = await this._browserContextFactory.createContext(this.clientVersion!);
const result = await this._browserContextFactory.createContext(this.clientVersion!, this._abortController.signal);
const { browserContext } = result;
await this._setupRequestInterception(browserContext);
if (this.sessionLog)
Expand Down
70 changes: 14 additions & 56 deletions src/extension/cdpRelay.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,18 @@
* - /extension/guid - Extension connection for chrome.debugger forwarding
*/

import http from 'http';
import { spawn } from 'child_process';
import { WebSocket, WebSocketServer } from 'ws';
import http from 'http';
import debug from 'debug';
import * as playwright from 'playwright';
// @ts-ignore
const { registry } = await import('playwright-core/lib/server/registry/index');
import { httpAddressToString, startHttpServer } from '../httpServer.js';
import { WebSocket, WebSocketServer } from 'ws';
import { httpAddressToString } from '../httpServer.js';
import { logUnhandledError } from '../log.js';
import { ManualPromise } from '../manualPromise.js';
import type { BrowserContextFactory } from '../browserContextFactory.js';
import type websocket from 'ws';
import type { ClientInfo } from '../browserContextFactory.js';

// @ts-ignore
const { registry } = await import('playwright-core/lib/server/registry/index');

const debugLogger = debug('pw:mcp:relay');

Expand Down Expand Up @@ -90,17 +90,20 @@ export class CDPRelayServer {
return `${this._wsHost}${this._extensionPath}`;
}

async ensureExtensionConnectionForMCPContext(clientInfo: { name: string, version: string }) {
async ensureExtensionConnectionForMCPContext(clientInfo: ClientInfo, abortSignal: AbortSignal) {
debugLogger('Ensuring extension connection for MCP context');
if (this._extensionConnection)
return;
await this._connectBrowser(clientInfo);
this._connectBrowser(clientInfo);
debugLogger('Waiting for incoming extension connection');
await this._extensionConnectionPromise;
await Promise.race([
this._extensionConnectionPromise,
new Promise((_, reject) => abortSignal.addEventListener('abort', reject))
]);
debugLogger('Extension connection established');
}

private async _connectBrowser(clientInfo: { name: string, version: string }) {
private _connectBrowser(clientInfo: ClientInfo) {
const mcpRelayEndpoint = `${this._wsHost}${this._extensionPath}`;
// Need to specify "key" in the manifest.json to make the id stable when loading from file.
const url = new URL('chrome-extension://jakfalbnbhgkpmoaakfflhflbfpkailf/lib/ui/connect.html');
Expand Down Expand Up @@ -300,51 +303,6 @@ export class CDPRelayServer {
}
}

class ExtensionContextFactory implements BrowserContextFactory {
private _relay: CDPRelayServer;
private _browserPromise: Promise<playwright.Browser> | undefined;

constructor(relay: CDPRelayServer) {
this._relay = relay;
}

async createContext(clientInfo: { name: string, version: string }): Promise<{ browserContext: playwright.BrowserContext, close: () => Promise<void> }> {
// First call will establish the connection to the extension.
if (!this._browserPromise)
this._browserPromise = this._obtainBrowser(clientInfo);
const browser = await this._browserPromise;
return {
browserContext: browser.contexts()[0],
close: async () => {
debugLogger('close() called for browser context, ignoring');
}
};
}

clientDisconnected() {
this._relay.closeConnections('MCP client disconnected');
this._browserPromise = undefined;
}

private async _obtainBrowser(clientInfo: { name: string, version: string }): Promise<playwright.Browser> {
await this._relay.ensureExtensionConnectionForMCPContext(clientInfo);
const browser = await playwright.chromium.connectOverCDP(this._relay.cdpEndpoint());
browser.on('disconnected', () => {
this._browserPromise = undefined;
debugLogger('Browser disconnected');
});
return browser;
}
}

export async function startCDPRelayServer(browserChannel: string, abortController: AbortController) {
const httpServer = await startHttpServer({});
const cdpRelayServer = new CDPRelayServer(httpServer, browserChannel);
abortController.signal.addEventListener('abort', () => cdpRelayServer.stop());
debugLogger(`CDP relay server started, extension endpoint: ${cdpRelayServer.extensionEndpoint()}.`);
return new ExtensionContextFactory(cdpRelayServer);
}

type ExtensionResponse = {
id?: number;
method?: string;
Expand Down
75 changes: 75 additions & 0 deletions src/extension/extensionContextFactory.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/**
* Copyright (c) Microsoft Corporation.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import debug from 'debug';
import * as playwright from 'playwright';
import { startHttpServer } from '../httpServer.js';
import { CDPRelayServer } from './cdpRelay.js';

import type { BrowserContextFactory, ClientInfo } from '../browserContextFactory.js';

const debugLogger = debug('pw:mcp:relay');

export class ExtensionContextFactory implements BrowserContextFactory {
private _browserChannel: string;
private _relayPromise: Promise<CDPRelayServer> | undefined;
private _browserPromise: Promise<playwright.Browser> | undefined;

constructor(browserChannel: string) {
this._browserChannel = browserChannel;
}

async createContext(clientInfo: ClientInfo, abortSignal: AbortSignal): Promise<{ browserContext: playwright.BrowserContext, close: () => Promise<void> }> {
// First call will establish the connection to the extension.
if (!this._browserPromise)
this._browserPromise = this._obtainBrowser(clientInfo, abortSignal);
const browser = await this._browserPromise;
return {
browserContext: browser.contexts()[0],
close: async () => {
debugLogger('close() called for browser context');
await browser.close();
this._browserPromise = undefined;
}
};
}

private async _obtainBrowser(clientInfo: ClientInfo, abortSignal: AbortSignal): Promise<playwright.Browser> {
if (!this._relayPromise)
this._relayPromise = this._startRelay(abortSignal);
const relay = await this._relayPromise;

abortSignal.throwIfAborted();
await relay.ensureExtensionConnectionForMCPContext(clientInfo, abortSignal);
const browser = await playwright.chromium.connectOverCDP(relay.cdpEndpoint());
browser.on('disconnected', () => {
this._browserPromise = undefined;
debugLogger('Browser disconnected');
});
return browser;
}

private async _startRelay(abortSignal: AbortSignal) {
const httpServer = await startHttpServer({});
const cdpRelayServer = new CDPRelayServer(httpServer, this._browserChannel);
debugLogger(`CDP relay server started, extension endpoint: ${cdpRelayServer.extensionEndpoint()}.`);
if (abortSignal.aborted)
cdpRelayServer.stop();
else
abortSignal.addEventListener('abort', () => cdpRelayServer.stop());
return cdpRelayServer;
}
}
20 changes: 4 additions & 16 deletions src/extension/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,14 @@
* limitations under the License.
*/

import { startCDPRelayServer } from './cdpRelay.js';
import { ExtensionContextFactory } from './extensionContextFactory.js';
import { BrowserServerBackend } from '../browserServerBackend.js';
import * as mcpTransport from '../mcp/transport.js';

import type { FullConfig } from '../config.js';

export async function runWithExtension(config: FullConfig, abortController: AbortController) {
const contextFactory = await startCDPRelayServer(config.browser.launchOptions.channel || 'chrome', abortController);

let backend: BrowserServerBackend | undefined;
const serverBackendFactory = () => {
if (backend)
throw new Error('Another MCP client is still connected. Only one connection at a time is allowed.');
backend = new BrowserServerBackend(config, contextFactory);
backend.onclose = () => {
contextFactory.clientDisconnected();
backend = undefined;
};
return backend;
};

export async function runWithExtension(config: FullConfig) {
const contextFactory = new ExtensionContextFactory(config.browser.launchOptions.channel || 'chrome');
const serverBackendFactory = () => new BrowserServerBackend(config, contextFactory);
await mcpTransport.start(serverBackendFactory, config.server);
}
9 changes: 2 additions & 7 deletions src/program.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ program
.addOption(new Option('--loop-tools', 'Run loop tools').hideHelp())
.addOption(new Option('--vision', 'Legacy option, use --caps=vision instead').hideHelp())
.action(async options => {
const abortController = setupExitWatchdog();
setupExitWatchdog();

if (options.vision) {
// eslint-disable-next-line no-console
Expand All @@ -69,7 +69,7 @@ program
const config = await resolveCLIConfig(options);

if (options.extension) {
await runWithExtension(config, abortController);
await runWithExtension(config);
return;
}
if (options.loopTools) {
Expand All @@ -91,24 +91,19 @@ program
});

function setupExitWatchdog() {
const abortController = new AbortController();

let isExiting = false;
const handleExit = async () => {
if (isExiting)
return;
isExiting = true;
setTimeout(() => process.exit(0), 15000);
abortController.abort('Process exiting');
await Context.disposeAll();
process.exit(0);
};

process.stdin.on('close', handleExit);
process.on('SIGINT', handleExit);
process.on('SIGTERM', handleExit);

return abortController;
}

void program.parseAsync(process.argv);
Loading