diff --git a/src/browserContextFactory.ts b/src/browserContextFactory.ts
index 7e17b7865..470ead51c 100644
--- a/src/browserContextFactory.ts
+++ b/src/browserContextFactory.ts
@@ -35,8 +35,10 @@ export function contextFactory(browserConfig: FullConfig['browser']): BrowserCon
   return new PersistentContextFactory(browserConfig);
 }
 
+export type ClientInfo = { name: string, version: string };
+
 export interface BrowserContextFactory {
-  createContext(clientInfo: { name: string, version: string }): Promise<{ browserContext: playwright.BrowserContext, close: () => Promise<void> }>;
+  createContext(clientInfo: ClientInfo, abortSignal: AbortSignal): Promise<{ browserContext: playwright.BrowserContext, close: () => Promise<void> }>;
 }
 
 class BaseContextFactory implements BrowserContextFactory {
diff --git a/src/browserServerBackend.ts b/src/browserServerBackend.ts
index 7f5fda515..a8cd7c36c 100644
--- a/src/browserServerBackend.ts
+++ b/src/browserServerBackend.ts
@@ -30,7 +30,6 @@ import type { Tool } from './tools/tool.js';
 export class BrowserServerBackend implements ServerBackend {
   name = 'Playwright';
   version = packageJSON.version;
-  onclose?: () => void;
 
   private _tools: Tool[];
   private _context: Context | undefined;
@@ -75,7 +74,6 @@ export class BrowserServerBackend implements ServerBackend {
   }
 
   serverClosed() {
-    this.onclose?.();
     void this._context!.dispose().catch(logUnhandledError);
   }
 }
diff --git a/src/context.ts b/src/context.ts
index e9de93a77..5dfc5a1ed 100644
--- a/src/context.ts
+++ b/src/context.ts
@@ -42,6 +42,7 @@ export class Context {
   private static _allContexts: Set<Context> = new Set();
   private _closeBrowserContextPromise: Promise<void> | undefined;
   private _isRunningTool: boolean = false;
+  private _abortController = new AbortController();
 
   constructor(tools: Tool[], config: FullConfig, browserContextFactory: BrowserContextFactory, sessionLog: SessionLog | undefined) {
     this.tools = tools;
@@ -154,6 +155,7 @@ export class Context {
   }
 
   async dispose() {
+    this._abortController.abort('MCP context disposed');
     await this.closeBrowserContext();
     Context._allContexts.delete(this);
   }
@@ -186,7 +188,7 @@ export class Context {
     if (this._closeBrowserContextPromise)
       throw new Error('Another browser context is being closed.');
     // TODO: move to the browser context factory to make it based on isolation mode.
-    const result = await this._browserContextFactory.createContext(this.clientVersion!);
+    const result = await this._browserContextFactory.createContext(this.clientVersion!, this._abortController.signal);
     const { browserContext } = result;
     await this._setupRequestInterception(browserContext);
     if (this.sessionLog)
diff --git a/src/extension/cdpRelay.ts b/src/extension/cdpRelay.ts
index 324acfcdf..7a6c55125 100644
--- a/src/extension/cdpRelay.ts
+++ b/src/extension/cdpRelay.ts
@@ -22,18 +22,18 @@
  * - /extension/guid - Extension connection for chrome.debugger forwarding
  */
 
-import http from 'http';
 import { spawn } from 'child_process';
-import { WebSocket, WebSocketServer } from 'ws';
+import http from 'http';
 import debug from 'debug';
-import * as playwright from 'playwright';
-// @ts-ignore
-const { registry } = await import('playwright-core/lib/server/registry/index');
-import { httpAddressToString, startHttpServer } from '../httpServer.js';
+import { WebSocket, WebSocketServer } from 'ws';
+import { httpAddressToString } from '../httpServer.js';
 import { logUnhandledError } from '../log.js';
 import { ManualPromise } from '../manualPromise.js';
-import type { BrowserContextFactory } from '../browserContextFactory.js';
 import type websocket from 'ws';
+import type { ClientInfo } from '../browserContextFactory.js';
+
+// @ts-ignore
+const { registry } = await import('playwright-core/lib/server/registry/index');
 
 const debugLogger = debug('pw:mcp:relay');
 
@@ -90,17 +90,31 @@ export class CDPRelayServer {
     return `${this._wsHost}${this._extensionPath}`;
   }
 
-  async ensureExtensionConnectionForMCPContext(clientInfo: { name: string, version: string }) {
+  async ensureExtensionConnectionForMCPContext(clientInfo: ClientInfo, abortSignal: AbortSignal) {
     debugLogger('Ensuring extension connection for MCP context');
     if (this._extensionConnection)
       return;
-    await this._connectBrowser(clientInfo);
+    this._connectBrowser(clientInfo);
     debugLogger('Waiting for incoming extension connection');
-    await this._extensionConnectionPromise;
+    let onAbort = () => {};
+    const aborted = new Promise<never>((_, reject) => {
+      // Reject with an Error carrying the abort reason rather than the raw 'abort' Event.
+      onAbort = () => reject(abortSignal.reason instanceof Error ? abortSignal.reason : new Error(String(abortSignal.reason)));
+      if (abortSignal.aborted)
+        onAbort();
+      else
+        abortSignal.addEventListener('abort', onAbort, { once: true });
+    });
+    try {
+      await Promise.race([this._extensionConnectionPromise, aborted]);
+    } finally {
+      // Always detach so repeated calls do not pile listeners onto the signal.
+      abortSignal.removeEventListener('abort', onAbort);
+    }
     debugLogger('Extension connection established');
   }
 
-  private async _connectBrowser(clientInfo: { name: string, version: string }) {
+  private _connectBrowser(clientInfo: ClientInfo) {
     const mcpRelayEndpoint = `${this._wsHost}${this._extensionPath}`;
     // Need to specify "key" in the manifest.json to make the id stable when loading from file.
     const url = new URL('chrome-extension://jakfalbnbhgkpmoaakfflhflbfpkailf/lib/ui/connect.html');
@@ -300,51 +314,6 @@ export class CDPRelayServer {
   }
 }
 
-class ExtensionContextFactory implements BrowserContextFactory {
-  private _relay: CDPRelayServer;
-  private _browserPromise: Promise<playwright.Browser> | undefined;
-
-  constructor(relay: CDPRelayServer) {
-    this._relay = relay;
-  }
-
-  async createContext(clientInfo: { name: string, version: string }): Promise<{ browserContext: playwright.BrowserContext, close: () => Promise<void> }> {
-    // First call will establish the connection to the extension.
-    if (!this._browserPromise)
-      this._browserPromise = this._obtainBrowser(clientInfo);
-    const browser = await this._browserPromise;
-    return {
-      browserContext: browser.contexts()[0],
-      close: async () => {
-        debugLogger('close() called for browser context, ignoring');
-      }
-    };
-  }
-
-  clientDisconnected() {
-    this._relay.closeConnections('MCP client disconnected');
-    this._browserPromise = undefined;
-  }
-
-  private async _obtainBrowser(clientInfo: { name: string, version: string }): Promise<playwright.Browser> {
-    await this._relay.ensureExtensionConnectionForMCPContext(clientInfo);
-    const browser = await playwright.chromium.connectOverCDP(this._relay.cdpEndpoint());
-    browser.on('disconnected', () => {
-      this._browserPromise = undefined;
-      debugLogger('Browser disconnected');
-    });
-    return browser;
-  }
-}
-
-export async function startCDPRelayServer(browserChannel: string, abortController: AbortController) {
-  const httpServer = await startHttpServer({});
-  const cdpRelayServer = new CDPRelayServer(httpServer, browserChannel);
-  abortController.signal.addEventListener('abort', () => cdpRelayServer.stop());
-  debugLogger(`CDP relay server started, extension endpoint: ${cdpRelayServer.extensionEndpoint()}.`);
-  return new ExtensionContextFactory(cdpRelayServer);
-}
-
 type ExtensionResponse = {
   id?: number;
   method?: string;
diff --git a/src/extension/extensionContextFactory.ts b/src/extension/extensionContextFactory.ts
new file mode 100644
index 000000000..bb2360576
--- /dev/null
+++ b/src/extension/extensionContextFactory.ts
@@ -0,0 +1,97 @@
+/**
+ * Copyright (c) Microsoft Corporation.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import debug from 'debug';
+import * as playwright from 'playwright';
+import { startHttpServer } from '../httpServer.js';
+import { CDPRelayServer } from './cdpRelay.js';
+
+import type { BrowserContextFactory, ClientInfo } from '../browserContextFactory.js';
+
+const debugLogger = debug('pw:mcp:relay');
+
+/**
+ * Creates browser contexts by connecting over CDP to a browser reached through a CDPRelayServer.
+ * One instance is shared by every backend created in runWithExtension.
+ */
+export class ExtensionContextFactory implements BrowserContextFactory {
+  private _browserChannel: string;
+  private _relayPromise: Promise<CDPRelayServer> | undefined;
+  private _browserPromise: Promise<playwright.Browser> | undefined;
+
+  constructor(browserChannel: string) {
+    this._browserChannel = browserChannel;
+  }
+
+  /**
+   * Returns the first context of the connected browser, connecting on first use.
+   * Rejects if `abortSignal` fires while the connection is still being established.
+   */
+  async createContext(clientInfo: ClientInfo, abortSignal: AbortSignal): Promise<{ browserContext: playwright.BrowserContext, close: () => Promise<void> }> {
+    // First call will establish the connection to the extension.
+    if (!this._browserPromise) {
+      const browserPromise = this._obtainBrowser(clientInfo, abortSignal);
+      this._browserPromise = browserPromise;
+      // Forget a failed attempt so that later callers retry instead of
+      // inheriting the cached rejection (e.g. after an aborted connection).
+      void browserPromise.catch(() => {
+        if (this._browserPromise === browserPromise)
+          this._browserPromise = undefined;
+      });
+    }
+    const browser = await this._browserPromise;
+    return {
+      browserContext: browser.contexts()[0],
+      close: async () => {
+        debugLogger('close() called for browser context');
+        await browser.close();
+        this._browserPromise = undefined;
+      }
+    };
+  }
+
+  private async _obtainBrowser(clientInfo: ClientInfo, abortSignal: AbortSignal): Promise<playwright.Browser> {
+    if (!this._relayPromise)
+      this._relayPromise = this._startRelay(abortSignal);
+    const relay = await this._relayPromise;
+
+    abortSignal.throwIfAborted();
+    await relay.ensureExtensionConnectionForMCPContext(clientInfo, abortSignal);
+    const browser = await playwright.chromium.connectOverCDP(relay.cdpEndpoint());
+    browser.on('disconnected', () => {
+      this._browserPromise = undefined;
+      debugLogger('Browser disconnected');
+    });
+    return browser;
+  }
+
+  private async _startRelay(abortSignal: AbortSignal) {
+    const httpServer = await startHttpServer({});
+    const cdpRelayServer = new CDPRelayServer(httpServer, this._browserChannel);
+    debugLogger(`CDP relay server started, extension endpoint: ${cdpRelayServer.extensionEndpoint()}.`);
+    const stopRelay = () => {
+      // NOTE(review): assumes a stopped relay cannot serve new connections, so the
+      // cached instance is dropped and the next client starts a fresh one - confirm against stop().
+      this._relayPromise = undefined;
+      cdpRelayServer.stop();
+    };
+    if (abortSignal.aborted)
+      stopRelay();
+    else
+      abortSignal.addEventListener('abort', stopRelay, { once: true });
+    return cdpRelayServer;
+  }
+}
diff --git a/src/extension/main.ts b/src/extension/main.ts
index f6d519c3f..50c2e3a87 100644
--- a/src/extension/main.ts
+++ b/src/extension/main.ts
@@ -14,26 +14,14 @@
  * limitations under the License.
  */
 
-import { startCDPRelayServer } from './cdpRelay.js';
+import { ExtensionContextFactory } from './extensionContextFactory.js';
 import { BrowserServerBackend } from '../browserServerBackend.js';
 import * as mcpTransport from '../mcp/transport.js';
 
 import type { FullConfig } from '../config.js';
 
-export async function runWithExtension(config: FullConfig, abortController: AbortController) {
-  const contextFactory = await startCDPRelayServer(config.browser.launchOptions.channel || 'chrome', abortController);
-
-  let backend: BrowserServerBackend | undefined;
-  const serverBackendFactory = () => {
-    if (backend)
-      throw new Error('Another MCP client is still connected. Only one connection at a time is allowed.');
-    backend = new BrowserServerBackend(config, contextFactory);
-    backend.onclose = () => {
-      contextFactory.clientDisconnected();
-      backend = undefined;
-    };
-    return backend;
-  };
-
+export async function runWithExtension(config: FullConfig) {
+  const contextFactory = new ExtensionContextFactory(config.browser.launchOptions.channel || 'chrome');
+  const serverBackendFactory = () => new BrowserServerBackend(config, contextFactory);
   await mcpTransport.start(serverBackendFactory, config.server);
 }
diff --git a/src/program.ts b/src/program.ts
index 508e977af..f34c57e96 100644
--- a/src/program.ts
+++ b/src/program.ts
@@ -59,7 +59,7 @@ program
     .addOption(new Option('--loop-tools', 'Run loop tools').hideHelp())
     .addOption(new Option('--vision', 'Legacy option, use --caps=vision instead').hideHelp())
     .action(async options => {
-      const abortController = setupExitWatchdog();
+      setupExitWatchdog();
 
       if (options.vision) {
         // eslint-disable-next-line no-console
@@ -69,7 +69,7 @@ program
       const config = await resolveCLIConfig(options);
 
       if (options.extension) {
-        await runWithExtension(config, abortController);
+        await runWithExtension(config);
         return;
       }
       if (options.loopTools) {
@@ -91,15 +91,12 @@ program
     });
 
 function setupExitWatchdog() {
-  const abortController = new AbortController();
-
   let isExiting = false;
   const handleExit = async () => {
     if (isExiting)
       return;
     isExiting = true;
     setTimeout(() => process.exit(0), 15000);
-    abortController.abort('Process exiting');
     await Context.disposeAll();
     process.exit(0);
   };
@@ -107,8 +104,6 @@ function setupExitWatchdog() {
   process.stdin.on('close', handleExit);
   process.on('SIGINT', handleExit);
   process.on('SIGTERM', handleExit);
-
-  return abortController;
 }
 
 void program.parseAsync(process.argv);