diff --git a/cli/src/connection.rs b/cli/src/connection.rs index 220d5224..10e898d5 100644 --- a/cli/src/connection.rs +++ b/cli/src/connection.rs @@ -220,6 +220,7 @@ pub fn ensure_daemon( provider: Option<&str>, device: Option<&str>, session_name: Option<&str>, + allow_origins: Option<&str>, ) -> Result { // Check if daemon is running AND responsive if is_daemon_running(session) && daemon_ready(session) { @@ -364,6 +365,10 @@ pub fn ensure_daemon( cmd.env("AGENT_BROWSER_SESSION_NAME", sn); } + if let Some(ao) = allow_origins { + cmd.env("AGENT_BROWSER_ALLOWED_ORIGINS", ao); + } + // Create new process group and session to fully detach unsafe { cmd.pre_exec(|| { @@ -447,6 +452,10 @@ pub fn ensure_daemon( cmd.env("AGENT_BROWSER_SESSION_NAME", sn); } + if let Some(ao) = allow_origins { + cmd.env("AGENT_BROWSER_ALLOWED_ORIGINS", ao); + } + // CREATE_NEW_PROCESS_GROUP | DETACHED_PROCESS const CREATE_NEW_PROCESS_GROUP: u32 = 0x00000200; const DETACHED_PROCESS: u32 = 0x00000008; diff --git a/cli/src/flags.rs b/cli/src/flags.rs index 84d41ebf..c2a137b0 100644 --- a/cli/src/flags.rs +++ b/cli/src/flags.rs @@ -22,6 +22,7 @@ pub struct Flags { pub device: Option, pub auto_connect: bool, pub session_name: Option, + pub allow_origins: Option, // Track which launch-time options were explicitly passed via CLI // (as opposed to being set only via environment variables) @@ -69,6 +70,7 @@ pub fn parse_flags(args: &[String]) -> Flags { device: env::var("AGENT_BROWSER_IOS_DEVICE").ok(), auto_connect: env::var("AGENT_BROWSER_AUTO_CONNECT").is_ok(), session_name: env::var("AGENT_BROWSER_SESSION_NAME").ok(), + allow_origins: env::var("AGENT_BROWSER_ALLOWED_ORIGINS").ok(), // Track CLI-passed flags (default false, set to true when flag is passed) cli_executable_path: false, cli_extensions: false, @@ -186,6 +188,12 @@ pub fn parse_flags(args: &[String]) -> Flags { i += 1; } } + "--allow-origins" => { + if let Some(s) = args.get(i + 1) { + flags.allow_origins = Some(s.clone()); + i += 1; + } + } _ => {} } i += 1; @@ -224,6 +232,7 @@ pub fn clean_args(args: &[String]) -> Vec { "--provider", "--device", "--session-name", + "--allow-origins", ]; for arg in args.iter() { diff --git a/cli/src/main.rs b/cli/src/main.rs index 7be0f153..f3ef9f88 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -226,6 +226,7 @@ fn main() { flags.provider.as_deref(), flags.device.as_deref(), flags.session_name.as_deref(), + flags.allow_origins.as_deref(), ) { Ok(result) => result, Err(e) => { diff --git a/cli/src/output.rs b/cli/src/output.rs index 7f083a51..2f70a63c 100644 --- a/cli/src/output.rs +++ b/cli/src/output.rs @@ -1873,6 +1873,8 @@ Options: --cdp Connect via CDP (Chrome DevTools Protocol) --auto-connect Auto-discover and connect to running Chrome --session-name Auto-save/restore session state (cookies, localStorage) + --allow-origins Extra allowed WebSocket origins, comma-separated + (or AGENT_BROWSER_ALLOWED_ORIGINS env) --debug Debug output --version, -V Show version @@ -1887,6 +1889,7 @@ Environment: AGENT_BROWSER_STREAM_PORT Enable WebSocket streaming on port (e.g., 9223) AGENT_BROWSER_IOS_DEVICE Default iOS device name AGENT_BROWSER_IOS_UDID Default iOS device UDID + AGENT_BROWSER_ALLOWED_ORIGINS Extra allowed WebSocket origins (comma-separated) Examples: agent-browser open example.com diff --git a/src/daemon.ts b/src/daemon.ts index 36d5acc4..7bc658b9 100644 --- a/src/daemon.ts +++ b/src/daemon.ts @@ -308,6 +308,13 @@ export async function startDaemon(options?: { ? parseInt(process.env.AGENT_BROWSER_STREAM_PORT, 10) : 0); + // Configure custom allowed origins for stream server + const allowedOriginsEnv = process.env.AGENT_BROWSER_ALLOWED_ORIGINS; + if (allowedOriginsEnv) { + const { setAllowedOrigins } = await import('./stream-server.js'); + setAllowedOrigins(allowedOriginsEnv.split(',').map(s => s.trim())); + } + if (streamPort > 0 && !isIOS && manager instanceof BrowserManager) { streamServer = new StreamServer(manager, streamPort); await streamServer.start(); diff --git a/src/stream-server.test.ts b/src/stream-server.test.ts index 096a4de2..bfdbbc4f 100644 --- a/src/stream-server.test.ts +++ b/src/stream-server.test.ts @@ -1,7 +1,11 @@ -import { describe, it, expect } from 'vitest'; -import { isAllowedOrigin } from './stream-server.js'; +import { describe, it, expect, afterEach } from 'vitest'; +import { isAllowedOrigin, setAllowedOrigins } from './stream-server.js'; describe('isAllowedOrigin', () => { + afterEach(() => { + setAllowedOrigins([]); + }); + describe('allowed origins', () => { it('should allow connections with no origin (CLI tools)', () => { expect(isAllowedOrigin(undefined)).toBe(true); @@ -38,6 +42,22 @@ describe('isAllowedOrigin', () => { expect(isAllowedOrigin('http://[::1]')).toBe(true); expect(isAllowedOrigin('http://[::1]:3000')).toBe(true); }); + + it('should allow vscode-webview:// origins', () => { + expect(isAllowedOrigin('vscode-webview://abc123')).toBe(true); + expect(isAllowedOrigin('vscode-webview://some-extension-id/index.html')).toBe(true); + }); + + it('should allow custom origins', () => { + setAllowedOrigins(['https://my-app.com']); + expect(isAllowedOrigin('https://my-app.com')).toBe(true); + expect(isAllowedOrigin('https://evil.com')).toBe(false); + }); + + it('should allow custom origin prefixes', () => { + setAllowedOrigins(['chrome-extension://']); + expect(isAllowedOrigin('chrome-extension://abcdef123456')).toBe(true); + }); }); describe('rejected origins', () => { diff --git a/src/stream-server.ts b/src/stream-server.ts index 9981eccc..cacfe32e 100644 --- a/src/stream-server.ts +++ b/src/stream-server.ts @@ -2,9 +2,17 @@ import { WebSocketServer, WebSocket } from 'ws'; import type { BrowserManager, ScreencastFrame } from './browser.js'; import { setScreencastFrameCallback } from './actions.js'; +// Custom allowed origins set via setAllowedOrigins() +let customAllowedOrigins: string[] = []; + +export function setAllowedOrigins(origins: string[]): void { + customAllowedOrigins = origins; +} + /** * Check whether a WebSocket connection origin should be allowed. - * Allows: no origin (CLI tools), file:// origins, and localhost/loopback origins. + * Allows: no origin (CLI tools), file:// origins, vscode-webview:// origins, + * custom allowed origins, and localhost/loopback origins. * Rejects: all other origins (prevents malicious web pages from connecting). */ export function isAllowedOrigin(origin: string | undefined): boolean { @@ -16,6 +24,20 @@ export function isAllowedOrigin(origin: string | undefined): boolean { if (origin.startsWith('file://')) { return true; } + // Allow vscode-webview:// origins (VSCode Webview extensions) + if (origin.startsWith('vscode-webview://')) { + return true; + } + // Check custom allowed origins + for (const allowed of customAllowedOrigins) { + if (origin === allowed) return true; + if (origin.startsWith(allowed)) { + // Scheme prefixes (e.g. "chrome-extension://") match any extension ID after them + if (allowed.endsWith('://')) return true; + const next = origin[allowed.length]; + if (next === undefined || next === '/' || next === ':') return true; + } + } // Allow localhost/loopback origins (browser-based stream viewers) try { const url = new URL(origin);