Skip to content
93 changes: 93 additions & 0 deletions packages/core/src/config/config.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ vi.mock('../core/client.js', () => ({
initialize: vi.fn().mockResolvedValue(undefined),
stripThoughtsFromHistory: vi.fn(),
isInitialized: vi.fn().mockReturnValue(false),
setTools: vi.fn().mockResolvedValue(undefined),
updateSystemInstruction: vi.fn(),
})),
}));

Expand Down Expand Up @@ -199,6 +201,8 @@ import { getExperiments } from '../code_assist/experiments/experiments.js';
import type { CodeAssistServer } from '../code_assist/server.js';
import { ContextManager } from '../services/contextManager.js';
import { UserTierId } from 'src/code_assist/types.js';
import { ExitPlanModeTool } from '../tools/exit-plan-mode.js';
import { EnterPlanModeTool } from '../tools/enter-plan-mode.js';

vi.mock('../core/baseLlmClient.js');
vi.mock('../core/tokenLimits.js', () => ({
Expand Down Expand Up @@ -1324,6 +1328,11 @@ describe('setApprovalMode with folder trust', () => {
it('should update system instruction when entering Plan mode', () => {
const config = new Config(baseParams);
vi.spyOn(config, 'isTrustedFolder').mockReturnValue(true);
vi.spyOn(config, 'getToolRegistry').mockReturnValue({
getTool: vi.fn().mockReturnValue(undefined),
unregisterTool: vi.fn(),
registerTool: vi.fn(),
} as unknown as ReturnType<Config['getToolRegistry']>);
const updateSpy = vi.spyOn(config, 'updateSystemInstructionIfInitialized');

config.setApprovalMode(ApprovalMode.PLAN);
Expand All @@ -1337,6 +1346,11 @@ describe('setApprovalMode with folder trust', () => {
approvalMode: ApprovalMode.PLAN,
});
vi.spyOn(config, 'isTrustedFolder').mockReturnValue(true);
vi.spyOn(config, 'getToolRegistry').mockReturnValue({
getTool: vi.fn().mockReturnValue(undefined),
unregisterTool: vi.fn(),
registerTool: vi.fn(),
} as unknown as ReturnType<Config['getToolRegistry']>);
const updateSpy = vi.spyOn(config, 'updateSystemInstructionIfInitialized');

config.setApprovalMode(ApprovalMode.DEFAULT);
Expand Down Expand Up @@ -2398,3 +2412,82 @@ describe('Plans Directory Initialization', () => {
expect(context.getDirectories()).not.toContain(plansDir);
});
});

describe('syncPlanModeTools', () => {
const baseParams: ConfigParameters = {
sessionId: 'test-session',
targetDir: '.',
debugMode: false,
model: 'test-model',
cwd: '.',
};

it('should register ExitPlanModeTool and unregister EnterPlanModeTool when in PLAN mode', async () => {
const config = new Config({
...baseParams,
approvalMode: ApprovalMode.PLAN,
});
const registry = new ToolRegistry(config, config.getMessageBus());
vi.spyOn(config, 'getToolRegistry').mockReturnValue(registry);

const registerSpy = vi.spyOn(registry, 'registerTool');
const unregisterSpy = vi.spyOn(registry, 'unregisterTool');
const getToolSpy = vi.spyOn(registry, 'getTool');

getToolSpy.mockImplementation((name) => {
if (name === 'enter_plan_mode')
return new EnterPlanModeTool(config, config.getMessageBus());
return undefined;
});

config.syncPlanModeTools();

expect(unregisterSpy).toHaveBeenCalledWith('enter_plan_mode');
expect(registerSpy).toHaveBeenCalledWith(expect.anything());
const registeredTool = registerSpy.mock.calls[0][0];
const { ExitPlanModeTool } = await import('../tools/exit-plan-mode.js');
expect(registeredTool).toBeInstanceOf(ExitPlanModeTool);
});

it('should register EnterPlanModeTool and unregister ExitPlanModeTool when NOT in PLAN mode', async () => {
const config = new Config({
...baseParams,
approvalMode: ApprovalMode.DEFAULT,
});
const registry = new ToolRegistry(config, config.getMessageBus());
vi.spyOn(config, 'getToolRegistry').mockReturnValue(registry);

const registerSpy = vi.spyOn(registry, 'registerTool');
const unregisterSpy = vi.spyOn(registry, 'unregisterTool');
const getToolSpy = vi.spyOn(registry, 'getTool');

getToolSpy.mockImplementation((name) => {
if (name === 'exit_plan_mode')
return new ExitPlanModeTool(config, config.getMessageBus());
return undefined;
});

config.syncPlanModeTools();

expect(unregisterSpy).toHaveBeenCalledWith('exit_plan_mode');
expect(registerSpy).toHaveBeenCalledWith(expect.anything());
const registeredTool = registerSpy.mock.calls[0][0];
const { EnterPlanModeTool } = await import('../tools/enter-plan-mode.js');
expect(registeredTool).toBeInstanceOf(EnterPlanModeTool);
});

it('should call geminiClient.setTools if initialized', async () => {
const config = new Config(baseParams);
const registry = new ToolRegistry(config, config.getMessageBus());
vi.spyOn(config, 'getToolRegistry').mockReturnValue(registry);
const client = config.getGeminiClient();
vi.spyOn(client, 'isInitialized').mockReturnValue(true);
const setToolsSpy = vi
.spyOn(client, 'setTools')
.mockResolvedValue(undefined);

config.syncPlanModeTools();

expect(setToolsSpy).toHaveBeenCalled();
});
});
36 changes: 36 additions & 0 deletions packages/core/src/config/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,10 @@ import {
import { McpClientManager } from '../tools/mcp-client-manager.js';
import type { EnvironmentSanitizationConfig } from '../services/environmentSanitization.js';
import { getErrorMessage } from '../utils/errors.js';
import {
ENTER_PLAN_MODE_TOOL_NAME,
EXIT_PLAN_MODE_TOOL_NAME,
} from '../tools/tool-names.js';

export type { FileFilteringOptions };
export {
Expand Down Expand Up @@ -948,6 +952,7 @@ export class Config {
}

await this.geminiClient.initialize();
this.syncPlanModeTools();
}

getContentGenerator(): ContentGenerator {
Expand Down Expand Up @@ -1489,10 +1494,41 @@ export class Config {
currentMode !== mode &&
(currentMode === ApprovalMode.PLAN || mode === ApprovalMode.PLAN);
if (isPlanModeTransition) {
this.syncPlanModeTools();
this.updateSystemInstructionIfInitialized();
}
}

/**
* Synchronizes enter/exit plan mode tools based on current mode.
*/
syncPlanModeTools(): void {
const isPlanMode = this.getApprovalMode() === ApprovalMode.PLAN;
const registry = this.getToolRegistry();

if (isPlanMode) {
if (registry.getTool(ENTER_PLAN_MODE_TOOL_NAME)) {
registry.unregisterTool(ENTER_PLAN_MODE_TOOL_NAME);
}
if (!registry.getTool(EXIT_PLAN_MODE_TOOL_NAME)) {
registry.registerTool(new ExitPlanModeTool(this, this.messageBus));
}
} else {
if (registry.getTool(EXIT_PLAN_MODE_TOOL_NAME)) {
registry.unregisterTool(EXIT_PLAN_MODE_TOOL_NAME);
}
if (!registry.getTool(ENTER_PLAN_MODE_TOOL_NAME)) {
registry.registerTool(new EnterPlanModeTool(this, this.messageBus));
}
}

if (this.geminiClient?.isInitialized()) {
this.geminiClient.setTools().catch((err) => {
debugLogger.error('Failed to update tools', err);
});
}
}

/**
* Logs the duration of the current approval mode.
*/
Expand Down
Loading