diff --git a/src/actions/provider-endpoints.ts b/src/actions/provider-endpoints.ts index 47c6e7e70..fccc7e7c1 100644 --- a/src/actions/provider-endpoints.ts +++ b/src/actions/provider-endpoints.ts @@ -333,7 +333,15 @@ export async function removeProviderEndpoint(input: unknown): Promise // Auto cleanup: delete vendor if it has no active providers/endpoints. if (provider?.providerVendorId) { - await tryDeleteProviderVendorIfEmpty(provider.providerVendorId); + try { + await tryDeleteProviderVendorIfEmpty(provider.providerVendorId); + } catch (error) { + logger.warn("removeProvider:vendor_cleanup_failed", { + providerId, + vendorId: provider.providerVendorId, + error: error instanceof Error ? error.message : String(error), + }); + } } // 广播缓存更新(跨实例即时生效) @@ -3713,10 +3721,13 @@ export async function reclusterProviderVendors(args: { if (!provider) continue; // Get or create new vendor - const newVendorId = await getOrCreateProviderVendorIdFromUrls({ - providerUrl: provider.url, - websiteUrl: provider.websiteUrl ?? null, - }); + const newVendorId = await getOrCreateProviderVendorIdFromUrls( + { + providerUrl: provider.url, + websiteUrl: provider.websiteUrl ?? null, + }, + { tx } + ); // Update provider's vendorId await tx @@ -3731,7 +3742,14 @@ export async function reclusterProviderVendors(args: { // Cleanup empty vendors for (const oldVendorId of oldVendorIds) { - await tryDeleteProviderVendorIfEmpty(oldVendorId); + try { + await tryDeleteProviderVendorIfEmpty(oldVendorId); + } catch (error) { + logger.warn("reclusterProviderVendors:vendor_cleanup_failed", { + vendorId: oldVendorId, + error: error instanceof Error ? error.message : String(error), + }); + } } // Publish cache invalidation diff --git a/src/repository/provider-endpoints.ts b/src/repository/provider-endpoints.ts index b395c068c..0fabd3a6d 100644 --- a/src/repository/provider-endpoints.ts +++ b/src/repository/provider-endpoints.ts @@ -1,6 +1,6 @@ "use server"; -import { and, asc, desc, eq, gt, isNotNull, isNull, or, sql } from "drizzle-orm"; +import { and, asc, desc, eq, gt, isNotNull, isNull, ne, or, sql } from "drizzle-orm"; import { db } from "@/drizzle/db"; import { providerEndpointProbeLogs, @@ -8,6 +8,7 @@ import { providers, providerVendors, } from "@/drizzle/schema"; +import { resetEndpointCircuit } from "@/lib/endpoint-circuit-breaker"; import { logger } from "@/lib/logger"; import type { ProviderEndpoint, @@ -17,6 +18,41 @@ import type { ProviderVendor, } from "@/types/provider"; +type TransactionExecutor = Parameters[0]>[0]; +type QueryExecutor = Pick< + TransactionExecutor, + "select" | "insert" | "update" | "delete" | "execute" +>; + +function isUniqueViolationError(error: unknown): boolean { + if (!error || typeof error !== "object") { + return false; + } + + const candidate = error as { + code?: string; + message?: string; + cause?: { code?: string; message?: string }; + }; + + if (candidate.code === "23505") { + return true; + } + + if (typeof candidate.message === "string" && candidate.message.includes("duplicate key value")) { + return true; + } + + if (candidate.cause?.code === "23505") { + return true; + } + + return ( + typeof candidate.cause?.message === "string" && + candidate.cause.message.includes("duplicate key value") + ); +} + function toDate(value: unknown): Date { if (value instanceof Date) return value; if (typeof value === "string" || typeof value === "number") return new Date(value); @@ -252,12 +288,17 @@ export async function deleteProviderEndpointProbeLogsBeforeDateBatch(input: { return typeof rowCount === "number" ? rowCount : 0; } -export async function getOrCreateProviderVendorIdFromUrls(input: { - providerUrl: string; - websiteUrl?: string | null; - faviconUrl?: string | null; - displayName?: string | null; -}): Promise { +export async function getOrCreateProviderVendorIdFromUrls( + input: { + providerUrl: string; + websiteUrl?: string | null; + faviconUrl?: string | null; + displayName?: string | null; + }, + options?: { tx?: QueryExecutor } +): Promise { + const executor = options?.tx ?? db; + // Use new computeVendorKey for consistent vendor key calculation const websiteDomain = await computeVendorKey({ providerUrl: input.providerUrl, @@ -267,7 +308,7 @@ export async function getOrCreateProviderVendorIdFromUrls(input: { throw new Error("Failed to resolve provider vendor domain"); } - const existing = await db + const existing = await executor .select({ id: providerVendors.id }) .from(providerVendors) .where(eq(providerVendors.websiteDomain, websiteDomain)) @@ -277,7 +318,7 @@ export async function getOrCreateProviderVendorIdFromUrls(input: { } const now = new Date(); - const inserted = await db + const inserted = await executor .insert(providerVendors) .values({ websiteDomain, @@ -293,7 +334,7 @@ export async function getOrCreateProviderVendorIdFromUrls(input: { return inserted[0].id; } - const fallback = await db + const fallback = await executor .select({ id: providerVendors.id }) .from(providerVendors) .where(eq(providerVendors.websiteDomain, websiteDomain)) @@ -556,57 +597,58 @@ export async function deleteProviderVendor(vendorId: number): Promise { return deleted; } -export async function tryDeleteProviderVendorIfEmpty(vendorId: number): Promise { - try { - return await db.transaction(async (tx) => { - // 1) Must have no active providers (soft-deleted rows still exist but should not block). - const [activeProvider] = await tx - .select({ id: providers.id }) - .from(providers) - .where(and(eq(providers.providerVendorId, vendorId), isNull(providers.deletedAt))) - .limit(1); +export async function tryDeleteProviderVendorIfEmpty( + vendorId: number, + options?: { tx?: QueryExecutor } +): Promise { + const runInTx = async (tx: QueryExecutor): Promise => { + // 1) Must have no active providers (soft-deleted rows still exist but should not block). + const [activeProvider] = await tx + .select({ id: providers.id }) + .from(providers) + .where(and(eq(providers.providerVendorId, vendorId), isNull(providers.deletedAt))) + .limit(1); - if (activeProvider) { - return false; - } + if (activeProvider) { + return false; + } - // 2) Must have no active endpoints. - const [activeEndpoint] = await tx - .select({ id: providerEndpoints.id }) - .from(providerEndpoints) - .where(and(eq(providerEndpoints.vendorId, vendorId), isNull(providerEndpoints.deletedAt))) - .limit(1); + // 2) Must have no active endpoints. + const [activeEndpoint] = await tx + .select({ id: providerEndpoints.id }) + .from(providerEndpoints) + .where(and(eq(providerEndpoints.vendorId, vendorId), isNull(providerEndpoints.deletedAt))) + .limit(1); - if (activeEndpoint) { - return false; - } + if (activeEndpoint) { + return false; + } - // 3) Hard delete soft-deleted providers to satisfy FK `onDelete: restrict`. - await tx - .delete(providers) - .where(and(eq(providers.providerVendorId, vendorId), isNotNull(providers.deletedAt))); + // 3) Hard delete soft-deleted providers to satisfy FK `onDelete: restrict`. + await tx + .delete(providers) + .where(and(eq(providers.providerVendorId, vendorId), isNotNull(providers.deletedAt))); - // 4) Delete vendor. Endpoints will be physically removed by FK cascade. - const deleted = await tx - .delete(providerVendors) - .where( - and( - eq(providerVendors.id, vendorId), - sql`NOT EXISTS (SELECT 1 FROM providers p WHERE p.provider_vendor_id = ${vendorId} AND p.deleted_at IS NULL)`, - sql`NOT EXISTS (SELECT 1 FROM provider_endpoints e WHERE e.vendor_id = ${vendorId} AND e.deleted_at IS NULL)` - ) + // 4) Delete vendor. Endpoints will be physically removed by FK cascade. + const deleted = await tx + .delete(providerVendors) + .where( + and( + eq(providerVendors.id, vendorId), + sql`NOT EXISTS (SELECT 1 FROM providers p WHERE p.provider_vendor_id = ${vendorId} AND p.deleted_at IS NULL)`, + sql`NOT EXISTS (SELECT 1 FROM provider_endpoints e WHERE e.vendor_id = ${vendorId} AND e.deleted_at IS NULL)` ) - .returning({ id: providerVendors.id }); + ) + .returning({ id: providerVendors.id }); - return deleted.length > 0; - }); - } catch (error) { - logger.warn("[ProviderVendor] Auto delete failed", { - vendorId, - error: error instanceof Error ? error.message : String(error), - }); - return false; + return deleted.length > 0; + }; + + if (options?.tx) { + return await runInTx(options.tx); } + + return await db.transaction(async (tx) => runInTx(tx)); } export async function findProviderEndpointsByVendorAndType( @@ -714,26 +756,31 @@ export async function createProviderEndpoint(payload: { return toProviderEndpoint(row); } -export async function ensureProviderEndpointExistsForUrl(input: { - vendorId: number; - providerType: ProviderType; - url: string; - label?: string | null; -}): Promise { +export async function ensureProviderEndpointExistsForUrl( + input: { + vendorId: number; + providerType: ProviderType; + url: string; + label?: string | null; + }, + options?: { tx?: QueryExecutor } +): Promise { + const executor = options?.tx ?? db; + const trimmedUrl = input.url.trim(); if (!trimmedUrl) { - return false; + throw new Error("[ProviderEndpointEnsure] url is required"); } try { // eslint-disable-next-line no-new new URL(trimmedUrl); } catch { - return false; + throw new Error("[ProviderEndpointEnsure] url must be a valid URL"); } const now = new Date(); - const inserted = await db + const inserted = await executor .insert(providerEndpoints) .values({ vendorId: input.vendorId, @@ -750,6 +797,373 @@ export async function ensureProviderEndpointExistsForUrl(input: { return inserted.length > 0; } +export interface SyncProviderEndpointOnProviderEditInput { + providerId: number; + vendorId: number; + providerType: ProviderType; + previousVendorId?: number | null; + previousProviderType?: ProviderType | null; + previousUrl: string; + nextUrl: string; + keepPreviousWhenReferenced?: boolean; +} + +type ProviderEndpointSyncAction = + | "noop" + | "created-next" + | "revived-next" + | "updated-previous-in-place" + | "kept-previous-and-created-next" + | "kept-previous-and-revived-next" + | "kept-previous-and-kept-next" + | "soft-deleted-previous-and-kept-next" + | "soft-deleted-previous-and-revived-next"; + +export interface SyncProviderEndpointOnProviderEditResult { + action: ProviderEndpointSyncAction; + resetCircuitEndpointId?: number; +} + +export async function syncProviderEndpointOnProviderEdit( + input: SyncProviderEndpointOnProviderEditInput, + options?: { tx?: QueryExecutor } +): Promise { + const previousUrl = input.previousUrl.trim(); + const nextUrl = input.nextUrl.trim(); + + if (!nextUrl) { + throw new Error("[ProviderEndpointSync] nextUrl is required"); + } + + try { + // eslint-disable-next-line no-new + new URL(nextUrl); + } catch { + throw new Error("[ProviderEndpointSync] nextUrl must be a valid URL"); + } + + const previousVendorId = input.previousVendorId ?? input.vendorId; + const previousProviderType = input.previousProviderType ?? input.providerType; + const keepPreviousWhenReferenced = input.keepPreviousWhenReferenced !== false; + + const runInTx = async (tx: QueryExecutor): Promise => { + const now = new Date(); + + const loadEndpoint = async (args: { + vendorId: number; + providerType: ProviderType; + url: string; + }): Promise<{ id: number; deletedAt: Date | null; isEnabled: boolean } | null> => { + const [row] = await tx + .select({ + id: providerEndpoints.id, + deletedAt: providerEndpoints.deletedAt, + isEnabled: providerEndpoints.isEnabled, + }) + .from(providerEndpoints) + .where( + and( + eq(providerEndpoints.vendorId, args.vendorId), + eq(providerEndpoints.providerType, args.providerType), + eq(providerEndpoints.url, args.url) + ) + ) + .limit(1); + + return row + ? { + id: row.id, + deletedAt: row.deletedAt, + isEnabled: row.isEnabled, + } + : null; + }; + + const hasActiveReferencesOnPreviousUrl = async (): Promise => { + const [activeReference] = await tx + .select({ id: providers.id }) + .from(providers) + .where( + and( + eq(providers.providerVendorId, previousVendorId), + eq(providers.providerType, previousProviderType), + eq(providers.url, previousUrl), + isNull(providers.deletedAt), + ne(providers.id, input.providerId) + ) + ) + .limit(1); + + return Boolean(activeReference); + }; + + const ensureNextEndpointActive = async (options?: { + reactivateDisabled?: boolean; + }): Promise<"created-next" | "revived-next" | "noop"> => { + const reactivateDisabled = options?.reactivateDisabled ?? true; + const nextEndpoint = await loadEndpoint({ + vendorId: input.vendorId, + providerType: input.providerType, + url: nextUrl, + }); + + if (!nextEndpoint) { + const inserted = await tx + .insert(providerEndpoints) + .values({ + vendorId: input.vendorId, + providerType: input.providerType, + url: nextUrl, + label: null, + updatedAt: now, + }) + .onConflictDoNothing({ + target: [ + providerEndpoints.vendorId, + providerEndpoints.providerType, + providerEndpoints.url, + ], + }) + .returning({ id: providerEndpoints.id }); + + if (inserted[0]) { + return "created-next"; + } + + const concurrentEndpoint = await loadEndpoint({ + vendorId: input.vendorId, + providerType: input.providerType, + url: nextUrl, + }); + + if (!concurrentEndpoint) { + throw new Error("[ProviderEndpointSync] failed to load next endpoint after conflict"); + } + + if (concurrentEndpoint.deletedAt !== null) { + await tx + .update(providerEndpoints) + .set({ + deletedAt: null, + isEnabled: true, + updatedAt: now, + }) + .where(eq(providerEndpoints.id, concurrentEndpoint.id)); + + return "revived-next"; + } + + if (reactivateDisabled && !concurrentEndpoint.isEnabled) { + await tx + .update(providerEndpoints) + .set({ + isEnabled: true, + updatedAt: now, + }) + .where(eq(providerEndpoints.id, concurrentEndpoint.id)); + + return "revived-next"; + } + + return "noop"; + } + + if (nextEndpoint.deletedAt !== null) { + await tx + .update(providerEndpoints) + .set({ + deletedAt: null, + isEnabled: true, + updatedAt: now, + }) + .where(eq(providerEndpoints.id, nextEndpoint.id)); + + return "revived-next"; + } + + if (reactivateDisabled && !nextEndpoint.isEnabled) { + await tx + .update(providerEndpoints) + .set({ + isEnabled: true, + updatedAt: now, + }) + .where(eq(providerEndpoints.id, nextEndpoint.id)); + + return "revived-next"; + } + + return "noop"; + }; + + const previousKeyEqualsNextKey = + previousVendorId === input.vendorId && + previousProviderType === input.providerType && + previousUrl === nextUrl; + + if (previousKeyEqualsNextKey) { + const ensureResult = await ensureNextEndpointActive({ + reactivateDisabled: false, + }); + return { action: ensureResult === "noop" ? "noop" : ensureResult }; + } + + const previousEndpoint = await loadEndpoint({ + vendorId: previousVendorId, + providerType: previousProviderType, + url: previousUrl, + }); + + const nextEndpoint = await loadEndpoint({ + vendorId: input.vendorId, + providerType: input.providerType, + url: nextUrl, + }); + + if (previousEndpoint && !nextEndpoint) { + const previousIsReferenced = + keepPreviousWhenReferenced && (await hasActiveReferencesOnPreviousUrl()); + + if (!previousIsReferenced) { + const updatePreviousEndpointInPlace = async (executor: QueryExecutor): Promise => { + await executor + .update(providerEndpoints) + .set({ + vendorId: input.vendorId, + providerType: input.providerType, + url: nextUrl, + deletedAt: null, + isEnabled: true, + lastProbedAt: null, + lastProbeOk: null, + lastProbeStatusCode: null, + lastProbeLatencyMs: null, + lastProbeErrorType: null, + lastProbeErrorMessage: null, + updatedAt: now, + }) + .where(eq(providerEndpoints.id, previousEndpoint.id)); + }; + + let movedInPlace = false; + const executorWithSavepoint = tx as QueryExecutor & { + transaction?: (runInTx: (nestedTx: TransactionExecutor) => Promise) => Promise; + }; + + if (typeof executorWithSavepoint.transaction === "function") { + try { + await executorWithSavepoint.transaction(async (nestedTx) => { + await updatePreviousEndpointInPlace(nestedTx); + }); + movedInPlace = true; + } catch (error) { + if (!isUniqueViolationError(error)) { + throw error; + } + } + } else { + // No savepoint support means we cannot safely continue after unique violations. + await updatePreviousEndpointInPlace(tx); + movedInPlace = true; + } + + if (movedInPlace) { + return { + action: "updated-previous-in-place", + // Reset is an external side-effect and must run only after transaction commit. + resetCircuitEndpointId: previousEndpoint.id, + }; + } + + const ensureResult = await ensureNextEndpointActive(); + + await tx + .update(providerEndpoints) + .set({ + deletedAt: now, + isEnabled: false, + updatedAt: now, + }) + .where( + and(eq(providerEndpoints.id, previousEndpoint.id), isNull(providerEndpoints.deletedAt)) + ); + + return { + action: + ensureResult === "revived-next" + ? "soft-deleted-previous-and-revived-next" + : "soft-deleted-previous-and-kept-next", + }; + } + + const ensureResult = await ensureNextEndpointActive(); + return { + action: + ensureResult === "created-next" + ? "kept-previous-and-created-next" + : ensureResult === "revived-next" + ? "kept-previous-and-revived-next" + : "kept-previous-and-kept-next", + }; + } + + const ensureResult = await ensureNextEndpointActive(); + + if ( + previousEndpoint && + nextEndpoint && + previousEndpoint.id !== nextEndpoint.id && + previousEndpoint.deletedAt === null + ) { + const previousIsReferenced = + keepPreviousWhenReferenced && (await hasActiveReferencesOnPreviousUrl()); + + if (!previousIsReferenced) { + await tx + .update(providerEndpoints) + .set({ + deletedAt: now, + isEnabled: false, + updatedAt: now, + }) + .where( + and(eq(providerEndpoints.id, previousEndpoint.id), isNull(providerEndpoints.deletedAt)) + ); + + return { + action: + ensureResult === "revived-next" + ? "soft-deleted-previous-and-revived-next" + : "soft-deleted-previous-and-kept-next", + }; + } + } + + return { action: ensureResult === "noop" ? "noop" : ensureResult }; + }; + + if (options?.tx) { + return await runInTx(options.tx); + } + + const result = await db.transaction(async (tx) => runInTx(tx)); + + if (result.resetCircuitEndpointId != null) { + try { + await resetEndpointCircuit(result.resetCircuitEndpointId); + } catch (error) { + logger.warn("syncProviderEndpointOnProviderEdit:reset_endpoint_circuit_failed", { + endpointId: result.resetCircuitEndpointId, + error: error instanceof Error ? error.message : String(error), + }); + } + + return { action: result.action }; + } + + return result; +} + export async function backfillProviderEndpointsFromProviders(): Promise<{ inserted: number; uniqueCandidates: number; diff --git a/src/repository/provider.ts b/src/repository/provider.ts index 4eaa9cf23..697d22686 100644 --- a/src/repository/provider.ts +++ b/src/repository/provider.ts @@ -4,6 +4,7 @@ import { and, desc, eq, isNotNull, isNull, ne, sql } from "drizzle-orm"; import { db } from "@/drizzle/db"; import { providers } from "@/drizzle/schema"; import { getCachedProviders } from "@/lib/cache/provider-cache"; +import { resetEndpointCircuit } from "@/lib/endpoint-circuit-breaker"; import { logger } from "@/lib/logger"; import { resolveSystemTimezone } from "@/lib/utils/timezone"; import type { CreateProviderData, Provider, UpdateProviderData } from "@/types/provider"; @@ -11,22 +12,15 @@ import { toProvider } from "./_shared/transformers"; import { ensureProviderEndpointExistsForUrl, getOrCreateProviderVendorIdFromUrls, + syncProviderEndpointOnProviderEdit, tryDeleteProviderVendorIfEmpty, } from "./provider-endpoints"; export async function createProvider(providerData: CreateProviderData): Promise { - const providerVendorId = await getOrCreateProviderVendorIdFromUrls({ - providerUrl: providerData.url, - websiteUrl: providerData.website_url ?? null, - faviconUrl: providerData.favicon_url ?? null, - displayName: providerData.name, - }); - const dbData = { name: providerData.name, url: providerData.url, key: providerData.key, - providerVendorId, isEnabled: providerData.is_enabled, weight: providerData.weight, priority: providerData.priority, @@ -78,80 +72,93 @@ export async function createProvider(providerData: CreateProviderData): Promise< cc: providerData.cc, }; - const [provider] = await db.insert(providers).values(dbData).returning({ - id: providers.id, - name: providers.name, - url: providers.url, - key: providers.key, - providerVendorId: providers.providerVendorId, - isEnabled: providers.isEnabled, - weight: providers.weight, - priority: providers.priority, - costMultiplier: providers.costMultiplier, - groupTag: providers.groupTag, - providerType: providers.providerType, - preserveClientIp: providers.preserveClientIp, - modelRedirects: providers.modelRedirects, - allowedModels: providers.allowedModels, - mcpPassthroughType: providers.mcpPassthroughType, - mcpPassthroughUrl: providers.mcpPassthroughUrl, - limit5hUsd: providers.limit5hUsd, - limitDailyUsd: providers.limitDailyUsd, - dailyResetMode: providers.dailyResetMode, - dailyResetTime: providers.dailyResetTime, - limitWeeklyUsd: providers.limitWeeklyUsd, - limitMonthlyUsd: providers.limitMonthlyUsd, - limitTotalUsd: providers.limitTotalUsd, - totalCostResetAt: providers.totalCostResetAt, - limitConcurrentSessions: providers.limitConcurrentSessions, - maxRetryAttempts: providers.maxRetryAttempts, - circuitBreakerFailureThreshold: providers.circuitBreakerFailureThreshold, - circuitBreakerOpenDuration: providers.circuitBreakerOpenDuration, - circuitBreakerHalfOpenSuccessThreshold: providers.circuitBreakerHalfOpenSuccessThreshold, - proxyUrl: providers.proxyUrl, - proxyFallbackToDirect: providers.proxyFallbackToDirect, - firstByteTimeoutStreamingMs: providers.firstByteTimeoutStreamingMs, - streamingIdleTimeoutMs: providers.streamingIdleTimeoutMs, - requestTimeoutNonStreamingMs: providers.requestTimeoutNonStreamingMs, - websiteUrl: providers.websiteUrl, - faviconUrl: providers.faviconUrl, - cacheTtlPreference: providers.cacheTtlPreference, - context1mPreference: providers.context1mPreference, - codexReasoningEffortPreference: providers.codexReasoningEffortPreference, - codexReasoningSummaryPreference: providers.codexReasoningSummaryPreference, - codexTextVerbosityPreference: providers.codexTextVerbosityPreference, - codexParallelToolCallsPreference: providers.codexParallelToolCallsPreference, - anthropicMaxTokensPreference: providers.anthropicMaxTokensPreference, - anthropicThinkingBudgetPreference: providers.anthropicThinkingBudgetPreference, - geminiGoogleSearchPreference: providers.geminiGoogleSearchPreference, - tpm: providers.tpm, - rpm: providers.rpm, - rpd: providers.rpd, - cc: providers.cc, - createdAt: providers.createdAt, - updatedAt: providers.updatedAt, - deletedAt: providers.deletedAt, - }); - - const created = toProvider(provider); + return db.transaction(async (tx) => { + const providerVendorId = await getOrCreateProviderVendorIdFromUrls( + { + providerUrl: providerData.url, + websiteUrl: providerData.website_url ?? null, + faviconUrl: providerData.favicon_url ?? null, + displayName: providerData.name, + }, + { tx } + ); - if (created.providerVendorId) { - try { - await ensureProviderEndpointExistsForUrl({ - vendorId: created.providerVendorId, - providerType: created.providerType, - url: created.url, - }); - } catch (error) { - logger.warn("[Provider] Failed to seed provider endpoint from provider.url", { + const [provider] = await tx + .insert(providers) + .values({ + ...dbData, providerVendorId, - providerType: created.providerType, - error: error instanceof Error ? error.message : String(error), + }) + .returning({ + id: providers.id, + name: providers.name, + url: providers.url, + key: providers.key, + providerVendorId: providers.providerVendorId, + isEnabled: providers.isEnabled, + weight: providers.weight, + priority: providers.priority, + costMultiplier: providers.costMultiplier, + groupTag: providers.groupTag, + providerType: providers.providerType, + preserveClientIp: providers.preserveClientIp, + modelRedirects: providers.modelRedirects, + allowedModels: providers.allowedModels, + mcpPassthroughType: providers.mcpPassthroughType, + mcpPassthroughUrl: providers.mcpPassthroughUrl, + limit5hUsd: providers.limit5hUsd, + limitDailyUsd: providers.limitDailyUsd, + dailyResetMode: providers.dailyResetMode, + dailyResetTime: providers.dailyResetTime, + limitWeeklyUsd: providers.limitWeeklyUsd, + limitMonthlyUsd: providers.limitMonthlyUsd, + limitTotalUsd: providers.limitTotalUsd, + totalCostResetAt: providers.totalCostResetAt, + limitConcurrentSessions: providers.limitConcurrentSessions, + maxRetryAttempts: providers.maxRetryAttempts, + circuitBreakerFailureThreshold: providers.circuitBreakerFailureThreshold, + circuitBreakerOpenDuration: providers.circuitBreakerOpenDuration, + circuitBreakerHalfOpenSuccessThreshold: providers.circuitBreakerHalfOpenSuccessThreshold, + proxyUrl: providers.proxyUrl, + proxyFallbackToDirect: providers.proxyFallbackToDirect, + firstByteTimeoutStreamingMs: providers.firstByteTimeoutStreamingMs, + streamingIdleTimeoutMs: providers.streamingIdleTimeoutMs, + requestTimeoutNonStreamingMs: providers.requestTimeoutNonStreamingMs, + websiteUrl: providers.websiteUrl, + faviconUrl: providers.faviconUrl, + cacheTtlPreference: providers.cacheTtlPreference, + context1mPreference: providers.context1mPreference, + codexReasoningEffortPreference: providers.codexReasoningEffortPreference, + codexReasoningSummaryPreference: providers.codexReasoningSummaryPreference, + codexTextVerbosityPreference: providers.codexTextVerbosityPreference, + codexParallelToolCallsPreference: providers.codexParallelToolCallsPreference, + anthropicMaxTokensPreference: providers.anthropicMaxTokensPreference, + anthropicThinkingBudgetPreference: providers.anthropicThinkingBudgetPreference, + geminiGoogleSearchPreference: providers.geminiGoogleSearchPreference, + tpm: providers.tpm, + rpm: providers.rpm, + rpd: providers.rpd, + cc: providers.cc, + createdAt: providers.createdAt, + updatedAt: providers.updatedAt, + deletedAt: providers.deletedAt, }); + + const created = toProvider(provider); + + if (created.providerVendorId) { + await ensureProviderEndpointExistsForUrl( + { + vendorId: created.providerVendorId, + providerType: created.providerType, + url: created.url, + }, + { tx } + ); } - } - return created; + return created; + }); } export async function findProviderList( @@ -386,8 +393,7 @@ export async function updateProvider( return findProviderById(id); } - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const dbData: any = { + const dbData: Partial = { updatedAt: new Date(), }; @@ -478,127 +484,180 @@ export async function updateProvider( if (providerData.rpd !== undefined) dbData.rpd = providerData.rpd; if (providerData.cc !== undefined) dbData.cc = providerData.cc; - let previousVendorId: number | null = null; - if (providerData.url !== undefined || providerData.website_url !== undefined) { - const [current] = await db - .select({ + const shouldRefreshVendor = + providerData.url !== undefined || providerData.website_url !== undefined; + const shouldSyncEndpoint = shouldRefreshVendor || providerData.provider_type !== undefined; + + const updateResult = await db.transaction(async (tx) => { + let previousVendorId: number | null = null; + let previousUrl: string | null = null; + let previousProviderType: Provider["providerType"] | null = null; + let endpointCircuitResetId: number | null = null; + + if (shouldSyncEndpoint) { + const [current] = await tx + .select({ + url: providers.url, + websiteUrl: providers.websiteUrl, + faviconUrl: providers.faviconUrl, + name: providers.name, + providerVendorId: providers.providerVendorId, + providerType: providers.providerType, + }) + .from(providers) + .where(and(eq(providers.id, id), isNull(providers.deletedAt))) + .limit(1); + + if (current) { + previousVendorId = current.providerVendorId; + previousUrl = current.url; + previousProviderType = current.providerType; + + if (shouldRefreshVendor) { + const providerVendorId = await getOrCreateProviderVendorIdFromUrls( + { + providerUrl: providerData.url ?? current.url, + websiteUrl: providerData.website_url ?? current.websiteUrl, + faviconUrl: providerData.favicon_url ?? current.faviconUrl, + displayName: providerData.name ?? current.name, + }, + { tx } + ); + dbData.providerVendorId = providerVendorId; + } + } + } + + const [provider] = await tx + .update(providers) + .set(dbData) + .where(and(eq(providers.id, id), isNull(providers.deletedAt))) + .returning({ + id: providers.id, + name: providers.name, url: providers.url, + key: providers.key, + providerVendorId: providers.providerVendorId, + isEnabled: providers.isEnabled, + weight: providers.weight, + priority: providers.priority, + costMultiplier: providers.costMultiplier, + groupTag: providers.groupTag, + providerType: providers.providerType, + preserveClientIp: providers.preserveClientIp, + modelRedirects: providers.modelRedirects, + allowedModels: providers.allowedModels, + mcpPassthroughType: providers.mcpPassthroughType, + mcpPassthroughUrl: providers.mcpPassthroughUrl, + limit5hUsd: providers.limit5hUsd, + limitDailyUsd: providers.limitDailyUsd, + dailyResetMode: providers.dailyResetMode, + dailyResetTime: providers.dailyResetTime, + limitWeeklyUsd: providers.limitWeeklyUsd, + limitMonthlyUsd: providers.limitMonthlyUsd, + limitTotalUsd: providers.limitTotalUsd, + totalCostResetAt: providers.totalCostResetAt, + limitConcurrentSessions: providers.limitConcurrentSessions, + maxRetryAttempts: providers.maxRetryAttempts, + circuitBreakerFailureThreshold: providers.circuitBreakerFailureThreshold, + circuitBreakerOpenDuration: providers.circuitBreakerOpenDuration, + circuitBreakerHalfOpenSuccessThreshold: providers.circuitBreakerHalfOpenSuccessThreshold, + proxyUrl: providers.proxyUrl, + proxyFallbackToDirect: providers.proxyFallbackToDirect, + firstByteTimeoutStreamingMs: providers.firstByteTimeoutStreamingMs, + streamingIdleTimeoutMs: providers.streamingIdleTimeoutMs, + requestTimeoutNonStreamingMs: providers.requestTimeoutNonStreamingMs, websiteUrl: providers.websiteUrl, faviconUrl: providers.faviconUrl, - name: providers.name, - providerVendorId: providers.providerVendorId, - }) - .from(providers) - .where(and(eq(providers.id, id), isNull(providers.deletedAt))) - .limit(1); - - if (current) { - previousVendorId = current.providerVendorId; - const providerVendorId = await getOrCreateProviderVendorIdFromUrls({ - providerUrl: providerData.url ?? current.url, - websiteUrl: providerData.website_url ?? current.websiteUrl, - faviconUrl: providerData.favicon_url ?? current.faviconUrl, - displayName: providerData.name ?? current.name, + cacheTtlPreference: providers.cacheTtlPreference, + context1mPreference: providers.context1mPreference, + codexReasoningEffortPreference: providers.codexReasoningEffortPreference, + codexReasoningSummaryPreference: providers.codexReasoningSummaryPreference, + codexTextVerbosityPreference: providers.codexTextVerbosityPreference, + codexParallelToolCallsPreference: providers.codexParallelToolCallsPreference, + anthropicMaxTokensPreference: providers.anthropicMaxTokensPreference, + anthropicThinkingBudgetPreference: providers.anthropicThinkingBudgetPreference, + geminiGoogleSearchPreference: providers.geminiGoogleSearchPreference, + tpm: providers.tpm, + rpm: providers.rpm, + rpd: providers.rpd, + cc: providers.cc, + createdAt: providers.createdAt, + updatedAt: providers.updatedAt, + deletedAt: providers.deletedAt, }); - dbData.providerVendorId = providerVendorId; + + if (!provider) return null; + const transformed = toProvider(provider); + + if (shouldSyncEndpoint && transformed.providerVendorId) { + if (previousUrl && previousProviderType) { + const syncResult = await syncProviderEndpointOnProviderEdit( + { + providerId: transformed.id, + vendorId: transformed.providerVendorId, + providerType: transformed.providerType, + previousVendorId, + previousProviderType, + previousUrl, + nextUrl: transformed.url, + keepPreviousWhenReferenced: true, + }, + { tx } + ); + + endpointCircuitResetId = syncResult.resetCircuitEndpointId ?? null; + } else { + await ensureProviderEndpointExistsForUrl( + { + vendorId: transformed.providerVendorId, + providerType: transformed.providerType, + url: transformed.url, + }, + { tx } + ); + } } - } - const [provider] = await db - .update(providers) - .set(dbData) - .where(and(eq(providers.id, id), isNull(providers.deletedAt))) - .returning({ - id: providers.id, - name: providers.name, - url: providers.url, - key: providers.key, - providerVendorId: providers.providerVendorId, - isEnabled: providers.isEnabled, - weight: providers.weight, - priority: providers.priority, - costMultiplier: providers.costMultiplier, - groupTag: providers.groupTag, - providerType: providers.providerType, - preserveClientIp: providers.preserveClientIp, - modelRedirects: providers.modelRedirects, - allowedModels: providers.allowedModels, - mcpPassthroughType: providers.mcpPassthroughType, - mcpPassthroughUrl: providers.mcpPassthroughUrl, - limit5hUsd: providers.limit5hUsd, - limitDailyUsd: providers.limitDailyUsd, - dailyResetMode: providers.dailyResetMode, - dailyResetTime: providers.dailyResetTime, - limitWeeklyUsd: providers.limitWeeklyUsd, - limitMonthlyUsd: providers.limitMonthlyUsd, - limitTotalUsd: providers.limitTotalUsd, - totalCostResetAt: providers.totalCostResetAt, - limitConcurrentSessions: providers.limitConcurrentSessions, - maxRetryAttempts: providers.maxRetryAttempts, - circuitBreakerFailureThreshold: providers.circuitBreakerFailureThreshold, - circuitBreakerOpenDuration: providers.circuitBreakerOpenDuration, - circuitBreakerHalfOpenSuccessThreshold: providers.circuitBreakerHalfOpenSuccessThreshold, - proxyUrl: providers.proxyUrl, - proxyFallbackToDirect: providers.proxyFallbackToDirect, - firstByteTimeoutStreamingMs: providers.firstByteTimeoutStreamingMs, - streamingIdleTimeoutMs: providers.streamingIdleTimeoutMs, - requestTimeoutNonStreamingMs: providers.requestTimeoutNonStreamingMs, - websiteUrl: providers.websiteUrl, - faviconUrl: providers.faviconUrl, - cacheTtlPreference: providers.cacheTtlPreference, - context1mPreference: providers.context1mPreference, - codexReasoningEffortPreference: providers.codexReasoningEffortPreference, - codexReasoningSummaryPreference: providers.codexReasoningSummaryPreference, - codexTextVerbosityPreference: providers.codexTextVerbosityPreference, - codexParallelToolCallsPreference: providers.codexParallelToolCallsPreference, - anthropicMaxTokensPreference: providers.anthropicMaxTokensPreference, - anthropicThinkingBudgetPreference: providers.anthropicThinkingBudgetPreference, - geminiGoogleSearchPreference: providers.geminiGoogleSearchPreference, - tpm: providers.tpm, - rpm: providers.rpm, - rpd: providers.rpd, - cc: providers.cc, - createdAt: providers.createdAt, - updatedAt: providers.updatedAt, - deletedAt: providers.deletedAt, - }); + return { + provider: transformed, + previousVendorIdToCleanup: + previousVendorId && transformed.providerVendorId !== previousVendorId + ? previousVendorId + : null, + endpointCircuitResetId, + }; + }); - if (!provider) return null; - const transformed = toProvider(provider); - - if ( - providerData.url !== undefined || - providerData.provider_type !== undefined || - providerData.website_url !== undefined - ) { - if ( - transformed.providerVendorId && - (providerData.url !== undefined || - transformed.providerVendorId !== previousVendorId || - previousVendorId === null) - ) { - try { - await ensureProviderEndpointExistsForUrl({ - vendorId: transformed.providerVendorId, - providerType: transformed.providerType, - url: transformed.url, - }); - } catch (error) { - logger.warn("[Provider] Failed to seed provider endpoint after provider update", { - providerId: transformed.id, - providerVendorId: transformed.providerVendorId, - providerType: transformed.providerType, - error: error instanceof Error ? error.message : String(error), - }); - } + if (!updateResult) { + return null; + } + + if (updateResult.endpointCircuitResetId != null) { + try { + await resetEndpointCircuit(updateResult.endpointCircuitResetId); + } catch (error) { + logger.warn("updateProvider:reset_endpoint_circuit_failed", { + providerId: updateResult.provider.id, + endpointId: updateResult.endpointCircuitResetId, + error: error instanceof Error ? error.message : String(error), + }); } } - if (previousVendorId && transformed.providerVendorId !== previousVendorId) { - await tryDeleteProviderVendorIfEmpty(previousVendorId); + if (updateResult.previousVendorIdToCleanup) { + try { + await tryDeleteProviderVendorIfEmpty(updateResult.previousVendorIdToCleanup); + } catch (error) { + logger.warn("updateProvider:vendor_cleanup_failed", { + providerId: updateResult.provider.id, + previousVendorId: updateResult.previousVendorIdToCleanup, + error: error instanceof Error ? error.message : String(error), + }); + } } - return transformed; + return updateResult.provider; } export async function updateProviderPrioritiesBatch( diff --git a/tests/integration/provider-endpoint-sync-race.test.ts b/tests/integration/provider-endpoint-sync-race.test.ts new file mode 100644 index 000000000..e87d3b68b --- /dev/null +++ b/tests/integration/provider-endpoint-sync-race.test.ts @@ -0,0 +1,147 @@ +import { and, eq, isNull, sql } from "drizzle-orm"; +import { describe, expect, test } from "vitest"; +import { db } from "@/drizzle/db"; +import { providerEndpoints } from "@/drizzle/schema"; +import { + createProvider, + deleteProvider, + findProviderById, + updateProvider, +} from "@/repository/provider"; +import { + ensureProviderEndpointExistsForUrl, + findProviderEndpointsByVendorAndType, + tryDeleteProviderVendorIfEmpty, +} from "@/repository/provider-endpoints"; + +const run = process.env.DSN ? describe : describe.skip; + +function createDeferred() { + let resolve: () => void; + const promise = new Promise((res) => { + resolve = res; + }); + return { + promise, + resolve: resolve!, + }; +} + +run("Provider endpoint sync on edit (integration race)", () => { + test("concurrent next-url insert should not break provider edit transaction", async () => { + const suffix = `${Date.now()}-${Math.random().toString(16).slice(2)}`; + const oldUrl = `https://race-${suffix}.example.com/v1/messages`; + const nextUrl = `https://race-${suffix}.example.com/v2/messages`; + const websiteUrl = `https://vendor-${suffix}.example.com`; + + const created = await createProvider({ + name: `Race Provider ${suffix}`, + url: oldUrl, + key: `sk-race-${suffix}`, + provider_type: "claude", + website_url: websiteUrl, + favicon_url: null, + tpm: null, + rpm: null, + rpd: null, + cc: null, + }); + + const vendorId = created.providerVendorId; + expect(vendorId).not.toBeNull(); + + const [previousEndpoint] = await db + .select({ + id: providerEndpoints.id, + }) + .from(providerEndpoints) + .where( + and( + eq(providerEndpoints.vendorId, vendorId!), + eq(providerEndpoints.providerType, created.providerType), + eq(providerEndpoints.url, oldUrl), + isNull(providerEndpoints.deletedAt) + ) + ) + .limit(1); + + expect(previousEndpoint).toBeDefined(); + + const lockAcquired = createDeferred(); + const releaseLock = createDeferred(); + + const lockTask = db.transaction(async (tx) => { + await tx.execute(sql` + SELECT id + FROM provider_endpoints + WHERE id = ${previousEndpoint!.id} + FOR UPDATE + `); + + lockAcquired.resolve(); + await releaseLock.promise; + }); + + let updatePromise: Promise>> | null = null; + + try { + await lockAcquired.promise; + + updatePromise = updateProvider(created.id, { url: nextUrl }); + + await ensureProviderEndpointExistsForUrl({ + vendorId: vendorId!, + providerType: created.providerType, + url: nextUrl, + }); + + releaseLock.resolve(); + await lockTask; + + const updated = await updatePromise; + expect(updated).not.toBeNull(); + expect(updated?.url).toBe(nextUrl); + + const [previousAfter] = await db + .select({ + id: providerEndpoints.id, + url: providerEndpoints.url, + deletedAt: providerEndpoints.deletedAt, + isEnabled: providerEndpoints.isEnabled, + }) + .from(providerEndpoints) + .where(eq(providerEndpoints.id, previousEndpoint!.id)) + .limit(1); + + expect(previousAfter).toBeDefined(); + expect(previousAfter?.url).toBe(oldUrl); + expect(previousAfter?.deletedAt).toBeTruthy(); + expect(previousAfter?.isEnabled).toBe(false); + + const activeEndpoints = await findProviderEndpointsByVendorAndType( + vendorId!, + created.providerType + ); + + const nextActive = activeEndpoints.filter((endpoint) => endpoint.url === nextUrl); + expect(nextActive).toHaveLength(1); + expect(nextActive[0]?.isEnabled).toBe(true); + expect(activeEndpoints.some((endpoint) => endpoint.url === oldUrl)).toBe(false); + + const providerAfter = await findProviderById(created.id); + expect(providerAfter?.url).toBe(nextUrl); + } finally { + releaseLock.resolve(); + await lockTask.catch(() => {}); + + await deleteProvider(created.id); + if (vendorId) { + await tryDeleteProviderVendorIfEmpty(vendorId).catch(() => {}); + } + + if (updatePromise) { + await updatePromise.catch(() => {}); + } + } + }); +}); diff --git a/tests/unit/actions/providers-recluster.test.ts b/tests/unit/actions/providers-recluster.test.ts index a9b775d6b..3af78aa16 100644 --- a/tests/unit/actions/providers-recluster.test.ts +++ b/tests/unit/actions/providers-recluster.test.ts @@ -211,14 +211,15 @@ describe("reclusterProviderVendors", () => { getOrCreateProviderVendorIdFromUrlsMock.mockResolvedValue(2); backfillProviderEndpointsFromProvidersMock.mockResolvedValue({}); tryDeleteProviderVendorIfEmptyMock.mockResolvedValue(true); - dbMock.transaction.mockImplementation(async (fn) => { - return fn({ - update: vi.fn().mockReturnValue({ - set: vi.fn().mockReturnValue({ - where: vi.fn().mockResolvedValue({}), - }), + const tx = { + update: vi.fn().mockReturnValue({ + set: vi.fn().mockReturnValue({ + where: vi.fn().mockResolvedValue({}), }), - }); + }), + }; + dbMock.transaction.mockImplementation(async (fn) => { + return fn(tx); }); const { reclusterProviderVendors } = await import("@/actions/providers"); @@ -229,6 +230,13 @@ describe("reclusterProviderVendors", () => { expect(result.data.applied).toBe(true); } expect(dbMock.transaction).toHaveBeenCalled(); + expect(getOrCreateProviderVendorIdFromUrlsMock).toHaveBeenCalledWith( + expect.objectContaining({ + providerUrl: "http://192.168.1.1:8080/v1/messages", + websiteUrl: null, + }), + { tx } + ); }); it("publishes cache invalidation after apply", async () => { diff --git a/tests/unit/actions/providers.test.ts b/tests/unit/actions/providers.test.ts index c0c4c6682..30bc0b4c6 100644 --- a/tests/unit/actions/providers.test.ts +++ b/tests/unit/actions/providers.test.ts @@ -497,6 +497,65 @@ describe("Provider Actions - Async Optimization", () => { expect(result.ok).toBe(true); expect(revalidatePathMock).not.toHaveBeenCalled(); }); + + it("editProvider endpoint sync: should forward url/provider_type edits to repository", async () => { + const nextUrl = "https://new.example.com/v1/responses"; + const { editProvider } = await import("@/actions/providers"); + + const result = await editProvider(1, { + url: nextUrl, + provider_type: "codex", + }); + + expect(result.ok).toBe(true); + expect(updateProviderMock).toHaveBeenCalledWith( + 1, + expect.objectContaining({ + url: nextUrl, + provider_type: "codex", + }) + ); + expect(publishProviderCacheInvalidationMock).toHaveBeenCalledTimes(1); + }); + + it("editProvider endpoint sync: should generate favicon_url when website_url is updated", async () => { + const nextUrl = "https://new.example.com/v1/messages"; + const nextWebsiteUrl = "https://vendor.example.com/home"; + const { editProvider } = await import("@/actions/providers"); + + const result = await editProvider(1, { + url: nextUrl, + website_url: nextWebsiteUrl, + }); + + expect(result.ok).toBe(true); + expect(updateProviderMock).toHaveBeenCalledWith( + 1, + expect.objectContaining({ + url: nextUrl, + website_url: nextWebsiteUrl, + favicon_url: "https://www.google.com/s2/favicons?domain=vendor.example.com&sz=32", + }) + ); + }); + + it("editProvider endpoint sync: should clear favicon_url when website_url is cleared", async () => { + const { editProvider } = await import("@/actions/providers"); + + const result = await editProvider(1, { + url: "https://new.example.com/v1/messages", + website_url: null, + }); + + expect(result.ok).toBe(true); + expect(updateProviderMock).toHaveBeenCalledWith( + 1, + expect.objectContaining({ + website_url: null, + favicon_url: null, + }) + ); + }); }); describe("deleteProvider", () => { diff --git a/tests/unit/lib/rate-limit/cost-limits.test.ts b/tests/unit/lib/rate-limit/cost-limits.test.ts index aa3634baf..b8299251e 100644 --- a/tests/unit/lib/rate-limit/cost-limits.test.ts +++ b/tests/unit/lib/rate-limit/cost-limits.test.ts @@ -34,6 +34,12 @@ vi.mock("@/lib/redis", () => ({ getRedisClient: () => redisClient, })); +const resolveSystemTimezoneMock = vi.hoisted(() => vi.fn(async () => "Asia/Shanghai")); + +vi.mock("@/lib/utils/timezone", () => ({ + resolveSystemTimezone: resolveSystemTimezoneMock, +})); + const statisticsMock = { // total cost sumKeyTotalCost: vi.fn(async () => 0), @@ -59,6 +65,7 @@ describe("RateLimitService - cost limits and quota checks", () => { beforeEach(() => { pipelineCommands.length = 0; vi.resetAllMocks(); + resolveSystemTimezoneMock.mockResolvedValue("Asia/Shanghai"); vi.useFakeTimers(); vi.setSystemTime(new Date(nowMs)); }); diff --git a/tests/unit/lib/rate-limit/rolling-window-cache-warm.test.ts b/tests/unit/lib/rate-limit/rolling-window-cache-warm.test.ts index 95ffead80..cee080f43 100644 --- a/tests/unit/lib/rate-limit/rolling-window-cache-warm.test.ts +++ b/tests/unit/lib/rate-limit/rolling-window-cache-warm.test.ts @@ -31,6 +31,10 @@ vi.mock("@/lib/redis", () => ({ getRedisClient: () => redisClient, })); +vi.mock("@/lib/utils/timezone", () => ({ + resolveSystemTimezone: vi.fn(async () => "Asia/Shanghai"), +})); + const statisticsMock = { sumKeyTotalCost: vi.fn(async () => 0), sumUserCostToday: vi.fn(async () => 0), diff --git a/tests/unit/lib/rate-limit/service-extra.test.ts b/tests/unit/lib/rate-limit/service-extra.test.ts index 25118d5ba..faf58ebba 100644 --- a/tests/unit/lib/rate-limit/service-extra.test.ts +++ b/tests/unit/lib/rate-limit/service-extra.test.ts @@ -54,6 +54,12 @@ vi.mock("@/lib/redis", () => ({ getRedisClient: () => redisClientRef, })); +const resolveSystemTimezoneMock = vi.hoisted(() => vi.fn(async () => "Asia/Shanghai")); + +vi.mock("@/lib/utils/timezone", () => ({ + resolveSystemTimezone: resolveSystemTimezoneMock, +})); + const statisticsMock = { // service.ts 顶层静态导入需要这些 export 存在 sumKeyTotalCost: vi.fn(async () => 0), @@ -85,6 +91,7 @@ describe("RateLimitService - other quota paths", () => { beforeEach(() => { vi.resetAllMocks(); + resolveSystemTimezoneMock.mockResolvedValue("Asia/Shanghai"); pipelineCalls.length = 0; vi.useFakeTimers(); vi.setSystemTime(new Date(nowMs)); diff --git a/tests/unit/repository/provider-create-transaction.test.ts b/tests/unit/repository/provider-create-transaction.test.ts new file mode 100644 index 000000000..f31cdf9c8 --- /dev/null +++ b/tests/unit/repository/provider-create-transaction.test.ts @@ -0,0 +1,187 @@ +import { describe, expect, test, vi } from "vitest"; + +type ProviderRow = Record; + +function createProviderRow(overrides: Partial = {}): ProviderRow { + const now = new Date("2025-01-01T00:00:00.000Z"); + + return { + id: 101, + name: "Provider A", + url: "https://new.example.com/v1/messages", + key: "test-key", + providerVendorId: 11, + isEnabled: true, + weight: 1, + priority: 0, + costMultiplier: "1.0", + groupTag: null, + providerType: "claude", + preserveClientIp: false, + modelRedirects: null, + allowedModels: null, + mcpPassthroughType: "none", + mcpPassthroughUrl: null, + limit5hUsd: null, + limitDailyUsd: null, + dailyResetMode: "fixed", + dailyResetTime: "00:00", + limitWeeklyUsd: null, + limitMonthlyUsd: null, + limitTotalUsd: null, + totalCostResetAt: null, + limitConcurrentSessions: 0, + maxRetryAttempts: null, + circuitBreakerFailureThreshold: 5, + circuitBreakerOpenDuration: 1800000, + circuitBreakerHalfOpenSuccessThreshold: 2, + proxyUrl: null, + proxyFallbackToDirect: false, + firstByteTimeoutStreamingMs: 30000, + streamingIdleTimeoutMs: 10000, + requestTimeoutNonStreamingMs: 600000, + websiteUrl: "https://vendor.example.com", + faviconUrl: null, + cacheTtlPreference: null, + context1mPreference: null, + codexReasoningEffortPreference: null, + codexReasoningSummaryPreference: null, + codexTextVerbosityPreference: null, + codexParallelToolCallsPreference: null, + anthropicMaxTokensPreference: null, + anthropicThinkingBudgetPreference: null, + geminiGoogleSearchPreference: null, + tpm: null, + rpm: null, + rpd: null, + cc: null, + createdAt: now, + updatedAt: now, + deletedAt: null, + ...overrides, + }; +} + +function createCreateProviderInput(overrides: Record = {}) { + return { + name: "Provider A", + url: "https://new.example.com/v1/messages", + key: "test-key", + provider_type: "claude", + website_url: "https://vendor.example.com", + favicon_url: null, + tpm: null, + rpm: null, + rpd: null, + cc: null, + ...overrides, + }; +} + +function createDbMock(insertedRow: ProviderRow) { + const insertReturningMock = vi.fn(async () => [insertedRow]); + const insertValuesMock = vi.fn(() => ({ returning: insertReturningMock })); + const insertMock = vi.fn(() => ({ values: insertValuesMock })); + + const tx = { + insert: insertMock, + }; + + const transactionMock = vi.fn(async (runInTx: (trx: typeof tx) => Promise) => { + return runInTx(tx); + }); + + return { + db: { + transaction: transactionMock, + }, + mocks: { + transactionMock, + insertMock, + }, + }; +} + +describe("provider repository - createProvider transactional endpoint seeding", () => { + test("createProvider should execute vendor resolve + provider insert + endpoint seed in one transaction", async () => { + vi.resetModules(); + + const dbState = createDbMock( + createProviderRow({ + providerType: "codex", + url: "https://new.example.com/v1/responses", + }) + ); + + vi.doMock("@/drizzle/db", () => ({ + db: dbState.db, + })); + + const getOrCreateProviderVendorIdFromUrlsMock = vi.fn(async () => 11); + const ensureProviderEndpointExistsForUrlMock = vi.fn(async () => true); + + vi.doMock("@/repository/provider-endpoints", () => ({ + getOrCreateProviderVendorIdFromUrls: getOrCreateProviderVendorIdFromUrlsMock, + ensureProviderEndpointExistsForUrl: ensureProviderEndpointExistsForUrlMock, + syncProviderEndpointOnProviderEdit: vi.fn(), + tryDeleteProviderVendorIfEmpty: vi.fn(), + })); + + const { createProvider } = await import("@/repository/provider"); + const provider = await createProvider( + createCreateProviderInput({ + provider_type: "codex", + url: "https://new.example.com/v1/responses", + }) + ); + + expect(provider.id).toBe(101); + expect(dbState.mocks.transactionMock).toHaveBeenCalledTimes(1); + expect(dbState.mocks.insertMock).toHaveBeenCalledTimes(1); + + expect(getOrCreateProviderVendorIdFromUrlsMock).toHaveBeenCalledWith( + expect.objectContaining({ + providerUrl: "https://new.example.com/v1/responses", + }), + expect.objectContaining({ tx: expect.any(Object) }) + ); + + expect(ensureProviderEndpointExistsForUrlMock).toHaveBeenCalledWith( + expect.objectContaining({ + vendorId: 11, + providerType: "codex", + url: "https://new.example.com/v1/responses", + }), + expect.objectContaining({ tx: expect.any(Object) }) + ); + }); + + test("createProvider should bubble endpoint seed errors to avoid partial success", async () => { + vi.resetModules(); + + const dbState = createDbMock(createProviderRow()); + + vi.doMock("@/drizzle/db", () => ({ + db: dbState.db, + })); + + const getOrCreateProviderVendorIdFromUrlsMock = vi.fn(async () => 11); + const ensureProviderEndpointExistsForUrlMock = vi.fn(async () => { + throw new Error("endpoint seed failed"); + }); + + vi.doMock("@/repository/provider-endpoints", () => ({ + getOrCreateProviderVendorIdFromUrls: getOrCreateProviderVendorIdFromUrlsMock, + ensureProviderEndpointExistsForUrl: ensureProviderEndpointExistsForUrlMock, + syncProviderEndpointOnProviderEdit: vi.fn(), + tryDeleteProviderVendorIfEmpty: vi.fn(), + })); + + const { createProvider } = await import("@/repository/provider"); + + await expect(createProvider(createCreateProviderInput())).rejects.toThrow( + "endpoint seed failed" + ); + expect(dbState.mocks.transactionMock).toHaveBeenCalledTimes(1); + }); +}); diff --git a/tests/unit/repository/provider-endpoint-sync-helper.test.ts b/tests/unit/repository/provider-endpoint-sync-helper.test.ts new file mode 100644 index 000000000..c2e0eb962 --- /dev/null +++ b/tests/unit/repository/provider-endpoint-sync-helper.test.ts @@ -0,0 +1,292 @@ +import { describe, expect, test, vi } from "vitest"; + +type SelectRow = Record; + +function createTxMock(selectResults: SelectRow[][]) { + const queue = [...selectResults]; + + const selectLimitMock = vi.fn(async () => queue.shift() ?? []); + const selectWhereMock = vi.fn(() => ({ limit: selectLimitMock })); + const selectFromMock = vi.fn(() => ({ where: selectWhereMock })); + const selectMock = vi.fn(() => ({ from: selectFromMock })); + + const updatePayloads: Array> = []; + const updateWhereMock = vi.fn(async () => []); + const updateSetMock = vi.fn((payload: Record) => { + updatePayloads.push(payload); + return { where: updateWhereMock }; + }); + const updateMock = vi.fn(() => ({ set: updateSetMock })); + + const insertReturningMock = vi.fn(async () => []); + const insertOnConflictDoNothingMock = vi.fn(() => ({ returning: insertReturningMock })); + const insertValuesMock = vi.fn(() => ({ onConflictDoNothing: insertOnConflictDoNothingMock })); + const insertMock = vi.fn(() => ({ values: insertValuesMock })); + + const tx = { + select: selectMock, + update: updateMock, + insert: insertMock, + }; + + const nestedTransactionMock = vi.fn( + async (runInTx: (nestedTx: typeof tx) => Promise) => { + return runInTx(tx); + } + ); + + const txWithSavepoint = { + ...tx, + transaction: nestedTransactionMock, + }; + + return { + tx: txWithSavepoint, + updatePayloads, + mocks: { + updateMock, + updateWhereMock, + insertMock, + insertReturningMock, + selectLimitMock, + nestedTransactionMock, + }, + }; +} + +async function arrangeSyncTest(selectResults: SelectRow[][]) { + vi.resetModules(); + + const txState = createTxMock(selectResults); + const transactionMock = vi.fn(async (runInTx: (tx: typeof txState.tx) => Promise) => { + return runInTx(txState.tx); + }); + const resetEndpointCircuitMock = vi.fn(async () => {}); + + vi.doMock("@/drizzle/db", () => ({ + db: { + transaction: transactionMock, + }, + })); + vi.doMock("@/lib/endpoint-circuit-breaker", () => ({ + resetEndpointCircuit: resetEndpointCircuitMock, + })); + + const { syncProviderEndpointOnProviderEdit } = await import("@/repository/provider-endpoints"); + + return { + syncProviderEndpointOnProviderEdit, + transactionMock, + resetEndpointCircuitMock, + ...txState, + }; +} + +describe("syncProviderEndpointOnProviderEdit", () => { + test("invalid next url should throw instead of silent noop", async () => { + const { syncProviderEndpointOnProviderEdit, transactionMock, mocks } = await arrangeSyncTest( + [] + ); + + await expect( + syncProviderEndpointOnProviderEdit({ + providerId: 1, + vendorId: 11, + providerType: "claude", + previousVendorId: 11, + previousProviderType: "claude", + previousUrl: "https://old.example.com/v1/messages", + nextUrl: "not-a-valid-url", + keepPreviousWhenReferenced: true, + }) + ).rejects.toThrow("[ProviderEndpointSync] nextUrl must be a valid URL"); + + expect(transactionMock).not.toHaveBeenCalled(); + expect(mocks.updateMock).not.toHaveBeenCalled(); + expect(mocks.insertMock).not.toHaveBeenCalled(); + }); + + test("website_url only edit should not revive disabled endpoint when identity is unchanged", async () => { + const endpointUrl = "https://same.example.com/v1/messages"; + const { syncProviderEndpointOnProviderEdit, mocks, resetEndpointCircuitMock } = + await arrangeSyncTest([[{ id: 101, deletedAt: null, isEnabled: false }]]); + + const result = await syncProviderEndpointOnProviderEdit({ + providerId: 1, + vendorId: 11, + providerType: "claude", + previousVendorId: 11, + previousProviderType: "claude", + previousUrl: endpointUrl, + nextUrl: endpointUrl, + keepPreviousWhenReferenced: true, + }); + + expect(result).toEqual({ action: "noop" }); + expect(mocks.updateMock).not.toHaveBeenCalled(); + expect(mocks.insertMock).not.toHaveBeenCalled(); + expect(resetEndpointCircuitMock).not.toHaveBeenCalled(); + }); + + test("in-place url move should clear stale probe snapshot fields", async () => { + const oldUrl = "https://old.example.com/v1/messages"; + const newUrl = "https://new.example.com/v1/messages"; + const { syncProviderEndpointOnProviderEdit, updatePayloads, mocks, resetEndpointCircuitMock } = + await arrangeSyncTest([[{ id: 7, deletedAt: null, isEnabled: true }], [], []]); + + const result = await syncProviderEndpointOnProviderEdit({ + providerId: 1, + vendorId: 11, + providerType: "claude", + previousVendorId: 11, + previousProviderType: "claude", + previousUrl: oldUrl, + nextUrl: newUrl, + keepPreviousWhenReferenced: true, + }); + + expect(result).toEqual({ action: "updated-previous-in-place" }); + expect(mocks.updateMock).toHaveBeenCalledTimes(1); + expect(resetEndpointCircuitMock).toHaveBeenCalledTimes(1); + expect(resetEndpointCircuitMock).toHaveBeenCalledWith(7); + expect(updatePayloads[0]).toEqual( + expect.objectContaining({ + url: newUrl, + lastProbedAt: null, + lastProbeOk: null, + lastProbeStatusCode: null, + lastProbeLatencyMs: null, + lastProbeErrorType: null, + lastProbeErrorMessage: null, + }) + ); + }); + + test("in-place url move with external tx should defer circuit reset until caller commits", async () => { + const oldUrl = "https://old.example.com/v1/messages"; + const newUrl = "https://new.example.com/v1/messages"; + const { + syncProviderEndpointOnProviderEdit, + tx, + transactionMock, + resetEndpointCircuitMock, + updatePayloads, + mocks, + } = await arrangeSyncTest([[{ id: 7, deletedAt: null, isEnabled: true }], [], []]); + + const result = await syncProviderEndpointOnProviderEdit( + { + providerId: 1, + vendorId: 11, + providerType: "claude", + previousVendorId: 11, + previousProviderType: "claude", + previousUrl: oldUrl, + nextUrl: newUrl, + keepPreviousWhenReferenced: true, + }, + { tx } + ); + + expect(result).toEqual({ + action: "updated-previous-in-place", + resetCircuitEndpointId: 7, + }); + expect(mocks.updateMock).toHaveBeenCalledTimes(1); + expect(updatePayloads[0]).toEqual(expect.objectContaining({ url: newUrl })); + expect(transactionMock).not.toHaveBeenCalled(); + expect(resetEndpointCircuitMock).not.toHaveBeenCalled(); + }); + + test("concurrent insert conflict should degrade to noop instead of throwing", async () => { + const oldUrl = "https://old.example.com/v1/messages"; + const newUrl = "https://new.example.com/v1/messages"; + const { syncProviderEndpointOnProviderEdit, mocks, resetEndpointCircuitMock } = + await arrangeSyncTest([[], [], [], [{ id: 201, deletedAt: null, isEnabled: true }]]); + + const result = await syncProviderEndpointOnProviderEdit({ + providerId: 1, + vendorId: 11, + providerType: "claude", + previousVendorId: 11, + previousProviderType: "claude", + previousUrl: oldUrl, + nextUrl: newUrl, + keepPreviousWhenReferenced: true, + }); + + expect(result).toEqual({ action: "noop" }); + expect(mocks.insertMock).toHaveBeenCalledTimes(1); + expect(mocks.updateMock).not.toHaveBeenCalled(); + expect(resetEndpointCircuitMock).not.toHaveBeenCalled(); + }); + + test("in-place move unique conflict should fallback to keep-next and soft-delete previous", async () => { + const oldUrl = "https://old.example.com/v1/messages"; + const newUrl = "https://new.example.com/v1/messages"; + const { syncProviderEndpointOnProviderEdit, updatePayloads, mocks, resetEndpointCircuitMock } = + await arrangeSyncTest([ + [{ id: 7, deletedAt: null, isEnabled: true }], + [], + [], + [], + [{ id: 9, deletedAt: null, isEnabled: true }], + ]); + + mocks.updateWhereMock.mockRejectedValueOnce( + Object.assign(new Error("duplicate key value violates unique constraint"), { + code: "23505", + }) + ); + + const result = await syncProviderEndpointOnProviderEdit({ + providerId: 1, + vendorId: 11, + providerType: "claude", + previousVendorId: 11, + previousProviderType: "claude", + previousUrl: oldUrl, + nextUrl: newUrl, + keepPreviousWhenReferenced: true, + }); + + expect(result).toEqual({ action: "soft-deleted-previous-and-kept-next" }); + expect(mocks.insertMock).toHaveBeenCalledTimes(1); + expect(mocks.updateMock).toHaveBeenCalledTimes(2); + expect(updatePayloads[1]).toEqual( + expect.objectContaining({ + isEnabled: false, + }) + ); + expect(resetEndpointCircuitMock).not.toHaveBeenCalled(); + }); + + test("kept-previous with concurrent noop should return kept-previous-and-kept-next", async () => { + const oldUrl = "https://old.example.com/v1/messages"; + const newUrl = "https://new.example.com/v1/messages"; + const { syncProviderEndpointOnProviderEdit, mocks, resetEndpointCircuitMock } = + await arrangeSyncTest([ + [{ id: 7, deletedAt: null, isEnabled: true }], + [], + [{ id: 99 }], + [], + [{ id: 9, deletedAt: null, isEnabled: true }], + ]); + + const result = await syncProviderEndpointOnProviderEdit({ + providerId: 1, + vendorId: 11, + providerType: "claude", + previousVendorId: 11, + previousProviderType: "claude", + previousUrl: oldUrl, + nextUrl: newUrl, + keepPreviousWhenReferenced: true, + }); + + expect(result).toEqual({ action: "kept-previous-and-kept-next" }); + expect(mocks.insertMock).toHaveBeenCalledTimes(1); + expect(mocks.updateMock).not.toHaveBeenCalled(); + expect(resetEndpointCircuitMock).not.toHaveBeenCalled(); + }); +}); diff --git a/tests/unit/repository/provider-endpoint-sync-on-edit.test.ts b/tests/unit/repository/provider-endpoint-sync-on-edit.test.ts new file mode 100644 index 000000000..19a7817fb --- /dev/null +++ b/tests/unit/repository/provider-endpoint-sync-on-edit.test.ts @@ -0,0 +1,242 @@ +import { describe, expect, test, vi } from "vitest"; + +type ProviderRow = Record; + +function createProviderRow(overrides: Partial = {}): ProviderRow { + const now = new Date("2025-01-01T00:00:00.000Z"); + + return { + id: 1, + name: "Provider A", + url: "https://old.example.com/v1/messages", + key: "test-key", + providerVendorId: 11, + isEnabled: true, + weight: 1, + priority: 0, + costMultiplier: "1.0", + groupTag: null, + providerType: "claude", + preserveClientIp: false, + modelRedirects: null, + allowedModels: null, + mcpPassthroughType: "none", + mcpPassthroughUrl: null, + limit5hUsd: null, + limitDailyUsd: null, + dailyResetMode: "fixed", + dailyResetTime: "00:00", + limitWeeklyUsd: null, + limitMonthlyUsd: null, + limitTotalUsd: null, + totalCostResetAt: null, + limitConcurrentSessions: 0, + maxRetryAttempts: null, + circuitBreakerFailureThreshold: 5, + circuitBreakerOpenDuration: 1800000, + circuitBreakerHalfOpenSuccessThreshold: 2, + proxyUrl: null, + proxyFallbackToDirect: false, + firstByteTimeoutStreamingMs: 30000, + streamingIdleTimeoutMs: 10000, + requestTimeoutNonStreamingMs: 600000, + websiteUrl: "https://vendor.example.com", + faviconUrl: null, + cacheTtlPreference: null, + context1mPreference: null, + codexReasoningEffortPreference: null, + codexReasoningSummaryPreference: null, + codexTextVerbosityPreference: null, + codexParallelToolCallsPreference: null, + anthropicMaxTokensPreference: null, + anthropicThinkingBudgetPreference: null, + geminiGoogleSearchPreference: null, + tpm: null, + rpm: null, + rpd: null, + cc: null, + createdAt: now, + updatedAt: now, + deletedAt: null, + ...overrides, + }; +} + +function createDbMock(currentRow: ProviderRow, updatedRow: ProviderRow) { + const selectLimitMock = vi.fn(async () => [currentRow]); + const selectWhereMock = vi.fn(() => ({ limit: selectLimitMock })); + const selectFromMock = vi.fn(() => ({ where: selectWhereMock })); + const selectMock = vi.fn(() => ({ from: selectFromMock })); + + const updateReturningMock = vi.fn(async () => [updatedRow]); + const updateWhereMock = vi.fn(() => ({ returning: updateReturningMock })); + const updateSetMock = vi.fn(() => ({ where: updateWhereMock })); + const updateMock = vi.fn(() => ({ set: updateSetMock })); + + const tx = { + select: selectMock, + update: updateMock, + }; + const transactionMock = vi.fn(async (runInTx: (trx: typeof tx) => Promise) => { + return runInTx(tx); + }); + + return { + select: selectMock, + update: updateMock, + transaction: transactionMock, + }; +} + +async function arrangeUrlEditRedScenario(input: { + oldUrl: string; + newUrl: string; + previousVendorId?: number; + nextVendorId?: number; +}) { + vi.resetModules(); + + const previousVendorId = input.previousVendorId ?? 11; + const nextVendorId = input.nextVendorId ?? previousVendorId; + + const currentRow = createProviderRow({ + id: 1, + url: input.oldUrl, + providerVendorId: previousVendorId, + providerType: "claude", + }); + const updatedRow = createProviderRow({ + id: 1, + url: input.newUrl, + providerVendorId: nextVendorId, + providerType: "claude", + }); + + const db = createDbMock(currentRow, updatedRow); + vi.doMock("@/drizzle/db", () => ({ db })); + + const getOrCreateProviderVendorIdFromUrlsMock = vi.fn(async () => nextVendorId); + const ensureProviderEndpointExistsForUrlMock = vi.fn(async () => true); + const tryDeleteProviderVendorIfEmptyMock = vi.fn(async () => false); + const syncProviderEndpointOnProviderEditMock = vi.fn( + async (): Promise<{ action: string; resetCircuitEndpointId?: number }> => ({ action: "noop" }) + ); + const resetEndpointCircuitMock = vi.fn(async () => {}); + + vi.doMock("@/repository/provider-endpoints", () => ({ + getOrCreateProviderVendorIdFromUrls: getOrCreateProviderVendorIdFromUrlsMock, + ensureProviderEndpointExistsForUrl: ensureProviderEndpointExistsForUrlMock, + tryDeleteProviderVendorIfEmpty: tryDeleteProviderVendorIfEmptyMock, + syncProviderEndpointOnProviderEdit: syncProviderEndpointOnProviderEditMock, + })); + vi.doMock("@/lib/endpoint-circuit-breaker", () => ({ + resetEndpointCircuit: resetEndpointCircuitMock, + })); + + const { updateProvider } = await import("@/repository/provider"); + + return { + updateProvider, + mocks: { + ensureProviderEndpointExistsForUrlMock, + syncProviderEndpointOnProviderEditMock, + tryDeleteProviderVendorIfEmptyMock, + resetEndpointCircuitMock, + }, + }; +} + +describe("provider repository - endpoint sync on edit (#722 RED)", () => { + test("old-url exists + new-url absent: should update endpoint row instead of insert-only ensure", async () => { + const oldUrl = "https://old.example.com/v1/messages"; + const newUrl = "https://new.example.com/v1/messages"; + + const { updateProvider, mocks } = await arrangeUrlEditRedScenario({ oldUrl, newUrl }); + const provider = await updateProvider(1, { url: newUrl }); + + expect(provider?.url).toBe(newUrl); + expect(mocks.syncProviderEndpointOnProviderEditMock).toHaveBeenCalledWith( + expect.objectContaining({ + providerId: 1, + vendorId: 11, + providerType: "claude", + previousUrl: oldUrl, + nextUrl: newUrl, + }), + expect.objectContaining({ tx: expect.any(Object) }) + ); + }); + + test("sync result with reset endpoint id should reset circuit after update commit", async () => { + const oldUrl = "https://old.example.com/v1/messages"; + const newUrl = "https://new.example.com/v1/messages"; + + const { updateProvider, mocks } = await arrangeUrlEditRedScenario({ oldUrl, newUrl }); + mocks.syncProviderEndpointOnProviderEditMock.mockResolvedValueOnce({ + action: "updated-previous-in-place", + resetCircuitEndpointId: 7, + }); + + await updateProvider(1, { url: newUrl }); + + expect(mocks.resetEndpointCircuitMock).toHaveBeenCalledTimes(1); + expect(mocks.resetEndpointCircuitMock).toHaveBeenCalledWith(7); + }); + + test("old-url exists + new-url exists: should avoid duplicate accumulation and not call insert-only ensure", async () => { + const oldUrl = "https://old.example.com/v1/messages"; + const newUrl = "https://new.example.com/v1/messages"; + + const { updateProvider, mocks } = await arrangeUrlEditRedScenario({ oldUrl, newUrl }); + await updateProvider(1, { url: newUrl }); + + expect(mocks.ensureProviderEndpointExistsForUrlMock).not.toHaveBeenCalled(); + }); + + test("old-url still referenced by another active provider: should keep old-url endpoint (safe cleanup guard)", async () => { + const oldUrl = "https://shared.example.com/v1/messages"; + const newUrl = "https://new.example.com/v1/messages"; + + const { updateProvider, mocks } = await arrangeUrlEditRedScenario({ oldUrl, newUrl }); + await updateProvider(1, { url: newUrl }); + + expect(mocks.syncProviderEndpointOnProviderEditMock).toHaveBeenCalledWith( + expect.objectContaining({ + previousUrl: oldUrl, + nextUrl: newUrl, + keepPreviousWhenReferenced: true, + }), + expect.objectContaining({ tx: expect.any(Object) }) + ); + expect(mocks.tryDeleteProviderVendorIfEmptyMock).not.toHaveBeenCalled(); + }); + + test("endpoint sync throw: should bubble error instead of silent partial success", async () => { + const oldUrl = "https://old.example.com/v1/messages"; + const newUrl = "https://new.example.com/v1/messages"; + + const { updateProvider, mocks } = await arrangeUrlEditRedScenario({ oldUrl, newUrl }); + mocks.syncProviderEndpointOnProviderEditMock.mockRejectedValueOnce(new Error("sync failed")); + + await expect(updateProvider(1, { url: newUrl })).rejects.toThrow("sync failed"); + expect(mocks.tryDeleteProviderVendorIfEmptyMock).not.toHaveBeenCalled(); + }); + + test("vendor cleanup failure should not block provider update", async () => { + const oldUrl = "https://old-vendor.example.com/v1/messages"; + const newUrl = "https://new-vendor.example.com/v1/messages"; + + const { updateProvider, mocks } = await arrangeUrlEditRedScenario({ + oldUrl, + newUrl, + previousVendorId: 11, + nextVendorId: 22, + }); + + mocks.tryDeleteProviderVendorIfEmptyMock.mockRejectedValueOnce(new Error("cleanup failed")); + + const provider = await updateProvider(1, { url: newUrl }); + expect(provider?.providerVendorId).toBe(22); + expect(mocks.tryDeleteProviderVendorIfEmptyMock).toHaveBeenCalledWith(11); + }); +}); diff --git a/tests/unit/repository/provider-endpoints.test.ts b/tests/unit/repository/provider-endpoints.test.ts index a8a4015bc..ecfd2ef1b 100644 --- a/tests/unit/repository/provider-endpoints.test.ts +++ b/tests/unit/repository/provider-endpoints.test.ts @@ -17,7 +17,7 @@ function createThenableQuery(result: T) { } describe("provider-endpoints repository", () => { - test("ensureProviderEndpointExistsForUrl: url 为空时返回 false 且不写 DB", async () => { + test("ensureProviderEndpointExistsForUrl: url 为空时抛错且不写 DB", async () => { vi.resetModules(); const insertMock = vi.fn(); @@ -28,17 +28,17 @@ describe("provider-endpoints repository", () => { })); const { ensureProviderEndpointExistsForUrl } = await import("@/repository/provider-endpoints"); - const ok = await ensureProviderEndpointExistsForUrl({ - vendorId: 1, - providerType: "claude", - url: " ", - }); - - expect(ok).toBe(false); + await expect( + ensureProviderEndpointExistsForUrl({ + vendorId: 1, + providerType: "claude", + url: " ", + }) + ).rejects.toThrow("[ProviderEndpointEnsure] url is required"); expect(insertMock).not.toHaveBeenCalled(); }); - test("ensureProviderEndpointExistsForUrl: url 非法时返回 false 且不写 DB", async () => { + test("ensureProviderEndpointExistsForUrl: url 非法时抛错且不写 DB", async () => { vi.resetModules(); const insertMock = vi.fn(); @@ -49,13 +49,13 @@ describe("provider-endpoints repository", () => { })); const { ensureProviderEndpointExistsForUrl } = await import("@/repository/provider-endpoints"); - const ok = await ensureProviderEndpointExistsForUrl({ - vendorId: 1, - providerType: "claude", - url: "not a url", - }); - - expect(ok).toBe(false); + await expect( + ensureProviderEndpointExistsForUrl({ + vendorId: 1, + providerType: "claude", + url: "not a url", + }) + ).rejects.toThrow("[ProviderEndpointEnsure] url must be a valid URL"); expect(insertMock).not.toHaveBeenCalled(); }); @@ -122,6 +122,37 @@ describe("provider-endpoints repository", () => { expect(ok).toBe(false); }); + test("ensureProviderEndpointExistsForUrl: 非编辑路径保持 insert-only 语义(不触发 update/transaction)", async () => { + vi.resetModules(); + + const returning = vi.fn(async () => []); + const onConflictDoNothing = vi.fn(() => ({ returning })); + const values = vi.fn(() => ({ onConflictDoNothing })); + const insertMock = vi.fn(() => ({ values })); + const updateMock = vi.fn(); + const transactionMock = vi.fn(); + + vi.doMock("@/drizzle/db", () => ({ + db: { + insert: insertMock, + update: updateMock, + transaction: transactionMock, + }, + })); + + const { ensureProviderEndpointExistsForUrl } = await import("@/repository/provider-endpoints"); + const ok = await ensureProviderEndpointExistsForUrl({ + vendorId: 1, + providerType: "codex", + url: "https://api.example.com/v1/responses", + }); + + expect(ok).toBe(false); + expect(insertMock).toHaveBeenCalledTimes(1); + expect(updateMock).not.toHaveBeenCalled(); + expect(transactionMock).not.toHaveBeenCalled(); + }); + test("backfillProviderEndpointsFromProviders: 全部无效时不写 DB", async () => { vi.resetModules(); @@ -390,7 +421,7 @@ describe("provider-endpoints repository", () => { expect(deleteMock).toHaveBeenCalledTimes(2); }); - test("tryDeleteProviderVendorIfEmpty: transaction 抛错时返回 false", async () => { + test("tryDeleteProviderVendorIfEmpty: transaction 抛错时抛出异常", async () => { vi.resetModules(); const transactionMock = vi.fn(async () => { @@ -404,9 +435,7 @@ describe("provider-endpoints repository", () => { })); const { tryDeleteProviderVendorIfEmpty } = await import("@/repository/provider-endpoints"); - const ok = await tryDeleteProviderVendorIfEmpty(123); - - expect(ok).toBe(false); + await expect(tryDeleteProviderVendorIfEmpty(123)).rejects.toThrow("boom"); }); test("deleteProviderVendor: vendor 存在时返回 true 且执行级联删除", async () => { diff --git a/vitest.integration.config.ts b/vitest.integration.config.ts index a9818dec7..2980e0924 100644 --- a/vitest.integration.config.ts +++ b/vitest.integration.config.ts @@ -22,6 +22,7 @@ export default defineConfig({ "tests/integration/webhook-targets-crud.test.ts", "tests/integration/notification-bindings.test.ts", "tests/integration/auth.test.ts", + "tests/integration/provider-endpoint-sync-race.test.ts", // 需要 DB 的 API 测试(从主配置排除,在此运行) "tests/api/users-actions.test.ts", "tests/api/providers-actions.test.ts",