From 105ce3e518aac2e5768ef5e5837b525035244e34 Mon Sep 17 00:00:00 2001 From: ding113 Date: Tue, 17 Feb 2026 21:00:06 +0800 Subject: [PATCH 1/2] refactor(proxy): introduce EndpointPolicy to replace hardcoded count_tokens checks Replace scattered isCountTokensRequest() conditionals with a unified EndpointPolicy system resolved once at session construction time. This generalizes the "raw passthrough" behavior to cover both count_tokens and responses/compact endpoints via a single policy object. Key changes: - Add endpoint-paths.ts (path constants + normalization with case/slash/query handling) - Add endpoint-policy.ts (EndpointPolicy interface + resolution logic) - ProxySession holds immutable EndpointPolicy resolved at construction - GuardPipeline.fromSession() reads policy instead of RequestType enum - Forwarder, ResponseHandler, RequestFilter all gate on policy flags - proxy-handler uses trackConcurrentRequests from policy --- src/app/v1/_lib/proxy-handler.ts | 11 +- src/app/v1/_lib/proxy/endpoint-paths.ts | 64 ++++ src/app/v1/_lib/proxy/endpoint-policy.ts | 68 ++++ src/app/v1/_lib/proxy/forwarder.ts | 251 ++++++------- src/app/v1/_lib/proxy/guard-pipeline.ts | 32 +- .../v1/_lib/proxy/provider-request-filter.ts | 4 + src/app/v1/_lib/proxy/request-filter.ts | 4 + src/app/v1/_lib/proxy/response-handler.ts | 100 ++--- src/app/v1/_lib/proxy/session.ts | 22 +- .../proxy/endpoint-path-normalization.test.ts | 56 +++ .../unit/proxy/endpoint-policy-parity.test.ts | 341 ++++++++++++++++++ tests/unit/proxy/endpoint-policy.test.ts | 61 ++++ .../unit/proxy/guard-pipeline-warmup.test.ts | 71 +++- .../proxy-forwarder-fake-200-html.test.ts | 2 + ...y-forwarder-large-chunked-response.test.ts | 2 + .../proxy-forwarder-nonok-body-hang.test.ts | 2 + .../proxy/proxy-forwarder-retry-limit.test.ts | 69 ++++ .../proxy-handler-session-id-error.test.ts | 44 +++ ...handler-endpoint-circuit-isolation.test.ts | 2 + ...gemini-stream-passthrough-timeouts.test.ts | 2 + .../response-handler-lease-decrement.test.ts | 2 + tests/unit/proxy/session.test.ts | 49 +++ 22 files changed, 1062 insertions(+), 197 deletions(-) create mode 100644 src/app/v1/_lib/proxy/endpoint-paths.ts create mode 100644 src/app/v1/_lib/proxy/endpoint-policy.ts create mode 100644 tests/unit/proxy/endpoint-path-normalization.test.ts create mode 100644 tests/unit/proxy/endpoint-policy-parity.test.ts create mode 100644 tests/unit/proxy/endpoint-policy.test.ts diff --git a/src/app/v1/_lib/proxy-handler.ts b/src/app/v1/_lib/proxy-handler.ts index 5f2b90b4e..744791aa9 100644 --- a/src/app/v1/_lib/proxy-handler.ts +++ b/src/app/v1/_lib/proxy-handler.ts @@ -7,7 +7,7 @@ import { attachSessionIdToErrorResponse } from "./proxy/error-session-id"; import { ProxyError } from "./proxy/errors"; import { detectClientFormat, detectFormatByEndpoint } from "./proxy/format-mapper"; import { ProxyForwarder } from "./proxy/forwarder"; -import { GuardPipelineBuilder, RequestType } from "./proxy/guard-pipeline"; +import { GuardPipelineBuilder } from "./proxy/guard-pipeline"; import { ProxyResponseHandler } from "./proxy/response-handler"; import { ProxyResponses } from "./proxy/responses"; import { ProxySession } from "./proxy/session"; @@ -49,9 +49,8 @@ export async function handleProxyRequest(c: Context): Promise { } } - // Decide request type and build configured guard pipeline - const type = session.isCountTokensRequest() ? RequestType.COUNT_TOKENS : RequestType.CHAT; - const pipeline = GuardPipelineBuilder.fromRequestType(type); + // Build guard pipeline from session endpoint policy + const pipeline = GuardPipelineBuilder.fromSession(session); // Run guard chain; may return early Response const early = await pipeline.run(session); @@ -60,7 +59,7 @@ export async function handleProxyRequest(c: Context): Promise { } // 9. 增加并发计数(在所有检查通过后,请求开始前)- 跳过 count_tokens - if (session.sessionId && !session.isCountTokensRequest()) { + if (session.sessionId && session.getEndpointPolicy().trackConcurrentRequests) { await SessionTracker.incrementConcurrentCount(session.sessionId); } @@ -97,7 +96,7 @@ export async function handleProxyRequest(c: Context): Promise { return ProxyResponses.buildError(500, "代理请求发生未知错误"); } finally { // 11. 减少并发计数(确保无论成功失败都执行)- 跳过 count_tokens - if (session?.sessionId && !session.isCountTokensRequest()) { + if (session?.sessionId && session.getEndpointPolicy().trackConcurrentRequests) { await SessionTracker.decrementConcurrentCount(session.sessionId); } } diff --git a/src/app/v1/_lib/proxy/endpoint-paths.ts b/src/app/v1/_lib/proxy/endpoint-paths.ts new file mode 100644 index 000000000..2c32d4411 --- /dev/null +++ b/src/app/v1/_lib/proxy/endpoint-paths.ts @@ -0,0 +1,64 @@ +const V1_PREFIX = "/v1"; + +export const V1_ENDPOINT_PATHS = { + MESSAGES: "/v1/messages", + MESSAGES_COUNT_TOKENS: "/v1/messages/count_tokens", + RESPONSES: "/v1/responses", + RESPONSES_COMPACT: "/v1/responses/compact", + CHAT_COMPLETIONS: "/v1/chat/completions", + MODELS: "/v1/models", +} as const; + +export const STANDARD_ENDPOINT_PATHS = [ + V1_ENDPOINT_PATHS.MESSAGES, + V1_ENDPOINT_PATHS.MESSAGES_COUNT_TOKENS, + V1_ENDPOINT_PATHS.RESPONSES, + V1_ENDPOINT_PATHS.RESPONSES_COMPACT, + V1_ENDPOINT_PATHS.CHAT_COMPLETIONS, + V1_ENDPOINT_PATHS.MODELS, +] as const; + +export const STRICT_STANDARD_ENDPOINT_PATHS = [ + V1_ENDPOINT_PATHS.MESSAGES, + V1_ENDPOINT_PATHS.RESPONSES, + V1_ENDPOINT_PATHS.RESPONSES_COMPACT, + V1_ENDPOINT_PATHS.CHAT_COMPLETIONS, +] as const; + +const standardEndpointPathSet = new Set(STANDARD_ENDPOINT_PATHS); +const strictStandardEndpointPathSet = new Set(STRICT_STANDARD_ENDPOINT_PATHS); + +export function normalizeEndpointPath(pathname: string): string { + const pathWithoutQuery = pathname.split("?")[0]; + const trimmedPath = + pathWithoutQuery.length > 1 && pathWithoutQuery.endsWith("/") + ? pathWithoutQuery.slice(0, -1) + : pathWithoutQuery; + + return trimmedPath.toLowerCase(); +} + +export function isStandardEndpointPath(pathname: string): boolean { + return standardEndpointPathSet.has(normalizeEndpointPath(pathname)); +} + +export function isStrictStandardEndpointPath(pathname: string): boolean { + return strictStandardEndpointPathSet.has(normalizeEndpointPath(pathname)); +} + +export function isCountTokensEndpointPath(pathname: string): boolean { + return normalizeEndpointPath(pathname) === V1_ENDPOINT_PATHS.MESSAGES_COUNT_TOKENS; +} + +export function isResponseCompactEndpointPath(pathname: string): boolean { + return normalizeEndpointPath(pathname) === V1_ENDPOINT_PATHS.RESPONSES_COMPACT; +} + +export function toV1RoutePath(pathname: string): string { + if (!pathname.startsWith(V1_PREFIX)) { + return pathname; + } + + const routePath = pathname.slice(V1_PREFIX.length); + return routePath.length > 0 ? routePath : "/"; +} diff --git a/src/app/v1/_lib/proxy/endpoint-policy.ts b/src/app/v1/_lib/proxy/endpoint-policy.ts new file mode 100644 index 000000000..1bd77d89f --- /dev/null +++ b/src/app/v1/_lib/proxy/endpoint-policy.ts @@ -0,0 +1,68 @@ +import { normalizeEndpointPath, V1_ENDPOINT_PATHS } from "./endpoint-paths"; + +export type EndpointGuardPreset = "chat" | "raw_passthrough"; + +export type EndpointPoolStrictness = "inherit" | "strict"; + +export interface EndpointPolicy { + readonly kind: "default" | "raw_passthrough"; + readonly guardPreset: EndpointGuardPreset; + readonly allowRetry: boolean; + readonly allowProviderSwitch: boolean; + readonly allowCircuitBreakerAccounting: boolean; + readonly trackConcurrentRequests: boolean; + readonly bypassRequestFilters: boolean; + readonly bypassForwarderPreprocessing: boolean; + readonly bypassSpecialSettings: boolean; + readonly bypassResponseRectifier: boolean; + readonly endpointPoolStrictness: EndpointPoolStrictness; +} + +const DEFAULT_ENDPOINT_POLICY: EndpointPolicy = Object.freeze({ + kind: "default", + guardPreset: "chat", + allowRetry: true, + allowProviderSwitch: true, + allowCircuitBreakerAccounting: true, + trackConcurrentRequests: true, + bypassRequestFilters: false, + bypassForwarderPreprocessing: false, + bypassSpecialSettings: false, + bypassResponseRectifier: false, + endpointPoolStrictness: "inherit", +}); + +const RAW_PASSTHROUGH_ENDPOINT_POLICY: EndpointPolicy = Object.freeze({ + kind: "raw_passthrough", + guardPreset: "raw_passthrough", + allowRetry: false, + allowProviderSwitch: false, + allowCircuitBreakerAccounting: false, + trackConcurrentRequests: false, + bypassRequestFilters: true, + bypassForwarderPreprocessing: true, + bypassSpecialSettings: true, + bypassResponseRectifier: true, + endpointPoolStrictness: "strict", +}); + +const rawPassthroughEndpointPathSet = new Set([ + V1_ENDPOINT_PATHS.MESSAGES_COUNT_TOKENS, + V1_ENDPOINT_PATHS.RESPONSES_COMPACT, +]); + +export function isRawPassthroughEndpointPath(pathname: string): boolean { + return rawPassthroughEndpointPathSet.has(normalizeEndpointPath(pathname)); +} + +export function isRawPassthroughEndpointPolicy(policy: EndpointPolicy): boolean { + return policy.kind === "raw_passthrough"; +} + +export function resolveEndpointPolicy(pathname: string): EndpointPolicy { + if (isRawPassthroughEndpointPath(pathname)) { + return RAW_PASSTHROUGH_ENDPOINT_POLICY; + } + + return DEFAULT_ENDPOINT_POLICY; +} diff --git a/src/app/v1/_lib/proxy/forwarder.ts b/src/app/v1/_lib/proxy/forwarder.ts index 5ef27aa31..b3d6e4193 100644 --- a/src/app/v1/_lib/proxy/forwarder.ts +++ b/src/app/v1/_lib/proxy/forwarder.ts @@ -1621,18 +1621,19 @@ export class ProxyForwarder { break; } - // 🆕 count_tokens 请求特殊处理:不计入熔断,不触发供应商切换 - if (session.isCountTokensRequest()) { + // Raw passthrough endpoints: no circuit breaker, no provider switch, no retry + if (!session.getEndpointPolicy().allowRetry) { logger.debug( - "ProxyForwarder: count_tokens request error, skipping circuit breaker and provider switch", + "ProxyForwarder: raw passthrough endpoint error, skipping circuit breaker and provider switch", { providerId: currentProvider.id, providerName: currentProvider.name, statusCode, error: proxyError.message, + policyKind: session.getEndpointPolicy().kind, } ); - // 直接抛出错误,不重试,不切换供应商 + // Throw immediately: no retry, no provider switch throw lastError; } @@ -1780,12 +1781,14 @@ export class ProxyForwarder { session.setContext1mApplied(context1mApplied); } - // 应用模型重定向(如果配置了) - const wasRedirected = ModelRedirector.apply(session, provider); - if (wasRedirected) { - logger.debug("ProxyForwarder: Model redirected", { - providerId: provider.id, - }); + // Apply model redirect (if configured) - skip for raw passthrough endpoints + if (!session.getEndpointPolicy().bypassForwarderPreprocessing) { + const wasRedirected = ModelRedirector.apply(session, provider); + if (wasRedirected) { + logger.debug("ProxyForwarder: Model redirected", { + providerId: provider.id, + }); + } } let processedHeaders: Headers; @@ -1898,138 +1901,140 @@ export class ProxyForwarder { }); } else { // --- STANDARD HANDLING --- - if ( - resolvedCacheTtl && - (provider.providerType === "claude" || provider.providerType === "claude-auth") - ) { - const applied = applyCacheTtlOverrideToMessage(session.request.message, resolvedCacheTtl); - if (applied) { - logger.info("ProxyForwarder: Applied cache TTL override to request", { - providerId: provider.id, - providerName: provider.name, - cacheTtl: resolvedCacheTtl, - }); + if (!session.getEndpointPolicy().bypassForwarderPreprocessing) { + if ( + resolvedCacheTtl && + (provider.providerType === "claude" || provider.providerType === "claude-auth") + ) { + const applied = applyCacheTtlOverrideToMessage(session.request.message, resolvedCacheTtl); + if (applied) { + logger.info("ProxyForwarder: Applied cache TTL override to request", { + providerId: provider.id, + providerName: provider.name, + cacheTtl: resolvedCacheTtl, + }); + } } - } - - // Codex 供应商级参数覆写(默认 inherit=遵循客户端) - if (provider.providerType === "codex") { - const { request: overridden, audit } = applyCodexProviderOverridesWithAudit( - provider, - session.request.message as Record - ); - session.request.message = overridden; - - if (audit) { - session.addSpecialSetting(audit); - const specialSettings = session.getSpecialSettings(); - if (session.sessionId) { - // 这里用 await:避免后续响应侧写入(ResponseFixer 等)先完成后,被本次旧快照覆写 - await SessionManager.storeSessionSpecialSettings( - session.sessionId, - specialSettings, - session.requestSequence - ).catch((err) => { - logger.error("[ProxyForwarder] Failed to store special settings", { - error: err, - sessionId: session.sessionId, + // Codex 供应商级参数覆写(默认 inherit=遵循客户端) + if (provider.providerType === "codex") { + const { request: overridden, audit } = applyCodexProviderOverridesWithAudit( + provider, + session.request.message as Record + ); + session.request.message = overridden; + + if (audit) { + session.addSpecialSetting(audit); + const specialSettings = session.getSpecialSettings(); + + if (session.sessionId) { + // 这里用 await:避免后续响应侧写入(ResponseFixer 等)先完成后,被本次旧快照覆写 + await SessionManager.storeSessionSpecialSettings( + session.sessionId, + specialSettings, + session.requestSequence + ).catch((err) => { + logger.error("[ProxyForwarder] Failed to store special settings", { + error: err, + sessionId: session.sessionId, + }); }); - }); - } + } - if (session.messageContext?.id) { - // 同上:确保 special_settings 的"旧值"不会在并发下覆盖"新值" - await updateMessageRequestDetails(session.messageContext.id, { - specialSettings, - }).catch((err) => { - logger.error("[ProxyForwarder] Failed to persist special settings", { - error: err, - messageRequestId: session.messageContext?.id, + if (session.messageContext?.id) { + // 同上:确保 special_settings 的"旧值"不会在并发下覆盖"新值" + await updateMessageRequestDetails(session.messageContext.id, { + specialSettings, + }).catch((err) => { + logger.error("[ProxyForwarder] Failed to persist special settings", { + error: err, + messageRequestId: session.messageContext?.id, + }); }); - }); + } } } - } - // Anthropic 供应商级参数覆写(默认 inherit=遵循客户端) - // 说明:允许管理员在供应商层面强制覆写 max_tokens 和 thinking.budget_tokens - if (provider.providerType === "claude" || provider.providerType === "claude-auth") { - // Billing header rectifier: proactively strip x-anthropic-billing-header from system prompt - { - const settings = await getCachedSystemSettings(); - const billingRectifierEnabled = settings.enableBillingHeaderRectifier ?? true; - if (billingRectifierEnabled) { - const billingResult = rectifyBillingHeader( + // Anthropic 供应商级参数覆写(默认 inherit=遵循客户端) + // 说明:允许管理员在供应商层面强制覆写 max_tokens 和 thinking.budget_tokens + if (provider.providerType === "claude" || provider.providerType === "claude-auth") { + // Billing header rectifier: proactively strip x-anthropic-billing-header from system prompt + { + const settings = await getCachedSystemSettings(); + const billingRectifierEnabled = settings.enableBillingHeaderRectifier ?? true; + if (billingRectifierEnabled) { + const billingResult = rectifyBillingHeader( + session.request.message as Record + ); + if (billingResult.applied) { + session.addSpecialSetting({ + type: "billing_header_rectifier", + scope: "request", + hit: true, + removedCount: billingResult.removedCount, + extractedValues: billingResult.extractedValues, + }); + logger.info("ProxyForwarder: Billing header rectifier applied", { + providerId: provider.id, + providerName: provider.name, + removedCount: billingResult.removedCount, + }); + await persistSpecialSettings(session); + } + } + } + + const { request: anthropicOverridden, audit: anthropicAudit } = + applyAnthropicProviderOverridesWithAudit( + provider, session.request.message as Record ); - if (billingResult.applied) { - session.addSpecialSetting({ - type: "billing_header_rectifier", - scope: "request", - hit: true, - removedCount: billingResult.removedCount, - extractedValues: billingResult.extractedValues, - }); - logger.info("ProxyForwarder: Billing header rectifier applied", { - providerId: provider.id, - providerName: provider.name, - removedCount: billingResult.removedCount, + session.request.message = anthropicOverridden; + + if (anthropicAudit) { + session.addSpecialSetting(anthropicAudit); + const specialSettings = session.getSpecialSettings(); + + if (session.sessionId) { + await SessionManager.storeSessionSpecialSettings( + session.sessionId, + specialSettings, + session.requestSequence + ).catch((err) => { + logger.error("[ProxyForwarder] Failed to store Anthropic special settings", { + error: err, + sessionId: session.sessionId, + }); }); - await persistSpecialSettings(session); } - } - } - - const { request: anthropicOverridden, audit: anthropicAudit } = - applyAnthropicProviderOverridesWithAudit( - provider, - session.request.message as Record - ); - session.request.message = anthropicOverridden; - if (anthropicAudit) { - session.addSpecialSetting(anthropicAudit); - const specialSettings = session.getSpecialSettings(); - - if (session.sessionId) { - await SessionManager.storeSessionSpecialSettings( - session.sessionId, - specialSettings, - session.requestSequence - ).catch((err) => { - logger.error("[ProxyForwarder] Failed to store Anthropic special settings", { - error: err, - sessionId: session.sessionId, + if (session.messageContext?.id) { + await updateMessageRequestDetails(session.messageContext.id, { + specialSettings, + }).catch((err) => { + logger.error("[ProxyForwarder] Failed to persist Anthropic special settings", { + error: err, + messageRequestId: session.messageContext?.id, + }); }); - }); + } } + } - if (session.messageContext?.id) { - await updateMessageRequestDetails(session.messageContext.id, { - specialSettings, - }).catch((err) => { - logger.error("[ProxyForwarder] Failed to persist Anthropic special settings", { - error: err, - messageRequestId: session.messageContext?.id, - }); + if ( + resolvedCacheTtl && + (provider.providerType === "claude" || provider.providerType === "claude-auth") + ) { + const applied = applyCacheTtlOverrideToMessage(session.request.message, resolvedCacheTtl); + if (applied) { + logger.debug("ProxyForwarder: Applied cache TTL override to request", { + providerId: provider.id, + ttl: resolvedCacheTtl, }); } } - } - - if ( - resolvedCacheTtl && - (provider.providerType === "claude" || provider.providerType === "claude-auth") - ) { - const applied = applyCacheTtlOverrideToMessage(session.request.message, resolvedCacheTtl); - if (applied) { - logger.debug("ProxyForwarder: Applied cache TTL override to request", { - providerId: provider.id, - ttl: resolvedCacheTtl, - }); - } - } + } // end bypassForwarderPreprocessing gate processedHeaders = ProxyForwarder.buildHeaders(session, provider); diff --git a/src/app/v1/_lib/proxy/guard-pipeline.ts b/src/app/v1/_lib/proxy/guard-pipeline.ts index 7070477d8..8f881ca25 100644 --- a/src/app/v1/_lib/proxy/guard-pipeline.ts +++ b/src/app/v1/_lib/proxy/guard-pipeline.ts @@ -1,5 +1,6 @@ import { ProxyAuthenticator } from "./auth-guard"; import { ProxyClientGuard } from "./client-guard"; +import type { EndpointPolicy } from "./endpoint-policy"; import { ProxyMessageService } from "./message-service"; import { ProxyModelGuard } from "./model-guard"; import { ProxyProviderRequestFilter } from "./provider-request-filter"; @@ -157,11 +158,24 @@ export class GuardPipelineBuilder { }; } + static fromSession(session: Pick): GuardPipeline { + return GuardPipelineBuilder.fromEndpointPolicy(session.getEndpointPolicy()); + } + + static fromEndpointPolicy(policy: Pick): GuardPipeline { + switch (policy.guardPreset) { + case "raw_passthrough": + return GuardPipelineBuilder.build(RAW_PASSTHROUGH_PIPELINE); + default: + return GuardPipelineBuilder.build(CHAT_PIPELINE); + } + } + // Convenience: build a pipeline from preset request type static fromRequestType(type: RequestType): GuardPipeline { switch (type) { case RequestType.COUNT_TOKENS: - return GuardPipelineBuilder.build(COUNT_TOKENS_PIPELINE); + return GuardPipelineBuilder.build(RAW_PASSTHROUGH_PIPELINE); default: return GuardPipelineBuilder.build(CHAT_PIPELINE); } @@ -188,16 +202,8 @@ export const CHAT_PIPELINE: GuardConfig = { ], }; -export const COUNT_TOKENS_PIPELINE: GuardConfig = { - // Minimal chain for count_tokens: no session, no sensitive, no rate limit, no message logging - steps: [ - "auth", - "client", - "model", - "version", - "probe", - "requestFilter", - "provider", - "providerRequestFilter", - ], +export const RAW_PASSTHROUGH_PIPELINE: GuardConfig = { + steps: ["auth", "client", "model", "version", "probe", "provider"], }; + +export const COUNT_TOKENS_PIPELINE: GuardConfig = RAW_PASSTHROUGH_PIPELINE; diff --git a/src/app/v1/_lib/proxy/provider-request-filter.ts b/src/app/v1/_lib/proxy/provider-request-filter.ts index 719bc1307..68242b029 100644 --- a/src/app/v1/_lib/proxy/provider-request-filter.ts +++ b/src/app/v1/_lib/proxy/provider-request-filter.ts @@ -9,6 +9,10 @@ import type { ProxySession } from "./session"; */ export class ProxyProviderRequestFilter { static async ensure(session: ProxySession): Promise { + if (session.getEndpointPolicy().bypassRequestFilters) { + return; + } + if (!session.provider) { logger.warn( "[ProxyProviderRequestFilter] No provider selected, skipping provider-specific filters" diff --git a/src/app/v1/_lib/proxy/request-filter.ts b/src/app/v1/_lib/proxy/request-filter.ts index 468863724..68e45aae7 100644 --- a/src/app/v1/_lib/proxy/request-filter.ts +++ b/src/app/v1/_lib/proxy/request-filter.ts @@ -12,6 +12,10 @@ import type { ProxySession } from "./session"; */ export class ProxyRequestFilter { static async ensure(session: ProxySession): Promise { + if (session.getEndpointPolicy().bypassRequestFilters) { + return; + } + try { await requestFilterEngine.applyGlobal(session); } catch (error) { diff --git a/src/app/v1/_lib/proxy/response-handler.ts b/src/app/v1/_lib/proxy/response-handler.ts index 13abdc3bf..28798a4fd 100644 --- a/src/app/v1/_lib/proxy/response-handler.ts +++ b/src/app/v1/_lib/proxy/response-handler.ts @@ -253,7 +253,7 @@ async function finalizeDeferredStreamingFinalizationIfNeeded( // - 客户端主动中断:不计入熔断器(这通常不是供应商问题) // - 非客户端中断:计入 provider/endpoint 熔断失败(与 timeout 路径保持一致) if (!streamEndedNormally) { - if (!clientAborted) { + if (!clientAborted && session.getEndpointPolicy().allowCircuitBreakerAccounting) { try { // 动态导入:避免 proxy 模块与熔断器模块之间潜在的循环依赖。 const { recordFailure } = await import("@/lib/circuit-breaker"); @@ -301,7 +301,7 @@ async function finalizeDeferredStreamingFinalizationIfNeeded( // 计入熔断器:让后续请求能正确触发故障转移/熔断。 // // 注意:404 语义在 forwarder 中属于 RESOURCE_NOT_FOUND,不计入熔断器(避免把“资源/模型不存在”当作供应商故障)。 - if (effectiveStatusCode !== 404) { + if (effectiveStatusCode !== 404 && session.getEndpointPolicy().allowCircuitBreakerAccounting) { try { // 动态导入:避免 proxy 模块与熔断器模块之间潜在的循环依赖。 const { recordFailure } = await import("@/lib/circuit-breaker"); @@ -350,7 +350,7 @@ async function finalizeDeferredStreamingFinalizationIfNeeded( // 计入熔断器:让后续请求能正确触发故障转移/熔断。 // 注意:与 forwarder 口径保持一致:404 不计入熔断器(资源不存在不是供应商故障)。 - if (effectiveStatusCode !== 404) { + if (effectiveStatusCode !== 404 && session.getEndpointPolicy().allowCircuitBreakerAccounting) { try { const { recordFailure } = await import("@/lib/circuit-breaker"); await recordFailure(meta.providerId, new Error(errorMessage)); @@ -469,19 +469,21 @@ async function finalizeDeferredStreamingFinalizationIfNeeded( export class ProxyResponseHandler { static async dispatch(session: ProxySession, response: Response): Promise { let fixedResponse = response; - try { - fixedResponse = await ResponseFixer.process(session, response); - } catch (error) { - logger.error( - "[ResponseHandler] ResponseFixer failed (getCachedSystemSettings/processNonStream)", - { - error: error instanceof Error ? error.message : String(error), - sessionId: session.sessionId ?? null, - messageRequestId: session.messageContext?.id ?? null, - requestSequence: session.requestSequence ?? null, - } - ); - fixedResponse = response; + if (!session.getEndpointPolicy().bypassResponseRectifier) { + try { + fixedResponse = await ResponseFixer.process(session, response); + } catch (error) { + logger.error( + "[ResponseHandler] ResponseFixer failed (getCachedSystemSettings/processNonStream)", + { + error: error instanceof Error ? error.message : String(error), + sessionId: session.sessionId ?? null, + messageRequestId: session.messageContext?.id ?? null, + requestSequence: session.requestSequence ?? null, + } + ); + fixedResponse = response; + } } const contentType = fixedResponse.headers.get("content-type") || ""; @@ -561,17 +563,19 @@ export class ProxyResponseHandler { errorMessageForFinalize = detected.isError ? detected.code : `HTTP ${statusCode}`; // 计入熔断器 - try { - const { recordFailure } = await import("@/lib/circuit-breaker"); - await recordFailure(provider.id, new Error(errorMessageForFinalize)); - } catch (cbError) { - logger.warn( - "ResponseHandler: Failed to record non-200 error in circuit breaker (passthrough)", - { - providerId: provider.id, - error: cbError, - } - ); + if (session.getEndpointPolicy().allowCircuitBreakerAccounting) { + try { + const { recordFailure } = await import("@/lib/circuit-breaker"); + await recordFailure(provider.id, new Error(errorMessageForFinalize)); + } catch (cbError) { + logger.warn( + "ResponseHandler: Failed to record non-200 error in circuit breaker (passthrough)", + { + providerId: provider.id, + error: cbError, + } + ); + } } // 记录到决策链 @@ -843,14 +847,16 @@ export class ProxyResponseHandler { const errorMessageForDb = detected.isError ? detected.code : `HTTP ${statusCode}`; // 计入熔断器 - try { - const { recordFailure } = await import("@/lib/circuit-breaker"); - await recordFailure(provider.id, new Error(errorMessageForDb)); - } catch (cbError) { - logger.warn("ResponseHandler: Failed to record non-200 error in circuit breaker", { - providerId: provider.id, - error: cbError, - }); + if (session.getEndpointPolicy().allowCircuitBreakerAccounting) { + try { + const { recordFailure } = await import("@/lib/circuit-breaker"); + await recordFailure(provider.id, new Error(errorMessageForDb)); + } catch (cbError) { + logger.warn("ResponseHandler: Failed to record non-200 error in circuit breaker", { + providerId: provider.id, + error: cbError, + }); + } } // 记录到决策链 @@ -929,17 +935,19 @@ export class ProxyResponseHandler { }); // 计入熔断器(动态导入避免循环依赖) - try { - const { recordFailure } = await import("@/lib/circuit-breaker"); - await recordFailure(provider.id, err); - logger.debug("ResponseHandler: Response timeout recorded in circuit breaker", { - providerId: provider.id, - }); - } catch (cbError) { - logger.warn("ResponseHandler: Failed to record timeout in circuit breaker", { - providerId: provider.id, - error: cbError, - }); + if (session.getEndpointPolicy().allowCircuitBreakerAccounting) { + try { + const { recordFailure } = await import("@/lib/circuit-breaker"); + await recordFailure(provider.id, err); + logger.debug("ResponseHandler: Response timeout recorded in circuit breaker", { + providerId: provider.id, + }); + } catch (cbError) { + logger.warn("ResponseHandler: Failed to record timeout in circuit breaker", { + providerId: provider.id, + error: cbError, + }); + } } // 注意:无法重试,因为客户端已收到 HTTP 200 diff --git a/src/app/v1/_lib/proxy/session.ts b/src/app/v1/_lib/proxy/session.ts index dda2c4a46..bebf1ba9e 100644 --- a/src/app/v1/_lib/proxy/session.ts +++ b/src/app/v1/_lib/proxy/session.ts @@ -11,6 +11,8 @@ import type { ModelPriceData } from "@/types/model-price"; import type { Provider, ProviderType } from "@/types/provider"; import type { SpecialSetting } from "@/types/special-settings"; import type { User } from "@/types/user"; +import { isCountTokensEndpointPath } from "./endpoint-paths"; +import { type EndpointPolicy, resolveEndpointPolicy } from "./endpoint-policy"; import { ProxyError } from "./errors"; import type { ClientFormat } from "./format-mapper"; @@ -83,6 +85,8 @@ export class ProxySession { originalFormat: ClientFormat = "claude"; providerType: ProviderType | null = null; + private readonly endpointPolicy: EndpointPolicy; + // 模型重定向追踪:保存原始模型名(重定向前) private originalModelName: string | null = null; @@ -154,6 +158,7 @@ export class ProxySession { this.messageContext = null; this.sessionId = null; this.providerChain = []; + this.endpointPolicy = resolveSessionEndpointPolicy(init.requestUrl); } static async fromContext(c: Context): Promise { @@ -528,6 +533,10 @@ export class ProxySession { return this.request.model; } + getEndpointPolicy(): EndpointPolicy { + return this.endpointPolicy; + } + /** * 获取请求的 API endpoint(来自 URL.pathname) * 处理边界:若 URL 不存在则返回 null @@ -548,7 +557,7 @@ export class ProxySession { */ isCountTokensRequest(): boolean { const endpoint = this.getEndpoint(); - return endpoint === "/v1/messages/count_tokens"; + return endpoint !== null && isCountTokensEndpointPath(endpoint); } /** @@ -793,6 +802,17 @@ function optimizeRequestMessage(message: Record): Record 0) { + return resolveEndpointPolicy(pathname); + } + } catch {} + + return resolveEndpointPolicy("/"); +} + export function extractModelFromPath(pathname: string): string | null { // 匹配 Vertex AI 路径:/v1/publishers/google/models/{model}: const publishersMatch = pathname.match(/\/publishers\/google\/models\/([^/:]+)(?::[^/]+)?/); diff --git a/tests/unit/proxy/endpoint-path-normalization.test.ts b/tests/unit/proxy/endpoint-path-normalization.test.ts new file mode 100644 index 000000000..8b4662e04 --- /dev/null +++ b/tests/unit/proxy/endpoint-path-normalization.test.ts @@ -0,0 +1,56 @@ +import { describe, expect, test } from "vitest"; +import { isRawPassthroughEndpointPath } from "@/app/v1/_lib/proxy/endpoint-policy"; +import { + isCountTokensEndpointPath, + isResponseCompactEndpointPath, +} from "@/app/v1/_lib/proxy/endpoint-paths"; +import { ProxySession } from "@/app/v1/_lib/proxy/session"; + +const countTokensVariants = [ + "/v1/messages/count_tokens", + "/v1/messages/count_tokens/", + "/V1/MESSAGES/COUNT_TOKENS", +]; + +const compactVariants = [ + "/v1/responses/compact", + "/v1/responses/compact/", + "/V1/RESPONSES/COMPACT", +]; + +function isCountTokensRequestWithEndpoint(pathname: string | null): boolean { + const sessionLike = { + getEndpoint: () => pathname, + } as Pick; + + return ProxySession.prototype.isCountTokensRequest.call(sessionLike as ProxySession); +} + +describe("endpoint path normalization", () => { + test.each(countTokensVariants)("count_tokens stays classified for variant %s", (pathname) => { + expect(isCountTokensEndpointPath(pathname)).toBe(true); + expect(isRawPassthroughEndpointPath(pathname)).toBe(true); + expect(isCountTokensRequestWithEndpoint(pathname)).toBe(true); + }); + + test.each(compactVariants)("responses/compact stays classified for variant %s", (pathname) => { + expect(isResponseCompactEndpointPath(pathname)).toBe(true); + expect(isRawPassthroughEndpointPath(pathname)).toBe(true); + }); + + test.each([ + "/v1/messages", + "/v1/responses", + "/v1/messages/count", + "/v1/responses/mini", + ])("non-target path is not misclassified for %s", (pathname) => { + expect(isCountTokensEndpointPath(pathname)).toBe(false); + expect(isResponseCompactEndpointPath(pathname)).toBe(false); + expect(isRawPassthroughEndpointPath(pathname)).toBe(false); + expect(isCountTokensRequestWithEndpoint(pathname)).toBe(false); + }); + + test("session count_tokens detection handles null endpoint", () => { + expect(isCountTokensRequestWithEndpoint(null)).toBe(false); + }); +}); diff --git a/tests/unit/proxy/endpoint-policy-parity.test.ts b/tests/unit/proxy/endpoint-policy-parity.test.ts new file mode 100644 index 000000000..7145728fd --- /dev/null +++ b/tests/unit/proxy/endpoint-policy-parity.test.ts @@ -0,0 +1,341 @@ +import { beforeEach, describe, expect, test, vi } from "vitest"; +import { + type EndpointPolicy, + isRawPassthroughEndpointPath, + isRawPassthroughEndpointPolicy, + resolveEndpointPolicy, +} from "@/app/v1/_lib/proxy/endpoint-policy"; +import { V1_ENDPOINT_PATHS } from "@/app/v1/_lib/proxy/endpoint-paths"; + +// --------------------------------------------------------------------------- +// Shared constants +// --------------------------------------------------------------------------- + +const RAW_PASSTHROUGH_ENDPOINTS = [ + V1_ENDPOINT_PATHS.MESSAGES_COUNT_TOKENS, + V1_ENDPOINT_PATHS.RESPONSES_COMPACT, +] as const; + +const DEFAULT_ENDPOINTS = [ + V1_ENDPOINT_PATHS.MESSAGES, + V1_ENDPOINT_PATHS.RESPONSES, + V1_ENDPOINT_PATHS.CHAT_COMPLETIONS, +] as const; + +// --------------------------------------------------------------------------- +// T11: Endpoint parity -- count_tokens and responses/compact produce +// identical EndpointPolicy objects and exhibit identical behaviour +// under provider errors. +// --------------------------------------------------------------------------- + +describe("T11: raw passthrough endpoint parity", () => { + test("count_tokens and responses/compact resolve to the exact same EndpointPolicy object", () => { + const countTokensPolicy = resolveEndpointPolicy(V1_ENDPOINT_PATHS.MESSAGES_COUNT_TOKENS); + const compactPolicy = resolveEndpointPolicy(V1_ENDPOINT_PATHS.RESPONSES_COMPACT); + + // Reference equality: same frozen singleton + expect(countTokensPolicy).toBe(compactPolicy); + + // Both recognized as raw_passthrough + expect(isRawPassthroughEndpointPolicy(countTokensPolicy)).toBe(true); + expect(isRawPassthroughEndpointPolicy(compactPolicy)).toBe(true); + }); + + test("both raw passthrough endpoints have identical strict policy fields", () => { + const countTokensPolicy = resolveEndpointPolicy(V1_ENDPOINT_PATHS.MESSAGES_COUNT_TOKENS); + const compactPolicy = resolveEndpointPolicy(V1_ENDPOINT_PATHS.RESPONSES_COMPACT); + + const expectedPolicy: EndpointPolicy = { + kind: "raw_passthrough", + guardPreset: "raw_passthrough", + allowRetry: false, + allowProviderSwitch: false, + allowCircuitBreakerAccounting: false, + trackConcurrentRequests: false, + bypassRequestFilters: true, + bypassForwarderPreprocessing: true, + bypassSpecialSettings: true, + bypassResponseRectifier: true, + endpointPoolStrictness: "strict", + }; + + expect(countTokensPolicy).toEqual(expectedPolicy); + expect(compactPolicy).toEqual(expectedPolicy); + }); + + test("under provider error, both endpoints result in no retry, no provider switch, no circuit breaker accounting", () => { + for (const pathname of RAW_PASSTHROUGH_ENDPOINTS) { + const policy = resolveEndpointPolicy(pathname); + + expect(policy.allowRetry).toBe(false); + expect(policy.allowProviderSwitch).toBe(false); + expect(policy.allowCircuitBreakerAccounting).toBe(false); + } + }); + + test("isRawPassthroughEndpointPath returns true for both raw passthrough canonical paths", () => { + for (const pathname of RAW_PASSTHROUGH_ENDPOINTS) { + expect(isRawPassthroughEndpointPath(pathname)).toBe(true); + } + }); +}); + +// --------------------------------------------------------------------------- +// T12: Bypass completeness -- spy-based zero-call assertions to verify that +// request filter guards early-return without invoking the engine. +// --------------------------------------------------------------------------- + +const applyGlobalMock = vi.fn(async () => {}); +const applyForProviderMock = vi.fn(async () => {}); + +vi.mock("@/lib/request-filter-engine", () => ({ + requestFilterEngine: { + applyGlobal: applyGlobalMock, + applyForProvider: applyForProviderMock, + }, +})); + +vi.mock("@/lib/logger", () => ({ + logger: { + debug: vi.fn(), + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + trace: vi.fn(), + fatal: vi.fn(), + }, +})); + +describe("T12: bypass completeness (spy-based zero-call assertions)", () => { + beforeEach(() => { + applyGlobalMock.mockClear(); + applyForProviderMock.mockClear(); + }); + + test("ProxyRequestFilter.ensure early-returns without calling applyGlobal for raw passthrough", async () => { + const { ProxyRequestFilter } = await import("@/app/v1/_lib/proxy/request-filter"); + + for (const pathname of RAW_PASSTHROUGH_ENDPOINTS) { + applyGlobalMock.mockClear(); + + const session = { + getEndpointPolicy: () => resolveEndpointPolicy(pathname), + } as any; + + await ProxyRequestFilter.ensure(session); + expect(applyGlobalMock).not.toHaveBeenCalled(); + } + }); + + test("ProxyProviderRequestFilter.ensure early-returns without calling applyForProvider for raw passthrough", async () => { + const { ProxyProviderRequestFilter } = await import( + "@/app/v1/_lib/proxy/provider-request-filter" + ); + + for (const pathname of RAW_PASSTHROUGH_ENDPOINTS) { + applyForProviderMock.mockClear(); + + const session = { + getEndpointPolicy: () => resolveEndpointPolicy(pathname), + provider: { id: 1 }, + } as any; + + await ProxyProviderRequestFilter.ensure(session); + expect(applyForProviderMock).not.toHaveBeenCalled(); + } + }); + + test("ProxyRequestFilter.ensure calls applyGlobal for default policy endpoints", async () => { + const { ProxyRequestFilter } = await import("@/app/v1/_lib/proxy/request-filter"); + + for (const pathname of DEFAULT_ENDPOINTS) { + applyGlobalMock.mockClear(); + + const session = { + getEndpointPolicy: () => resolveEndpointPolicy(pathname), + } as any; + + await ProxyRequestFilter.ensure(session); + expect(applyGlobalMock).toHaveBeenCalledTimes(1); + } + }); + + test("ProxyProviderRequestFilter.ensure calls applyForProvider for default policy endpoints", async () => { + const { ProxyProviderRequestFilter } = await import( + "@/app/v1/_lib/proxy/provider-request-filter" + ); + + for (const pathname of DEFAULT_ENDPOINTS) { + applyForProviderMock.mockClear(); + + const session = { + getEndpointPolicy: () => resolveEndpointPolicy(pathname), + provider: { id: 1 }, + } as any; + + await ProxyProviderRequestFilter.ensure(session); + expect(applyForProviderMock).toHaveBeenCalledTimes(1); + } + }); +}); + +// --------------------------------------------------------------------------- +// T13: Non-target regression -- default endpoints retain full default policy. +// --------------------------------------------------------------------------- + +describe("T13: non-target regression (default policy preserved)", () => { + const expectedDefaultPolicy: EndpointPolicy = { + kind: "default", + guardPreset: "chat", + allowRetry: true, + allowProviderSwitch: true, + allowCircuitBreakerAccounting: true, + trackConcurrentRequests: true, + bypassRequestFilters: false, + bypassForwarderPreprocessing: false, + bypassSpecialSettings: false, + bypassResponseRectifier: false, + endpointPoolStrictness: "inherit", + }; + + test("/v1/messages retains full default policy", () => { + const policy = resolveEndpointPolicy(V1_ENDPOINT_PATHS.MESSAGES); + expect(policy).toEqual(expectedDefaultPolicy); + expect(isRawPassthroughEndpointPolicy(policy)).toBe(false); + }); + + test("/v1/responses retains full default policy", () => { + const policy = resolveEndpointPolicy(V1_ENDPOINT_PATHS.RESPONSES); + expect(policy).toEqual(expectedDefaultPolicy); + expect(isRawPassthroughEndpointPolicy(policy)).toBe(false); + }); + + test("/v1/chat/completions retains full default policy", () => { + const policy = resolveEndpointPolicy(V1_ENDPOINT_PATHS.CHAT_COMPLETIONS); + expect(policy).toEqual(expectedDefaultPolicy); + expect(isRawPassthroughEndpointPolicy(policy)).toBe(false); + }); + + test("all default endpoints resolve to the same singleton object", () => { + const policies = DEFAULT_ENDPOINTS.map((p) => resolveEndpointPolicy(p)); + // All should be the same reference + for (let i = 1; i < policies.length; i++) { + expect(policies[i]).toBe(policies[0]); + } + }); + + test("default policy has all bypass flags set to false", () => { + for (const pathname of DEFAULT_ENDPOINTS) { + const policy = resolveEndpointPolicy(pathname); + expect(policy.bypassRequestFilters).toBe(false); + expect(policy.bypassForwarderPreprocessing).toBe(false); + expect(policy.bypassSpecialSettings).toBe(false); + expect(policy.bypassResponseRectifier).toBe(false); + } + }); + + test("default policy has all allow flags set to true", () => { + for (const pathname of DEFAULT_ENDPOINTS) { + const policy = resolveEndpointPolicy(pathname); + expect(policy.allowRetry).toBe(true); + expect(policy.allowProviderSwitch).toBe(true); + expect(policy.allowCircuitBreakerAccounting).toBe(true); + expect(policy.trackConcurrentRequests).toBe(true); + } + }); +}); + +// --------------------------------------------------------------------------- +// T14: Path edge-case tests -- normalization handles trailing slashes, case +// variants, query strings, and non-matching paths correctly. +// --------------------------------------------------------------------------- + +describe("T14: path edge-case normalization", () => { + test("trailing slash: /v1/messages/count_tokens/ -> raw_passthrough", () => { + expect(isRawPassthroughEndpointPath("/v1/messages/count_tokens/")).toBe(true); + const policy = resolveEndpointPolicy("/v1/messages/count_tokens/"); + expect(policy.kind).toBe("raw_passthrough"); + }); + + test("trailing slash: /v1/responses/compact/ -> raw_passthrough", () => { + expect(isRawPassthroughEndpointPath("/v1/responses/compact/")).toBe(true); + const policy = resolveEndpointPolicy("/v1/responses/compact/"); + expect(policy.kind).toBe("raw_passthrough"); + }); + + test("uppercase: /V1/MESSAGES/COUNT_TOKENS -> raw_passthrough", () => { + expect(isRawPassthroughEndpointPath("/V1/MESSAGES/COUNT_TOKENS")).toBe(true); + const policy = resolveEndpointPolicy("/V1/MESSAGES/COUNT_TOKENS"); + expect(policy.kind).toBe("raw_passthrough"); + }); + + test("uppercase: /V1/RESPONSES/COMPACT -> raw_passthrough", () => { + expect(isRawPassthroughEndpointPath("/V1/RESPONSES/COMPACT")).toBe(true); + const policy = resolveEndpointPolicy("/V1/RESPONSES/COMPACT"); + expect(policy.kind).toBe("raw_passthrough"); + }); + + test("query string: /v1/messages/count_tokens?foo=bar -> raw_passthrough", () => { + expect(isRawPassthroughEndpointPath("/v1/messages/count_tokens?foo=bar")).toBe(true); + const policy = resolveEndpointPolicy("/v1/messages/count_tokens?foo=bar"); + expect(policy.kind).toBe("raw_passthrough"); + }); + + test("query string: /v1/responses/compact?foo=bar -> raw_passthrough", () => { + expect(isRawPassthroughEndpointPath("/v1/responses/compact?foo=bar")).toBe(true); + const policy = resolveEndpointPolicy("/v1/responses/compact?foo=bar"); + expect(policy.kind).toBe("raw_passthrough"); + }); + + test("combined edge case: uppercase + trailing slash + query string", () => { + expect(isRawPassthroughEndpointPath("/V1/MESSAGES/COUNT_TOKENS/?x=1")).toBe(true); + expect(isRawPassthroughEndpointPath("/V1/RESPONSES/COMPACT/?x=1")).toBe(true); + + const policy1 = resolveEndpointPolicy("/V1/MESSAGES/COUNT_TOKENS/?x=1"); + const policy2 = resolveEndpointPolicy("/V1/RESPONSES/COMPACT/?x=1"); + expect(policy1.kind).toBe("raw_passthrough"); + expect(policy2.kind).toBe("raw_passthrough"); + }); + + test("/v1/messages/ (with trailing slash) -> default, NOT raw_passthrough", () => { + expect(isRawPassthroughEndpointPath("/v1/messages/")).toBe(false); + const policy = resolveEndpointPolicy("/v1/messages/"); + expect(policy.kind).toBe("default"); + }); + + test("/v1/messages (no trailing slash) -> default", () => { + expect(isRawPassthroughEndpointPath("/v1/messages")).toBe(false); + const policy = resolveEndpointPolicy("/v1/messages"); + expect(policy.kind).toBe("default"); + }); + + test("/v1/responses (no sub-path) -> default", () => { + expect(isRawPassthroughEndpointPath("/v1/responses")).toBe(false); + const policy = resolveEndpointPolicy("/v1/responses"); + expect(policy.kind).toBe("default"); + }); + + test("/v1/chat/completions -> default", () => { + expect(isRawPassthroughEndpointPath("/v1/chat/completions")).toBe(false); + const policy = resolveEndpointPolicy("/v1/chat/completions"); + expect(policy.kind).toBe("default"); + }); + + test.each([ + "/v1/messages/count", + "/v1/messages/count_token", + "/v1/responses/mini", + "/v1/responses/compacted", + "/v2/messages/count_tokens", + "/v1/messages/count_tokens/extra", + ])("non-matching path %s -> default", (pathname) => { + expect(isRawPassthroughEndpointPath(pathname)).toBe(false); + const policy = resolveEndpointPolicy(pathname); + expect(policy.kind).toBe("default"); + }); + + test("empty and root paths -> default", () => { + expect(resolveEndpointPolicy("/").kind).toBe("default"); + expect(resolveEndpointPolicy("").kind).toBe("default"); + }); +}); diff --git a/tests/unit/proxy/endpoint-policy.test.ts b/tests/unit/proxy/endpoint-policy.test.ts new file mode 100644 index 000000000..af540acc6 --- /dev/null +++ b/tests/unit/proxy/endpoint-policy.test.ts @@ -0,0 +1,61 @@ +import { describe, expect, test } from "vitest"; +import { + isRawPassthroughEndpointPath, + isRawPassthroughEndpointPolicy, + resolveEndpointPolicy, +} from "@/app/v1/_lib/proxy/endpoint-policy"; +import { V1_ENDPOINT_PATHS } from "@/app/v1/_lib/proxy/endpoint-paths"; + +describe("endpoint-policy", () => { + test("raw passthrough endpoints resolve to identical strict policy", () => { + const countTokensPolicy = resolveEndpointPolicy(V1_ENDPOINT_PATHS.MESSAGES_COUNT_TOKENS); + const compactPolicy = resolveEndpointPolicy(V1_ENDPOINT_PATHS.RESPONSES_COMPACT); + + expect(countTokensPolicy).toBe(compactPolicy); + expect(isRawPassthroughEndpointPolicy(countTokensPolicy)).toBe(true); + expect(countTokensPolicy).toEqual({ + kind: "raw_passthrough", + guardPreset: "raw_passthrough", + allowRetry: false, + allowProviderSwitch: false, + allowCircuitBreakerAccounting: false, + trackConcurrentRequests: false, + bypassRequestFilters: true, + bypassForwarderPreprocessing: true, + bypassSpecialSettings: true, + bypassResponseRectifier: true, + endpointPoolStrictness: "strict", + }); + }); + + test.each([ + "/v1/messages/count_tokens/", + "/V1/MESSAGES/COUNT_TOKENS", + "/v1/responses/compact/", + "/V1/RESPONSES/COMPACT", + ])("raw passthrough endpoints path helper matches variant %s", (pathname) => { + expect(isRawPassthroughEndpointPath(pathname)).toBe(true); + expect(isRawPassthroughEndpointPolicy(resolveEndpointPolicy(pathname))).toBe(true); + }); + + test("default policy stays on non-target endpoints", () => { + const messagesPolicy = resolveEndpointPolicy(V1_ENDPOINT_PATHS.MESSAGES); + const responsesPolicy = resolveEndpointPolicy(V1_ENDPOINT_PATHS.RESPONSES); + + expect(messagesPolicy).toBe(responsesPolicy); + expect(isRawPassthroughEndpointPolicy(messagesPolicy)).toBe(false); + expect(messagesPolicy).toEqual({ + kind: "default", + guardPreset: "chat", + allowRetry: true, + allowProviderSwitch: true, + allowCircuitBreakerAccounting: true, + trackConcurrentRequests: true, + bypassRequestFilters: false, + bypassForwarderPreprocessing: false, + bypassSpecialSettings: false, + bypassResponseRectifier: false, + endpointPoolStrictness: "inherit", + }); + }); +}); diff --git a/tests/unit/proxy/guard-pipeline-warmup.test.ts b/tests/unit/proxy/guard-pipeline-warmup.test.ts index 401565913..aa953f336 100644 --- a/tests/unit/proxy/guard-pipeline-warmup.test.ts +++ b/tests/unit/proxy/guard-pipeline-warmup.test.ts @@ -1,4 +1,6 @@ import { describe, expect, test, vi } from "vitest"; +import { resolveEndpointPolicy } from "@/app/v1/_lib/proxy/endpoint-policy"; +import { V1_ENDPOINT_PATHS } from "@/app/v1/_lib/proxy/endpoint-paths"; const callOrder: string[] = []; @@ -204,20 +206,73 @@ describe("GuardPipeline:warmup 拦截点", () => { const res = await pipeline.run(session); expect(res).toBeNull(); + expect(callOrder).toEqual(["auth", "client", "model", "version", "probe", "provider"]); + expect(callOrder).not.toContain("session"); + expect(callOrder).not.toContain("warmup"); + expect(callOrder).not.toContain("sensitive"); + expect(callOrder).not.toContain("rateLimit"); + expect(callOrder).not.toContain("requestFilter"); + expect(callOrder).not.toContain("providerRequestFilter"); + expect(callOrder).not.toContain("messageContext"); + }); + + test("count_tokens 和 responses/compact 应通过 endpoint policy 选择同一 raw preset", async () => { + const { GuardPipelineBuilder } = await import("@/app/v1/_lib/proxy/guard-pipeline"); + + const endpoints = [ + V1_ENDPOINT_PATHS.MESSAGES_COUNT_TOKENS, + V1_ENDPOINT_PATHS.RESPONSES_COMPACT, + ]; + const orders: string[][] = []; + + for (const endpoint of endpoints) { + callOrder.length = 0; + const session = { + getEndpointPolicy: () => resolveEndpointPolicy(endpoint), + isProbeRequest: () => { + callOrder.push("probe"); + return false; + }, + } as any; + + const pipeline = GuardPipelineBuilder.fromSession(session); + const res = await pipeline.run(session); + + expect(res).toBeNull(); + orders.push([...callOrder]); + } + + expect(orders[0]).toEqual(orders[1]); + expect(orders[0]).toEqual(["auth", "client", "model", "version", "probe", "provider"]); + }); + + test("/v1/messages 仍应通过 endpoint policy 选择现有 chat preset", async () => { + callOrder.length = 0; + + const { GuardPipelineBuilder } = await import("@/app/v1/_lib/proxy/guard-pipeline"); + + const session = { + getEndpointPolicy: () => resolveEndpointPolicy(V1_ENDPOINT_PATHS.MESSAGES), + isProbeRequest: () => { + callOrder.push("probe"); + return false; + }, + } as any; + + const pipeline = GuardPipelineBuilder.fromSession(session); + const res = await pipeline.run(session); + + expect(res).not.toBeNull(); + expect(res?.status).toBe(200); expect(callOrder).toEqual([ "auth", + "sensitive", "client", "model", "version", "probe", - "requestFilter", - "provider", - "providerRequestFilter", + "session", + "warmup", ]); - expect(callOrder).not.toContain("session"); - expect(callOrder).not.toContain("warmup"); - expect(callOrder).not.toContain("sensitive"); - expect(callOrder).not.toContain("rateLimit"); - expect(callOrder).not.toContain("messageContext"); }); }); diff --git a/tests/unit/proxy/proxy-forwarder-fake-200-html.test.ts b/tests/unit/proxy/proxy-forwarder-fake-200-html.test.ts index 2aa054a10..8b57d3915 100644 --- a/tests/unit/proxy/proxy-forwarder-fake-200-html.test.ts +++ b/tests/unit/proxy/proxy-forwarder-fake-200-html.test.ts @@ -1,4 +1,5 @@ import { beforeEach, describe, expect, test, vi } from "vitest"; +import { resolveEndpointPolicy } from "@/app/v1/_lib/proxy/endpoint-policy"; const mocks = vi.hoisted(() => { return { @@ -185,6 +186,7 @@ function createSession(): ProxySession { specialSettings: [], cachedPriceData: undefined, cachedBillingModelSource: undefined, + endpointPolicy: resolveEndpointPolicy("/v1/messages"), isHeaderModified: () => false, }); diff --git a/tests/unit/proxy/proxy-forwarder-large-chunked-response.test.ts b/tests/unit/proxy/proxy-forwarder-large-chunked-response.test.ts index 2bdfc284b..4e9cfd9c7 100644 --- a/tests/unit/proxy/proxy-forwarder-large-chunked-response.test.ts +++ b/tests/unit/proxy/proxy-forwarder-large-chunked-response.test.ts @@ -2,6 +2,7 @@ import { createServer } from "node:http"; import type { Socket } from "node:net"; import { describe, expect, test, vi } from "vitest"; import { ProxyForwarder } from "@/app/v1/_lib/proxy/forwarder"; +import { resolveEndpointPolicy } from "@/app/v1/_lib/proxy/endpoint-policy"; import { ProxySession } from "@/app/v1/_lib/proxy/session"; import type { Provider } from "@/types/provider"; @@ -128,6 +129,7 @@ function createSession(params?: { clientAbortSignal?: AbortSignal | null }): Pro specialSettings: [], cachedPriceData: undefined, cachedBillingModelSource: undefined, + endpointPolicy: resolveEndpointPolicy("/v1/chat/completions"), isHeaderModified: () => false, }); diff --git a/tests/unit/proxy/proxy-forwarder-nonok-body-hang.test.ts b/tests/unit/proxy/proxy-forwarder-nonok-body-hang.test.ts index d1a1e3cf4..95eee1076 100644 --- a/tests/unit/proxy/proxy-forwarder-nonok-body-hang.test.ts +++ b/tests/unit/proxy/proxy-forwarder-nonok-body-hang.test.ts @@ -2,6 +2,7 @@ import { createServer } from "node:http"; import type { Socket } from "node:net"; import { describe, expect, test, vi } from "vitest"; import { ProxyForwarder } from "@/app/v1/_lib/proxy/forwarder"; +import { resolveEndpointPolicy } from "@/app/v1/_lib/proxy/endpoint-policy"; import { ProxyError } from "@/app/v1/_lib/proxy/errors"; import { ProxySession } from "@/app/v1/_lib/proxy/session"; import type { Provider } from "@/types/provider"; @@ -128,6 +129,7 @@ function createSession(params?: { clientAbortSignal?: AbortSignal | null }): Pro specialSettings: [], cachedPriceData: undefined, cachedBillingModelSource: undefined, + endpointPolicy: resolveEndpointPolicy("/v1/chat/completions"), isHeaderModified: () => false, }); diff --git a/tests/unit/proxy/proxy-forwarder-retry-limit.test.ts b/tests/unit/proxy/proxy-forwarder-retry-limit.test.ts index 9e3f11c16..266fbb2bb 100644 --- a/tests/unit/proxy/proxy-forwarder-retry-limit.test.ts +++ b/tests/unit/proxy/proxy-forwarder-retry-limit.test.ts @@ -1,4 +1,6 @@ import { beforeEach, describe, expect, test, vi } from "vitest"; +import { resolveEndpointPolicy } from "@/app/v1/_lib/proxy/endpoint-policy"; +import { V1_ENDPOINT_PATHS } from "@/app/v1/_lib/proxy/endpoint-paths"; const mocks = vi.hoisted(() => { return { @@ -188,6 +190,7 @@ function createSession(requestUrl: URL = new URL("https://example.com/v1/message originalModelName: null, originalUrlPathname: null, providerChain: [], + endpointPolicy: resolveEndpointPolicy(requestUrl.pathname), cacheTtlResolved: null, context1mApplied: false, specialSettings: [], @@ -200,6 +203,72 @@ function createSession(requestUrl: URL = new URL("https://example.com/v1/message return session as ProxySession; } +describe("ProxyForwarder - raw passthrough policy parity (T5 RED)", () => { + beforeEach(() => { + vi.clearAllMocks(); + vi.mocked(categorizeErrorAsync).mockResolvedValue(ErrorCategory.PROVIDER_ERROR); + }); + + test.each([ + V1_ENDPOINT_PATHS.MESSAGES_COUNT_TOKENS, + V1_ENDPOINT_PATHS.RESPONSES_COMPACT, + ])("RED: %s 失败时都应统一为 no-retry/no-switch/no-circuit(Wave2 未实现前应失败)", async (pathname) => { + vi.useFakeTimers(); + + try { + const session = createSession(new URL(`https://example.com${pathname}`)); + const provider = createProvider({ + providerType: "claude", + providerVendorId: 123, + maxRetryAttempts: 3, + }); + session.setProvider(provider); + + mocks.getPreferredProviderEndpoints.mockResolvedValue([ + makeEndpoint({ + id: 1, + vendorId: 123, + providerType: "claude", + url: "https://ep1.example.com", + }), + makeEndpoint({ + id: 2, + vendorId: 123, + providerType: "claude", + url: "https://ep2.example.com", + }), + ]); + + const doForward = vi.spyOn( + ProxyForwarder as unknown as { doForward: (...args: unknown[]) => unknown }, + "doForward" + ); + const selectAlternative = vi.spyOn( + ProxyForwarder as unknown as { selectAlternative: (...args: unknown[]) => unknown }, + "selectAlternative" + ); + + doForward.mockImplementation(async () => { + throw new ProxyError("upstream failed", 500); + }); + + const sendPromise = ProxyForwarder.send(session); + let caughtError: Error | null = null; + sendPromise.catch((error) => { + caughtError = error as Error; + }); + await vi.runAllTimersAsync(); + + expect(caughtError).toBeInstanceOf(ProxyError); + expect(doForward).toHaveBeenCalledTimes(1); + expect(selectAlternative).not.toHaveBeenCalled(); + expect(mocks.recordFailure).not.toHaveBeenCalled(); + } finally { + vi.useRealTimers(); + } + }); +}); + describe("ProxyForwarder - retry limit enforcement", () => { beforeEach(() => { vi.clearAllMocks(); diff --git a/tests/unit/proxy/proxy-handler-session-id-error.test.ts b/tests/unit/proxy/proxy-handler-session-id-error.test.ts index 18062cc79..a7132a7a7 100644 --- a/tests/unit/proxy/proxy-handler-session-id-error.test.ts +++ b/tests/unit/proxy/proxy-handler-session-id-error.test.ts @@ -1,4 +1,6 @@ import { describe, expect, test, vi } from "vitest"; +import { resolveEndpointPolicy } from "@/app/v1/_lib/proxy/endpoint-policy"; +import { V1_ENDPOINT_PATHS } from "@/app/v1/_lib/proxy/endpoint-paths"; import { ProxyResponses } from "@/app/v1/_lib/proxy/responses"; import { ProxyError } from "@/app/v1/_lib/proxy/errors"; @@ -11,6 +13,7 @@ const h = vi.hoisted(() => ({ model: "gpt", message: {}, }, + getEndpointPolicy: () => resolveEndpointPolicy(h.session.requestUrl.pathname), isCountTokensRequest: () => false, setOriginalFormat: () => {}, recordForwardStart: () => {}, @@ -40,6 +43,12 @@ vi.mock("@/app/v1/_lib/proxy/session", () => ({ vi.mock("@/app/v1/_lib/proxy/guard-pipeline", () => ({ RequestType: { CHAT: "CHAT", COUNT_TOKENS: "COUNT_TOKENS" }, GuardPipelineBuilder: { + fromSession: () => ({ + run: async () => { + if (h.pipelineError) throw h.pipelineError; + return h.earlyResponse; + }, + }), fromRequestType: () => ({ run: async () => { if (h.pipelineError) throw h.pipelineError; @@ -167,6 +176,41 @@ describe("handleProxyRequest - session id on errors", async () => { expect(h.trackerCalls).toEqual(["inc", "startRequest", "dec"]); }); + test.each([ + { + pathname: V1_ENDPOINT_PATHS.MESSAGES_COUNT_TOKENS, + isCountTokensRequest: true, + }, + { + pathname: V1_ENDPOINT_PATHS.RESPONSES_COMPACT, + isCountTokensRequest: false, + }, + ])("RED: raw endpoint $pathname 应统一跳过并发计数(Wave2 未实现前会失败)", async ({ + pathname, + isCountTokensRequest, + }) => { + h.fromContextError = null; + h.session.originalFormat = "claude"; + h.endpointFormat = "openai"; + h.trackerCalls.length = 0; + h.pipelineError = null; + h.earlyResponse = null; + h.forwardResponse = new Response("ok", { status: 200 }); + h.dispatchedResponse = null; + + h.session.requestUrl = new URL(`http://localhost${pathname}`); + h.session.getEndpointPolicy = () => resolveEndpointPolicy(h.session.requestUrl.pathname); + h.session.sessionId = "s_123"; + h.session.messageContext = { id: 1, user: { id: 1, name: "u" }, key: { name: "k" } }; + h.session.provider = { id: 1, name: "p" }; + h.session.isCountTokensRequest = () => isCountTokensRequest; + + const res = await handleProxyRequest({} as any); + + expect(res.status).toBe(200); + expect(h.trackerCalls).toEqual(["startRequest"]); + }); + test("session not created and ProxyError thrown: returns buildError without session header", async () => { h.fromContextError = new ProxyError("upstream", 401); h.endpointFormat = null; diff --git a/tests/unit/proxy/response-handler-endpoint-circuit-isolation.test.ts b/tests/unit/proxy/response-handler-endpoint-circuit-isolation.test.ts index 533f8247f..e3e83fbd7 100644 --- a/tests/unit/proxy/response-handler-endpoint-circuit-isolation.test.ts +++ b/tests/unit/proxy/response-handler-endpoint-circuit-isolation.test.ts @@ -9,6 +9,7 @@ */ import { beforeEach, describe, expect, it, vi } from "vitest"; +import { resolveEndpointPolicy } from "@/app/v1/_lib/proxy/endpoint-policy"; import type { ModelPriceData } from "@/types/model-price"; // Track async tasks for draining @@ -173,6 +174,7 @@ function createSession(opts?: { sessionId?: string | null }): ProxySession { specialSettings: [], cachedPriceData: undefined, cachedBillingModelSource: undefined, + endpointPolicy: resolveEndpointPolicy("/v1/messages"), isHeaderModified: () => false, getContext1mApplied: () => false, getOriginalModel: () => "test-model", diff --git a/tests/unit/proxy/response-handler-gemini-stream-passthrough-timeouts.test.ts b/tests/unit/proxy/response-handler-gemini-stream-passthrough-timeouts.test.ts index bc26ea8af..afc42b326 100644 --- a/tests/unit/proxy/response-handler-gemini-stream-passthrough-timeouts.test.ts +++ b/tests/unit/proxy/response-handler-gemini-stream-passthrough-timeouts.test.ts @@ -2,6 +2,7 @@ import { createServer } from "node:http"; import type { Socket } from "node:net"; import { beforeEach, describe, expect, test, vi } from "vitest"; import { ProxyForwarder } from "@/app/v1/_lib/proxy/forwarder"; +import { resolveEndpointPolicy } from "@/app/v1/_lib/proxy/endpoint-policy"; import { ProxyResponseHandler } from "@/app/v1/_lib/proxy/response-handler"; import { ProxySession } from "@/app/v1/_lib/proxy/session"; import type { Provider } from "@/types/provider"; @@ -191,6 +192,7 @@ function createSession(params: { specialSettings: [], cachedPriceData: undefined, cachedBillingModelSource: undefined, + endpointPolicy: resolveEndpointPolicy("/v1/chat/completions"), isHeaderModified: () => false, }); diff --git a/tests/unit/proxy/response-handler-lease-decrement.test.ts b/tests/unit/proxy/response-handler-lease-decrement.test.ts index b6a0a4773..1100256ef 100644 --- a/tests/unit/proxy/response-handler-lease-decrement.test.ts +++ b/tests/unit/proxy/response-handler-lease-decrement.test.ts @@ -9,6 +9,7 @@ */ import { beforeEach, describe, expect, it, vi } from "vitest"; +import { resolveEndpointPolicy } from "@/app/v1/_lib/proxy/endpoint-policy"; import type { ModelPriceData } from "@/types/model-price"; // Track async tasks for draining @@ -157,6 +158,7 @@ function createSession(opts: { specialSettings: [], cachedPriceData: undefined, cachedBillingModelSource: undefined, + endpointPolicy: resolveEndpointPolicy("/v1/messages"), isHeaderModified: () => false, getContext1mApplied: () => false, getOriginalModel: () => originalModel, diff --git a/tests/unit/proxy/session.test.ts b/tests/unit/proxy/session.test.ts index 5afe30248..9771ea4df 100644 --- a/tests/unit/proxy/session.test.ts +++ b/tests/unit/proxy/session.test.ts @@ -1,4 +1,6 @@ import { describe, expect, it, vi } from "vitest"; +import { isRawPassthroughEndpointPolicy } from "@/app/v1/_lib/proxy/endpoint-policy"; +import { V1_ENDPOINT_PATHS } from "@/app/v1/_lib/proxy/endpoint-paths"; import type { ModelPrice, ModelPriceData } from "@/types/model-price"; import type { SystemSettings } from "@/types/system-config"; import type { Provider } from "@/types/provider"; @@ -101,6 +103,53 @@ function createSession({ return session; } +describe("ProxySession endpoint policy", () => { + it.each([ + V1_ENDPOINT_PATHS.MESSAGES_COUNT_TOKENS, + "/V1/RESPONSES/COMPACT/", + ])("应在创建时解析 raw passthrough policy: %s", (pathname) => { + const session = createSession({ + redirectedModel: null, + requestUrl: new URL(`http://localhost${pathname}`), + }); + + const policy = session.getEndpointPolicy(); + expect(isRawPassthroughEndpointPolicy(policy)).toBe(true); + expect(policy.trackConcurrentRequests).toBe(false); + }); + + it("应在请求路径后续变更后保持创建时 policy 不变", () => { + const session = createSession({ + redirectedModel: null, + requestUrl: new URL(`http://localhost${V1_ENDPOINT_PATHS.MESSAGES_COUNT_TOKENS}`), + }); + + const policyAtCreation = session.getEndpointPolicy(); + session.requestUrl = new URL(`http://localhost${V1_ENDPOINT_PATHS.MESSAGES}`); + + expect(session.getEndpointPolicy()).toBe(policyAtCreation); + expect(isRawPassthroughEndpointPolicy(session.getEndpointPolicy())).toBe(true); + }); + + it("应在 pathname 无法读取时回退到 default policy", () => { + const malformedUrl = { + get pathname() { + throw new Error("broken pathname"); + }, + } as unknown as URL; + + const session = createSession({ + redirectedModel: null, + requestUrl: malformedUrl, + }); + + const policy = session.getEndpointPolicy(); + expect(isRawPassthroughEndpointPolicy(policy)).toBe(false); + expect(policy.kind).toBe("default"); + expect(policy.trackConcurrentRequests).toBe(true); + }); +}); + describe("ProxySession.getCachedPriceDataByBillingSource", () => { it("配置 = original 时应优先使用原始模型", async () => { const originalPriceData: ModelPriceData = { input_cost_per_token: 1, output_cost_per_token: 2 }; From 40b030df615962f35499f2a7e58e95b15e1584c4 Mon Sep 17 00:00:00 2001 From: ding113 Date: Tue, 17 Feb 2026 21:17:53 +0800 Subject: [PATCH 2/2] fix(proxy): remove duplicate cache TTL call and cache endpoint policy in local var - Remove redundant first applyCacheTtlOverrideToMessage call (lines 1905-1917) that duplicated the post-Anthropic-overrides call (lines 2025-2036) - Cache session.getEndpointPolicy() in local variable in error handling path to avoid repeated accessor calls Addresses: gemini-code-assist and coderabbitai review comments --- src/app/v1/_lib/proxy/forwarder.ts | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/src/app/v1/_lib/proxy/forwarder.ts b/src/app/v1/_lib/proxy/forwarder.ts index b3d6e4193..bc161dd0e 100644 --- a/src/app/v1/_lib/proxy/forwarder.ts +++ b/src/app/v1/_lib/proxy/forwarder.ts @@ -1622,7 +1622,8 @@ export class ProxyForwarder { } // Raw passthrough endpoints: no circuit breaker, no provider switch, no retry - if (!session.getEndpointPolicy().allowRetry) { + const endpointPolicy = session.getEndpointPolicy(); + if (!endpointPolicy.allowRetry) { logger.debug( "ProxyForwarder: raw passthrough endpoint error, skipping circuit breaker and provider switch", { @@ -1630,7 +1631,7 @@ export class ProxyForwarder { providerName: currentProvider.name, statusCode, error: proxyError.message, - policyKind: session.getEndpointPolicy().kind, + policyKind: endpointPolicy.kind, } ); // Throw immediately: no retry, no provider switch @@ -1902,20 +1903,6 @@ export class ProxyForwarder { } else { // --- STANDARD HANDLING --- if (!session.getEndpointPolicy().bypassForwarderPreprocessing) { - if ( - resolvedCacheTtl && - (provider.providerType === "claude" || provider.providerType === "claude-auth") - ) { - const applied = applyCacheTtlOverrideToMessage(session.request.message, resolvedCacheTtl); - if (applied) { - logger.info("ProxyForwarder: Applied cache TTL override to request", { - providerId: provider.id, - providerName: provider.name, - cacheTtl: resolvedCacheTtl, - }); - } - } - // Codex 供应商级参数覆写(默认 inherit=遵循客户端) if (provider.providerType === "codex") { const { request: overridden, audit } = applyCodexProviderOverridesWithAudit(