diff --git a/packages/cloud/src/bridge/BridgeOrchestrator.ts b/packages/cloud/src/bridge/BridgeOrchestrator.ts index 15b5c65eb20..1b82cd99e49 100644 --- a/packages/cloud/src/bridge/BridgeOrchestrator.ts +++ b/packages/cloud/src/bridge/BridgeOrchestrator.ts @@ -59,13 +59,28 @@ export class BridgeOrchestrator { return BridgeOrchestrator.instance } - public static isEnabled(user?: CloudUserInfo | null, remoteControlEnabled?: boolean): boolean { - return !!(user?.id && user.extensionBridgeEnabled && remoteControlEnabled) + public static isEnabled(user: CloudUserInfo | null, remoteControlEnabled: boolean): boolean { + // Always disabled if signed out. + if (!user) { + return false + } + + // Disabled by the user's organization? + if (!user.extensionBridgeEnabled) { + return false + } + + // Disabled by the user? + if (!remoteControlEnabled) { + return false + } + + return true } public static async connectOrDisconnect( - userInfo: CloudUserInfo | null, - remoteControlEnabled: boolean | undefined, + userInfo: CloudUserInfo, + remoteControlEnabled: boolean, options: BridgeOrchestratorOptions, ): Promise { if (BridgeOrchestrator.isEnabled(userInfo, remoteControlEnabled)) { diff --git a/src/__tests__/extension.spec.ts b/src/__tests__/extension.spec.ts new file mode 100644 index 00000000000..c39854f697c --- /dev/null +++ b/src/__tests__/extension.spec.ts @@ -0,0 +1,257 @@ +// npx vitest run __tests__/extension.spec.ts + +import type * as vscode from "vscode" +import type { AuthState } from "@roo-code/types" + +vi.mock("vscode", () => ({ + window: { + createOutputChannel: vi.fn().mockReturnValue({ + appendLine: vi.fn(), + }), + registerWebviewViewProvider: vi.fn(), + registerUriHandler: vi.fn(), + tabGroups: { + onDidChangeTabs: vi.fn(), + }, + onDidChangeActiveTextEditor: vi.fn(), + }, + workspace: { + registerTextDocumentContentProvider: vi.fn(), + getConfiguration: vi.fn().mockReturnValue({ + get: vi.fn().mockReturnValue([]), + }), + createFileSystemWatcher: vi.fn().mockReturnValue({ + onDidCreate: vi.fn(), + onDidChange: vi.fn(), + onDidDelete: vi.fn(), + dispose: vi.fn(), + }), + onDidChangeWorkspaceFolders: vi.fn(), + }, + languages: { + registerCodeActionsProvider: vi.fn(), + }, + commands: { + executeCommand: vi.fn(), + }, + env: { + language: "en", + }, + ExtensionMode: { + Production: 1, + }, +})) + +vi.mock("@dotenvx/dotenvx", () => ({ + config: vi.fn(), +})) + +const mockBridgeOrchestratorDisconnect = vi.fn().mockResolvedValue(undefined) + +vi.mock("@roo-code/cloud", () => ({ + CloudService: { + createInstance: vi.fn(), + hasInstance: vi.fn().mockReturnValue(true), + get instance() { + return { + off: vi.fn(), + on: vi.fn(), + getUserInfo: vi.fn().mockReturnValue(null), + isTaskSyncEnabled: vi.fn().mockReturnValue(false), + } + }, + }, + BridgeOrchestrator: { + disconnect: mockBridgeOrchestratorDisconnect, + }, + getRooCodeApiUrl: vi.fn().mockReturnValue("https://app.roocode.com"), +})) + +vi.mock("@roo-code/telemetry", () => ({ + TelemetryService: { + createInstance: vi.fn().mockReturnValue({ + register: vi.fn(), + setProvider: vi.fn(), + shutdown: vi.fn(), + }), + get instance() { + return { + register: vi.fn(), + setProvider: vi.fn(), + shutdown: vi.fn(), + } + }, + }, + PostHogTelemetryClient: vi.fn(), +})) + +vi.mock("../utils/outputChannelLogger", () => ({ + createOutputChannelLogger: vi.fn().mockReturnValue(vi.fn()), + createDualLogger: vi.fn().mockReturnValue(vi.fn()), +})) + +vi.mock("../shared/package", () => ({ + Package: { + name: "test-extension", + outputChannel: "Test Output", + version: "1.0.0", + }, +})) + +vi.mock("../shared/language", () => ({ + formatLanguage: vi.fn().mockReturnValue("en"), +})) + +vi.mock("../core/config/ContextProxy", () => ({ + ContextProxy: { + getInstance: vi.fn().mockResolvedValue({ + getValue: vi.fn(), + setValue: vi.fn(), + getValues: vi.fn().mockReturnValue({}), + getProviderSettings: vi.fn().mockReturnValue({}), + }), + }, +})) + +vi.mock("../integrations/editor/DiffViewProvider", () => ({ + DIFF_VIEW_URI_SCHEME: "test-diff-scheme", +})) + +vi.mock("../integrations/terminal/TerminalRegistry", () => ({ + TerminalRegistry: { + initialize: vi.fn(), + cleanup: vi.fn(), + }, +})) + +vi.mock("../services/mcp/McpServerManager", () => ({ + McpServerManager: { + cleanup: vi.fn().mockResolvedValue(undefined), + getInstance: vi.fn().mockResolvedValue(null), + unregisterProvider: vi.fn(), + }, +})) + +vi.mock("../services/code-index/manager", () => ({ + CodeIndexManager: { + getInstance: vi.fn().mockReturnValue(null), + }, +})) + +vi.mock("../services/mdm/MdmService", () => ({ + MdmService: { + createInstance: vi.fn().mockResolvedValue(null), + }, +})) + +vi.mock("../utils/migrateSettings", () => ({ + migrateSettings: vi.fn().mockResolvedValue(undefined), +})) + +vi.mock("../utils/autoImportSettings", () => ({ + autoImportSettings: vi.fn().mockResolvedValue(undefined), +})) + +vi.mock("../extension/api", () => ({ + API: vi.fn().mockImplementation(() => ({})), +})) + +vi.mock("../activate", () => ({ + handleUri: vi.fn(), + registerCommands: vi.fn(), + registerCodeActions: vi.fn(), + registerTerminalActions: vi.fn(), + CodeActionProvider: vi.fn().mockImplementation(() => ({ + providedCodeActionKinds: [], + })), +})) + +vi.mock("../i18n", () => ({ + initializeI18n: vi.fn(), +})) + +describe("extension.ts", () => { + let mockContext: vscode.ExtensionContext + let authStateChangedHandler: + | ((data: { state: AuthState; previousState: AuthState }) => void | Promise) + | undefined + + beforeEach(() => { + vi.clearAllMocks() + mockBridgeOrchestratorDisconnect.mockClear() + + mockContext = { + extensionPath: "/test/path", + globalState: { + get: vi.fn().mockReturnValue(undefined), + update: vi.fn(), + }, + subscriptions: [], + } as unknown as vscode.ExtensionContext + + authStateChangedHandler = undefined + }) + + test("authStateChangedHandler calls BridgeOrchestrator.disconnect when logged-out event fires", async () => { + const { CloudService, BridgeOrchestrator } = await import("@roo-code/cloud") + + // Capture the auth state changed handler. + vi.mocked(CloudService.createInstance).mockImplementation(async (_context, _logger, handlers) => { + if (handlers?.["auth-state-changed"]) { + authStateChangedHandler = handlers["auth-state-changed"] + } + + return { + off: vi.fn(), + on: vi.fn(), + telemetryClient: null, + } as any + }) + + // Activate the extension. + const { activate } = await import("../extension") + await activate(mockContext) + + // Verify handler was registered. + expect(authStateChangedHandler).toBeDefined() + + // Trigger logout. + await authStateChangedHandler!({ + state: "logged-out" as AuthState, + previousState: "logged-in" as AuthState, + }) + + // Verify BridgeOrchestrator.disconnect was called + expect(mockBridgeOrchestratorDisconnect).toHaveBeenCalled() + }) + + test("authStateChangedHandler does not call BridgeOrchestrator.disconnect for other states", async () => { + const { CloudService } = await import("@roo-code/cloud") + + // Capture the auth state changed handler. + vi.mocked(CloudService.createInstance).mockImplementation(async (_context, _logger, handlers) => { + if (handlers?.["auth-state-changed"]) { + authStateChangedHandler = handlers["auth-state-changed"] + } + + return { + off: vi.fn(), + on: vi.fn(), + telemetryClient: null, + } as any + }) + + // Activate the extension. + const { activate } = await import("../extension") + await activate(mockContext) + + // Trigger login. + await authStateChangedHandler!({ + state: "logged-in" as AuthState, + previousState: "logged-out" as AuthState, + }) + + // Verify BridgeOrchestrator.disconnect was NOT called. + expect(mockBridgeOrchestratorDisconnect).not.toHaveBeenCalled() + }) +}) diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index f198fad8b2b..9abddc6d962 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -2262,7 +2262,19 @@ export class ClineProvider } public async remoteControlEnabled(enabled: boolean) { + if (!enabled) { + await BridgeOrchestrator.disconnect() + return + } + const userInfo = CloudService.instance.getUserInfo() + + if (!userInfo) { + this.log("[ClineProvider#remoteControlEnabled] Failed to get user info, disconnecting") + await BridgeOrchestrator.disconnect() + return + } + const config = await CloudService.instance.cloudAPI?.bridgeConfig().catch(() => undefined) if (!config) { diff --git a/src/extension.ts b/src/extension.ts index 26a8f9d6e2f..dc96e282c43 100644 --- a/src/extension.ts +++ b/src/extension.ts @@ -134,10 +134,9 @@ export async function activate(context: vscode.ExtensionContext) { if (data.state === "logged-out") { try { await provider.remoteControlEnabled(false) - cloudLogger("[CloudService] BridgeOrchestrator disconnected on logout") } catch (error) { cloudLogger( - `[CloudService] Failed to disconnect BridgeOrchestrator on logout: ${error instanceof Error ? error.message : String(error)}`, + `[authStateChangedHandler] remoteControlEnabled(false) failed: ${error instanceof Error ? error.message : String(error)}`, ) } } @@ -151,7 +150,7 @@ export async function activate(context: vscode.ExtensionContext) { provider.remoteControlEnabled(CloudService.instance.isTaskSyncEnabled()) } catch (error) { cloudLogger( - `[CloudService] BridgeOrchestrator#connectOrDisconnect failed on settings change: ${error instanceof Error ? error.message : String(error)}`, + `[settingsUpdatedHandler] remoteControlEnabled failed: ${error instanceof Error ? error.message : String(error)}`, ) } } @@ -163,7 +162,7 @@ export async function activate(context: vscode.ExtensionContext) { postStateListener() if (!CloudService.instance.cloudAPI) { - cloudLogger("[CloudService] CloudAPI is not initialized") + cloudLogger("[userInfoHandler] CloudAPI is not initialized") return } @@ -171,7 +170,7 @@ export async function activate(context: vscode.ExtensionContext) { provider.remoteControlEnabled(CloudService.instance.isTaskSyncEnabled()) } catch (error) { cloudLogger( - `[CloudService] BridgeOrchestrator#connectOrDisconnect failed on user change: ${error instanceof Error ? error.message : String(error)}`, + `[userInfoHandler] remoteControlEnabled failed: ${error instanceof Error ? error.message : String(error)}`, ) } }