diff --git a/dcp.schema.json b/dcp.schema.json index e25f09b1..a0356ee6 100644 --- a/dcp.schema.json +++ b/dcp.schema.json @@ -121,6 +121,21 @@ "pattern": "^\\d+(?:\\.\\d+)?%$" } ] + }, + "modelLimits": { + "description": "Model-specific context limits with optional wildcard patterns (exact match first, then most specific wildcard). Examples: \"openai/gpt-5\", \"*/zen-1\", \"ollama/*\", \"*sonnet*\"", + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "number" + }, + { + "type": "string", + "pattern": "^\\d+(?:\\.\\d+)?%$" + } + ] + } } } }, diff --git a/lib/config.ts b/lib/config.ts index 1a60307f..082bc163 100644 --- a/lib/config.ts +++ b/lib/config.ts @@ -28,6 +28,7 @@ export interface ToolSettings { nudgeFrequency: number protectedTools: string[] contextLimit: number | `${number}%` + modelLimits?: Record } export interface Tools { @@ -107,6 +108,7 @@ export const VALID_CONFIG_KEYS = new Set([ "tools.settings.nudgeFrequency", "tools.settings.protectedTools", "tools.settings.contextLimit", + "tools.settings.modelLimits", "tools.distill", "tools.distill.permission", "tools.distill.showDistillation", @@ -136,6 +138,12 @@ function getConfigKeyPaths(obj: Record, prefix = ""): string[] { for (const key of Object.keys(obj)) { const fullKey = prefix ? `${prefix}.${key}` : key keys.push(fullKey) + + // modelLimits is a dynamic map keyed by model ID; do not recurse into arbitrary IDs. + if (fullKey === "tools.settings.modelLimits") { + continue + } + if (obj[key] && typeof obj[key] === "object" && !Array.isArray(obj[key])) { keys.push(...getConfigKeyPaths(obj[key], fullKey)) } @@ -156,7 +164,7 @@ interface ValidationError { actual: string } -function validateConfigTypes(config: Record): ValidationError[] { +export function validateConfigTypes(config: Record): ValidationError[] { const errors: ValidationError[] = [] // Top-level validators @@ -303,9 +311,32 @@ function validateConfigTypes(config: Record): ValidationError[] { }) } } - } - if (tools.distill) { - if (tools.distill.permission !== undefined) { + if (tools.settings.modelLimits !== undefined) { + if ( + typeof tools.settings.modelLimits !== "object" || + Array.isArray(tools.settings.modelLimits) + ) { + errors.push({ + key: "tools.settings.modelLimits", + expected: "Record", + actual: typeof tools.settings.modelLimits, + }) + } else { + for (const [modelId, limit] of Object.entries(tools.settings.modelLimits)) { + const isValidNumber = typeof limit === "number" + const isPercentString = + typeof limit === "string" && /^\d+(?:\.\d+)?%$/.test(limit) + if (!isValidNumber && !isPercentString) { + errors.push({ + key: `tools.settings.modelLimits.${modelId}`, + expected: 'number | "${number}%"', + actual: JSON.stringify(limit), + }) + } + } + } + } + if (tools.distill?.permission !== undefined) { const validValues = ["ask", "allow", "deny"] if (!validValues.includes(tools.distill.permission)) { errors.push({ @@ -316,7 +347,7 @@ function validateConfigTypes(config: Record): ValidationError[] { } } if ( - tools.distill.showDistillation !== undefined && + tools.distill?.showDistillation !== undefined && typeof tools.distill.showDistillation !== "boolean" ) { errors.push({ @@ -684,6 +715,7 @@ function mergeTools( ]), ], contextLimit: override.settings?.contextLimit ?? base.settings.contextLimit, + modelLimits: override.settings?.modelLimits ?? base.settings.modelLimits, }, distill: { permission: override.distill?.permission ?? base.distill.permission, @@ -724,6 +756,7 @@ function deepCloneConfig(config: PluginConfig): PluginConfig { settings: { ...config.tools.settings, protectedTools: [...config.tools.settings.protectedTools], + modelLimits: { ...config.tools.settings.modelLimits }, }, distill: { ...config.tools.distill }, compress: { ...config.tools.compress }, diff --git a/lib/messages/inject.ts b/lib/messages/inject.ts index 3f0c60b1..a7228a85 100644 --- a/lib/messages/inject.ts +++ b/lib/messages/inject.ts @@ -27,6 +27,48 @@ function parsePercentageString(value: string, total: number): number | undefined return Math.round((clampedPercent / 100) * total) } +const escapeRegex = (value: string): string => { + return value.replace(/[.+?^${}()|[\]\\]/g, "\\$&") +} + +const wildcardPatternToRegex = (pattern: string): RegExp => { + const escapedPattern = escapeRegex(pattern) + const regexPattern = escapedPattern.replace(/\*/g, ".*") + return new RegExp(`^${regexPattern}$`) +} + +const wildcardSpecificity = (pattern: string): number => { + return pattern.replace(/\*/g, "").length +} + +export const findModelLimit = ( + modelId: string, + modelLimits: Record, +): number | `${number}%` | undefined => { + const exactMatch = modelLimits[modelId] + if (exactMatch !== undefined) { + return exactMatch + } + + const wildcardMatches = Object.entries(modelLimits) + .filter(([pattern]) => pattern.includes("*")) + .filter(([pattern]) => wildcardPatternToRegex(pattern).test(modelId)) + + if (wildcardMatches.length === 0) { + return undefined + } + + wildcardMatches.sort(([leftPattern], [rightPattern]) => { + const specificityDiff = wildcardSpecificity(rightPattern) - wildcardSpecificity(leftPattern) + if (specificityDiff !== 0) { + return specificityDiff + } + return leftPattern.localeCompare(rightPattern) + }) + + return wildcardMatches[0][1] +} + // XML wrappers export const wrapPrunableTools = (content: string): string => { return ` @@ -66,21 +108,41 @@ Context management was just performed. Do NOT use the ${toolName} again. A fresh ` } -const resolveContextLimit = (config: PluginConfig, state: SessionState): number | undefined => { - const configLimit = config.tools.settings.contextLimit +const resolveContextLimit = ( + config: PluginConfig, + state: SessionState, + messages: WithParts[], +): number | undefined => { + const { settings } = config.tools + const { modelLimits, contextLimit } = settings + + if (modelLimits) { + const userMsg = getLastUserMessage(messages) + const modelId = userMsg ? (userMsg.info as UserMessage).model.modelID : undefined + const limit = modelId !== undefined ? findModelLimit(modelId, modelLimits) : undefined + + if (limit !== undefined) { + if (typeof limit === "string" && limit.endsWith("%")) { + if (state.modelContextLimit === undefined) { + return undefined + } + return parsePercentageString(limit, state.modelContextLimit) + } + return typeof limit === "number" ? limit : undefined + } + } - if (typeof configLimit === "string") { - if (configLimit.endsWith("%")) { + if (typeof contextLimit === "string") { + if (contextLimit.endsWith("%")) { if (state.modelContextLimit === undefined) { return undefined } - return parsePercentageString(configLimit, state.modelContextLimit) + return parsePercentageString(contextLimit, state.modelContextLimit) } - return undefined } - return configLimit + return contextLimit } const shouldInjectCompressNudge = ( @@ -92,7 +154,7 @@ const shouldInjectCompressNudge = ( return false } - const contextLimit = resolveContextLimit(config, state) + const contextLimit = resolveContextLimit(config, state, messages) if (contextLimit === undefined) { return false } diff --git a/tests/config-model-limits.test.ts b/tests/config-model-limits.test.ts new file mode 100644 index 00000000..82a6857e --- /dev/null +++ b/tests/config-model-limits.test.ts @@ -0,0 +1,158 @@ +import assert from "node:assert" +import { describe, it } from "node:test" +import { getInvalidConfigKeys, validateConfigTypes } from "../lib/config" + +function createConfig(modelLimits?: Record) { + return { + enabled: true, + debug: false, + pruneNotification: "minimal", + pruneNotificationType: "chat", + commands: { + enabled: true, + protectedTools: [], + }, + turnProtection: { + enabled: false, + turns: 0, + }, + protectedFilePatterns: [], + tools: { + settings: { + nudgeEnabled: true, + nudgeFrequency: 5, + protectedTools: [], + contextLimit: "60%", + ...(modelLimits !== undefined ? { modelLimits } : {}), + }, + distill: { + permission: "allow", + showDistillation: false, + }, + compress: { + permission: "deny", + showCompression: false, + }, + prune: { + permission: "allow", + }, + }, + strategies: { + deduplication: { + enabled: true, + protectedTools: [], + }, + supersedeWrites: { + enabled: true, + }, + purgeErrors: { + enabled: true, + turns: 4, + protectedTools: [], + }, + }, + } +} + +describe("Config Validation - modelLimits", () => { + it("accepts valid modelLimits configuration", () => { + const config = createConfig({ + "anthropic/claude-3.5-sonnet": "70%", + "anthropic/claude-3-opus": 150000, + "gpt-4": "80%", + }) + + const errors = validateConfigTypes(config) + assert.strictEqual(errors.length, 0) + }) + + it("rejects invalid modelLimits string value", () => { + const config = createConfig({ + "anthropic/claude-3.5-sonnet": "invalid", + }) + + const errors = validateConfigTypes(config) + assert.ok( + errors.some( + (error) => error.key === "tools.settings.modelLimits.anthropic/claude-3.5-sonnet", + ), + ) + }) + + it("rejects modelLimits when not an object", () => { + const config = createConfig() + ;(config.tools.settings as any).modelLimits = "not-an-object" + + const errors = validateConfigTypes(config) + assert.ok(errors.some((error) => error.key === "tools.settings.modelLimits")) + }) + + it("works without modelLimits", () => { + const config = createConfig() + + const errors = validateConfigTypes(config) + assert.strictEqual(errors.length, 0) + }) + + it("rejects malformed percentage strings", () => { + const config = createConfig({ + model1: "abc%", + model2: "50 %", + model3: "%50", + model4: "50.5.5%", + }) + + const errors = validateConfigTypes(config) + assert.ok(errors.some((error) => error.key === "tools.settings.modelLimits.model1")) + assert.ok(errors.some((error) => error.key === "tools.settings.modelLimits.model2")) + assert.ok(errors.some((error) => error.key === "tools.settings.modelLimits.model3")) + assert.ok(errors.some((error) => error.key === "tools.settings.modelLimits.model4")) + }) + + it("rejects strings without percent suffix", () => { + const config = createConfig({ model: "50" }) + + const errors = validateConfigTypes(config) + assert.ok(errors.some((error) => error.key === "tools.settings.modelLimits.model")) + }) + + it("rejects empty strings", () => { + const config = createConfig({ model: "" }) + + const errors = validateConfigTypes(config) + assert.ok(errors.some((error) => error.key === "tools.settings.modelLimits.model")) + }) + + it("accepts boundary percentages and numbers", () => { + const config = createConfig({ + p0: "0%", + p100: "100%", + n0: 0, + negative: -50000, + above100: "150%", + decimal: "50.5%", + huge: 1000000000000, + }) + + const errors = validateConfigTypes(config) + assert.strictEqual(errors.length, 0) + }) + + it("rejects modelLimits arrays", () => { + const config = createConfig() + ;(config.tools.settings as any).modelLimits = ["not-an-object"] + + const errors = validateConfigTypes(config) + assert.ok(errors.some((error) => error.key === "tools.settings.modelLimits")) + }) + + it("does not flag model-specific keys as unknown config keys", () => { + const config = createConfig({ + "anthropic/claude-3.5-sonnet": "70%", + "openai/gpt-4o": 120000, + }) + + const invalidKeys = getInvalidConfigKeys(config) + assert.strictEqual(invalidKeys.length, 0) + }) +}) diff --git a/tests/inject-model-limits-wildcard.test.ts b/tests/inject-model-limits-wildcard.test.ts new file mode 100644 index 00000000..502e3435 --- /dev/null +++ b/tests/inject-model-limits-wildcard.test.ts @@ -0,0 +1,71 @@ +import assert from "node:assert" +import { describe, it } from "node:test" +import { findModelLimit } from "../lib/messages/inject" + +describe("findModelLimit", () => { + it("prefers exact matches over wildcard matches", () => { + const modelLimits = { + "ollama/zen-1": "35%", + "*/zen-1": "40%", + } + + const limit = findModelLimit("ollama/zen-1", modelLimits) + assert.strictEqual(limit, "35%") + }) + + it("matches provider wildcard patterns", () => { + const modelLimits = { + "*/zen-1": "40%", + } + + const limit = findModelLimit("opencode/zen-1", modelLimits) + assert.strictEqual(limit, "40%") + }) + + it("matches model wildcard patterns", () => { + const modelLimits = { + "ollama/*": "25%", + } + + const limit = findModelLimit("ollama/zen-3", modelLimits) + assert.strictEqual(limit, "25%") + }) + + it("matches substring wildcard patterns", () => { + const modelLimits = { + "*sonnet*": 120000, + } + + const limit = findModelLimit("anthropic/claude-3.5-sonnet", modelLimits) + assert.strictEqual(limit, 120000) + }) + + it("prefers the most specific wildcard pattern", () => { + const modelLimits = { + "*sonnet*": "45%", + "ollama/*": "25%", + } + + const limit = findModelLimit("ollama/sonnet", modelLimits) + assert.strictEqual(limit, "25%") + }) + + it("uses lexical order as deterministic tiebreaker", () => { + const modelLimits = { + "a*": 100, + "*a": 200, + } + + const limit = findModelLimit("a", modelLimits) + assert.strictEqual(limit, 200) + }) + + it("returns undefined when no pattern matches", () => { + const modelLimits = { + "ollama/*": "25%", + } + + const limit = findModelLimit("openai/gpt-5", modelLimits) + assert.strictEqual(limit, undefined) + }) +}) diff --git a/tests/schema-model-limits.test.ts b/tests/schema-model-limits.test.ts new file mode 100644 index 00000000..e31c8fea --- /dev/null +++ b/tests/schema-model-limits.test.ts @@ -0,0 +1,84 @@ +import { describe, it } from "node:test" +import assert from "node:assert" +import { readFile } from "fs/promises" +import { fileURLToPath } from "url" +import { dirname, join } from "path" + +const __filename = fileURLToPath(import.meta.url) +const __dirname = dirname(__filename) + +describe("Schema Validation - modelLimits", () => { + it("should accept valid modelLimits configuration", async () => { + const schema = JSON.parse(await readFile(join(__dirname, "../dcp.schema.json"), "utf-8")) + const modelLimitsSchema = + schema.properties?.tools?.properties?.settings?.properties?.modelLimits + + assert.ok(modelLimitsSchema, "modelLimits field should exist") + assert.strictEqual(modelLimitsSchema.type, "object") + assert.ok(modelLimitsSchema.additionalProperties) + assert.ok(modelLimitsSchema.additionalProperties.oneOf) + assert.strictEqual(modelLimitsSchema.additionalProperties.oneOf.length, 2) + }) + + it("should support number values in modelLimits", async () => { + const schema = JSON.parse(await readFile(join(__dirname, "../dcp.schema.json"), "utf-8")) + const numberSchema = + schema.properties?.tools?.properties?.settings?.properties?.modelLimits + ?.additionalProperties?.oneOf?.[0] + + assert.ok(numberSchema, "number schema should exist") + assert.strictEqual(numberSchema.type, "number") + }) + + it("should support percentage strings in modelLimits", async () => { + const schema = JSON.parse(await readFile(join(__dirname, "../dcp.schema.json"), "utf-8")) + const percentSchema = + schema.properties?.tools?.properties?.settings?.properties?.modelLimits + ?.additionalProperties?.oneOf?.[1] + + assert.ok(percentSchema, "percentage schema should exist") + assert.strictEqual(percentSchema.type, "string") + assert.ok(percentSchema.pattern) + assert.strictEqual(percentSchema.pattern, "^\\d+(?:\\.\\d+)?%$") + }) + + // Test valid percentage patterns + it("should accept valid percentage patterns", async () => { + const schema = JSON.parse(await readFile(join(__dirname, "../dcp.schema.json"), "utf-8")) + const pattern = + schema.properties?.tools?.properties?.settings?.properties?.modelLimits + ?.additionalProperties?.oneOf?.[1]?.pattern + + const validPatterns = ["0%", "50%", "100%", "50.5%", "0.1%", "99.99%", "1000%"] + const regex = new RegExp(pattern) + + for (const test of validPatterns) { + assert.ok(regex.test(test), `Should accept: ${test}`) + } + }) + + it("should reject invalid percentage patterns", async () => { + const schema = JSON.parse(await readFile(join(__dirname, "../dcp.schema.json"), "utf-8")) + const pattern = + schema.properties?.tools?.properties?.settings?.properties?.modelLimits + ?.additionalProperties?.oneOf?.[1]?.pattern + + const invalidPatterns = [ + "abc%", // non-numeric + "50 %", // space before % + "%50", // % before number + "50.5.5%", // multiple decimals + "%%", // no number + "", // empty string + "50", // no % + "-50%", // negative (regex doesn't support -) + ".5%", // starts with decimal + "50.%", // decimal without fraction + ] + const regex = new RegExp(pattern) + + for (const test of invalidPatterns) { + assert.ok(!regex.test(test), `Should reject: ${test}`) + } + }) +})