diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index 58026eebee3..aa459bd7374 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -42,6 +42,7 @@ import { ORGANIZATION_ALLOW_ALL, DEFAULT_MODES, DEFAULT_CHECKPOINT_TIMEOUT_SECONDS, + getModelId, } from "@roo-code/types" import { TelemetryService } from "@roo-code/telemetry" import { CloudService, BridgeOrchestrator, getRooCodeApiUrl } from "@roo-code/cloud" @@ -1295,6 +1296,31 @@ export class ClineProvider // Provider Profile Management + /** + * Updates the current task's API handler if the provider or model has changed. + * This prevents unnecessary context condensing when only non-model settings change. + * @param providerSettings The new provider settings to apply + */ + private updateTaskApiHandlerIfNeeded(providerSettings: ProviderSettings): void { + const task = this.getCurrentTask() + + if (task && task.apiConfiguration) { + // Only rebuild API handler if provider or model actually changed + // to avoid triggering unnecessary context condensing + const currentProvider = task.apiConfiguration.apiProvider + const newProvider = providerSettings.apiProvider + const currentModelId = getModelId(task.apiConfiguration) + const newModelId = getModelId(providerSettings) + + if (currentProvider !== newProvider || currentModelId !== newModelId) { + task.api = buildApiHandler(providerSettings) + } + } else if (task) { + // Fallback: rebuild if apiConfiguration is not available + task.api = buildApiHandler(providerSettings) + } + } + getProviderProfileEntries(): ProviderSettingsEntry[] { return this.contextProxy.getValues().listApiConfigMeta || [] } @@ -1342,11 +1368,7 @@ export class ClineProvider // Change the provider for the current task. // TODO: We should rename `buildApiHandler` for clarity (e.g. `getProviderClient`). - const task = this.getCurrentTask() - - if (task) { - task.api = buildApiHandler(providerSettings) - } + this.updateTaskApiHandlerIfNeeded(providerSettings) } else { await this.updateGlobalState("listApiConfigMeta", await this.providerSettingsManager.listConfig()) } @@ -1403,11 +1425,7 @@ export class ClineProvider } // Change the provider for the current task. - const task = this.getCurrentTask() - - if (task) { - task.api = buildApiHandler(providerSettings) - } + this.updateTaskApiHandlerIfNeeded(providerSettings) await this.postStateToWebview() diff --git a/src/core/webview/__tests__/ClineProvider.apiHandlerRebuild.spec.ts b/src/core/webview/__tests__/ClineProvider.apiHandlerRebuild.spec.ts new file mode 100644 index 00000000000..4037d9df266 --- /dev/null +++ b/src/core/webview/__tests__/ClineProvider.apiHandlerRebuild.spec.ts @@ -0,0 +1,520 @@ +// npx vitest core/webview/__tests__/ClineProvider.apiHandlerRebuild.spec.ts + +import * as vscode from "vscode" + +import { TelemetryService } from "@roo-code/telemetry" +import { getModelId } from "@roo-code/types" + +import { ContextProxy } from "../../config/ContextProxy" +import { Task, TaskOptions } from "../../task/Task" +import { ClineProvider } from "../ClineProvider" + +// Mock setup +vi.mock("fs/promises", () => ({ + mkdir: vi.fn().mockResolvedValue(undefined), + writeFile: vi.fn().mockResolvedValue(undefined), + readFile: vi.fn().mockResolvedValue(""), + unlink: vi.fn().mockResolvedValue(undefined), + rmdir: vi.fn().mockResolvedValue(undefined), +})) + +vi.mock("../../../utils/storage", () => ({ + getSettingsDirectoryPath: vi.fn().mockResolvedValue("/test/settings/path"), + getTaskDirectoryPath: vi.fn().mockResolvedValue("/test/task/path"), + getGlobalStoragePath: vi.fn().mockResolvedValue("/test/storage/path"), +})) + +vi.mock("p-wait-for", () => ({ + __esModule: true, + default: vi.fn().mockResolvedValue(undefined), +})) + +vi.mock("delay", () => { + const delayFn = (_ms: number) => Promise.resolve() + delayFn.createDelay = () => delayFn + delayFn.reject = () => Promise.reject(new Error("Delay rejected")) + delayFn.range = () => Promise.resolve() + return { default: delayFn } +}) + +vi.mock("vscode", () => ({ + ExtensionContext: vi.fn(), + OutputChannel: vi.fn(), + WebviewView: vi.fn(), + Uri: { + joinPath: vi.fn(), + file: vi.fn(), + }, + commands: { + executeCommand: vi.fn().mockResolvedValue(undefined), + }, + window: { + showInformationMessage: vi.fn(), + showWarningMessage: vi.fn(), + showErrorMessage: vi.fn(), + onDidChangeActiveTextEditor: vi.fn(() => ({ dispose: vi.fn() })), + }, + workspace: { + getConfiguration: vi.fn().mockReturnValue({ + get: vi.fn().mockReturnValue([]), + update: vi.fn(), + }), + onDidChangeConfiguration: vi.fn().mockImplementation(() => ({ + dispose: vi.fn(), + })), + }, + env: { + uriScheme: "vscode", + language: "en", + appName: "Visual Studio Code", + }, + ExtensionMode: { + Production: 1, + Development: 2, + Test: 3, + }, + version: "1.85.0", +})) + +vi.mock("../../../utils/tts", () => ({ + setTtsEnabled: vi.fn(), + setTtsSpeed: vi.fn(), +})) + +vi.mock("../../../api", () => ({ + buildApiHandler: vi.fn(), +})) + +vi.mock("../../../integrations/workspace/WorkspaceTracker", () => { + return { + default: vi.fn().mockImplementation(() => ({ + initializeFilePaths: vi.fn(), + dispose: vi.fn(), + })), + } +}) + +vi.mock("../../task/Task", () => ({ + Task: vi.fn().mockImplementation((options) => { + const mockTask = { + api: undefined, + abortTask: vi.fn(), + handleWebviewAskResponse: vi.fn(), + clineMessages: [], + apiConversationHistory: [], + overwriteClineMessages: vi.fn(), + overwriteApiConversationHistory: vi.fn(), + taskId: options?.historyItem?.id || "test-task-id", + emit: vi.fn(), + } + // Define apiConfiguration as a property so tests can read it + Object.defineProperty(mockTask, "apiConfiguration", { + value: options?.apiConfiguration || { apiProvider: "openrouter", openRouterModelId: "openai/gpt-4" }, + writable: true, + configurable: true, + }) + return mockTask + }), +})) + +vi.mock("@roo-code/cloud", () => ({ + CloudService: { + hasInstance: vi.fn().mockReturnValue(true), + get instance() { + return { + isAuthenticated: vi.fn().mockReturnValue(false), + } + }, + }, + BridgeOrchestrator: { + isEnabled: vi.fn().mockReturnValue(false), + }, + getRooCodeApiUrl: vi.fn().mockReturnValue("https://app.roocode.com"), +})) + +describe("ClineProvider - API Handler Rebuild Guard", () => { + let provider: ClineProvider + let mockContext: vscode.ExtensionContext + let mockOutputChannel: vscode.OutputChannel + let mockWebviewView: vscode.WebviewView + let mockPostMessage: any + let defaultTaskOptions: TaskOptions + let buildApiHandlerMock: any + + beforeEach(async () => { + vi.clearAllMocks() + + if (!TelemetryService.hasInstance()) { + TelemetryService.createInstance([]) + } + + const globalState: Record = { + mode: "code", + currentApiConfigName: "test-config", + } + + const secrets: Record = {} + + mockContext = { + extensionPath: "/test/path", + extensionUri: {} as vscode.Uri, + globalState: { + get: vi.fn().mockImplementation((key: string) => globalState[key]), + update: vi.fn().mockImplementation((key: string, value: any) => (globalState[key] = value)), + keys: vi.fn().mockImplementation(() => Object.keys(globalState)), + }, + secrets: { + get: vi.fn().mockImplementation((key: string) => secrets[key]), + store: vi.fn().mockImplementation((key: string, value: string | undefined) => (secrets[key] = value)), + delete: vi.fn().mockImplementation((key: string) => delete secrets[key]), + }, + subscriptions: [], + extension: { + packageJSON: { version: "1.0.0" }, + }, + globalStorageUri: { + fsPath: "/test/storage/path", + }, + } as unknown as vscode.ExtensionContext + + mockOutputChannel = { + appendLine: vi.fn(), + clear: vi.fn(), + dispose: vi.fn(), + } as unknown as vscode.OutputChannel + + mockPostMessage = vi.fn() + + mockWebviewView = { + webview: { + postMessage: mockPostMessage, + html: "", + options: {}, + onDidReceiveMessage: vi.fn(), + asWebviewUri: vi.fn(), + }, + visible: true, + onDidDispose: vi.fn().mockImplementation((callback) => { + callback() + return { dispose: vi.fn() } + }), + onDidChangeVisibility: vi.fn().mockImplementation(() => ({ dispose: vi.fn() })), + } as unknown as vscode.WebviewView + + provider = new ClineProvider(mockContext, mockOutputChannel, "sidebar", new ContextProxy(mockContext)) + + // Mock providerSettingsManager + ;(provider as any).providerSettingsManager = { + saveConfig: vi.fn().mockResolvedValue("test-id"), + listConfig: vi + .fn() + .mockResolvedValue([ + { name: "test-config", id: "test-id", apiProvider: "openrouter", modelId: "openai/gpt-4" }, + ]), + setModeConfig: vi.fn(), + activateProfile: vi.fn().mockResolvedValue({ + name: "test-config", + id: "test-id", + apiProvider: "openrouter", + openRouterModelId: "openai/gpt-4", + }), + getProfile: vi.fn().mockResolvedValue({ + name: "test-config", + id: "test-id", + apiProvider: "openrouter", + openRouterModelId: "openai/gpt-4", + }), + } + + // Get the buildApiHandler mock + const { buildApiHandler } = await import("../../../api") + buildApiHandlerMock = vi.mocked(buildApiHandler) + + // Setup default mock implementation + buildApiHandlerMock.mockReturnValue({ + getModel: vi.fn().mockReturnValue({ + id: "openai/gpt-4", + info: { contextWindow: 128000 }, + }), + }) + + defaultTaskOptions = { + provider, + apiConfiguration: { + apiProvider: "openrouter", + openRouterModelId: "openai/gpt-4", + }, + } + + await provider.resolveWebviewView(mockWebviewView) + }) + + describe("upsertProviderProfile", () => { + test("does NOT rebuild API handler when provider and model unchanged", async () => { + // Create a task with the current config + const mockTask = new Task({ + ...defaultTaskOptions, + apiConfiguration: { + apiProvider: "openrouter", + openRouterModelId: "openai/gpt-4", + }, + }) + const originalApi = { + getModel: vi.fn().mockReturnValue({ + id: "openai/gpt-4", + info: { contextWindow: 128000 }, + }), + } + mockTask.api = originalApi as any + + await provider.addClineToStack(mockTask) + + // Clear the mock to track new calls + buildApiHandlerMock.mockClear() + + // Save settings with SAME provider and model (simulating Save button click) + await provider.upsertProviderProfile( + "test-config", + { + apiProvider: "openrouter", + openRouterModelId: "openai/gpt-4", + // Other settings that might change + rateLimitSeconds: 5, + modelTemperature: 0.7, + }, + true, + ) + + // Verify buildApiHandler was NOT called since provider/model unchanged + expect(buildApiHandlerMock).not.toHaveBeenCalled() + // Verify the task's api property was NOT reassigned (still same reference) + expect(mockTask.api).toBe(originalApi) + }) + + test("rebuilds API handler when provider changes", async () => { + const mockTask = new Task({ + ...defaultTaskOptions, + apiConfiguration: { + apiProvider: "openrouter", + openRouterModelId: "openai/gpt-4", + }, + }) + mockTask.api = { + getModel: vi.fn().mockReturnValue({ + id: "openai/gpt-4", + info: { contextWindow: 128000 }, + }), + } as any + + await provider.addClineToStack(mockTask) + + buildApiHandlerMock.mockClear() + + // Change provider to anthropic + await provider.upsertProviderProfile( + "test-config", + { + apiProvider: "anthropic", + apiModelId: "claude-3-5-sonnet-20241022", + }, + true, + ) + + // Verify buildApiHandler WAS called since provider changed + expect(buildApiHandlerMock).toHaveBeenCalledWith( + expect.objectContaining({ + apiProvider: "anthropic", + apiModelId: "claude-3-5-sonnet-20241022", + }), + ) + }) + + test("rebuilds API handler when model changes", async () => { + const mockTask = new Task({ + ...defaultTaskOptions, + apiConfiguration: { + apiProvider: "openrouter", + openRouterModelId: "openai/gpt-4", + }, + }) + mockTask.api = { + getModel: vi.fn().mockReturnValue({ + id: "openai/gpt-4", + info: { contextWindow: 128000 }, + }), + } as any + + await provider.addClineToStack(mockTask) + + buildApiHandlerMock.mockClear() + + // Change model to different model + await provider.upsertProviderProfile( + "test-config", + { + apiProvider: "openrouter", + openRouterModelId: "anthropic/claude-3-5-sonnet-20241022", + }, + true, + ) + + // Verify buildApiHandler WAS called since model changed + expect(buildApiHandlerMock).toHaveBeenCalledWith( + expect.objectContaining({ + apiProvider: "openrouter", + openRouterModelId: "anthropic/claude-3-5-sonnet-20241022", + }), + ) + }) + + test("does nothing when no task is running", async () => { + // Don't add any task to stack + buildApiHandlerMock.mockClear() + + await provider.upsertProviderProfile( + "test-config", + { + apiProvider: "openrouter", + openRouterModelId: "openai/gpt-4", + }, + true, + ) + + // Should not call buildApiHandler when there's no task + expect(buildApiHandlerMock).not.toHaveBeenCalled() + }) + }) + + describe("activateProviderProfile", () => { + test("does NOT rebuild API handler when provider and model unchanged", async () => { + const mockTask = new Task({ + ...defaultTaskOptions, + apiConfiguration: { + apiProvider: "openrouter", + openRouterModelId: "openai/gpt-4", + }, + }) + const originalApi = { + getModel: vi.fn().mockReturnValue({ + id: "openai/gpt-4", + info: { contextWindow: 128000 }, + }), + } + mockTask.api = originalApi as any + + await provider.addClineToStack(mockTask) + + buildApiHandlerMock.mockClear() + + // Mock activateProfile to return same provider/model + ;(provider as any).providerSettingsManager.activateProfile = vi.fn().mockResolvedValue({ + name: "test-config", + id: "test-id", + apiProvider: "openrouter", + openRouterModelId: "openai/gpt-4", + }) + + await provider.activateProviderProfile({ name: "test-config" }) + + // Verify buildApiHandler was NOT called + expect(buildApiHandlerMock).not.toHaveBeenCalled() + // Verify the API reference wasn't changed + expect(mockTask.api).toBe(originalApi) + }) + + test("rebuilds API handler when provider changes", async () => { + const mockTask = new Task({ + ...defaultTaskOptions, + apiConfiguration: { + apiProvider: "openrouter", + openRouterModelId: "openai/gpt-4", + }, + }) + mockTask.api = { + getModel: vi.fn().mockReturnValue({ + id: "openai/gpt-4", + info: { contextWindow: 128000 }, + }), + } as any + + await provider.addClineToStack(mockTask) + + buildApiHandlerMock.mockClear() + + // Mock activateProfile to return different provider + ;(provider as any).providerSettingsManager.activateProfile = vi.fn().mockResolvedValue({ + name: "anthropic-config", + id: "anthropic-id", + apiProvider: "anthropic", + apiModelId: "claude-3-5-sonnet-20241022", + }) + + await provider.activateProviderProfile({ name: "anthropic-config" }) + + // Verify buildApiHandler WAS called + expect(buildApiHandlerMock).toHaveBeenCalledWith( + expect.objectContaining({ + apiProvider: "anthropic", + apiModelId: "claude-3-5-sonnet-20241022", + }), + ) + }) + + test("rebuilds API handler when model changes", async () => { + const mockTask = new Task({ + ...defaultTaskOptions, + apiConfiguration: { + apiProvider: "openrouter", + openRouterModelId: "openai/gpt-4", + }, + }) + mockTask.api = { + getModel: vi.fn().mockReturnValue({ + id: "openai/gpt-4", + info: { contextWindow: 128000 }, + }), + } as any + + await provider.addClineToStack(mockTask) + + buildApiHandlerMock.mockClear() + + // Mock activateProfile to return different model + ;(provider as any).providerSettingsManager.activateProfile = vi.fn().mockResolvedValue({ + name: "test-config", + id: "test-id", + apiProvider: "openrouter", + openRouterModelId: "anthropic/claude-3-5-sonnet-20241022", + }) + + await provider.activateProviderProfile({ name: "test-config" }) + + // Verify buildApiHandler WAS called + expect(buildApiHandlerMock).toHaveBeenCalledWith( + expect.objectContaining({ + apiProvider: "openrouter", + openRouterModelId: "anthropic/claude-3-5-sonnet-20241022", + }), + ) + }) + }) + + describe("getModelId helper", () => { + test("correctly extracts model ID from different provider configurations", () => { + expect(getModelId({ apiProvider: "openrouter", openRouterModelId: "openai/gpt-4" })).toBe("openai/gpt-4") + expect(getModelId({ apiProvider: "anthropic", apiModelId: "claude-3-5-sonnet-20241022" })).toBe( + "claude-3-5-sonnet-20241022", + ) + expect(getModelId({ apiProvider: "openai", openAiModelId: "gpt-4-turbo" })).toBe("gpt-4-turbo") + expect(getModelId({ apiProvider: "glama", glamaModelId: "some-model" })).toBe("some-model") + expect(getModelId({ apiProvider: "bedrock", apiModelId: "anthropic.claude-v2" })).toBe( + "anthropic.claude-v2", + ) + }) + + test("returns undefined when no model ID is present", () => { + expect(getModelId({ apiProvider: "anthropic" })).toBeUndefined() + expect(getModelId({})).toBeUndefined() + }) + }) +})