Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions packages/types/src/codebase-index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ export const codebaseIndexConfigSchema = z.object({
// Bedrock specific fields
codebaseIndexBedrockRegion: z.string().optional(),
codebaseIndexBedrockProfile: z.string().optional(),
// OpenRouter specific fields
codebaseIndexOpenRouterSpecificProvider: z.string().optional(),
})

export type CodebaseIndexConfig = z.infer<typeof codebaseIndexConfigSchema>
Expand Down
3 changes: 3 additions & 0 deletions src/core/webview/ClineProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2079,6 +2079,7 @@ export class ClineProvider
codebaseIndexSearchMinScore: codebaseIndexConfig?.codebaseIndexSearchMinScore,
codebaseIndexBedrockRegion: codebaseIndexConfig?.codebaseIndexBedrockRegion,
codebaseIndexBedrockProfile: codebaseIndexConfig?.codebaseIndexBedrockProfile,
codebaseIndexOpenRouterSpecificProvider: codebaseIndexConfig?.codebaseIndexOpenRouterSpecificProvider,
},
// Only set mdmCompliant if there's an actual MDM policy
// undefined means no MDM policy, true means compliant, false means non-compliant
Expand Down Expand Up @@ -2310,6 +2311,8 @@ export class ClineProvider
codebaseIndexSearchMinScore: stateValues.codebaseIndexConfig?.codebaseIndexSearchMinScore,
codebaseIndexBedrockRegion: stateValues.codebaseIndexConfig?.codebaseIndexBedrockRegion,
codebaseIndexBedrockProfile: stateValues.codebaseIndexConfig?.codebaseIndexBedrockProfile,
codebaseIndexOpenRouterSpecificProvider:
stateValues.codebaseIndexConfig?.codebaseIndexOpenRouterSpecificProvider,
},
profileThresholds: stateValues.profileThresholds ?? {},
includeDiagnosticMessages: stateValues.includeDiagnosticMessages ?? true,
Expand Down
1 change: 1 addition & 0 deletions src/core/webview/webviewMessageHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2373,6 +2373,7 @@ export const webviewMessageHandler = async (
codebaseIndexBedrockProfile: settings.codebaseIndexBedrockProfile,
codebaseIndexSearchMaxResults: settings.codebaseIndexSearchMaxResults,
codebaseIndexSearchMinScore: settings.codebaseIndexSearchMinScore,
codebaseIndexOpenRouterSpecificProvider: settings.codebaseIndexOpenRouterSpecificProvider,
}

