Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 116 additions & 5 deletions packages/cli/src/ui/commands/modelCommand.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { modelCommand } from './modelCommand.js';
import { type CommandContext } from './types.js';
import { createMockCommandContext } from '../../test-utils/mockCommandContext.js';
import type { Config } from '@google/gemini-cli-core';
import { MessageType } from '../types.js';

describe('modelCommand', () => {
let mockContext: CommandContext;
Expand All @@ -17,7 +18,7 @@ describe('modelCommand', () => {
mockContext = createMockCommandContext();
});

it('should return a dialog action to open the model dialog', async () => {
it('should return a dialog action to open the model dialog when no args', async () => {
if (!modelCommand.action) {
throw new Error('The model command must have an action.');
}
Expand All @@ -30,7 +31,7 @@ describe('modelCommand', () => {
});
});

it('should call refreshUserQuota if config is available', async () => {
it('should call refreshUserQuota if config is available when opening dialog', async () => {
if (!modelCommand.action) {
throw new Error('The model command must have an action.');
}
Expand All @@ -45,10 +46,120 @@ describe('modelCommand', () => {
expect(mockRefreshUserQuota).toHaveBeenCalled();
});

describe('manage subcommand', () => {
it('should return a dialog action to open the model dialog', async () => {
const manageCommand = modelCommand.subCommands?.find(
(c) => c.name === 'manage',
);
expect(manageCommand).toBeDefined();

const result = await manageCommand!.action!(mockContext, '');

expect(result).toEqual({
type: 'dialog',
dialog: 'model',
});
});

it('should call refreshUserQuota if config is available', async () => {
const manageCommand = modelCommand.subCommands?.find(
(c) => c.name === 'manage',
);
const mockRefreshUserQuota = vi.fn();
mockContext.services.config = {
refreshUserQuota: mockRefreshUserQuota,
} as unknown as Config;

await manageCommand!.action!(mockContext, '');

expect(mockRefreshUserQuota).toHaveBeenCalled();
});
});

describe('set subcommand', () => {
it('should set the model and log the command', async () => {
const setCommand = modelCommand.subCommands?.find(
(c) => c.name === 'set',
);
expect(setCommand).toBeDefined();

const mockSetModel = vi.fn();
mockContext.services.config = {
setModel: mockSetModel,
getHasAccessToPreviewModel: vi.fn().mockReturnValue(true),
getUserId: vi.fn().mockReturnValue('test-user'),
getUsageStatisticsEnabled: vi.fn().mockReturnValue(true),
getSessionId: vi.fn().mockReturnValue('test-session'),
getContentGeneratorConfig: vi
.fn()
.mockReturnValue({ authType: 'test-auth' }),
isInteractive: vi.fn().mockReturnValue(true),
getExperiments: vi.fn().mockReturnValue({ experimentIds: [] }),
getPolicyEngine: vi.fn().mockReturnValue({
getApprovalMode: vi.fn().mockReturnValue('auto'),
}),
} as unknown as Config;

await setCommand!.action!(mockContext, 'gemini-pro');

expect(mockSetModel).toHaveBeenCalledWith('gemini-pro', true);
expect(mockContext.ui.addItem).toHaveBeenCalledWith(
expect.objectContaining({
type: MessageType.INFO,
text: expect.stringContaining('Model set to gemini-pro'),
}),
);
});

it('should set the model with persistence when --persist is used', async () => {
const setCommand = modelCommand.subCommands?.find(
(c) => c.name === 'set',
);
const mockSetModel = vi.fn();
mockContext.services.config = {
setModel: mockSetModel,
getHasAccessToPreviewModel: vi.fn().mockReturnValue(true),
getUserId: vi.fn().mockReturnValue('test-user'),
getUsageStatisticsEnabled: vi.fn().mockReturnValue(true),
getSessionId: vi.fn().mockReturnValue('test-session'),
getContentGeneratorConfig: vi
.fn()
.mockReturnValue({ authType: 'test-auth' }),
isInteractive: vi.fn().mockReturnValue(true),
getExperiments: vi.fn().mockReturnValue({ experimentIds: [] }),
getPolicyEngine: vi.fn().mockReturnValue({
getApprovalMode: vi.fn().mockReturnValue('auto'),
}),
} as unknown as Config;

await setCommand!.action!(mockContext, 'gemini-pro --persist');

expect(mockSetModel).toHaveBeenCalledWith('gemini-pro', false);
expect(mockContext.ui.addItem).toHaveBeenCalledWith(
expect.objectContaining({
type: MessageType.INFO,
text: expect.stringContaining('Model set to gemini-pro (persisted)'),
}),
);
});

it('should show error if no model name is provided', async () => {
const setCommand = modelCommand.subCommands?.find(
(c) => c.name === 'set',
);
await setCommand!.action!(mockContext, '');

expect(mockContext.ui.addItem).toHaveBeenCalledWith(
expect.objectContaining({
type: MessageType.ERROR,
text: expect.stringContaining('Usage: /model set <model-name>'),
}),
);
});
});

it('should have the correct name and description', () => {
expect(modelCommand.name).toBe('model');
expect(modelCommand.description).toBe(
'Opens a dialog to configure the model',
);
expect(modelCommand.description).toBe('Manage model configuration');
});
});
51 changes: 49 additions & 2 deletions packages/cli/src/ui/commands/modelCommand.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,51 @@
* SPDX-License-Identifier: Apache-2.0
*/

import {
ModelSlashCommandEvent,
logModelSlashCommand,
} from '@google/gemini-cli-core';
import {
type CommandContext,
CommandKind,
type SlashCommand,
} from './types.js';
import { MessageType } from '../types.js';

export const modelCommand: SlashCommand = {
name: 'model',
const setModelCommand: SlashCommand = {
name: 'set',
description:
'Set the model to use. Usage: /model set <model-name> [--persist]',
kind: CommandKind.BUILT_IN,
autoExecute: false,
action: async (context: CommandContext, args: string) => {
const parts = args.trim().split(/\s+/).filter(Boolean);
if (parts.length === 0) {
context.ui.addItem({
type: MessageType.ERROR,
text: 'Usage: /model set <model-name> [--persist]',
});
return;
}

const modelName = parts[0];
const persist = parts.includes('--persist');

if (context.services.config) {
context.services.config.setModel(modelName, !persist);
const event = new ModelSlashCommandEvent(modelName);
logModelSlashCommand(context.services.config, event);

context.ui.addItem({
type: MessageType.INFO,
text: `Model set to ${modelName}${persist ? ' (persisted)' : ''}`,
});
}
},
};

const manageModelCommand: SlashCommand = {
name: 'manage',
description: 'Opens a dialog to configure the model',
kind: CommandKind.BUILT_IN,
autoExecute: true,
Expand All @@ -25,3 +62,13 @@ export const modelCommand: SlashCommand = {
};
},
};

export const modelCommand: SlashCommand = {
name: 'model',
description: 'Manage model configuration',
kind: CommandKind.BUILT_IN,
autoExecute: false,
subCommands: [manageModelCommand, setModelCommand],
action: async (context: CommandContext, args: string) =>
manageModelCommand.action!(context, args),
};
Loading