Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions packages/core/src/config/models.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import {
resolveModel,
resolveClassifierModel,
isGemini3Model,
isGemini2Model,
isAutoModel,
getDisplayString,
Expand All @@ -25,6 +26,29 @@
DEFAULT_GEMINI_MODEL_AUTO,
} from './models.js';

// Unit tests for isGemini3Model: concrete Gemini 3 names, aliases, Gemini 2
// names, and an unrelated model name.
describe('isGemini3Model', () => {
  it('should return true for gemini-3 models', () => {
    const gemini3Names = ['gemini-3-pro-preview', 'gemini-3-flash-preview'];
    for (const name of gemini3Names) {
      expect(isGemini3Model(name)).toBe(true);
    }
  });

  it('should return true for aliases that resolve to Gemini 3', () => {
    // The generic aliases are checked with the preview-features flag set.
    const previewAliases = [GEMINI_MODEL_ALIAS_AUTO, GEMINI_MODEL_ALIAS_PRO];
    for (const alias of previewAliases) {
      expect(isGemini3Model(alias, true)).toBe(true);
    }
    // The preview auto model is checked with the flag left at its default.
    expect(isGemini3Model(PREVIEW_GEMINI_MODEL_AUTO)).toBe(true);
  });

  it('should return false for Gemini 2 models', () => {
    const gemini2Names = [
      'gemini-2.5-pro',
      'gemini-2.5-flash',
      DEFAULT_GEMINI_MODEL_AUTO,
    ];
    for (const name of gemini2Names) {
      expect(isGemini3Model(name)).toBe(false);
    }
  });

  it('should return false for arbitrary strings', () => {
    const unrelatedName = 'gpt-4';
    expect(isGemini3Model(unrelatedName)).toBe(false);
  });
});

describe('getDisplayString', () => {
it('should return Auto (Gemini 3) for preview auto model', () => {
expect(getDisplayString(PREVIEW_GEMINI_MODEL_AUTO)).toBe('Auto (Gemini 3)');
Expand Down Expand Up @@ -67,7 +91,7 @@
expect(supportsMultimodalFunctionResponse('gemini-3-pro')).toBe(true);
});

it('should return false for gemini-2 models', () => {

Check warning on line 94 in packages/core/src/config/models.test.ts

View workflow job for this annotation

GitHub Actions / Lint

Found sensitive keyword "gemini-2". Please make sure this change is appropriate to submit.
expect(supportsMultimodalFunctionResponse('gemini-2.5-pro')).toBe(false);
expect(supportsMultimodalFunctionResponse('gemini-2.5-flash')).toBe(false);
});
Expand Down
15 changes: 15 additions & 0 deletions packages/core/src/config/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,21 @@
);
}

/**
 * Determines whether a model name refers to a Gemini 3 model.
 *
 * The name is first passed through `resolveModel` together with
 * `previewFeaturesEnabled`; the resolved name is then matched against a
 * pattern that accepts `gemini-3` on its own or followed by `.` or `-`.
 *
 * @param model The model name (or alias) to check.
 * @param previewFeaturesEnabled Whether preview features are enabled; forwarded
 *   to `resolveModel`. Defaults to `false`.
 * @returns True if the resolved model name is a Gemini 3 model.
 */
export function isGemini3Model(
  model: string,
  previewFeaturesEnabled: boolean = false,
): boolean {
  return /^gemini-3(\.|-|$)/.test(resolveModel(model, previewFeaturesEnabled));
}

/**
* Checks if the model is a Gemini 2.x model.
*
Expand All @@ -144,7 +159,7 @@
* @returns True if the model is a Gemini-2.x model.
*/
export function isGemini2Model(model: string): boolean {
return /^gemini-2(\.|$)/.test(model);

Check warning on line 162 in packages/core/src/config/models.ts

View workflow job for this annotation

GitHub Actions / Lint

Found sensitive keyword "gemini-2". Please make sure this change is appropriate to submit.
}

/**
Expand Down
101 changes: 55 additions & 46 deletions packages/core/src/routing/strategies/classifierStrategy.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,16 @@
* SPDX-License-Identifier: Apache-2.0
*/