// Save global state first
Expand Down
15 changes: 13 additions & 2 deletions src/services/code-index/config-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ export class CodeIndexConfigManager {
private mistralOptions?: { apiKey: string }
private vercelAiGatewayOptions?: { apiKey: string }
private bedrockOptions?: { region: string; profile?: string }
private openRouterOptions?: { apiKey: string }
private openRouterOptions?: { apiKey: string; specificProvider?: string }
private qdrantUrl?: string = "http://localhost:6333"
private qdrantApiKey?: string
private searchMinScore?: number
Expand Down Expand Up @@ -78,6 +78,7 @@ export class CodeIndexConfigManager {
const bedrockRegion = codebaseIndexConfig.codebaseIndexBedrockRegion ?? "us-east-1"
const bedrockProfile = codebaseIndexConfig.codebaseIndexBedrockProfile ?? ""
const openRouterApiKey = this.contextProxy?.getSecret("codebaseIndexOpenRouterApiKey") ?? ""
const openRouterSpecificProvider = codebaseIndexConfig.codebaseIndexOpenRouterSpecificProvider ?? ""

// Update instance variables with configuration
this.codebaseIndexEnabled = codebaseIndexEnabled ?? false
Expand Down Expand Up @@ -140,7 +141,9 @@ export class CodeIndexConfigManager {
this.geminiOptions = geminiApiKey ? { apiKey: geminiApiKey } : undefined
this.mistralOptions = mistralApiKey ? { apiKey: mistralApiKey } : undefined
this.vercelAiGatewayOptions = vercelAiGatewayApiKey ? { apiKey: vercelAiGatewayApiKey } : undefined
this.openRouterOptions = openRouterApiKey ? { apiKey: openRouterApiKey } : undefined
this.openRouterOptions = openRouterApiKey
? { apiKey: openRouterApiKey, specificProvider: openRouterSpecificProvider || undefined }
: undefined
// Set bedrockOptions if region is provided (profile is optional)
this.bedrockOptions = bedrockRegion
? { region: bedrockRegion, profile: bedrockProfile || undefined }
Expand Down Expand Up @@ -188,6 +191,7 @@ export class CodeIndexConfigManager {
bedrockRegion: this.bedrockOptions?.region ?? "",
bedrockProfile: this.bedrockOptions?.profile ?? "",
openRouterApiKey: this.openRouterOptions?.apiKey ?? "",
openRouterSpecificProvider: this.openRouterOptions?.specificProvider ?? "",
qdrantUrl: this.qdrantUrl ?? "",
qdrantApiKey: this.qdrantApiKey ?? "",
}
Expand Down Expand Up @@ -306,6 +310,7 @@ export class CodeIndexConfigManager {
const prevBedrockRegion = prev?.bedrockRegion ?? ""
const prevBedrockProfile = prev?.bedrockProfile ?? ""
const prevOpenRouterApiKey = prev?.openRouterApiKey ?? ""
const prevOpenRouterSpecificProvider = prev?.openRouterSpecificProvider ?? ""
const prevQdrantUrl = prev?.qdrantUrl ?? ""
const prevQdrantApiKey = prev?.qdrantApiKey ?? ""

Expand Down Expand Up @@ -347,6 +352,7 @@ export class CodeIndexConfigManager {
const currentBedrockRegion = this.bedrockOptions?.region ?? ""
const currentBedrockProfile = this.bedrockOptions?.profile ?? ""
const currentOpenRouterApiKey = this.openRouterOptions?.apiKey ?? ""
const currentOpenRouterSpecificProvider = this.openRouterOptions?.specificProvider ?? ""
const currentQdrantUrl = this.qdrantUrl ?? ""
const currentQdrantApiKey = this.qdrantApiKey ?? ""

Expand Down Expand Up @@ -385,6 +391,11 @@ export class CodeIndexConfigManager {
return true
}

// OpenRouter specific provider change
if (prevOpenRouterSpecificProvider !== currentOpenRouterSpecificProvider) {
return true
}

// Check for model dimension changes (generic for all providers)
if (prevModelDimension !== currentModelDimension) {
return true
Expand Down
120 changes: 119 additions & 1 deletion src/services/code-index/embedders/__tests__/openrouter.spec.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import type { MockedClass, MockedFunction } from "vitest"
import { describe, it, expect, beforeEach, vi } from "vitest"
import { OpenAI } from "openai"
import { OpenRouterEmbedder } from "../openrouter"
import { OpenRouterEmbedder, OPENROUTER_DEFAULT_PROVIDER_NAME } from "../openrouter"
import { getModelDimension, getDefaultModelId } from "../../../../shared/embeddingModels"

// Mock the OpenAI SDK
Expand Down Expand Up @@ -95,6 +95,16 @@ describe("OpenRouterEmbedder", () => {
},
})
})

it("should accept specificProvider parameter", () => {
const embedder = new OpenRouterEmbedder(mockApiKey, undefined, undefined, "together")
expect(embedder).toBeInstanceOf(OpenRouterEmbedder)
})

it("should ignore default provider name as specificProvider", () => {
const embedder = new OpenRouterEmbedder(mockApiKey, undefined, undefined, OPENROUTER_DEFAULT_PROVIDER_NAME)
expect(embedder).toBeInstanceOf(OpenRouterEmbedder)
})
})

describe("embedderInfo", () => {
Expand Down Expand Up @@ -205,6 +215,77 @@ describe("OpenRouterEmbedder", () => {
encoding_format: "base64",
})
})

it("should include provider routing when specificProvider is set", async () => {
const specificProvider = "together"
const embedderWithProvider = new OpenRouterEmbedder(mockApiKey, undefined, undefined, specificProvider)

const testEmbedding = new Float32Array([0.25, 0.5])
const base64String = Buffer.from(testEmbedding.buffer).toString("base64")

const mockResponse = {
data: [
{
embedding: base64String,
},
],
usage: {
prompt_tokens: 5,
total_tokens: 5,
},
}

mockEmbeddingsCreate.mockResolvedValue(mockResponse)

await embedderWithProvider.createEmbeddings(["test"])

// Verify the embeddings.create was called with provider routing
expect(mockEmbeddingsCreate).toHaveBeenCalledWith({
input: ["test"],
model: "openai/text-embedding-3-large",
encoding_format: "base64",
provider: {
order: [specificProvider],
only: [specificProvider],
allow_fallbacks: false,
},
})
})

it("should not include provider routing when specificProvider is default", async () => {
const embedderWithDefaultProvider = new OpenRouterEmbedder(
mockApiKey,
undefined,
undefined,
OPENROUTER_DEFAULT_PROVIDER_NAME,
)

const testEmbedding = new Float32Array([0.25, 0.5])
const base64String = Buffer.from(testEmbedding.buffer).toString("base64")

const mockResponse = {
data: [
{
embedding: base64String,
},
],
usage: {
prompt_tokens: 5,
total_tokens: 5,
},
}

mockEmbeddingsCreate.mockResolvedValue(mockResponse)

await embedderWithDefaultProvider.createEmbeddings(["test"])

// Verify the embeddings.create was called without provider routing
expect(mockEmbeddingsCreate).toHaveBeenCalledWith({
input: ["test"],
model: "openai/text-embedding-3-large",
encoding_format: "base64",
})
})
})

