Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 37 additions & 3 deletions src/actions/providers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { revalidatePath } from "next/cache";
import { GeminiAuth } from "@/app/v1/_lib/gemini/auth";
import { isClientAbortError } from "@/app/v1/_lib/proxy/errors";
import { getSession } from "@/lib/auth";
import { publishProviderCacheInvalidation } from "@/lib/cache/provider-cache";
import {
clearConfigCache,
clearProviderState,
Expand Down Expand Up @@ -36,6 +37,7 @@ import {
createProvider,
deleteProvider,
findAllProviders,
findAllProvidersFresh,
findProviderById,
getProviderStatistics,
resetProviderTotalCostResetAt,
Expand Down Expand Up @@ -102,6 +104,29 @@ const API_TEST_CONFIG = {
const PROXY_RETRY_STATUS_CODES = new Set([502, 504, 520, 521, 522, 523, 524, 525, 526, 527, 530]);
const CLOUDFLARE_ERROR_STATUS_CODES = new Set([520, 521, 522, 523, 524, 525, 526, 527, 530]);

/**
* 广播 Provider 缓存失效通知(跨实例)
*
* CRUD 操作后调用,通知所有实例清除缓存。
* 失败时不影响主流程,其他实例将依赖 TTL 过期后刷新。
*/
async function broadcastProviderCacheInvalidation(context: {
operation: "add" | "edit" | "remove";
providerId: number;
}): Promise<void> {
try {
await publishProviderCacheInvalidation();
logger.debug(`${context.operation} Provider:cache_invalidation_success`, {
providerId: context.providerId,
});
} catch (error) {
logger.warn(`${context.operation} Provider:cache_invalidation_failed`, {
providerId: context.providerId,
error: error instanceof Error ? error.message : String(error),
});
}
}

// 获取服务商数据
export async function getProviders(): Promise<ProviderDisplay[]> {
try {
Expand All @@ -121,7 +146,7 @@ export async function getProviders(): Promise<ProviderDisplay[]> {

// 并行获取供应商列表和统计数据
const [providers, statistics] = await Promise.all([
findAllProviders(),
findAllProvidersFresh(),
getProviderStatistics().catch((error) => {
logger.trace("getProviders:statistics_error", {
message: error.message,
Expand Down Expand Up @@ -309,7 +334,7 @@ export async function getProviderGroupsWithCount(): Promise<
ActionResult<Array<{ group: string; providerCount: number }>>
> {
try {
const providers = await findAllProviders();
const providers = await findAllProvidersFresh();
const groupCounts = new Map<string, number>();

for (const provider of providers) {
Expand Down Expand Up @@ -495,6 +520,9 @@ export async function addProvider(data: {
// 不影响主流程,仅记录警告
}

// 广播缓存更新(跨实例即时生效)
await broadcastProviderCacheInvalidation({ operation: "add", providerId: provider.id });

revalidatePath("/settings/providers");
logger.trace("addProvider:revalidated", { path: "/settings/providers" });

Expand Down Expand Up @@ -633,6 +661,9 @@ export async function editProvider(
}
}

// 广播缓存更新(跨实例即时生效)
await broadcastProviderCacheInvalidation({ operation: "edit", providerId });

revalidatePath("/settings/providers");
return { ok: true };
} catch (error) {
Expand Down Expand Up @@ -667,6 +698,9 @@ export async function removeProvider(providerId: number): Promise<ActionResult>
});
}

// 广播缓存更新(跨实例即时生效)
await broadcastProviderCacheInvalidation({ operation: "remove", providerId });

revalidatePath("/settings/providers");
return { ok: true };
} catch (error) {
Expand All @@ -687,7 +721,7 @@ export async function getProvidersHealthStatus() {
return {};
}

const providerIds = await findAllProviders().then((providers) => providers.map((p) => p.id));
const providerIds = await findAllProvidersFresh().then((providers) => providers.map((p) => p.id));
const healthStatus = await getAllHealthStatusAsync(providerIds, {
forceRefresh: true,
});
Expand Down
6 changes: 5 additions & 1 deletion src/app/v1/_lib/proxy/provider-selector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,11 @@ export class ProxyProviderResolver {
provider: Provider | null;
context: NonNullable<ProviderChainItem["decisionContext"]>;
}> {
const allProviders = await findAllProviders();
// 使用 Session 快照保证故障迁移期间数据一致性
// 如果没有 session,回退到 findAllProviders(内部已使用缓存)
const allProviders = session
? await session.getProvidersSnapshot()
: await findAllProviders();
const requestedModel = session?.getCurrentModel() || "";

// === Step 1: 分组预过滤(静默,用户只能看到自己分组内的供应商)===
Expand Down
26 changes: 26 additions & 0 deletions src/app/v1/_lib/proxy/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import type { Context } from "hono";
import { logger } from "@/lib/logger";
import { clientRequestsContext1m as clientRequestsContext1mHelper } from "@/lib/special-attributes";
import { findLatestPriceByModel } from "@/repository/model-price";
import { findAllProviders } from "@/repository/provider";
import type { CacheTtlResolved } from "@/types/cache";
import type { Key } from "@/types/key";
import type { ProviderChainItem } from "@/types/message";
Expand Down Expand Up @@ -108,6 +109,14 @@ export class ProxySession {
// Cached price data for billing model source (lazy loaded: undefined=not loaded, null=no data)
private cachedBillingPriceData?: ModelPriceData | null;

/**
* 请求级 Provider 快照
*
* 在 Session 首次获取时冻结,整个请求生命周期保持不变。
* 用于保证故障迁移期间数据一致性(避免同一请求多次调用返回不同结果)。
*/
private providersSnapshot: Provider[] | null = null;

private constructor(init: {
startTime: number;
method: string;
Expand Down Expand Up @@ -313,6 +322,23 @@ export class ProxySession {
return this.requestSequence;
}

/**
* 获取 Provider 列表快照
*
* 首次调用时从进程缓存获取并冻结,后续调用返回相同数据。
* 用于保证故障迁移期间数据一致性(避免同一请求多次调用返回不同结果)。
*
* @returns Provider 列表(整个请求生命周期不变)
*/
async getProvidersSnapshot(): Promise<Provider[]> {
if (this.providersSnapshot !== null) {
return this.providersSnapshot;
}

this.providersSnapshot = await findAllProviders();
return this.providersSnapshot;
}

/**
* 生成基于请求指纹的确定性 Session ID
*
Expand Down
172 changes: 172 additions & 0 deletions src/lib/cache/provider-cache.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
/**
* Provider 进程级缓存
*
* 特性:
* - 30s TTL 自动过期
* - Redis Pub/Sub 失效通知(跨实例即时同步)
* - 降级策略:Redis 不可用时依赖 TTL 自动过期
* - 版本号防止并发刷新竞态
* - 请求级快照支持(保证故障迁移期间数据一致性)
*/

import "server-only";

import { logger } from "@/lib/logger";
import {
publishCacheInvalidation,
subscribeCacheInvalidation,
} from "@/lib/redis/pubsub";
import type { Provider } from "@/types/provider";

export const CHANNEL_PROVIDERS_UPDATED = "cch:cache:providers:updated";

const CACHE_TTL_MS = 30_000; // 30 seconds

interface ProviderCacheState {
data: Provider[] | null;
expiresAt: number;
version: number; // 防止并发刷新竞态
refreshPromise: Promise<Provider[]> | null; // 防止并发请求同时刷新
}

const cache: ProviderCacheState = {
data: null,
expiresAt: 0,
version: 0,
refreshPromise: null,
};

let subscriptionInitialized = false;

/**
* 初始化 Redis 订阅
*
* 使用失效通知模式:收到通知后清除本地缓存,下次请求时从 DB 刷新
* pubsub.ts 订阅:静默降级
*/
async function ensureSubscription(): Promise<void> {
if (subscriptionInitialized) return;

// CI/build 阶段跳过
if (process.env.CI === "true" || process.env.NEXT_PHASE === "phase-production-build") {
subscriptionInitialized = true;
return;
}

subscriptionInitialized = true;
// pubsub.ts 订阅机制
await subscribeCacheInvalidation(CHANNEL_PROVIDERS_UPDATED, () => {
invalidateCache();
logger.debug("[ProviderCache] Cache invalidated via pub/sub");
});
}

/**
* 失效缓存(本地)
*/
export function invalidateCache(): void {
cache.data = null;
cache.expiresAt = 0;
cache.version++;
cache.refreshPromise = null;
}

/**
* 发布缓存失效通知(跨实例)
*
* CRUD 操作后调用,通知所有实例清除缓存。
* 各实例在下次请求时自行从 DB 刷新,保证:
* - 类型安全:Date 等类型从 DB 正确构造
* - 数据安全:不通过 Redis 传输敏感数据(如 provider.key)
*/
export async function publishProviderCacheInvalidation(): Promise<void> {
invalidateCache();
await publishCacheInvalidation(CHANNEL_PROVIDERS_UPDATED);
logger.debug("[ProviderCache] Published cache invalidation");
}

/**
* 获取缓存的 Provider 列表(带自动刷新)
*
* @param fetcher - 数据库查询函数(依赖注入,便于测试)
* @returns Provider 列表
*/
export async function getCachedProviders(
fetcher: () => Promise<Provider[]>
): Promise<Provider[]> {
// 确保订阅已初始化(异步,不阻塞)
void ensureSubscription();

const now = Date.now();

// 1. 缓存命中且未过期
if (cache.data && cache.expiresAt > now) {
return cache.data;
}

// 2. 已有刷新任务在进行中,等待它完成(防止并发刷新)
if (cache.refreshPromise) {
return cache.refreshPromise;
}

// 3. 需要刷新,创建新的刷新任务
const currentVersion = cache.version;
cache.refreshPromise = (async () => {
try {
const data = await fetcher();

// 检查版本号,防止被更新的失效事件覆盖
if (cache.version === currentVersion) {
cache.data = data;
cache.expiresAt = Date.now() + CACHE_TTL_MS;
logger.debug("[ProviderCache] Cache refreshed from DB", {
count: data.length,
ttlMs: CACHE_TTL_MS,
});
}

return data;
} finally {
// 清除 refreshPromise(允许下次刷新)
if (cache.version === currentVersion) {
cache.refreshPromise = null;
}
}
})();

return cache.refreshPromise;
}

/**
* 预热缓存(启动时调用)
*/
export async function warmupProviderCache(
fetcher: () => Promise<Provider[]>
): Promise<void> {
try {
await getCachedProviders(fetcher);
logger.info("[ProviderCache] Cache warmed up successfully");
} catch (error) {
logger.warn("[ProviderCache] Cache warmup failed", { error });
}
}

/**
* 获取缓存统计信息(用于监控/调试)
*/
export function getProviderCacheStats(): {
hasData: boolean;
count: number;
expiresIn: number;
version: number;
isRefreshing: boolean;
} {
const now = Date.now();
return {
hasData: cache.data !== null,
count: cache.data?.length ?? 0,
expiresIn: Math.max(0, cache.expiresAt - now),
version: cache.version,
isRefreshing: cache.refreshPromise !== null,
};
}
25 changes: 21 additions & 4 deletions src/repository/provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import { and, desc, eq, isNotNull, isNull, ne, sql } from "drizzle-orm";
import { db } from "@/drizzle/db";
import { providers } from "@/drizzle/schema";
import { getCachedProviders } from "@/lib/cache/provider-cache";
import { getEnvConfig } from "@/lib/config";
import { logger } from "@/lib/logger";
import type { CreateProviderData, Provider, UpdateProviderData } from "@/types/provider";
Expand Down Expand Up @@ -191,10 +192,13 @@ export async function findProviderList(
}

/**
* Fetch all providers without pagination limits.
* Use this when you need the complete provider list (e.g., for selection, health status).
* 直接从数据库获取所有供应商(绕过缓存)
*
* 用于:
* - 管理后台需要保证数据新鲜度的场景
* - 缓存刷新时的数据源
*/
export async function findAllProviders(): Promise<Provider[]> {
export async function findAllProvidersFresh(): Promise<Provider[]> {
const result = await db
.select({
id: providers.id,
Expand Down Expand Up @@ -252,14 +256,27 @@ export async function findAllProviders(): Promise<Provider[]> {
.where(isNull(providers.deletedAt))
.orderBy(desc(providers.createdAt));

logger.trace("findAllProviders:query_result", {
logger.trace("findAllProvidersFresh:query_result", {
count: result.length,
ids: result.map((r) => r.id),
});

return result.map(toProvider);
}

/**
* 获取所有供应商(带缓存)
*
* 使用进程级缓存:
* - 30s TTL 自动过期
* - Redis Pub/Sub 跨实例即时失效
*
* 用于高频读取场景(如供应商选择)
*/
export async function findAllProviders(): Promise<Provider[]> {
return getCachedProviders(findAllProvidersFresh);
}

export async function findProviderById(id: number): Promise<Provider | null> {
const [provider] = await db
.select({
Expand Down
Loading