import { describe, it, expect, vi, beforeEach } from 'vitest';
import { describe, it, expect, vi, beforeEach, type Mock } from 'vitest';
import { ClassifierStrategy } from './classifierStrategy.js';
import type { RoutingContext } from '../routingStrategy.js';
import type { Config } from '../../config/config.js';
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
import {
isFunctionCall,
isFunctionResponse,
} from '../../utils/messageInspectors.js';
import {
DEFAULT_GEMINI_FLASH_MODEL,
DEFAULT_GEMINI_MODEL,
DEFAULT_GEMINI_MODEL_AUTO,
PREVIEW_GEMINI_MODEL_AUTO,
} from '../../config/models.js';
import { promptIdContext } from '../../utils/promptIdContext.js';
import type { Content } from '@google/genai';
Expand All @@ -31,6 +28,9 @@ describe('ClassifierStrategy', () => {
let mockConfig: Config;
let mockBaseLlmClient: BaseLlmClient;
let mockResolvedConfig: ResolvedModelConfig;
let mockGetModel: Mock;
let mockGetNumericalRoutingEnabled: Mock;
let mockGenerateJson: Mock;

beforeEach(() => {
vi.clearAllMocks();
Expand All @@ -46,23 +46,30 @@ describe('ClassifierStrategy', () => {
model: 'classifier',
generateContentConfig: {},
} as unknown as ResolvedModelConfig;

mockGetModel = vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO);
mockGetNumericalRoutingEnabled = vi.fn().mockResolvedValue(false);
mockGenerateJson = vi.fn();

mockConfig = {
modelConfigService: {
getResolvedConfig: vi.fn().mockReturnValue(mockResolvedConfig),
},
getModel: () => DEFAULT_GEMINI_MODEL_AUTO,
getModel: mockGetModel,
getPreviewFeatures: () => false,
getNumericalRoutingEnabled: vi.fn().mockResolvedValue(false),
getNumericalRoutingEnabled: mockGetNumericalRoutingEnabled,
} as unknown as Config;

mockBaseLlmClient = {
generateJson: vi.fn(),
generateJson: mockGenerateJson,
} as unknown as BaseLlmClient;

vi.spyOn(promptIdContext, 'getStore').mockReturnValue('test-prompt-id');
});

it('should return null if numerical routing is enabled', async () => {
vi.mocked(mockConfig.getNumericalRoutingEnabled).mockResolvedValue(true);
it('should return null if numerical routing is enabled and model is Gemini 3', async () => {
mockGetNumericalRoutingEnabled.mockResolvedValue(true);
mockGetModel.mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO);

const decision = await strategy.route(
mockContext,
Expand All @@ -71,21 +78,37 @@ describe('ClassifierStrategy', () => {
);

expect(decision).toBeNull();
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
expect(mockGenerateJson).not.toHaveBeenCalled();
});

// Numerical routing is switched on, but the configured model is the default
// auto model rather than a Gemini 3 one, so the strategy is expected to run
// the classifier instead of returning null.
it('should NOT return null if numerical routing is enabled but model is NOT Gemini 3', async () => {
mockGetNumericalRoutingEnabled.mockResolvedValue(true);
mockGetModel.mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO);
// Stubbed classifier reply so route() has a response to build a decision from.
mockGenerateJson.mockResolvedValue({
reasoning: 'test',
model_choice: 'flash',
});

const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
);

// A decision was produced and the classifier LLM call was actually made.
expect(decision).not.toBeNull();
expect(mockGenerateJson).toHaveBeenCalled();
});

it('should call generateJson with the correct parameters', async () => {
const mockApiResponse = {
reasoning: 'Simple task',
model_choice: 'flash',
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,
);
mockGenerateJson.mockResolvedValue(mockApiResponse);

await strategy.route(mockContext, mockConfig, mockBaseLlmClient);

expect(mockBaseLlmClient.generateJson).toHaveBeenCalledWith(
expect(mockGenerateJson).toHaveBeenCalledWith(
expect.objectContaining({
modelConfigKey: { model: mockResolvedConfig.model },
promptId: 'test-prompt-id',
Expand All @@ -98,17 +121,15 @@ describe('ClassifierStrategy', () => {
reasoning: 'This is a simple task.',
model_choice: 'flash',
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,
);
mockGenerateJson.mockResolvedValue(mockApiResponse);

const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
);

expect(mockBaseLlmClient.generateJson).toHaveBeenCalledOnce();
expect(mockGenerateJson).toHaveBeenCalledOnce();
expect(decision).toEqual({
model: DEFAULT_GEMINI_FLASH_MODEL,
metadata: {
Expand All @@ -124,9 +145,7 @@ describe('ClassifierStrategy', () => {
reasoning: 'This is a complex task.',
model_choice: 'pro',
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,
);
mockGenerateJson.mockResolvedValue(mockApiResponse);
mockContext.request = [{ text: 'how do I build a spaceship?' }];

const decision = await strategy.route(
Expand All @@ -135,7 +154,7 @@ describe('ClassifierStrategy', () => {
mockBaseLlmClient,
);

expect(mockBaseLlmClient.generateJson).toHaveBeenCalledOnce();
expect(mockGenerateJson).toHaveBeenCalledOnce();
expect(decision).toEqual({
model: DEFAULT_GEMINI_MODEL,
metadata: {
Expand All @@ -151,7 +170,7 @@ describe('ClassifierStrategy', () => {
.spyOn(debugLogger, 'warn')
.mockImplementation(() => {});
const testError = new Error('API Failure');
vi.mocked(mockBaseLlmClient.generateJson).mockRejectedValue(testError);
mockGenerateJson.mockRejectedValue(testError);

const decision = await strategy.route(
mockContext,
Expand All @@ -172,9 +191,7 @@ describe('ClassifierStrategy', () => {
reasoning: 'This is a simple task.',
// model_choice is missing, which will cause a Zod parsing error.
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
malformedApiResponse,
);
mockGenerateJson.mockResolvedValue(malformedApiResponse);

const decision = await strategy.route(
mockContext,
Expand Down Expand Up @@ -203,14 +220,11 @@ describe('ClassifierStrategy', () => {
reasoning: 'Simple.',
model_choice: 'flash',
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,
);
mockGenerateJson.mockResolvedValue(mockApiResponse);

await strategy.route(mockContext, mockConfig, mockBaseLlmClient);

const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock
.calls[0][0];
const generateJsonCall = mockGenerateJson.mock.calls[0][0];
const contents = generateJsonCall.contents;

const expectedContents = [
Expand Down Expand Up @@ -239,22 +253,22 @@ describe('ClassifierStrategy', () => {
reasoning: 'Simple.',
model_choice: 'flash',
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,
);
mockGenerateJson.mockResolvedValue(mockApiResponse);

await strategy.route(mockContext, mockConfig, mockBaseLlmClient);

const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock
.calls[0][0];
const generateJsonCall = mockGenerateJson.mock.calls[0][0];
const contents = generateJsonCall.contents;

// Manually calculate what the history should be
const HISTORY_SEARCH_WINDOW = 20;
const HISTORY_TURNS_FOR_CONTEXT = 4;
const historySlice = longHistory.slice(-HISTORY_SEARCH_WINDOW);
const cleanHistory = historySlice.filter(
(content) => !isFunctionCall(content) && !isFunctionResponse(content),
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(content: any) =>
!content.parts?.[0]?.functionCall &&
!content.parts?.[0]?.functionResponse,
);
const finalHistory = cleanHistory.slice(-HISTORY_TURNS_FOR_CONTEXT);

Expand All @@ -275,14 +289,11 @@ describe('ClassifierStrategy', () => {
reasoning: 'Simple.',
model_choice: 'flash',
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,
);
mockGenerateJson.mockResolvedValue(mockApiResponse);

await strategy.route(mockContext, mockConfig, mockBaseLlmClient);

const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock
.calls[0][0];
const generateJsonCall = mockGenerateJson.mock.calls[0][0];

expect(generateJsonCall.promptId).toMatch(
/^classifier-router-fallback-\d+-\w+$/,
Expand All @@ -301,9 +312,7 @@ describe('ClassifierStrategy', () => {
reasoning: 'Choice is flash',
model_choice: 'flash',
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,
);
mockGenerateJson.mockResolvedValue(mockApiResponse);

const contextWithRequestedModel = {
...mockContext,
Expand Down
11 changes: 8 additions & 3 deletions packages/core/src/routing/strategies/classifierStrategy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import type {
RoutingDecision,
RoutingStrategy,
} from '../routingStrategy.js';
import { resolveClassifierModel } from '../../config/models.js';
import { resolveClassifierModel, isGemini3Model } from '../../config/models.js';
import { createUserContent, Type } from '@google/genai';
import type { Config } from '../../config/config.js';
import {
Expand Down Expand Up @@ -133,7 +133,12 @@ export class ClassifierStrategy implements RoutingStrategy {
): Promise<RoutingDecision | null> {
const startTime = Date.now();
try {
if (await config.getNumericalRoutingEnabled()) {
const model = context.requestedModel ?? config.getModel();
const previewFeaturesEnabled = config.getPreviewFeatures();
if (
(await config.getNumericalRoutingEnabled()) &&
isGemini3Model(model, previewFeaturesEnabled)
) {
return null;
}

Expand Down Expand Up @@ -164,7 +169,7 @@ export class ClassifierStrategy implements RoutingStrategy {
const reasoning = routerResponse.reasoning;
const latencyMs = Date.now() - startTime;
const selectedModel = resolveClassifierModel(
context.requestedModel ?? config.getModel(),
model,
routerResponse.model_choice,
config.getPreviewFeatures(),
);
Expand Down
Loading
Loading