describe("validateConfiguration", () => {
Expand Down Expand Up @@ -254,6 +335,43 @@ describe("OpenRouterEmbedder", () => {
expect(result.valid).toBe(false)
expect(result.error).toBe("embeddings:validation.authenticationFailed")
})

it("should validate configuration with specificProvider", async () => {
const specificProvider = "openai"
const embedderWithProvider = new OpenRouterEmbedder(mockApiKey, undefined, undefined, specificProvider)

const testEmbedding = new Float32Array([0.25, 0.5])
const base64String = Buffer.from(testEmbedding.buffer).toString("base64")

const mockResponse = {
data: [
{
embedding: base64String,
},
],
usage: {
prompt_tokens: 1,
total_tokens: 1,
},
}

mockEmbeddingsCreate.mockResolvedValue(mockResponse)

const result = await embedderWithProvider.validateConfiguration()

expect(result.valid).toBe(true)
expect(result.error).toBeUndefined()
expect(mockEmbeddingsCreate).toHaveBeenCalledWith({
input: ["test"],
model: "openai/text-embedding-3-large",
encoding_format: "base64",
provider: {
order: [specificProvider],
only: [specificProvider],
allow_fallbacks: false,
},
})
})
})

describe("integration with shared models", () => {
Expand Down
46 changes: 41 additions & 5 deletions src/services/code-index/embedders/openrouter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ import { TelemetryService } from "@roo-code/telemetry"
import { Mutex } from "async-mutex"
import { handleOpenAIError } from "../../../api/providers/utils/openai-error-handler"

// Default provider name when no specific provider is selected
export const OPENROUTER_DEFAULT_PROVIDER_NAME = "[default]"

interface EmbeddingItem {
embedding: string | number[]
[key: string]: any
Expand All @@ -38,6 +41,7 @@ export class OpenRouterEmbedder implements IEmbedder {
private readonly apiKey: string
private readonly maxItemTokens: number
private readonly baseUrl: string = "https://openrouter.ai/api/v1"
private readonly specificProvider?: string

// Global rate limiting state shared across all instances
private static globalRateLimitState = {
Expand All @@ -54,13 +58,17 @@ export class OpenRouterEmbedder implements IEmbedder {
* @param apiKey The API key for authentication
* @param modelId Optional model identifier (defaults to "openai/text-embedding-3-large")
* @param maxItemTokens Optional maximum tokens per item (defaults to MAX_ITEM_TOKENS)
* @param specificProvider Optional specific provider to route requests to
*/
constructor(apiKey: string, modelId?: string, maxItemTokens?: number) {
constructor(apiKey: string, modelId?: string, maxItemTokens?: number, specificProvider?: string) {
if (!apiKey) {
throw new Error(t("embeddings:validation.apiKeyRequired"))
}

this.apiKey = apiKey
// Only set specificProvider if it's not the default value
this.specificProvider =
specificProvider && specificProvider !== OPENROUTER_DEFAULT_PROVIDER_NAME ? specificProvider : undefined

// Wrap OpenAI client creation to handle invalid API key characters
try {
Expand Down Expand Up @@ -180,14 +188,28 @@ export class OpenRouterEmbedder implements IEmbedder {
await this.waitForGlobalRateLimit()

try {
const response = (await this.embeddingsClient.embeddings.create({
// Build the request parameters
const requestParams: any = {
input: batchTexts,
model: model,
// OpenAI package (as of v4.78.1) has a parsing issue that truncates embedding dimensions to 256
// when processing numeric arrays, which breaks compatibility with models using larger dimensions.
// By requesting base64 encoding, we bypass the package's parser and handle decoding ourselves.
encoding_format: "base64",
})) as OpenRouterEmbeddingResponse
}

// Add provider routing if a specific provider is set
if (this.specificProvider) {
requestParams.provider = {
order: [this.specificProvider],
only: [this.specificProvider],
allow_fallbacks: false,
}
}

const response = (await this.embeddingsClient.embeddings.create(
requestParams,
)) as OpenRouterEmbeddingResponse

// Convert base64 embeddings to float32 arrays
const processedEmbeddings = response.data.map((item: EmbeddingItem) => {
Expand Down Expand Up @@ -274,11 +296,25 @@ export class OpenRouterEmbedder implements IEmbedder {
const testTexts = ["test"]
const modelToUse = this.defaultModelId

const response = (await this.embeddingsClient.embeddings.create({
// Build the request parameters
const requestParams: any = {
input: testTexts,
model: modelToUse,
encoding_format: "base64",
})) as OpenRouterEmbeddingResponse
}

// Add provider routing if a specific provider is set
if (this.specificProvider) {
requestParams.provider = {
order: [this.specificProvider],
only: [this.specificProvider],
allow_fallbacks: false,
}
}

const response = (await this.embeddingsClient.embeddings.create(
requestParams,
)) as OpenRouterEmbeddingResponse

// Check if we got a valid response
if (!response?.data || response.data.length === 0) {
Expand Down
3 changes: 2 additions & 1 deletion src/services/code-index/interfaces/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export interface CodeIndexConfig {
mistralOptions?: { apiKey: string }
vercelAiGatewayOptions?: { apiKey: string }
bedrockOptions?: { region: string; profile?: string }
openRouterOptions?: { apiKey: string }
openRouterOptions?: { apiKey: string; specificProvider?: string }
qdrantUrl?: string
qdrantApiKey?: string
searchMinScore?: number
Expand All @@ -42,6 +42,7 @@ export type PreviousConfigSnapshot = {
bedrockRegion?: string
bedrockProfile?: string
openRouterApiKey?: string
openRouterSpecificProvider?: string
qdrantUrl?: string
qdrantApiKey?: string
}
7 changes: 6 additions & 1 deletion src/services/code-index/service-factory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,12 @@ export class CodeIndexServiceFactory {
if (!config.openRouterOptions?.apiKey) {
throw new Error(t("embeddings:serviceFactory.openRouterConfigMissing"))
}
return new OpenRouterEmbedder(config.openRouterOptions.apiKey, config.modelId)
return new OpenRouterEmbedder(
config.openRouterOptions.apiKey,
config.modelId,
undefined, // maxItemTokens
config.openRouterOptions.specificProvider,
)
}

throw new Error(
Expand Down
1 change: 1 addition & 0 deletions src/shared/WebviewMessage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ export interface WebviewMessage {
codebaseIndexBedrockProfile?: string
codebaseIndexSearchMaxResults?: number
codebaseIndexSearchMinScore?: number
codebaseIndexOpenRouterSpecificProvider?: string // OpenRouter provider routing

// Secret settings
codeIndexOpenAiKey?: string
Expand Down
Loading
Loading