diff --git a/.env.example b/.env.example index 14193bd05..fb1120931 100644 --- a/.env.example +++ b/.env.example @@ -8,6 +8,11 @@ AUTO_MIGRATE=true # 数据库连接字符串(仅用于本地开发或非 Docker Compose 部署) DSN="postgres://user:password@host:port/db_name" +# API Key Vacuum Filter(真空过滤器) +# - true (默认):启用。用于在访问 DB 前“负向短路”无效 key,降低 DB 压力、抵御爆破 +# - false:禁用(例如:需要排查问题或节省内存时) +ENABLE_API_KEY_VACUUM_FILTER="true" + # PostgreSQL 连接池配置(postgres.js) # 说明: # - 这些值是“每个应用进程”的连接池上限;k8s 多副本时需要按副本数分摊 @@ -58,6 +63,11 @@ REDIS_TLS_REJECT_UNAUTHORIZED=true # 是否验证 Redis TLS 证书(默认 # 设置为 false 可跳过证书验证,用于自签证书或共享证书场景 # 仅在 rediss:// 协议时生效 +# API Key 鉴权缓存(Vacuum Filter -> Redis -> DB) +# 说明:需要 ENABLE_RATE_LIMIT=true 且配置 REDIS_URL 才会启用 Redis 缓存;否则自动回落到 DB。 +API_KEY_AUTH_CACHE_TTL_SECONDS="60" # 鉴权缓存 TTL(秒,默认 60,最大 3600) +ENABLE_API_KEY_REDIS_CACHE="true" # 是否启用 API Key Redis 缓存(默认:true) + # Session 配置 SESSION_TTL=300 # Session 过期时间(秒,默认 300 = 5 分钟) STORE_SESSION_MESSAGES=false # 会话消息存储模式(默认:false) diff --git a/README.en.md b/README.en.md index 01d7876ed..afcd7c6c8 100644 --- a/README.en.md +++ b/README.en.md @@ -276,6 +276,9 @@ Docker Compose is the **preferred deployment method** — it automatically provi | `REDIS_URL` | `redis://localhost:6379` | Redis endpoint, supports `rediss://` for TLS providers. | | `REDIS_TLS_REJECT_UNAUTHORIZED` | `true` | Validate Redis TLS certificates; set `false` to skip (for self-signed/shared certs). | | `ENABLE_RATE_LIMIT` | `true` | Toggles multi-dimensional rate limiting; Fail-Open handles Redis outages gracefully. | +| `ENABLE_API_KEY_VACUUM_FILTER` | `true` | Enables API Key Vacuum Filter (negative short-circuit only; set to `false/0` to disable). | +| `ENABLE_API_KEY_REDIS_CACHE` | `true` | Enables API Key auth Redis cache (requires Redis; auto-fallback to DB on errors). | +| `API_KEY_AUTH_CACHE_TTL_SECONDS` | `60` | API Key auth cache TTL in seconds (default 60, max 3600). | | `SESSION_TTL` | `300` | Session cache window (seconds) that drives vendor reuse. | | `ENABLE_SECURE_COOKIES` | `true` | Browsers require HTTPS for Secure cookies; set to `false` when serving plain HTTP outside localhost. | | `ENABLE_CIRCUIT_BREAKER_ON_NETWORK_ERRORS` | `false` | When `true`, network errors also trip the circuit breaker for quicker isolation. | @@ -283,7 +286,7 @@ Docker Compose is the **preferred deployment method** — it automatically provi | `APP_URL` | empty | Populate to expose correct `servers` entries in OpenAPI docs. | | `API_TEST_TIMEOUT_MS` | `15000` | Timeout (ms) for provider API connectivity tests. Accepts 5000-120000 for regional tuning. | -> Boolean values should be `true/false` or `1/0` without quotes; otherwise Zod may coerce strings incorrectly. See `.env.example` for the full list. +> Boolean values support `true/false` or `1/0`. Quoting in `.env` is also fine (dotenv will strip quotes). See `.env.example` for the full list. ## ❓ FAQ diff --git a/README.md b/README.md index 595d1b01c..14d24676f 100644 --- a/README.md +++ b/README.md @@ -286,6 +286,9 @@ Docker Compose 是**首选部署方式**,自动配置数据库、Redis 和应 | `REDIS_URL` | `redis://localhost:6379` | Redis 地址,支持 `rediss://` 用于 TLS。 | | `REDIS_TLS_REJECT_UNAUTHORIZED` | `true` | 是否验证 Redis TLS 证书;设为 `false` 可跳过验证(用于自签/共享证书)。 | | `ENABLE_RATE_LIMIT` | `true` | 控制多维限流开关;Fail-Open 策略在 Redis 不可用时自动降级。 | +| `ENABLE_API_KEY_VACUUM_FILTER` | `true` | 是否启用 API Key 真空过滤器(仅负向短路无效 key;可设为 `false/0` 关闭用于排查/节省内存)。 | +| `ENABLE_API_KEY_REDIS_CACHE` | `true` | 是否启用 API Key 鉴权 Redis 缓存(需 Redis 可用;异常自动回落到 DB)。 | +| `API_KEY_AUTH_CACHE_TTL_SECONDS` | `60` | API Key 鉴权缓存 TTL(秒,默认 60,最大 3600)。 | | `SESSION_TTL` | `300` | Session 缓存时间(秒),影响供应商复用策略。 | | `ENABLE_SECURE_COOKIES` | `true` | 仅 HTTPS 场景能设置 Secure Cookie;HTTP 访问(非 localhost)需改为 `false`。 | | `ENABLE_CIRCUIT_BREAKER_ON_NETWORK_ERRORS` | `false` | 是否将网络错误计入熔断器;开启后能更激进地阻断异常线路。 | @@ -293,7 +296,7 @@ Docker Compose 是**首选部署方式**,自动配置数据库、Redis 和应 | `APP_URL` | 空 | 设置后 OpenAPI 文档 `servers` 将展示正确域名/端口。 | | `API_TEST_TIMEOUT_MS` | `15000` | 供应商 API 测试超时时间(毫秒,范围 5000-120000),跨境网络可适当提高。 | -> 布尔变量请直接写 `true/false` 或 `1/0`,勿加引号,避免被 Zod 转换为真值。更多字段参考 `.env.example`。 +> 布尔变量支持 `true/false` 或 `1/0`;在 `.env` 文件里写成带引号形式也没问题(dotenv 会解析并去掉引号)。更多字段参考 `.env.example`。 ## ❓ FAQ diff --git a/package.json b/package.json index db64d8491..7a72b1fcc 100644 --- a/package.json +++ b/package.json @@ -4,7 +4,7 @@ "private": true, "scripts": { "dev": "next dev --port 13500", - "build": "next build && cp VERSION .next/standalone/VERSION", + "build": "next build && (node scripts/copy-version-to-standalone.cjs || bun scripts/copy-version-to-standalone.cjs)", "start": "next start", "lint": "biome check .", "lint:fix": "biome check --write .", diff --git a/scripts/copy-version-to-standalone.cjs b/scripts/copy-version-to-standalone.cjs new file mode 100644 index 000000000..a8776bdcc --- /dev/null +++ b/scripts/copy-version-to-standalone.cjs @@ -0,0 +1,15 @@ +const fs = require("node:fs"); +const path = require("node:path"); + +const src = path.resolve(process.cwd(), "VERSION"); +const dstDir = path.resolve(process.cwd(), ".next", "standalone"); +const dst = path.join(dstDir, "VERSION"); + +if (!fs.existsSync(src)) { + console.error(`[copy-version] VERSION not found at ${src}`); + process.exit(1); +} + +fs.mkdirSync(dstDir, { recursive: true }); +fs.copyFileSync(src, dst); +console.log(`[copy-version] Copied VERSION -> ${dst}`); diff --git a/src/app/[locale]/dashboard/logs/_hooks/use-lazy-filter-options.ts b/src/app/[locale]/dashboard/logs/_hooks/use-lazy-filter-options.ts index 79beb9d1c..64dfbc811 100644 --- a/src/app/[locale]/dashboard/logs/_hooks/use-lazy-filter-options.ts +++ b/src/app/[locale]/dashboard/logs/_hooks/use-lazy-filter-options.ts @@ -47,7 +47,6 @@ function createLazyFilterHook( }; }, []); - // biome-ignore lint/correctness/useExhaustiveDependencies: fetcher 是工厂函数的闭包参数,在 hook 生命周期内永不改变 const load = useCallback(async () => { // 如果已加载或有进行中的请求,跳过 if (isLoaded || inFlightRef.current) return; diff --git a/src/instrumentation.ts b/src/instrumentation.ts index a2940b1be..d81b33ae9 100644 --- a/src/instrumentation.ts +++ b/src/instrumentation.ts @@ -3,11 +3,13 @@ * 在服务器启动时自动执行数据库迁移 */ -// instrumentation 需要 Node.js runtime(依赖数据库与 Redis 等 Node 能力) -export const runtime = "nodejs"; - import { startCacheCleanup, stopCacheCleanup } from "@/lib/cache/session-cache"; import { logger } from "@/lib/logger"; +import { CHANNEL_API_KEYS_UPDATED, subscribeCacheInvalidation } from "@/lib/redis/pubsub"; +import { apiKeyVacuumFilter } from "@/lib/security/api-key-vacuum-filter"; + +// instrumentation 需要 Node.js runtime(依赖数据库与 Redis 等 Node 能力) +export const runtime = "nodejs"; const instrumentationState = globalThis as unknown as { __CCH_CACHE_CLEANUP_STARTED__?: boolean; @@ -15,6 +17,8 @@ const instrumentationState = globalThis as unknown as { __CCH_SHUTDOWN_IN_PROGRESS__?: boolean; __CCH_CLOUD_PRICE_SYNC_STARTED__?: boolean; __CCH_CLOUD_PRICE_SYNC_INTERVAL_ID__?: ReturnType; + __CCH_API_KEY_VF_SYNC_STARTED__?: boolean; + __CCH_API_KEY_VF_SYNC_CLEANUP__?: (() => void) | null; }; /** @@ -82,6 +86,57 @@ async function startCloudPriceSyncScheduler(): Promise { } } +/** + * 多实例:订阅 API Key 变更广播,触发本机 Vacuum Filter 失效并重建。 + * + * 目标: + * - 避免“本机 filter 漏包含新 key”导致的误拒绝 + * - 重建失败/Redis 未配置时自动降级(不阻塞启动) + */ +async function startApiKeyVacuumFilterSync(): Promise { + if (instrumentationState.__CCH_API_KEY_VF_SYNC_STARTED__) { + return; + } + + // 与 Redis client 的启用条件保持一致:未启用限流/未配置 Redis 时不尝试订阅,避免额外 warn 日志 + const rateLimitRaw = process.env.ENABLE_RATE_LIMIT?.trim(); + if (rateLimitRaw === "false" || rateLimitRaw === "0" || !process.env.REDIS_URL) { + return; + } + + try { + const cleanup = await subscribeCacheInvalidation(CHANNEL_API_KEYS_UPDATED, () => { + apiKeyVacuumFilter.invalidateAndReload({ reason: "api_keys_updated" }); + }); + + if (!cleanup) { + return; + } + + instrumentationState.__CCH_API_KEY_VF_SYNC_STARTED__ = true; + instrumentationState.__CCH_API_KEY_VF_SYNC_CLEANUP__ = cleanup; + logger.info("[Instrumentation] API Key Vacuum Filter sync enabled"); + } catch (error) { + logger.warn("[Instrumentation] API Key Vacuum Filter sync init failed", { + error: error instanceof Error ? error.message : String(error), + }); + } +} + +function warmupApiKeyVacuumFilter(): void { + // 预热 API Key Vacuum Filter(减少无效 key 对 DB 的压力) + try { + apiKeyVacuumFilter.startBackgroundReload({ reason: "startup" }); + } catch (error) { + logger.warn("[Instrumentation] Failed to start API key vacuum filter preload", { + error: error instanceof Error ? error.message : String(error), + }); + } + + // 多实例:订阅 key 变更广播以触发本机 filter 重建 + void startApiKeyVacuumFilterSync(); +} + export async function register() { // 仅在服务器端执行 if (process.env.NEXT_RUNTIME === "nodejs") { @@ -121,6 +176,15 @@ export async function register() { }); } + try { + instrumentationState.__CCH_API_KEY_VF_SYNC_CLEANUP__?.(); + instrumentationState.__CCH_API_KEY_VF_SYNC_STARTED__ = false; + } catch (error) { + logger.warn("[Instrumentation] Failed to cleanup API key vacuum filter sync", { + error: error instanceof Error ? error.message : String(error), + }); + } + try { const { stopEndpointProbeScheduler } = await import( "@/lib/provider-endpoints/probe-scheduler" @@ -206,6 +270,8 @@ export async function register() { logger.info("[Instrumentation] AUTO_MIGRATE=false: skipping migrations"); } + warmupApiKeyVacuumFilter(); + // 回填 provider_vendors(按域名自动聚合旧 providers) try { const { backfillProviderVendorsFromProviders } = await import( @@ -306,6 +372,8 @@ export async function register() { if (isConnected) { await runMigrations(); + warmupApiKeyVacuumFilter(); + // 回填 provider_vendors(按域名自动聚合旧 providers) try { const { backfillProviderVendorsFromProviders } = await import( diff --git a/src/lib/auth.ts b/src/lib/auth.ts index 9e41effaa..62a2cac0f 100644 --- a/src/lib/auth.ts +++ b/src/lib/auth.ts @@ -1,8 +1,7 @@ import { cookies, headers } from "next/headers"; import { config } from "@/lib/config/config"; import { getEnvConfig } from "@/lib/config/env.schema"; -import { findActiveKeyByKeyString } from "@/repository/key"; -import { findUserById } from "@/repository/user"; +import { validateApiKeyAndGetUser } from "@/repository/key"; import type { Key } from "@/types/key"; import type { User } from "@/types/user"; @@ -107,18 +106,24 @@ export async function validateKey( return { user: adminUser, key: adminKey }; } - const key = await findActiveKeyByKeyString(keyString); - if (!key) { + // 默认鉴权链路:Vacuum Filter(仅负向短路) → Redis(key/user 缓存) → DB(权威校验) + const authResult = await validateApiKeyAndGetUser(keyString); + if (!authResult) { return null; } - // 检查 Web UI 登录权限 - if (!allowReadOnlyAccess && !key.canLoginWebUi) { + const { user, key } = authResult; + + // 用户状态校验:与 v1 proxy 侧保持一致,避免禁用/过期用户继续登录或持有会话 + if (!user.isEnabled) { + return null; + } + if (user.expiresAt && user.expiresAt.getTime() <= Date.now()) { return null; } - const user = await findUserById(key.userId); - if (!user) { + // 检查 Web UI 登录权限 + if (!allowReadOnlyAccess && !key.canLoginWebUi) { return null; } diff --git a/src/lib/redis/client.ts b/src/lib/redis/client.ts index 0ec54f2f3..ea999e041 100644 --- a/src/lib/redis/client.ts +++ b/src/lib/redis/client.ts @@ -21,7 +21,8 @@ function maskRedisUrl(redisUrl: string) { * Includes servername for SNI (Server Name Indication) support. */ function buildTlsConfig(redisUrl: string): Record { - const rejectUnauthorized = process.env.REDIS_TLS_REJECT_UNAUTHORIZED !== "false"; + const raw = process.env.REDIS_TLS_REJECT_UNAUTHORIZED?.trim(); + const rejectUnauthorized = raw !== "false" && raw !== "0"; try { const url = new URL(redisUrl); @@ -79,7 +80,8 @@ export function getRedisClient(): Redis | null { } const redisUrl = process.env.REDIS_URL; - const isEnabled = process.env.ENABLE_RATE_LIMIT === "true"; + const rateLimitRaw = process.env.ENABLE_RATE_LIMIT?.trim(); + const isEnabled = rateLimitRaw !== "false" && rateLimitRaw !== "0"; if (!isEnabled || !redisUrl) { logger.warn("[Redis] Rate limiting disabled or REDIS_URL not configured"); @@ -112,7 +114,8 @@ export function getRedisClient(): Redis | null { // 2. 如果使用 rediss://,则添加显式的 TLS 配置(支持跳过证书验证) if (useTls) { - const rejectUnauthorized = process.env.REDIS_TLS_REJECT_UNAUTHORIZED !== "false"; + const raw = process.env.REDIS_TLS_REJECT_UNAUTHORIZED?.trim(); + const rejectUnauthorized = raw !== "false" && raw !== "0"; logger.info("[Redis] Using TLS connection (rediss://)", { redisUrl: safeRedisUrl, rejectUnauthorized, diff --git a/src/lib/redis/pubsub.ts b/src/lib/redis/pubsub.ts index 07e854521..c00dd4c93 100644 --- a/src/lib/redis/pubsub.ts +++ b/src/lib/redis/pubsub.ts @@ -7,6 +7,8 @@ import { getRedisClient } from "./client"; export const CHANNEL_ERROR_RULES_UPDATED = "cch:cache:error_rules:updated"; export const CHANNEL_REQUEST_FILTERS_UPDATED = "cch:cache:request_filters:updated"; export const CHANNEL_SENSITIVE_WORDS_UPDATED = "cch:cache:sensitive_words:updated"; +// API Key 集合发生变化(典型:创建新 key)时,通知各实例重建 Vacuum Filter,避免误拒绝 +export const CHANNEL_API_KEYS_UPDATED = "cch:cache:api_keys:updated"; type CacheInvalidationCallback = () => void; diff --git a/src/lib/security/api-key-auth-cache.ts b/src/lib/security/api-key-auth-cache.ts new file mode 100644 index 000000000..dcb15358b --- /dev/null +++ b/src/lib/security/api-key-auth-cache.ts @@ -0,0 +1,403 @@ +import { logger } from "@/lib/logger"; +import type { Key } from "@/types/key"; +import type { User } from "@/types/user"; + +type RedisPipelineLike = { + setex(key: string, ttlSeconds: number, value: string): RedisPipelineLike; + del(key: string): RedisPipelineLike; + exec(): Promise; +}; + +type RedisLike = { + get(key: string): Promise; + setex(key: string, ttlSeconds: number, value: string): Promise; + del(key: string): Promise; + pipeline(): RedisPipelineLike; +}; + +const CACHE_VERSION = 1 as const; + +const REDIS_KEYS = { + keyByHash: (sha256Hex: string) => `api_key_auth:v${CACHE_VERSION}:key:${sha256Hex}`, + userById: (userId: number) => `api_key_auth:v${CACHE_VERSION}:user:${userId}`, +}; + +function isEdgeRuntime(): boolean { + if (typeof process === "undefined") return true; + return process.env.NEXT_RUNTIME === "edge"; +} + +function isApiKeyRedisCacheEnabled(): boolean { + if (isEdgeRuntime()) return false; + const raw = process.env.ENABLE_API_KEY_REDIS_CACHE?.trim(); + return raw !== "false" && raw !== "0"; +} + +function getCacheTtlSeconds(): number { + const raw = process.env.API_KEY_AUTH_CACHE_TTL_SECONDS; + const parsed = raw ? Number.parseInt(raw, 10) : 60; + if (!Number.isFinite(parsed) || parsed <= 0) return 60; + // 上限 1 小时,避免配置错误导致“长时间脏读” + return Math.min(parsed, 3600); +} + +const textEncoder = new TextEncoder(); +const byteToHex = Array.from({ length: 256 }, (_, index) => index.toString(16).padStart(2, "0")); + +function bufferToHex(buffer: ArrayBuffer): string { + const bytes = new Uint8Array(buffer); + let out = ""; + for (let i = 0; i < bytes.length; i++) { + out += byteToHex[bytes[i]]; + } + return out; +} + +async function sha256Hex(value: string): Promise { + const subtle = (globalThis as unknown as { crypto?: Crypto }).crypto?.subtle; + if (!subtle) return null; + + try { + const digest = await subtle.digest("SHA-256", textEncoder.encode(value)); + return bufferToHex(digest); + } catch (error) { + logger.debug( + { error: error instanceof Error ? error.message : String(error) }, + "[ApiKeyAuthCache] sha256 digest failed" + ); + return null; + } +} + +function shouldUseRedisClient(): boolean { + // Edge runtime/浏览器等无 process 环境:直接禁用 + if (typeof process === "undefined") return false; + + // 与 getRedisClient 的启用条件保持一致,避免在未配置 Redis 时触发热路径 warn 日志 + if (process.env.CI === "true" || process.env.NEXT_PHASE === "phase-production-build") return false; + if (!process.env.REDIS_URL) return false; + const rateLimitRaw = process.env.ENABLE_RATE_LIMIT?.trim(); + if (rateLimitRaw === "false" || rateLimitRaw === "0") return false; + return true; +} + +let getRedisClientFn: (() => unknown) | null | undefined; + +async function getRedisForApiKeyAuthCache(): Promise { + if (!isApiKeyRedisCacheEnabled()) return null; + if (!shouldUseRedisClient()) return null; + + if (getRedisClientFn === undefined) { + try { + const mod = await import("@/lib/redis/client"); + getRedisClientFn = mod.getRedisClient; + } catch (error) { + logger.debug( + { error: error instanceof Error ? error.message : String(error) }, + "[ApiKeyAuthCache] Load redis client failed" + ); + getRedisClientFn = null; + } + } + + if (!getRedisClientFn) return null; + return getRedisClientFn() as RedisLike | null; +} + +function parseRequiredDate(value: unknown): Date | null { + const date = value instanceof Date ? value : new Date(String(value)); + return Number.isNaN(date.getTime()) ? null : date; +} + +function parseOptionalDate(value: unknown): Date | null | undefined { + if (value === undefined) return undefined; + if (value === null) return null; + return parseRequiredDate(value); +} + +type CachedKeyPayloadV1 = { + v: 1; + key: Omit; +}; + +type CachedUserPayloadV1 = { + v: 1; + user: User; +}; + +function hydrateKeyFromCache(keyString: string, payload: CachedKeyPayloadV1): Key | null { + const key = payload.key as unknown as Record; + if (!key || typeof key !== "object") return null; + if (typeof key.id !== "number" || typeof key.userId !== "number") return null; + if (typeof key.name !== "string" || typeof key.isEnabled !== "boolean") return null; + if (typeof key.canLoginWebUi !== "boolean") return null; + if (typeof key.dailyResetMode !== "string" || typeof key.dailyResetTime !== "string") return null; + if (typeof key.limitConcurrentSessions !== "number") return null; + + const createdAt = parseRequiredDate(key.createdAt); + const updatedAt = parseRequiredDate(key.updatedAt); + if (!createdAt || !updatedAt) return null; + + const expiresAt = parseOptionalDate(key.expiresAt); + const deletedAt = parseOptionalDate(key.deletedAt); + if (key.expiresAt != null && !expiresAt) return null; + if (key.deletedAt != null && !deletedAt) return null; + + return { + ...(payload.key as Omit), + key: keyString, + createdAt, + updatedAt, + expiresAt: expiresAt === undefined ? undefined : expiresAt, + deletedAt: deletedAt === undefined ? undefined : deletedAt, + } as Key; +} + +function hydrateUserFromCache(payload: CachedUserPayloadV1): User | null { + const user = payload.user as unknown as Record; + if (!user || typeof user !== "object") return null; + if (typeof user.id !== "number" || typeof user.name !== "string") return null; + if (typeof user.role !== "string") return null; + if (typeof user.isEnabled !== "boolean") return null; + if (typeof user.dailyResetMode !== "string" || typeof user.dailyResetTime !== "string") return null; + + const createdAt = parseRequiredDate(user.createdAt); + const updatedAt = parseRequiredDate(user.updatedAt); + if (!createdAt || !updatedAt) return null; + + const expiresAt = parseOptionalDate(user.expiresAt); + const deletedAt = parseOptionalDate(user.deletedAt); + if (user.expiresAt != null && !expiresAt) return null; + if (user.deletedAt != null && !deletedAt) return null; + + return { + ...(payload.user as User), + createdAt, + updatedAt, + expiresAt: expiresAt === undefined ? undefined : expiresAt, + deletedAt: deletedAt === undefined ? undefined : deletedAt, + } as User; +} + +function stripKeySecret(key: Key): Omit { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const { key: _secretKey, ...rest } = key; + return rest; +} + +function resolveKeyCacheTtlSeconds(key: Key): number { + const base = getCacheTtlSeconds(); + const expiresAt = parseOptionalDate(key.expiresAt); + // expiresAt 存在但无法解析:安全起见不缓存 + if (key.expiresAt != null && !expiresAt) return 0; + if (!(expiresAt instanceof Date)) return base; + + const remainingMs = expiresAt.getTime() - Date.now(); + if (remainingMs <= 0) return 0; + const remainingSeconds = Math.max(1, Math.floor(remainingMs / 1000)); + return Math.min(base, remainingSeconds); +} + +export async function getCachedActiveKey(keyString: string): Promise { + const redis = await getRedisForApiKeyAuthCache(); + if (!redis) return null; + + const keyHash = await sha256Hex(keyString); + if (!keyHash) return null; + const redisKey = REDIS_KEYS.keyByHash(keyHash); + + try { + const raw = await redis.get(redisKey); + if (!raw) return null; + + const parsed = JSON.parse(raw) as CachedKeyPayloadV1; + if (parsed?.v !== 1 || !parsed.key) { + redis.del(redisKey).catch(() => {}); + return null; + } + + const hydrated = hydrateKeyFromCache(keyString, parsed); + if (!hydrated) { + redis.del(redisKey).catch(() => {}); + return null; + } + + // 仅用于“活跃 key”缓存:不满足条件时视为缓存失效 + if (hydrated.isEnabled !== true) { + redis.del(redisKey).catch(() => {}); + return null; + } + if (hydrated.deletedAt) { + redis.del(redisKey).catch(() => {}); + return null; + } + if (hydrated.expiresAt && hydrated.expiresAt.getTime() <= Date.now()) { + redis.del(redisKey).catch(() => {}); + return null; + } + + return hydrated; + } catch (error) { + // Fail open:缓存错误不影响鉴权正确性(会回落到 DB) + logger.debug( + { error: error instanceof Error ? error.message : String(error) }, + "[ApiKeyAuthCache] Read key cache failed" + ); + return null; + } +} + +export async function cacheActiveKey(key: Key): Promise { + const redis = await getRedisForApiKeyAuthCache(); + if (!redis) return; + + const ttlSeconds = resolveKeyCacheTtlSeconds(key); + const expiresAt = parseOptionalDate(key.expiresAt); + const expiresAtInvalid = key.expiresAt != null && !expiresAt; + const isExpired = expiresAt instanceof Date && expiresAt.getTime() <= Date.now(); + + const keyHash = await sha256Hex(key.key); + if (!keyHash) return; + const redisKey = REDIS_KEYS.keyByHash(keyHash); + + // 非活跃 key:直接清理缓存,避免脏读误放行 + if (key.isEnabled !== true || key.deletedAt || isExpired || expiresAtInvalid || ttlSeconds <= 0) { + try { + await redis.del(redisKey); + } catch { + // ignore + } + return; + } + + const payload: CachedKeyPayloadV1 = { v: 1, key: stripKeySecret(key) }; + try { + await redis.setex(redisKey, ttlSeconds, JSON.stringify(payload)); + } catch (error) { + logger.debug( + { error: error instanceof Error ? error.message : String(error) }, + "[ApiKeyAuthCache] Write key cache failed" + ); + } +} + +export async function invalidateCachedKey(keyString: string): Promise { + const redis = await getRedisForApiKeyAuthCache(); + if (!redis) return; + + const keyHash = await sha256Hex(keyString); + if (!keyHash) return; + const redisKey = REDIS_KEYS.keyByHash(keyHash); + try { + await redis.del(redisKey); + } catch { + // ignore + } +} + +export async function getCachedUser(userId: number): Promise { + const redis = await getRedisForApiKeyAuthCache(); + if (!redis) return null; + + const redisKey = REDIS_KEYS.userById(userId); + + try { + const raw = await redis.get(redisKey); + if (!raw) return null; + + const parsed = JSON.parse(raw) as CachedUserPayloadV1; + if (parsed?.v !== 1 || !parsed.user) { + redis.del(redisKey).catch(() => {}); + return null; + } + + const hydrated = hydrateUserFromCache(parsed); + if (!hydrated) { + redis.del(redisKey).catch(() => {}); + return null; + } + + // validateApiKeyAndGetUser 的语义:user 仅要求“未删除”;isEnabled/expiresAt 等状态由上层按需校验(如 auth.ts) + if (hydrated.deletedAt) { + redis.del(redisKey).catch(() => {}); + return null; + } + + return hydrated; + } catch (error) { + logger.debug( + { error: error instanceof Error ? error.message : String(error) }, + "[ApiKeyAuthCache] Read user cache failed" + ); + return null; + } +} + +export async function cacheUser(user: User): Promise { + const redis = await getRedisForApiKeyAuthCache(); + if (!redis) return; + + if (user.deletedAt) return; + + const ttlSeconds = getCacheTtlSeconds(); + const redisKey = REDIS_KEYS.userById(user.id); + const payload: CachedUserPayloadV1 = { v: 1, user }; + try { + await redis.setex(redisKey, ttlSeconds, JSON.stringify(payload)); + } catch (error) { + logger.debug( + { error: error instanceof Error ? error.message : String(error) }, + "[ApiKeyAuthCache] Write user cache failed" + ); + } +} + +export async function invalidateCachedUser(userId: number): Promise { + const redis = await getRedisForApiKeyAuthCache(); + if (!redis) return; + + const redisKey = REDIS_KEYS.userById(userId); + try { + await redis.del(redisKey); + } catch { + // ignore + } +} + +export async function cacheAuthResult(keyString: string, value: { key: Key; user: User }): Promise { + const redis = await getRedisForApiKeyAuthCache(); + if (!redis) return; + + const { key, user } = value; + const keyHash = await sha256Hex(keyString); + if (!keyHash) return; + const keyRedisKey = REDIS_KEYS.keyByHash(keyHash); + const userRedisKey = REDIS_KEYS.userById(user.id); + + const keyTtlSeconds = resolveKeyCacheTtlSeconds(key); + const userTtlSeconds = getCacheTtlSeconds(); + + try { + const pipeline = redis.pipeline(); + if (keyTtlSeconds > 0 && key.isEnabled === true && !key.deletedAt) { + const keyPayload: CachedKeyPayloadV1 = { v: 1, key: stripKeySecret(key) }; + pipeline.setex(keyRedisKey, keyTtlSeconds, JSON.stringify(keyPayload)); + } else { + pipeline.del(keyRedisKey); + } + + if (!user.deletedAt) { + const userPayload: CachedUserPayloadV1 = { v: 1, user }; + pipeline.setex(userRedisKey, userTtlSeconds, JSON.stringify(userPayload)); + } else { + pipeline.del(userRedisKey); + } + + await pipeline.exec(); + } catch (error) { + logger.debug( + { error: error instanceof Error ? error.message : String(error) }, + "[ApiKeyAuthCache] Write auth cache failed" + ); + } +} diff --git a/src/lib/security/api-key-vacuum-filter.ts b/src/lib/security/api-key-vacuum-filter.ts new file mode 100644 index 000000000..cf02ad6b2 --- /dev/null +++ b/src/lib/security/api-key-vacuum-filter.ts @@ -0,0 +1,384 @@ +import { logger } from "@/lib/logger"; +import { VacuumFilter } from "@/lib/vacuum-filter/vacuum-filter"; +import { randomBytes } from "@/lib/vacuum-filter/random"; + +type ApiKeyVacuumFilterStats = { + enabled: boolean; + ready: boolean; + loading: boolean; + lastReloadAt: number | null; + sourceKeyCount: number; + filterSize: number; + filterLoadFactor: number; + fingerprintBits: number; + maxKickSteps: number; +}; + +type ReloadOptions = { + reason: string; + /** + * 是否强制触发(忽略 cooldown)。 + * + * 用途: + * - 多实例场景收到“key 已新增”的广播后,需要尽快重建避免误拒绝 + */ + force?: boolean; +}; + +/** + * 纯构建函数:从 key 列表构建 VacuumFilter。 + * + * 导出原因: + * - 便于测试(不依赖 DB) + * - 便于未来扩展(例如:从 Redis/文件加载快照) + */ +export function buildVacuumFilterFromKeyStrings(options: { + keyStrings: string[]; + fingerprintBits: number; + maxKickSteps: number; + seed: Uint8Array; +}): VacuumFilter { + const { keyStrings, fingerprintBits, maxKickSteps, seed } = options; + + const uniqueKeys = Array.from(new Set(keyStrings)).filter((v) => v.length > 0); + + // 目标:尽量接近 Vacuum Filter 的高负载设计点,同时给“增量新增 key”留少量 headroom, + // 避免刚重建就接近极限导致频繁 insert_failed 重建。 + const targetLoadFactor = 0.96; + const desiredLoadFactor = 0.9; + let maxItems = Math.max( + 128, + Math.ceil((uniqueKeys.length * targetLoadFactor) / desiredLoadFactor) + ); + let lastError: Error | null = null; + + for (let attempt = 1; attempt <= 6; attempt++) { + const vf = new VacuumFilter({ + maxItems, + fingerprintBits, + maxKickSteps, + seed, + targetLoadFactor, + }); + + let okAll = true; + for (const key of uniqueKeys) { + if (!vf.add(key)) { + okAll = false; + break; + } + } + + if (okAll) { + return vf; + } + + lastError = new Error(`build failed at attempt=${attempt}, maxItems=${maxItems}`); + maxItems = Math.ceil(maxItems * 1.6); + } + + throw lastError ?? new Error("Vacuum filter build failed"); +} + +/** + * API Key Vacuum Filter(进程级单例) + * + * 用途: + * - 在访问数据库前,先用真空过滤器快速判定“肯定不存在”的 key,直接拒绝(减少 DB 压力、抵御爆破) + * + * 关键安全语义: + * - 仅用于“负向短路”:filter.has(key)===false 才能“肯定不存在” + * - filter.has(key)===true 只代表“可能存在”,仍必须走 DB 校验(避免假阳性误放行) + * + * 正确性约束: + * - 允许“过度包含”(比如包含禁用/过期 key,甚至包含已删除 key 的 fingerprint),只会降低短路命中率,不影响安全性。 + * - 严禁“漏包含”有效 key:否则会产生错误拒绝。因此: + * - 启动时尽量从 DB 全量加载(见 instrumentation) + * - 新增 key 时增量写入(createKey -> noteExistingKey) + */ +class ApiKeyVacuumFilter { + private readonly enabled: boolean; + private readonly seed: Uint8Array; + private readonly fingerprintBits = 32; + private readonly maxKickSteps = 500; + + private vf: VacuumFilter | null = null; + private loadingPromise: Promise | null = null; + + // 关键:当 vf 尚未就绪(或正在重建)时,新 key 可能在这段窗口期被创建。 + // 若不记录并在下一次重建时纳入,会导致“漏包含”有效 key,从而误拒绝(假阴性)。 + private pendingKeys = new Set(); + private readonly pendingKeysLimit = 10_000; + + // 若重建过程中又收到新的重建请求(例如:多实例收到 key 创建广播),需要串行再跑一次。 + private pendingReloadReason: string | null = null; + private pendingReloadForce = false; + + private lastReloadAttemptAt: number | null = null; + private readonly reloadCooldownMs = 10_000; + + private lastReloadAt: number | null = null; + private sourceKeyCount = 0; + + constructor() { + // 默认开启:升级后无需额外配置即可启用(仅负向短路;不会影响鉴权正确性)。 + // 如需排查或节省资源,可通过环境变量显式关闭:ENABLE_API_KEY_VACUUM_FILTER=false/0 + if (typeof process === "undefined") { + // Edge/浏览器等无 process 环境:强制关闭(避免访问 process.env 抛错) + this.enabled = false; + } else { + const isEdgeRuntime = process.env.NEXT_RUNTIME === "edge"; + const raw = process.env.ENABLE_API_KEY_VACUUM_FILTER?.trim(); + const explicitlyDisabled = raw === "false" || raw === "0"; + this.enabled = !isEdgeRuntime && !explicitlyDisabled; + } + this.seed = randomBytes(16); + } + + /** + * 返回: + * - true:过滤器“肯定判断不存在”(可直接拒绝) + * - false:过滤器认为“可能存在”(必须继续走 DB) + * - null:过滤器未就绪或未启用(不要短路) + */ + isDefinitelyNotPresent(keyString: string): boolean | null { + if (!this.enabled) return null; + + // 重建过程中:安全优先,不短路(避免使用可能过期的 vf 产生误拒绝) + if (this.loadingPromise) { + return null; + } + + const vf = this.vf; + if (!vf) { + // 懒加载:第一次触发时后台预热(同时保持“安全优先”:不就绪时不短路) + this.startBackgroundReload({ reason: "lazy_warmup" }); + return null; + } + + return !vf.has(keyString); + } + + /** + * 将一个“已确认为存在”的 key 写入过滤器(尽量保持新建 key 的即时可用性)。 + * + * 注意:写入失败不会影响正确性(仍会走 DB),只是降低短路命中率;失败后可依赖后台重建修复。 + */ + noteExistingKey(keyString: string): void { + if (!this.enabled) return; + const trimmed = keyString.trim(); + if (!trimmed) return; + + try { + const vf = this.vf; + if (!vf) { + // vf 未就绪:记录到 pending,确保下一次重建会覆盖到该 key(避免误拒绝) + if (this.pendingKeys.size < this.pendingKeysLimit) { + this.pendingKeys.add(trimmed); + } else { + logger.warn("[ApiKeyVacuumFilter] Pending keys overflow; scheduling rebuild", { + limit: this.pendingKeysLimit, + }); + } + this.startBackgroundReload({ reason: "pending_key", force: true }); + return; + } + + // 重建进行中:同时写入 pending,确保新 filter 不会漏包含该 key + if (this.loadingPromise) { + if (this.pendingKeys.size < this.pendingKeysLimit) { + this.pendingKeys.add(trimmed); + } else { + logger.warn("[ApiKeyVacuumFilter] Pending keys overflow; scheduling rebuild", { + limit: this.pendingKeysLimit, + }); + } + + // 合并重建请求:当前重建结束后再跑一次,确保纳入 pendingKeys + this.startBackgroundReload({ reason: "pending_key_during_reload", force: true }); + } + + // 注意:不要用 vf.has(key) 来“去重” —— has 可能是短暂假阳性,后续插入/搬移可能让假阳性消失, + // 从而导致真正存在的 key 没被写入、最终产生误拒绝风险。对新建 key(应唯一)直接 add 更安全。 + const ok = vf.add(trimmed); + if (!ok) { + logger.warn("[ApiKeyVacuumFilter] Insert failed; scheduling rebuild", { + keyLength: trimmed.length, + }); + // 安全优先:插入失败意味着新 key 可能未被覆盖。 + // 为避免误拒绝(假阴性),临时禁用短路,等待后台重建完成后再恢复。 + if (this.pendingKeys.size < this.pendingKeysLimit) { + this.pendingKeys.add(trimmed); + } else { + logger.warn("[ApiKeyVacuumFilter] Pending keys overflow; scheduling rebuild", { + limit: this.pendingKeysLimit, + }); + } + this.vf = null; + this.startBackgroundReload({ reason: "insert_failed", force: true }); + } + } catch (error) { + logger.warn("[ApiKeyVacuumFilter] noteExistingKey failed; scheduling rebuild", { + error: error instanceof Error ? error.message : String(error), + }); + if (this.pendingKeys.size < this.pendingKeysLimit) { + this.pendingKeys.add(trimmed); + } else { + logger.warn("[ApiKeyVacuumFilter] Pending keys overflow; scheduling rebuild", { + limit: this.pendingKeysLimit, + }); + } + this.vf = null; + try { + this.startBackgroundReload({ reason: "note_existing_key_failed", force: true }); + } catch { + // ignore + } + } + } + + /** + * 外部触发:标记过滤器可能已过期,并强制后台重建。 + * + * 典型场景:多实例环境下,某个实例创建了新 key;其它实例需要尽快重建,避免误拒绝。 + */ + invalidateAndReload(options: ReloadOptions): void { + if (!this.enabled) return; + this.vf = null; + this.startBackgroundReload({ ...options, force: true }); + } + + startBackgroundReload(options: ReloadOptions): void { + if (!this.enabled) return; + if (this.loadingPromise) { + // 重建进行中:合并请求,待当前重建结束后再跑一次(避免“读到旧快照”漏新 key) + this.pendingReloadReason = options.reason; + this.pendingReloadForce = this.pendingReloadForce || options.force === true; + return; + } + + const now = Date.now(); + if ( + options.force !== true && + this.lastReloadAttemptAt && + now - this.lastReloadAttemptAt < this.reloadCooldownMs + ) { + return; + } + this.lastReloadAttemptAt = now; + + this.loadingPromise = this.reloadFromDatabase(options) + .catch((error) => { + logger.warn("[ApiKeyVacuumFilter] Reload failed", { + reason: options.reason, + error: error instanceof Error ? error.message : String(error), + }); + }) + .finally(() => { + this.loadingPromise = null; + + // 若重建期间又收到新的重建请求,串行补一次(避免漏 key) + if (this.pendingReloadReason) { + const reason = this.pendingReloadReason; + const force = this.pendingReloadForce; + this.pendingReloadReason = null; + this.pendingReloadForce = false; + this.startBackgroundReload({ reason, force }); + } + }); + } + + getStats(): ApiKeyVacuumFilterStats { + const vf = this.vf; + return { + enabled: this.enabled, + ready: !!vf, + loading: !!this.loadingPromise, + lastReloadAt: this.lastReloadAt, + sourceKeyCount: this.sourceKeyCount, + filterSize: vf?.size() ?? 0, + filterLoadFactor: vf?.loadFactor() ?? 0, + fingerprintBits: this.fingerprintBits, + maxKickSteps: this.maxKickSteps, + }; + } + + // ==================== 预热/重建 ==================== + + private async reloadFromDatabase(options: ReloadOptions): Promise { + // CI / 测试环境通常不接 DB;避免大量告警日志 + const dsn = process.env.DSN || ""; + if ( + process.env.CI === "true" || + process.env.NODE_ENV === "test" || + process.env.VITEST === "true" || + !dsn || + dsn.includes("user:password@host:port") + ) { + logger.debug("[ApiKeyVacuumFilter] Skip reload (test env or DB not configured)"); + return; + } + + // 延迟 import,避免构建/测试阶段触发 DB 初始化 + const [{ db }, { keys }, { isNull }] = await Promise.all([ + import("@/drizzle/db"), + import("@/drizzle/schema"), + import("drizzle-orm"), + ]); + + const rows = await db + .select({ key: keys.key }) + .from(keys) + // 仅排除逻辑删除;禁用/过期 key 保留在 filter 中(安全:不会误拒绝) + .where(isNull(keys.deletedAt)); + + const keyStrings = rows + .map((r) => r.key) + .filter((v): v is string => typeof v === "string" && v.length > 0); + + // 将 pendingKeys 合并进来:覆盖“重建窗口期创建的新 key”。 + // 通过“Set 交换”获得快照,避免 snapshot-merge-clear 的竞态窗口: + // - reload 期间新增的 key 会进入新的 pendingKeys + // - 本次快照 key 会被纳入 built filter + // - 若 build 失败,会将快照 key 合并回 pendingKeys,避免漏 key + const pendingSnapshotSet = this.pendingKeys; + this.pendingKeys = new Set(); + const pendingSnapshot = + pendingSnapshotSet.size > 0 ? Array.from(pendingSnapshotSet.values()) : []; + + let built: VacuumFilter; + try { + built = buildVacuumFilterFromKeyStrings({ + keyStrings: pendingSnapshot.length > 0 ? keyStrings.concat(pendingSnapshot) : keyStrings, + fingerprintBits: this.fingerprintBits, + maxKickSteps: this.maxKickSteps, + seed: this.seed, + }); + } catch (error) { + // build 失败:回滚快照,避免漏 key(同时保留 reload 期间新增的 key) + for (const k of pendingSnapshotSet.values()) { + if (this.pendingKeys.size >= this.pendingKeysLimit) break; + this.pendingKeys.add(k); + } + throw error; + } + + this.vf = built; + this.sourceKeyCount = new Set(keyStrings).size; + this.lastReloadAt = Date.now(); + + logger.info("[ApiKeyVacuumFilter] Reloaded", { + reason: options.reason, + keyCount: this.sourceKeyCount, + loadFactor: Number(built.loadFactor().toFixed(4)), + }); + } +} + +// 使用 globalThis 保证单例(避免开发环境热重载重复实例化) +const g = globalThis as unknown as { __CCH_API_KEY_VACUUM_FILTER__?: ApiKeyVacuumFilter }; +if (!g.__CCH_API_KEY_VACUUM_FILTER__) { + g.__CCH_API_KEY_VACUUM_FILTER__ = new ApiKeyVacuumFilter(); +} + +export const apiKeyVacuumFilter = g.__CCH_API_KEY_VACUUM_FILTER__; diff --git a/src/lib/vacuum-filter/random.ts b/src/lib/vacuum-filter/random.ts new file mode 100644 index 000000000..4186eba20 --- /dev/null +++ b/src/lib/vacuum-filter/random.ts @@ -0,0 +1,23 @@ +type WebCryptoLike = { + getRandomValues(bytes: Uint8Array): Uint8Array; +}; + +function getWebCrypto(): WebCryptoLike | null { + const c = (globalThis as unknown as { crypto?: WebCryptoLike }).crypto; + return c && typeof c.getRandomValues === "function" ? c : null; +} + +export function randomBytes(size: number): Uint8Array { + const out = new Uint8Array(size); + const webCrypto = getWebCrypto(); + if (webCrypto) { + webCrypto.getRandomValues(out); + return out; + } + + // 兜底:极端环境无 Web Crypto 时,使用 Math.random(仅用于 seed,不影响正确性) + for (let i = 0; i < out.length; i++) { + out[i] = Math.floor(Math.random() * 256); + } + return out; +} diff --git a/src/lib/vacuum-filter/vacuum-filter.ts b/src/lib/vacuum-filter/vacuum-filter.ts new file mode 100644 index 000000000..d5a1edc53 --- /dev/null +++ b/src/lib/vacuum-filter/vacuum-filter.ts @@ -0,0 +1,611 @@ +import { randomBytes } from "@/lib/vacuum-filter/random"; + +const textEncoder = new TextEncoder(); +const BUCKET_SIZE = 4 as const; +const DEFAULT_SCRATCH_BYTES = 256; + +/** + * Vacuum Filter(真空过滤器) + * + * 目标: + * - 近似集合成员查询(AMQ):支持插入 / 查询 / 删除 + * - 无假阴性(在不发生“误删”的前提下):插入成功的元素,查询必定返回 true + * - 有假阳性:查询可能返回 true,但元素实际不存在(由 fingerprint 位数决定) + * + * 实现要点(对照论文与作者参考实现): + * - 结构与 Cuckoo Filter 类似:每个元素映射到两个 bucket(i1 与 i2),每个 bucket 4 个 slot + * - Alternate Range(AR):i2 在 i1 的局部范围内(提升局部性并提高高负载下成功率) + * - Vacuuming:插入遇到满桶时,优先做“局部换位路径搜索”(一跳前瞻),把空位“吸”过来,降低反复踢出重试 + * + * 注意: + * - 本实现为工程可用版本,核心算法与 vacuuming 逻辑对齐论文/作者代码,但未做 semi-sorting 的 bit packing; + * 为 API Key 防护场景选择 32-bit fingerprint 时仍然具备非常好的空间与性能表现。 + * - 删除是“近似删除”:理论上仍可能因 fingerprint 碰撞导致误删(概率与 FPR 同数量级)。 + * 对安全敏感场景建议使用 32-bit fingerprint,降低碰撞与误删风险。 + */ + +export type VacuumFilterInitOptions = { + /** + * 预期最多插入的元素数量(用于计算 bucket 数量与装载率)。 + * 该值越接近实际峰值,空间利用率越高;取值偏小可能导致插入失败。 + */ + maxItems: number; + /** + * 每个 bucket 的 slot 数;论文与常见实现为 4(此实现固定为 4)。 + */ + bucketSize?: 4; + /** + * fingerprint 位数(1~32)。 + * - 位数越大,假阳性越低,但占用内存越多。 + * - 推荐:32(用于安全敏感场景,尽量避免碰撞/误删风险)。 + */ + fingerprintBits?: number; + /** + * 最大踢出次数(失败后返回 false,调用方可选择扩容重建)。 + */ + maxKickSteps?: number; + /** + * 哈希种子(用于对抗可控输入导致的退化/碰撞攻击)。 + * - 不传则进程启动时随机生成(每次重启不同)。 + */ + seed?: Uint8Array | string; + /** + * 目标装载率(越高越省内存,但插入更困难)。 + * 论文/参考实现默认约 0.96(结合 VF 的 vacuuming 仍可维持高成功率)。 + */ + targetLoadFactor?: number; +}; + +type UndoLog = { pos: number[]; prev: number[] }; + +class XorShift32 { + private state: number; + + constructor(seed: number) { + const s = seed >>> 0; + // 避免全 0 状态(xorshift 会卡死) + this.state = s === 0 ? 0x9e3779b9 : s; + } + + nextU32(): number { + // xorshift32 + let x = this.state >>> 0; + x ^= (x << 13) >>> 0; + x ^= x >>> 17; + x ^= (x << 5) >>> 0; + this.state = x >>> 0; + return this.state; + } + + nextInt(maxExclusive: number): number { + return maxExclusive <= 1 ? 0 : this.nextU32() % maxExclusive; + } + + nextBool(): boolean { + return (this.nextU32() & 1) === 1; + } +} + +function upperPower2(x: number): number { + if (x <= 1) return 1; + let ret = 1; + // 注意:不要用位运算左移(JS 位运算是 32-bit),用乘法避免大数溢出/变负数 + while (ret < x) ret *= 2; + return ret; +} + +function roundUpToMultiple(x: number, base: number): number { + if (base <= 0) return x; + const r = x % base; + return r === 0 ? x : x + (base - r); +} + +// 解方程:1 + x(logc - logx + 1) - c = 0(参考实现同名函数) +function solveEquation(c: number): number { + let x = c + 0.1; + let guard = 0; + const f = (v: number) => 1 + v * (Math.log(c) - Math.log(v) + 1) - c; + const fd = (v: number) => Math.log(c) - Math.log(v); + while (Math.abs(f(x)) > 0.001 && guard++ < 10_000) { + x -= f(x) / fd(x); + if (!Number.isFinite(x) || x <= 0) { + // 数值异常时回退到一个保守值,避免死循环 + return c + 1; + } + } + return x; +} + +// balls-in-bins 最大负载上界(参考实现同名函数) +function ballsInBinsMaxLoad(balls: number, bins: number): number { + const m = balls; + const n = bins; + if (n <= 1) return m; + + const c = m / (n * Math.log(n)); + // 更准确的 bound..(c < 5 区间) + if (c < 5) { + const dc = solveEquation(c); + return (dc - 1 + 2) * Math.log(n); + } + + return m / n + 1.5 * Math.sqrt((2 * m * Math.log(n)) / n); +} + +/** + * 选择合适的 Alternate Range(power-of-two),移植自作者参考实现 proper_alt_range。 + * + * 直觉: + * - AR 越小:局部性越好,但高负载下更容易出现“局部拥堵”导致插入失败 + * - AR 越大:更容易找到空位,但局部性变差 + * - Vacuum Filter 采用多档 AR(按 tag 的低位分组)兼顾两者 + */ +function properAltRange(bucketCount: number, groupIndex: number): number { + const b = 4; // slots per bucket + const lf = 0.95; // target load factor (用于估算) + let altRange = 8; + while (altRange < bucketCount) { + const f = (4 - groupIndex) * 0.25; // group 占比(参考实现) + if ( + ballsInBinsMaxLoad(f * b * lf * bucketCount, bucketCount / altRange) < + 0.97 * b * altRange + ) { + return altRange; + } + // 同 upperPower2:避免 32-bit 位移溢出 + altRange *= 2; + } + return altRange; +} + +function normalizeSeed(seed?: VacuumFilterInitOptions["seed"]): Uint8Array { + if (!seed) return randomBytes(16); + if (typeof seed === "string") return textEncoder.encode(seed); + return new Uint8Array(seed); +} + +function readU32LE(bytes: Uint8Array, offset: number): number { + return ( + ((bytes[offset] ?? 0) | + ((bytes[offset + 1] ?? 0) << 8) | + ((bytes[offset + 2] ?? 0) << 16) | + ((bytes[offset + 3] ?? 0) << 24)) >>> + 0 + ); +} + +// MurmurHash3 x86 32-bit x2(共享同一份 bytes 扫描;用于生成 index/tag) +function murmur3X86_32x2( + bytes: Uint8Array, + len: number, + seedA: number, + seedB: number, + out: Uint32Array +): void { + let hA = seedA >>> 0; + let hB = seedB >>> 0; + const c1 = 0xcc9e2d51; + const c2 = 0x1b873593; + + const length = len >>> 0; + const nblocks = (length / 4) | 0; + const blockLen = nblocks * 4; + + for (let base = 0; base < blockLen; base += 4) { + let k = + (bytes[base] | + (bytes[base + 1] << 8) | + (bytes[base + 2] << 16) | + (bytes[base + 3] << 24)) >>> + 0; + + k = Math.imul(k, c1) >>> 0; + k = ((k << 15) | (k >>> 17)) >>> 0; + k = Math.imul(k, c2) >>> 0; + + hA ^= k; + hA = ((hA << 13) | (hA >>> 19)) >>> 0; + hA = (Math.imul(hA, 5) + 0xe6546b64) >>> 0; + + hB ^= k; + hB = ((hB << 13) | (hB >>> 19)) >>> 0; + hB = (Math.imul(hB, 5) + 0xe6546b64) >>> 0; + } + + // tail + let k1 = 0; + const tail = blockLen; + const rem = length & 3; + if (rem >= 3) { + k1 ^= bytes[tail + 2] << 16; + } + if (rem >= 2) { + k1 ^= bytes[tail + 1] << 8; + } + if (rem >= 1) { + k1 ^= bytes[tail]; + k1 = Math.imul(k1, c1) >>> 0; + k1 = ((k1 << 15) | (k1 >>> 17)) >>> 0; + k1 = Math.imul(k1, c2) >>> 0; + hA ^= k1; + hB ^= k1; + } + + // fmix (A) + hA ^= length; + hA ^= hA >>> 16; + hA = Math.imul(hA, 0x85ebca6b) >>> 0; + hA ^= hA >>> 13; + hA = Math.imul(hA, 0xc2b2ae35) >>> 0; + hA ^= hA >>> 16; + + // fmix (B) + hB ^= length; + hB ^= hB >>> 16; + hB = Math.imul(hB, 0x85ebca6b) >>> 0; + hB ^= hB >>> 13; + hB = Math.imul(hB, 0xc2b2ae35) >>> 0; + hB ^= hB >>> 16; + + out[0] = hA >>> 0; + out[1] = hB >>> 0; +} + +export class VacuumFilter { + private readonly fingerprintBits: number; + private readonly tagMask: number; + private readonly maxKickSteps: number; + private readonly seed: Uint8Array; + private readonly hashSeedA: number; + private readonly hashSeedB: number; + private readonly rng: XorShift32; + + // AR 组数固定为 4(与论文/参考实现一致) + private readonly lenMasks: [number, number, number, number]; + + private readonly numBuckets: number; + private readonly table: Uint32Array; + private numItems = 0; + + // 热路径优化:避免 TextEncoder.encode 分配;每次 has/add/delete 复用同一块 scratch + private scratch: Uint8Array = new Uint8Array(DEFAULT_SCRATCH_BYTES); + private readonly hashOut: Uint32Array = new Uint32Array(2); + private tmpIndex = 0; + private tmpTag = 0; + + constructor(options: VacuumFilterInitOptions) { + if (!Number.isFinite(options.maxItems) || options.maxItems <= 0) { + throw new Error("VacuumFilter: maxItems 必须为正数"); + } + + const rawFingerprintBits = options.fingerprintBits; + const fingerprintBits = + typeof rawFingerprintBits === "number" && Number.isFinite(rawFingerprintBits) + ? Math.floor(rawFingerprintBits) + : 32; + this.fingerprintBits = Math.max(1, Math.min(32, fingerprintBits)); + + const rawMaxKickSteps = options.maxKickSteps; + const maxKickSteps = + typeof rawMaxKickSteps === "number" && Number.isFinite(rawMaxKickSteps) + ? Math.floor(rawMaxKickSteps) + : 500; + this.maxKickSteps = Math.max(1, maxKickSteps); + this.seed = normalizeSeed(options.seed); + this.hashSeedA = (readU32LE(this.seed, 0) ^ 0x6a09e667) >>> 0; + this.hashSeedB = (readU32LE(this.seed, 4) ^ 0xbb67ae85) >>> 0; + this.rng = new XorShift32(readU32LE(this.seed, 8) ^ 0x3c6ef372); + + // tagMask:用于从哈希中截取 fingerprint(32-bit 特判;避免 1<<31 的有符号溢出陷阱) + this.tagMask = + this.fingerprintBits === 32 + ? 0xffffffff + : (0xffffffff >>> (32 - this.fingerprintBits)) >>> 0; + + const rawTargetLoadFactor = options.targetLoadFactor; + const rawTargetLoadFactorValue = + typeof rawTargetLoadFactor === "number" && Number.isFinite(rawTargetLoadFactor) + ? rawTargetLoadFactor + : 0.96; + const targetLoadFactor = Math.max(0.5, Math.min(0.99, rawTargetLoadFactorValue)); + + // 与作者实现一致:numBuckets ≈ maxItems / (0.96 * 4) + const maxItems = Math.ceil(options.maxItems); + // 工程上更保守:用 ceil 保证“按目标装载率”时能容纳 maxItems + let bucketCount = Math.ceil(maxItems / targetLoadFactor / BUCKET_SIZE); + bucketCount = Math.max(bucketCount, 128); // 避免过小导致 AR 设置异常 + + // 小规模表:使用更小的段长,避免强制对齐到 1024 导致空间浪费 + // 参考作者另一份实现(vacuum.h)的初始化策略。 + if (bucketCount < 10_000) { + const bigSeg = + bucketCount < 256 ? upperPower2(bucketCount) : upperPower2(Math.floor(bucketCount / 4)); + bucketCount = roundUpToMultiple(bucketCount, bigSeg); + + const mask = bigSeg - 1; + this.lenMasks = [mask, mask, mask, mask]; + this.numBuckets = bucketCount; + this.table = new Uint32Array(this.numBuckets * BUCKET_SIZE); + return; + } + + // Alternate Range 设置(aligned=false 路径) + const bigSeg = Math.max(1024, properAltRange(bucketCount, 0)); + bucketCount = roundUpToMultiple(bucketCount, bigSeg); + + const l0 = bigSeg - 1; + const l1 = properAltRange(bucketCount, 1) - 1; + const l2 = properAltRange(bucketCount, 2) - 1; + // 最后一组扩大一倍(参考实现) + const l3 = properAltRange(bucketCount, 3) * 2 - 1; + + this.lenMasks = [l0, l1, l2, l3]; + + // 重要:保证 bucketCount 是所有 segment length 的倍数,避免 AltIndex 落到末段“越界” + // 由于这些长度都是 2 的幂,取最大值即可覆盖其它组(大幂必为小幂的倍数)。 + const segLens = [l0 + 1, l1 + 1, l2 + 1, l3 + 1]; + const maxSegLen = Math.max(...segLens); + this.numBuckets = roundUpToMultiple(bucketCount, upperPower2(maxSegLen)); + this.table = new Uint32Array(this.numBuckets * BUCKET_SIZE); + } + + /** + * 当前已插入的元素数量(插入成功才计数) + */ + size(): number { + return this.numItems; + } + + /** + * 表容量(slot 总数) + */ + capacitySlots(): number { + return this.numBuckets * BUCKET_SIZE; + } + + /** + * 负载因子(占用 slot / 总 slot) + */ + loadFactor(): number { + return this.capacitySlots() === 0 ? 0 : this.numItems / this.capacitySlots(); + } + + /** + * 判断是否可能存在(true=可能存在;false=一定不存在) + */ + has(key: string): boolean { + this.indexTag(key); + const i1 = this.tmpIndex; + const tag = this.tmpTag; + + const table = this.table; + let start = i1 * BUCKET_SIZE; + if ( + table[start] === tag || + table[start + 1] === tag || + table[start + 2] === tag || + table[start + 3] === tag + ) { + return true; + } + + const i2 = this.altIndex(i1, tag); + start = i2 * BUCKET_SIZE; + return ( + table[start] === tag || + table[start + 1] === tag || + table[start + 2] === tag || + table[start + 3] === tag + ); + } + + /** + * 插入(成功返回 true;失败返回 false) + */ + add(key: string): boolean { + this.indexTag(key); + return this.addIndexTag(this.tmpIndex, this.tmpTag); + } + + /** + * 删除(成功返回 true;未找到返回 false) + * + * 注意:这是“近似删除”,存在极低概率误删(fingerprint 碰撞导致不可区分)。 + */ + delete(key: string): boolean { + this.indexTag(key); + const i1 = this.tmpIndex; + const tag = this.tmpTag; + const i2 = this.altIndex(i1, tag); + + const ok1 = this.deleteFromBucket(i1, tag); + if (ok1) { + this.numItems--; + return true; + } + + const ok2 = this.deleteFromBucket(i2, tag); + if (ok2) { + this.numItems--; + return true; + } + + return false; + } + + // ==================== 内部实现 ==================== + + private indexTag(key: string): void { + // 使用 seeded MurmurHash3(32-bit)生成确定性哈希,降低可控输入退化风险 + // 关键优化:ASCII 快路径(API Key/ID 通常为 ASCII),避免 TextEncoder.encode 分配 + const strLen = key.length; + if (this.scratch.length < strLen) { + this.scratch = new Uint8Array(Math.max(this.scratch.length * 2, strLen)); + } + + let asciiLen = 0; + for (; asciiLen < strLen; asciiLen++) { + const c = key.charCodeAt(asciiLen); + if (c > 0x7f) break; + this.scratch[asciiLen] = c; + } + + if (asciiLen === strLen) { + murmur3X86_32x2(this.scratch, strLen, this.hashSeedA, this.hashSeedB, this.hashOut); + } else { + // 非 ASCII:交给 TextEncoder(少见路径) + const keyBytes = textEncoder.encode(key); + murmur3X86_32x2(keyBytes, keyBytes.length, this.hashSeedA, this.hashSeedB, this.hashOut); + } + + const hvIndex = this.hashOut[0] >>> 0; + const hvTag = this.hashOut[1] >>> 0; + + // 参考实现使用 `hash % numBuckets`。这里保持简单、快速(即便 numBuckets 非 2 的幂也可用)。 + const index = hvIndex % this.numBuckets; + + let tag = (hvTag & this.tagMask) >>> 0; + if (tag === 0) tag = 1; + + this.tmpIndex = index; + this.tmpTag = tag; + } + + private altIndex(index: number, tag: number): number { + const segMask = this.lenMasks[tag & 3]; + + // delta = (tag * C) & segMask,若为 0 则置为 1,避免 alt==index + let delta = (Math.imul(tag, 0x5bd1e995) >>> 0) & segMask; + if (delta === 0) delta = 1; + + // segLen 为 2 的幂:index % segLen 等价于 index & segMask(index 来自 32-bit hash,安全使用位运算) + const offset = (index & segMask) >>> 0; + const altOffset = (offset ^ delta) >>> 0; + return index - offset + altOffset; + } + + private bucketStart(index: number): number { + return index * BUCKET_SIZE; + } + + private writeSlot(pos: number, value: number, undo?: UndoLog): void { + if (undo) { + undo.pos.push(pos); + undo.prev.push(this.table[pos]); + } + this.table[pos] = value; + } + + private rollback(undo: UndoLog): void { + for (let i = undo.pos.length - 1; i >= 0; i--) { + this.table[undo.pos[i]] = undo.prev[i]; + } + } + + private insertTagToBucket(index: number, tag: number, undo?: UndoLog): boolean { + const start = this.bucketStart(index); + if (this.table[start] === 0) { + this.writeSlot(start, tag, undo); + return true; + } + if (this.table[start + 1] === 0) { + this.writeSlot(start + 1, tag, undo); + return true; + } + if (this.table[start + 2] === 0) { + this.writeSlot(start + 2, tag, undo); + return true; + } + if (this.table[start + 3] === 0) { + this.writeSlot(start + 3, tag, undo); + return true; + } + return false; + } + + private deleteFromBucket(index: number, tag: number): boolean { + const start = this.bucketStart(index); + if (this.table[start] === tag) { + this.table[start] = 0; + return true; + } + if (this.table[start + 1] === tag) { + this.table[start + 1] = 0; + return true; + } + if (this.table[start + 2] === tag) { + this.table[start + 2] = 0; + return true; + } + if (this.table[start + 3] === tag) { + this.table[start + 3] = 0; + return true; + } + return false; + } + + private bucketOccupancy(index: number): number { + const start = this.bucketStart(index); + return ( + (this.table[start] !== 0 ? 1 : 0) + + (this.table[start + 1] !== 0 ? 1 : 0) + + (this.table[start + 2] !== 0 ? 1 : 0) + + (this.table[start + 3] !== 0 ? 1 : 0) + ); + } + + private addIndexTag(index: number, tag: number): boolean { + const i1 = index; + const i2 = this.altIndex(i1, tag); + + const occ1 = this.bucketOccupancy(i1); + const occ2 = this.bucketOccupancy(i2); + + // 先尝试插入到“更空”的 bucket(参考实现:优先更少元素的桶) + const first = occ1 <= occ2 ? i1 : i2; + const second = first === i1 ? i2 : i1; + + if (this.insertTagToBucket(first, tag) || this.insertTagToBucket(second, tag)) { + this.numItems++; + return true; + } + + // 两个 bucket 都满:进入踢出 + vacuuming + // 关键语义:若最终插入失败,必须回滚所有修改,避免“丢元素”导致假阴性。 + const undo: UndoLog = { pos: [], prev: [] }; + let curIndex = this.rng.nextBool() ? i1 : i2; + let curTag = tag; + + for (let count = 0; count < this.maxKickSteps; count++) { + // 1) 可能因上一次换位导致当前桶出现空位(保守再试一次) + if (this.insertTagToBucket(curIndex, curTag, undo)) { + this.numItems++; + return true; + } + + // 2) Vacuuming(一跳前瞻):尝试把当前桶内某个 tag 挪到它的 alternate bucket 的空位 + const start = this.bucketStart(curIndex); + for (let slot = 0; slot < BUCKET_SIZE; slot++) { + const existing = this.table[start + slot]; + if (existing === 0) continue; + const alt = this.altIndex(curIndex, existing); + if (this.insertTagToBucket(alt, existing, undo)) { + // 将空位“吸”到当前 slot:existing 移走,curTag 填入 + this.writeSlot(start + slot, curTag, undo); + this.numItems++; + return true; + } + } + + // 3) 随机踢出一个 tag,继续链式搬运 + const r = this.rng.nextInt(BUCKET_SIZE); + const oldTag = this.table[start + r]; + this.writeSlot(start + r, curTag, undo); + curTag = oldTag; + curIndex = this.altIndex(curIndex, curTag); + } + + this.rollback(undo); + return false; + } +} diff --git a/src/repository/key.ts b/src/repository/key.ts index 9a5631cfd..65ffdd6ff 100644 --- a/src/repository/key.ts +++ b/src/repository/key.ts @@ -3,6 +3,16 @@ import { and, count, desc, eq, gt, gte, inArray, isNull, lt, or, sql, sum } from "drizzle-orm"; import { db } from "@/drizzle/db"; import { keys, messageRequest, providers, users } from "@/drizzle/schema"; +import { CHANNEL_API_KEYS_UPDATED, publishCacheInvalidation } from "@/lib/redis/pubsub"; +import { + cacheActiveKey, + cacheAuthResult, + cacheUser, + getCachedActiveKey, + getCachedUser, + invalidateCachedKey, +} from "@/lib/security/api-key-auth-cache"; +import { apiKeyVacuumFilter } from "@/lib/security/api-key-vacuum-filter"; import { Decimal, toCostDecimal } from "@/lib/utils/currency"; import type { CreateKeyData, Key, UpdateKeyData } from "@/types/key"; import type { User } from "@/types/user"; @@ -161,12 +171,41 @@ export async function createKey(keyData: CreateKeyData): Promise { limitTotalUsd: keys.limitTotalUsd, limitConcurrentSessions: keys.limitConcurrentSessions, providerGroup: keys.providerGroup, + cacheTtlPreference: keys.cacheTtlPreference, createdAt: keys.createdAt, updatedAt: keys.updatedAt, deletedAt: keys.deletedAt, }); - return toKey(key); + const created = toKey(key); + // 将新建 key 写入 Vacuum Filter(提升新 key 的即时可用性;失败不影响正确性) + try { + apiKeyVacuumFilter.noteExistingKey(created.key); + } catch { + // ignore + } + // Redis 缓存(最佳努力,不影响正确性) + // 注意:多实例环境下其它实例可能在 Vacuum Filter 尚未重建时收到新 key 的请求。 + // 为减少“新 key 立刻使用偶发 401”的窗口,这里会等待 Redis 写入/广播; + // 但必须设置超时上限,避免 Redis 慢/不可用时拖慢 key 创建。 + const redisBestEffortTimeoutMs = 200; + const redisTasks: Array> = []; + + redisTasks.push(cacheActiveKey(created).catch(() => {})); + + // 多实例:广播 key 集合变更,触发其它实例重建 Vacuum Filter,避免误拒绝 + const rateLimitRaw = process.env.ENABLE_RATE_LIMIT?.trim(); + if (process.env.REDIS_URL && rateLimitRaw !== "false" && rateLimitRaw !== "0") { + redisTasks.push(publishCacheInvalidation(CHANNEL_API_KEYS_UPDATED).catch(() => {})); + } + + if (redisTasks.length > 0) { + await Promise.race([ + Promise.all(redisTasks), + new Promise((resolve) => setTimeout(resolve, redisBestEffortTimeoutMs)), + ]); + } + return created; } export async function updateKey(id: number, keyData: UpdateKeyData): Promise { @@ -232,7 +271,17 @@ export async function updateKey(id: number, keyData: UpdateKeyData): Promise {}); + } else { + await invalidateCachedKey(updated.key).catch(() => {}); + } + return updated; } export async function findActiveKeyByUserIdAndName( @@ -394,12 +443,34 @@ export async function deleteKey(id: number): Promise { .update(keys) .set({ deletedAt: new Date() }) .where(and(eq(keys.id, id), isNull(keys.deletedAt))) - .returning({ id: keys.id }); + .returning({ id: keys.id, key: keys.key }); + if (result.length > 0) { + await invalidateCachedKey(result[0].key).catch(() => {}); + } return result.length > 0; } export async function findActiveKeyByKeyString(keyString: string): Promise { + const vfSaysMissing = apiKeyVacuumFilter.isDefinitelyNotPresent(keyString) === true; + + // Redis 缓存命中:避免打 DB + const cached = await getCachedActiveKey(keyString); + if (cached) { + // 多实例一致性:若 Vacuum Filter 判定缺失但 Redis 命中,说明本机 filter 可能滞后。 + // 最佳努力将 key 写入本机 filter(不影响正确性,仅提升后续性能)。 + if (vfSaysMissing) { + apiKeyVacuumFilter.noteExistingKey(keyString); + } + return cached; + } + + // Vacuum Filter 负向短路:肯定不存在则直接返回 null,避免打 DB + // 注意:此处必须放在 Redis 读取之后,避免多实例环境中新建 key 的短暂误拒绝窗口。 + if (vfSaysMissing) { + return null; + } + const [key] = await db .select({ id: keys.id, @@ -418,6 +489,7 @@ export async function findActiveKeyByKeyString(keyString: string): Promise {}); + return active; } // 验证 API Key 并返回用户信息 export async function validateApiKeyAndGetUser( keyString: string ): Promise<{ user: User; key: Key } | null> { + const vfSaysMissing = apiKeyVacuumFilter.isDefinitelyNotPresent(keyString) === true; + + // 默认鉴权链路:Vacuum Filter -> Redis -> DB + const cachedKey = await getCachedActiveKey(keyString); + if (cachedKey) { + // 多实例一致性:若 Vacuum Filter 判定缺失但 Redis 命中,说明本机 filter 可能滞后。 + // 最佳努力将 key 写入本机 filter(不影响正确性,仅提升后续性能)。 + if (vfSaysMissing) { + apiKeyVacuumFilter.noteExistingKey(keyString); + } + + const cachedUser = await getCachedUser(cachedKey.userId); + if (cachedUser) { + return { user: cachedUser, key: cachedKey }; + } + + // user 缓存 miss:仅补齐 user(相较 join 更轻量) + const [userRow] = await db + .select({ + id: users.id, + name: users.name, + description: users.description, + role: users.role, + rpm: users.rpmLimit, + dailyQuota: users.dailyLimitUsd, + providerGroup: users.providerGroup, + tags: users.tags, + createdAt: users.createdAt, + updatedAt: users.updatedAt, + deletedAt: users.deletedAt, + limit5hUsd: users.limit5hUsd, + limitWeeklyUsd: users.limitWeeklyUsd, + limitMonthlyUsd: users.limitMonthlyUsd, + limitTotalUsd: users.limitTotalUsd, + limitConcurrentSessions: users.limitConcurrentSessions, + dailyResetMode: users.dailyResetMode, + dailyResetTime: users.dailyResetTime, + isEnabled: users.isEnabled, + expiresAt: users.expiresAt, + allowedClients: users.allowedClients, + allowedModels: users.allowedModels, + }) + .from(users) + .where(and(eq(users.id, cachedKey.userId), isNull(users.deletedAt))); + + if (!userRow) { + // join 语义:用户被删除则 key 无效;顺带清理 key 缓存避免重复 miss + invalidateCachedKey(keyString).catch(() => {}); + return null; + } + + const user = toUser(userRow); + cacheUser(user).catch(() => {}); + return { user, key: cachedKey }; + } + + // Vacuum Filter 负向短路:肯定不存在则直接返回 null,避免打 DB + // 注意:此处必须放在 Redis 读取之后,避免多实例环境中新建 key 的短暂误拒绝窗口。 + if (vfSaysMissing) { + return null; + } + const result = await db .select({ // Key fields @@ -551,6 +687,8 @@ export async function validateApiKeyAndGetUser( deletedAt: row.keyDeletedAt, }); + // 最佳努力:写入 Redis 缓存(不影响正确性) + cacheAuthResult(keyString, { user, key }).catch(() => {}); return { user, key }; } diff --git a/src/repository/user.ts b/src/repository/user.ts index e112603ee..350ccbf6c 100644 --- a/src/repository/user.ts +++ b/src/repository/user.ts @@ -3,6 +3,7 @@ import { and, asc, eq, isNull, type SQL, sql } from "drizzle-orm"; import { db } from "@/drizzle/db"; import { keys as keysTable, users } from "@/drizzle/schema"; +import { cacheUser, invalidateCachedUser } from "@/lib/security/api-key-auth-cache"; import type { CreateUserData, UpdateUserData, User } from "@/types/user"; import { toUser } from "./_shared/transformers"; @@ -86,7 +87,9 @@ export async function createUser(userData: CreateUserData): Promise { allowedModels: users.allowedModels, }); - return toUser(user); + const created = toUser(user); + await cacheUser(created).catch(() => {}); + return created; } export async function findUserList(limit: number = 50, offset: number = 0): Promise { @@ -432,7 +435,9 @@ export async function updateUser(id: number, userData: UpdateUserData): Promise< if (!user) return null; - return toUser(user); + const updated = toUser(user); + await cacheUser(updated).catch(() => {}); + return updated; } export async function deleteUser(id: number): Promise { @@ -442,6 +447,9 @@ export async function deleteUser(id: number): Promise { .where(and(eq(users.id, id), isNull(users.deletedAt))) .returning({ id: users.id }); + if (result.length > 0) { + await invalidateCachedUser(id).catch(() => {}); + } return result.length > 0; } @@ -456,6 +464,7 @@ export async function markUserExpired(userId: number): Promise { .where(and(eq(users.id, userId), eq(users.isEnabled, true), isNull(users.deletedAt))) .returning({ id: users.id }); + await invalidateCachedUser(userId).catch(() => {}); return result.length > 0; } diff --git a/tests/unit/security/api-key-auth-cache-redis-key.test.ts b/tests/unit/security/api-key-auth-cache-redis-key.test.ts new file mode 100644 index 000000000..7f6a3ec63 --- /dev/null +++ b/tests/unit/security/api-key-auth-cache-redis-key.test.ts @@ -0,0 +1,461 @@ +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { createHash, webcrypto } from "node:crypto"; +import type { Key } from "@/types/key"; +import type { User } from "@/types/user"; + +type RedisPipelineLike = { + setex(key: string, ttlSeconds: number, value: string): RedisPipelineLike; + del(key: string): RedisPipelineLike; + exec(): Promise; +}; + +type RedisLike = { + get(key: string): Promise; + setex(key: string, ttlSeconds: number, value: string): Promise; + del(key: string): Promise; + pipeline(): RedisPipelineLike; +}; + +type PipelineOp = + | { kind: "setex"; key: string; ttlSeconds: number; value: string } + | { kind: "del"; key: string }; + +class FakeRedisPipeline implements RedisPipelineLike { + readonly ops: PipelineOp[] = []; + readonly exec = vi.fn(async () => { + for (const op of this.ops) { + if (op.kind === "setex") { + this.parent.store.set(op.key, op.value); + } else { + this.parent.store.delete(op.key); + } + } + return []; + }); + + constructor(private readonly parent: FakeRedis) {} + + setex(key: string, ttlSeconds: number, value: string): RedisPipelineLike { + this.ops.push({ kind: "setex", key, ttlSeconds, value }); + return this; + } + + del(key: string): RedisPipelineLike { + this.ops.push({ kind: "del", key }); + return this; + } +} + +class FakeRedis implements RedisLike { + readonly store = new Map(); + readonly get = vi.fn(async (key: string) => this.store.get(key) ?? null); + readonly setex = vi.fn(async (key: string, _ttlSeconds: number, value: string) => { + this.store.set(key, value); + return "OK"; + }); + readonly del = vi.fn(async (key: string) => (this.store.delete(key) ? 1 : 0)); + readonly pipeline = vi.fn(() => { + const pipeline = new FakeRedisPipeline(this); + this.pipelines.push(pipeline); + return pipeline; + }); + + readonly pipelines: FakeRedisPipeline[] = []; +} + +let currentRedis: FakeRedis | null = null; +const getRedisClient = vi.fn(() => currentRedis); + +vi.mock("@/lib/redis/client", () => ({ + getRedisClient, +})); + +function sha256HexNode(value: string): string { + return createHash("sha256").update(value).digest("hex"); +} + +function buildKey(overrides?: Partial): Key { + return { + id: 1, + userId: 10, + name: "k1", + key: "sk-secret", + isEnabled: true, + expiresAt: undefined, + canLoginWebUi: true, + limit5hUsd: null, + limitDailyUsd: null, + dailyResetMode: "fixed", + dailyResetTime: "00:00", + limitWeeklyUsd: null, + limitMonthlyUsd: null, + limitTotalUsd: null, + limitConcurrentSessions: 0, + providerGroup: null, + cacheTtlPreference: null, + createdAt: new Date("2026-01-01T00:00:00.000Z"), + updatedAt: new Date("2026-01-02T00:00:00.000Z"), + deletedAt: undefined, + ...overrides, + }; +} + +function buildUser(overrides?: Partial): User { + return { + id: 10, + name: "u1", + description: "", + role: "user", + rpm: null, + dailyQuota: null, + providerGroup: null, + tags: [], + createdAt: new Date("2026-01-01T00:00:00.000Z"), + updatedAt: new Date("2026-01-02T00:00:00.000Z"), + deletedAt: undefined, + dailyResetMode: "fixed", + dailyResetTime: "00:00", + isEnabled: true, + expiresAt: null, + allowedClients: [], + allowedModels: [], + ...overrides, + }; +} + +function setEnv(values: Record): void { + for (const [key, value] of Object.entries(values)) { + if (value === undefined) { + // eslint-disable-next-line @typescript-eslint/no-dynamic-delete + delete process.env[key]; + } else { + process.env[key] = value; + } + } +} + +describe("ApiKeyAuthCache:Redis key(哈希/命名/TTL/失效)", () => { + const originalEnv: Record = {}; + + beforeEach(() => { + vi.resetModules(); + vi.clearAllMocks(); + currentRedis = new FakeRedis(); + + // 记录并覆盖本文件会改动的环境变量(避免泄漏到其它用例) + for (const k of [ + "CI", + "NEXT_PHASE", + "NEXT_RUNTIME", + "ENABLE_RATE_LIMIT", + "REDIS_URL", + "ENABLE_API_KEY_REDIS_CACHE", + "API_KEY_AUTH_CACHE_TTL_SECONDS", + ]) { + originalEnv[k] = process.env[k]; + } + + setEnv({ + CI: "false", + NEXT_PHASE: "", + NEXT_RUNTIME: "nodejs", + ENABLE_RATE_LIMIT: "true", + REDIS_URL: "redis://localhost:6379", + ENABLE_API_KEY_REDIS_CACHE: "true", + API_KEY_AUTH_CACHE_TTL_SECONDS: "60", + }); + + // 确保测试环境一定有 WebCrypto subtle(不依赖 Node 版本/运行模式) + vi.stubGlobal("crypto", webcrypto as unknown as Crypto); + }); + + afterEach(() => { + vi.useRealTimers(); + vi.unstubAllGlobals(); + setEnv(originalEnv); + currentRedis = null; + }); + + test("cacheActiveKey:应使用 SHA-256(keyString) 作为 Redis key,且不泄漏明文 key", async () => { + const { cacheActiveKey } = await import("@/lib/security/api-key-auth-cache"); + const key = buildKey({ key: "sk-secret" }); + + await cacheActiveKey(key); + + const expectedRedisKey = `api_key_auth:v1:key:${sha256HexNode("sk-secret")}`; + expect(getRedisClient).toHaveBeenCalled(); + expect(currentRedis?.setex).toHaveBeenCalledTimes(1); + + const [redisKey, ttlSeconds, payload] = currentRedis!.setex.mock.calls[0]; + expect(redisKey).toBe(expectedRedisKey); + expect(redisKey).not.toContain("sk-secret"); + expect(ttlSeconds).toBe(60); + expect(typeof payload).toBe("string"); + expect(payload).not.toContain("sk-secret"); + + const parsed = JSON.parse(payload) as { v: number; key: Record }; + expect(parsed.v).toBe(1); + // payload.key 不应包含明文 key 字段 + expect(Object.hasOwn(parsed.key, "key")).toBe(false); + }); + + test("cacheActiveKey + getCachedActiveKey:应可回读并水合 Date 字段", async () => { + const { cacheActiveKey, getCachedActiveKey } = await import("@/lib/security/api-key-auth-cache"); + const key = buildKey({ key: "sk-roundtrip" }); + + await cacheActiveKey(key); + const cached = await getCachedActiveKey("sk-roundtrip"); + + expect(cached?.key).toBe("sk-roundtrip"); + expect(cached?.id).toBe(1); + expect(cached?.userId).toBe(10); + expect(cached?.createdAt).toBeInstanceOf(Date); + expect(cached?.updatedAt).toBeInstanceOf(Date); + expect(cached?.createdAt.toISOString()).toBe(key.createdAt.toISOString()); + expect(cached?.updatedAt.toISOString()).toBe(key.updatedAt.toISOString()); + }); + + test("getCachedActiveKey:payload 版本不匹配时应删除缓存并返回 null", async () => { + const { getCachedActiveKey } = await import("@/lib/security/api-key-auth-cache"); + const keyString = "sk-version-mismatch"; + const redisKey = `api_key_auth:v1:key:${sha256HexNode(keyString)}`; + + currentRedis!.store.set( + redisKey, + JSON.stringify({ + v: 999, + key: { + id: 1, + userId: 10, + name: "k1", + isEnabled: true, + canLoginWebUi: true, + dailyResetMode: "fixed", + dailyResetTime: "00:00", + limitConcurrentSessions: 0, + createdAt: "2026-01-01T00:00:00.000Z", + updatedAt: "2026-01-02T00:00:00.000Z", + }, + }) + ); + + await expect(getCachedActiveKey(keyString)).resolves.toBeNull(); + expect(currentRedis!.del).toHaveBeenCalledWith(redisKey); + }); + + describe("getCachedActiveKey:disabled/deleted/expired 应视为失效并清理", () => { + const cases = [ + { name: "disabled", payload: { isEnabled: false } }, + { name: "deleted", payload: { deletedAt: "2026-01-01T00:00:00.000Z" } }, + { name: "expired", payload: { expiresAt: "2026-01-01T00:00:00.000Z" } }, + ] as const; + + test.each(cases)("$name", async ({ name, payload }) => { + vi.useFakeTimers(); + vi.setSystemTime(new Date("2026-01-10T00:00:00.000Z")); + + const { getCachedActiveKey } = await import("@/lib/security/api-key-auth-cache"); + + const keyString = `sk-${name}`; + const redisKey = `api_key_auth:v1:key:${sha256HexNode(keyString)}`; + currentRedis!.store.set( + redisKey, + JSON.stringify({ + v: 1, + key: { + id: 1, + userId: 10, + name: "k1", + isEnabled: true, + canLoginWebUi: true, + dailyResetMode: "fixed", + dailyResetTime: "00:00", + limitConcurrentSessions: 0, + createdAt: "2026-01-01T00:00:00.000Z", + updatedAt: "2026-01-02T00:00:00.000Z", + ...payload, + }, + }) + ); + + await expect(getCachedActiveKey(keyString)).resolves.toBeNull(); + expect(currentRedis!.del).toHaveBeenCalledWith(redisKey); + }); + }); + + describe("cacheActiveKey:非活跃 key(禁用/已删/已过期/无效 expiresAt)应删除缓存,不应 setex", () => { + const cases: Array<{ name: string; key: Key }> = [ + { name: "disabled", key: buildKey({ key: "sk-disabled", isEnabled: false }) }, + { + name: "deleted", + key: buildKey({ key: "sk-deleted", deletedAt: new Date("2026-01-01T00:00:00.000Z") }), + }, + { + name: "expired", + key: buildKey({ key: "sk-expired", expiresAt: new Date("2026-01-01T00:00:00.000Z") }), + }, + { + name: "invalid_expiresAt", + // @ts-expect-error: 覆盖运行时边界 + key: buildKey({ key: "sk-invalid", expiresAt: "not-a-date" }), + }, + ]; + + test.each(cases)("$name", async ({ key }) => { + vi.useFakeTimers(); + vi.setSystemTime(new Date("2026-01-10T00:00:00.000Z")); + + const { cacheActiveKey } = await import("@/lib/security/api-key-auth-cache"); + + await cacheActiveKey(key); + + const expectedRedisKey = `api_key_auth:v1:key:${sha256HexNode(key.key)}`; + expect(currentRedis!.setex).not.toHaveBeenCalled(); + expect(currentRedis!.del).toHaveBeenCalledWith(expectedRedisKey); + }); + }); + + test("cacheActiveKey:应按 key.expiresAt 剩余时间收敛 TTL(秒)", async () => { + vi.useFakeTimers(); + vi.setSystemTime(new Date("2026-01-01T00:00:00.000Z")); + + const { cacheActiveKey } = await import("@/lib/security/api-key-auth-cache"); + const expiresAt = new Date(Date.now() + 30_000); + const key = buildKey({ key: "sk-ttl-cap", expiresAt }); + + await cacheActiveKey(key); + + expect(currentRedis!.setex).toHaveBeenCalledTimes(1); + const [_redisKey, ttlSeconds] = currentRedis!.setex.mock.calls[0]; + expect(ttlSeconds).toBe(30); + }); + + test("API_KEY_AUTH_CACHE_TTL_SECONDS:应 clamp 到最大 3600s", async () => { + setEnv({ API_KEY_AUTH_CACHE_TTL_SECONDS: "999999" }); + + const { cacheActiveKey } = await import("@/lib/security/api-key-auth-cache"); + const key = buildKey({ key: "sk-ttl-max" }); + + await cacheActiveKey(key); + + expect(currentRedis!.setex).toHaveBeenCalledTimes(1); + const [_redisKey, ttlSeconds] = currentRedis!.setex.mock.calls[0]; + expect(ttlSeconds).toBe(3600); + }); + + test("invalidateCachedKey:应删除对应的 hashed Redis key", async () => { + const { invalidateCachedKey } = await import("@/lib/security/api-key-auth-cache"); + const keyString = "sk-invalidate"; + + await invalidateCachedKey(keyString); + + const expectedRedisKey = `api_key_auth:v1:key:${sha256HexNode(keyString)}`; + expect(currentRedis!.del).toHaveBeenCalledWith(expectedRedisKey); + }); + + test("cacheAuthResult:应使用 pipeline 写入 key cache(并遵守活跃条件)", async () => { + vi.useFakeTimers(); + vi.setSystemTime(new Date("2026-01-01T00:00:00.000Z")); + + const { cacheAuthResult } = await import("@/lib/security/api-key-auth-cache"); + + await cacheAuthResult("sk-auth", { + key: buildKey({ key: "sk-auth" }), + user: buildUser({ id: 10 }), + }); + + expect(currentRedis!.pipeline).toHaveBeenCalledTimes(1); + const pipeline = currentRedis!.pipelines[0]; + expect(pipeline.exec).toHaveBeenCalledTimes(1); + const keyRedisKey = `api_key_auth:v1:key:${sha256HexNode("sk-auth")}`; + expect(pipeline.ops.some((op) => op.kind === "setex" && op.key === keyRedisKey)).toBe(true); + }); + + test("cacheAuthResult:key 非活跃时应 del key cache(避免脏读误放行)", async () => { + const { cacheAuthResult } = await import("@/lib/security/api-key-auth-cache"); + + await cacheAuthResult("sk-inactive", { + key: buildKey({ key: "sk-inactive", isEnabled: false }), + user: buildUser({ id: 10 }), + }); + + const keyRedisKey = `api_key_auth:v1:key:${sha256HexNode("sk-inactive")}`; + const pipeline = currentRedis!.pipelines[0]; + expect(pipeline.ops.some((op) => op.kind === "del" && op.key === keyRedisKey)).toBe(true); + }); + + test("ENABLE_API_KEY_REDIS_CACHE=false:应完全禁用缓存(不触发 Redis 调用)", async () => { + setEnv({ ENABLE_API_KEY_REDIS_CACHE: "false" }); + const { cacheActiveKey } = await import("@/lib/security/api-key-auth-cache"); + + await cacheActiveKey(buildKey({ key: "sk-disabled-by-env" })); + + expect(getRedisClient).not.toHaveBeenCalled(); + expect(currentRedis!.setex).not.toHaveBeenCalled(); + expect(currentRedis!.del).not.toHaveBeenCalled(); + }); + + test("ENABLE_API_KEY_REDIS_CACHE=0:应完全禁用缓存(不触发 Redis 调用)", async () => { + setEnv({ ENABLE_API_KEY_REDIS_CACHE: "0" }); + const { cacheActiveKey } = await import("@/lib/security/api-key-auth-cache"); + + await cacheActiveKey(buildKey({ key: "sk-disabled-by-env-0" })); + + expect(getRedisClient).not.toHaveBeenCalled(); + expect(currentRedis!.setex).not.toHaveBeenCalled(); + expect(currentRedis!.del).not.toHaveBeenCalled(); + }); + + test("NEXT_RUNTIME=edge:应禁用缓存(避免在 Edge runtime 引入 Node Redis 依赖)", async () => { + setEnv({ NEXT_RUNTIME: "edge" }); + const { getCachedActiveKey } = await import("@/lib/security/api-key-auth-cache"); + + await expect(getCachedActiveKey("sk-edge")).resolves.toBeNull(); + expect(getRedisClient).not.toHaveBeenCalled(); + }); + + test("ENABLE_RATE_LIMIT!=true 或缺少 REDIS_URL:应自动回落(不触发 Redis 调用)", async () => { + setEnv({ ENABLE_RATE_LIMIT: "false" }); + const { cacheActiveKey } = await import("@/lib/security/api-key-auth-cache"); + await cacheActiveKey(buildKey({ key: "sk-fallback-1" })); + expect(getRedisClient).not.toHaveBeenCalled(); + + vi.resetModules(); + vi.clearAllMocks(); + currentRedis = new FakeRedis(); + setEnv({ ENABLE_RATE_LIMIT: "true", REDIS_URL: undefined }); + const { cacheActiveKey: cacheActiveKey2 } = await import("@/lib/security/api-key-auth-cache"); + await cacheActiveKey2(buildKey({ key: "sk-fallback-2" })); + expect(getRedisClient).not.toHaveBeenCalled(); + }); + + test("ENABLE_RATE_LIMIT=1:应允许使用 Redis 缓存(兼容 1/0 写法)", async () => { + setEnv({ ENABLE_RATE_LIMIT: "1" }); + const { cacheActiveKey } = await import("@/lib/security/api-key-auth-cache"); + + await cacheActiveKey(buildKey({ key: "sk-rate-limit-1" })); + + expect(getRedisClient).toHaveBeenCalled(); + expect(currentRedis!.setex).toHaveBeenCalledTimes(1); + }); + + test("crypto.subtle 缺失:sha256Hex 返回 null,应自动回落(不触发 Redis 调用)", async () => { + vi.unstubAllGlobals(); + vi.stubGlobal("crypto", {} as unknown as Crypto); + + const { cacheActiveKey } = await import("@/lib/security/api-key-auth-cache"); + await cacheActiveKey(buildKey({ key: "sk-no-crypto" })); + + expect(currentRedis!.setex).not.toHaveBeenCalled(); + expect(currentRedis!.del).not.toHaveBeenCalled(); + }); + + test("Redis 异常:get/setex 抛错时应 fail-open(不影响鉴权正确性)", async () => { + const { cacheActiveKey, getCachedActiveKey } = await import("@/lib/security/api-key-auth-cache"); + currentRedis!.setex.mockRejectedValueOnce(new Error("REDIS_DOWN")); + await expect(cacheActiveKey(buildKey({ key: "sk-redis-down" }))).resolves.toBeUndefined(); + + currentRedis!.get.mockRejectedValueOnce(new Error("REDIS_DOWN")); + await expect(getCachedActiveKey("sk-redis-down")).resolves.toBeNull(); + }); +}); diff --git a/tests/unit/security/api-key-auth-cache.test.ts b/tests/unit/security/api-key-auth-cache.test.ts new file mode 100644 index 000000000..c81dfdfa1 --- /dev/null +++ b/tests/unit/security/api-key-auth-cache.test.ts @@ -0,0 +1,400 @@ +import { beforeEach, describe, expect, test, vi } from "vitest"; +import type { Key } from "@/types/key"; +import type { User } from "@/types/user"; + +const isDefinitelyNotPresent = vi.fn(() => false); +const noteExistingKey = vi.fn(); + +const cacheActiveKey = vi.fn(async () => {}); +const cacheAuthResult = vi.fn(async () => {}); +const cacheUser = vi.fn(async () => {}); +const getCachedActiveKey = vi.fn<(keyString: string) => Promise>(); +const getCachedUser = vi.fn<(userId: number) => Promise>(); +const invalidateCachedKey = vi.fn(async () => {}); +const publishCacheInvalidation = vi.fn(async () => {}); + +const dbSelect = vi.fn(); +const dbInsert = vi.fn(); +const dbUpdate = vi.fn(); + +vi.mock("@/lib/security/api-key-vacuum-filter", () => ({ + apiKeyVacuumFilter: { + isDefinitelyNotPresent, + noteExistingKey, + startBackgroundReload: vi.fn(), + getStats: vi.fn(), + }, +})); + +vi.mock("@/lib/security/api-key-auth-cache", () => ({ + cacheActiveKey, + cacheAuthResult, + cacheUser, + getCachedActiveKey, + getCachedUser, + invalidateCachedKey, +})); + +vi.mock("@/lib/redis/pubsub", () => ({ + CHANNEL_ERROR_RULES_UPDATED: "cch:cache:error_rules:updated", + CHANNEL_REQUEST_FILTERS_UPDATED: "cch:cache:request_filters:updated", + CHANNEL_SENSITIVE_WORDS_UPDATED: "cch:cache:sensitive_words:updated", + CHANNEL_API_KEYS_UPDATED: "cch:cache:api_keys:updated", + publishCacheInvalidation, + subscribeCacheInvalidation: vi.fn(async () => null), +})); + +vi.mock("@/drizzle/db", () => ({ + db: { + select: dbSelect, + insert: dbInsert, + update: dbUpdate, + }, +})); + +beforeEach(() => { + vi.clearAllMocks(); + isDefinitelyNotPresent.mockReturnValue(false); + getCachedActiveKey.mockResolvedValue(null); + getCachedUser.mockResolvedValue(null); + dbSelect.mockImplementation(() => { + throw new Error("DB_ACCESS"); + }); + dbInsert.mockImplementation(() => { + throw new Error("DB_ACCESS"); + }); + dbUpdate.mockImplementation(() => { + throw new Error("DB_ACCESS"); + }); +}); + +function buildKey(overrides?: Partial): Key { + return { + id: 1, + userId: 10, + name: "k1", + key: "sk-test", + isEnabled: true, + expiresAt: undefined, + canLoginWebUi: true, + limit5hUsd: null, + limitDailyUsd: null, + dailyResetMode: "fixed", + dailyResetTime: "00:00", + limitWeeklyUsd: null, + limitMonthlyUsd: null, + limitTotalUsd: null, + limitConcurrentSessions: 0, + providerGroup: null, + cacheTtlPreference: null, + createdAt: new Date("2026-01-01T00:00:00.000Z"), + updatedAt: new Date("2026-01-02T00:00:00.000Z"), + deletedAt: undefined, + ...overrides, + }; +} + +function buildUser(overrides?: Partial): User { + return { + id: 10, + name: "u1", + description: "", + role: "user", + rpm: null, + dailyQuota: null, + providerGroup: null, + tags: [], + createdAt: new Date("2026-01-01T00:00:00.000Z"), + updatedAt: new Date("2026-01-02T00:00:00.000Z"), + deletedAt: undefined, + limit5hUsd: undefined, + limitWeeklyUsd: undefined, + limitMonthlyUsd: undefined, + limitTotalUsd: null, + limitConcurrentSessions: undefined, + dailyResetMode: "fixed", + dailyResetTime: "00:00", + isEnabled: true, + expiresAt: null, + allowedClients: [], + allowedModels: [], + ...overrides, + }; +} + +describe("API Key 鉴权缓存:VacuumFilter -> Redis -> DB", () => { + test("findActiveKeyByKeyString:Vacuum Filter 误判缺失时,Redis 命中应纠正(避免误拒绝)", async () => { + const cachedKey = buildKey({ key: "sk-cached-missing" }); + isDefinitelyNotPresent.mockReturnValueOnce(true); + getCachedActiveKey.mockResolvedValueOnce(cachedKey); + + const { findActiveKeyByKeyString } = await import("@/repository/key"); + await expect(findActiveKeyByKeyString("sk-cached-missing")).resolves.toEqual(cachedKey); + expect(noteExistingKey).toHaveBeenCalledWith("sk-cached-missing"); + expect(dbSelect).not.toHaveBeenCalled(); + }); + + test("validateApiKeyAndGetUser:Vacuum Filter 误判缺失时,Redis key+user 命中应纠正(避免误拒绝)", async () => { + const cachedKey = buildKey({ key: "sk-cached-missing", userId: 10 }); + const cachedUser = buildUser({ id: 10 }); + isDefinitelyNotPresent.mockReturnValueOnce(true); + getCachedActiveKey.mockResolvedValueOnce(cachedKey); + getCachedUser.mockResolvedValueOnce(cachedUser); + + const { validateApiKeyAndGetUser } = await import("@/repository/key"); + await expect(validateApiKeyAndGetUser("sk-cached-missing")).resolves.toEqual({ + user: cachedUser, + key: cachedKey, + }); + expect(noteExistingKey).toHaveBeenCalledWith("sk-cached-missing"); + expect(dbSelect).not.toHaveBeenCalled(); + }); + + test("findActiveKeyByKeyString:Redis 命中时应避免打 DB", async () => { + const cachedKey = buildKey({ key: "sk-cached" }); + getCachedActiveKey.mockResolvedValueOnce(cachedKey); + dbSelect.mockImplementation(() => { + throw new Error("DB_ACCESS"); + }); + + const { findActiveKeyByKeyString } = await import("@/repository/key"); + await expect(findActiveKeyByKeyString("sk-cached")).resolves.toEqual(cachedKey); + expect(getCachedActiveKey).toHaveBeenCalledWith("sk-cached"); + expect(dbSelect).not.toHaveBeenCalled(); + }); + + test("findActiveKeyByKeyString:VF 判定不存在且 Redis 未命中时应短路返回 null", async () => { + isDefinitelyNotPresent.mockReturnValueOnce(true); + getCachedActiveKey.mockResolvedValueOnce(null); + + const { findActiveKeyByKeyString } = await import("@/repository/key"); + await expect(findActiveKeyByKeyString("sk-nonexistent")).resolves.toBeNull(); + expect(dbSelect).not.toHaveBeenCalled(); + }); + + test("validateApiKeyAndGetUser:key+user Redis 命中时应避免打 DB", async () => { + const cachedKey = buildKey({ key: "sk-cached", userId: 10 }); + const cachedUser = buildUser({ id: 10 }); + getCachedActiveKey.mockResolvedValueOnce(cachedKey); + getCachedUser.mockResolvedValueOnce(cachedUser); + dbSelect.mockImplementation(() => { + throw new Error("DB_ACCESS"); + }); + + const { validateApiKeyAndGetUser } = await import("@/repository/key"); + await expect(validateApiKeyAndGetUser("sk-cached")).resolves.toEqual({ + user: cachedUser, + key: cachedKey, + }); + expect(getCachedActiveKey).toHaveBeenCalledWith("sk-cached"); + expect(getCachedUser).toHaveBeenCalledWith(10); + expect(dbSelect).not.toHaveBeenCalled(); + }); + + test("validateApiKeyAndGetUser:key Redis 命中 + user miss 时应只查 user 并写回缓存", async () => { + const cachedKey = buildKey({ key: "sk-cached", userId: 10 }); + getCachedActiveKey.mockResolvedValueOnce(cachedKey); + getCachedUser.mockResolvedValueOnce(null); + + const userRow = { + id: 10, + name: "u1", + description: "", + role: "user", + rpm: null, + dailyQuota: null, + providerGroup: null, + tags: [], + createdAt: new Date("2026-01-01T00:00:00.000Z"), + updatedAt: new Date("2026-01-02T00:00:00.000Z"), + deletedAt: null, + limit5hUsd: null, + limitWeeklyUsd: null, + limitMonthlyUsd: null, + limitTotalUsd: null, + limitConcurrentSessions: null, + dailyResetMode: "fixed", + dailyResetTime: "00:00", + isEnabled: true, + expiresAt: null, + allowedClients: [], + allowedModels: [], + }; + + dbSelect.mockReturnValueOnce({ + from: () => ({ + where: async () => [userRow], + }), + }); + + const { validateApiKeyAndGetUser } = await import("@/repository/key"); + const result = await validateApiKeyAndGetUser("sk-cached"); + expect(result?.key).toEqual(cachedKey); + expect(result?.user.id).toBe(10); + expect(cacheUser).toHaveBeenCalledTimes(1); + expect(cacheAuthResult).not.toHaveBeenCalled(); + }); + + test("validateApiKeyAndGetUser:缓存未命中时应走 DB join 并写入 auth 缓存", async () => { + getCachedActiveKey.mockResolvedValueOnce(null); + + const joinRow = { + keyId: 1, + keyUserId: 10, + keyString: "sk-db", + keyName: "k1", + keyIsEnabled: true, + keyExpiresAt: null, + keyCanLoginWebUi: true, + keyLimit5hUsd: null, + keyLimitDailyUsd: null, + keyDailyResetMode: "fixed", + keyDailyResetTime: "00:00", + keyLimitWeeklyUsd: null, + keyLimitMonthlyUsd: null, + keyLimitTotalUsd: null, + keyLimitConcurrentSessions: 0, + keyProviderGroup: null, + keyCacheTtlPreference: null, + keyCreatedAt: new Date("2026-01-01T00:00:00.000Z"), + keyUpdatedAt: new Date("2026-01-02T00:00:00.000Z"), + keyDeletedAt: null, + userId: 10, + userName: "u1", + userDescription: "", + userRole: "user", + userRpm: null, + userDailyQuota: null, + userProviderGroup: null, + userLimit5hUsd: null, + userLimitWeeklyUsd: null, + userLimitMonthlyUsd: null, + userLimitTotalUsd: null, + userLimitConcurrentSessions: null, + userDailyResetMode: "fixed", + userDailyResetTime: "00:00", + userIsEnabled: true, + userExpiresAt: null, + userAllowedClients: [], + userAllowedModels: [], + userCreatedAt: new Date("2026-01-01T00:00:00.000Z"), + userUpdatedAt: new Date("2026-01-02T00:00:00.000Z"), + userDeletedAt: null, + }; + + dbSelect.mockReturnValueOnce({ + from: () => ({ + innerJoin: () => ({ + where: async () => [joinRow], + }), + }), + }); + + const { validateApiKeyAndGetUser } = await import("@/repository/key"); + const result = await validateApiKeyAndGetUser("sk-db"); + expect(result?.key.key).toBe("sk-db"); + expect(result?.user.id).toBe(10); + expect(cacheAuthResult).toHaveBeenCalledTimes(1); + }); +}); + +describe("API Key 鉴权缓存:写入/失效点覆盖", () => { + test("createKey:应广播 API key 集合变更(多实例触发 Vacuum Filter 重建)", async () => { + const prevEnableRateLimit = process.env.ENABLE_RATE_LIMIT; + const prevRedisUrl = process.env.REDIS_URL; + process.env.ENABLE_RATE_LIMIT = "true"; + process.env.REDIS_URL = "redis://localhost:6379"; + + const now = new Date("2026-01-02T00:00:00.000Z"); + const keyRow = { + id: 1, + userId: 10, + key: "sk-created", + name: "k1", + isEnabled: true, + expiresAt: null, + canLoginWebUi: true, + limit5hUsd: null, + limitDailyUsd: null, + dailyResetMode: "fixed", + dailyResetTime: "00:00", + limitWeeklyUsd: null, + limitMonthlyUsd: null, + limitTotalUsd: null, + limitConcurrentSessions: 0, + providerGroup: null, + cacheTtlPreference: null, + createdAt: now, + updatedAt: now, + deletedAt: null, + }; + + dbInsert.mockReturnValueOnce({ + values: () => ({ + returning: async () => [keyRow], + }), + }); + + try { + const { createKey } = await import("@/repository/key"); + const created = await createKey({ user_id: 10, name: "k1", key: "sk-created" }); + expect(created.key).toBe("sk-created"); + expect(publishCacheInvalidation).toHaveBeenCalledWith("cch:cache:api_keys:updated"); + } finally { + process.env.ENABLE_RATE_LIMIT = prevEnableRateLimit; + process.env.REDIS_URL = prevRedisUrl; + } + }); + + test("updateKey:应触发 cacheActiveKey", async () => { + const keyRow = { + id: 1, + userId: 10, + key: "sk-update", + name: "k1", + isEnabled: true, + expiresAt: null, + canLoginWebUi: true, + limit5hUsd: null, + limitDailyUsd: null, + dailyResetMode: "fixed", + dailyResetTime: "00:00", + limitWeeklyUsd: null, + limitMonthlyUsd: null, + limitTotalUsd: null, + limitConcurrentSessions: 0, + providerGroup: null, + cacheTtlPreference: null, + createdAt: new Date("2026-01-01T00:00:00.000Z"), + updatedAt: new Date("2026-01-02T00:00:00.000Z"), + deletedAt: null, + }; + + dbUpdate.mockReturnValueOnce({ + set: () => ({ + where: () => ({ + returning: async () => [keyRow], + }), + }), + }); + + const { updateKey } = await import("@/repository/key"); + const updated = await updateKey(1, { name: "k2" }); + expect(updated?.key).toBe("sk-update"); + expect(cacheActiveKey).toHaveBeenCalledTimes(1); + }); + + test("deleteKey:删除成功时应触发 invalidateCachedKey", async () => { + dbUpdate.mockReturnValueOnce({ + set: () => ({ + where: () => ({ + returning: async () => [{ id: 1, key: "sk-deleted" }], + }), + }), + }); + + const { deleteKey } = await import("@/repository/key"); + await expect(deleteKey(1)).resolves.toBe(true); + expect(invalidateCachedKey).toHaveBeenCalledWith("sk-deleted"); + }); +}); diff --git a/tests/unit/security/api-key-vacuum-filter-build.test.ts b/tests/unit/security/api-key-vacuum-filter-build.test.ts new file mode 100644 index 000000000..6eb01ddcb --- /dev/null +++ b/tests/unit/security/api-key-vacuum-filter-build.test.ts @@ -0,0 +1,77 @@ +import { describe, expect, test, vi } from "vitest"; + +describe("buildVacuumFilterFromKeyStrings", () => { + test("应去重并忽略空字符串,且覆盖所有 key", async () => { + const { buildVacuumFilterFromKeyStrings } = await import("@/lib/security/api-key-vacuum-filter"); + const vf = buildVacuumFilterFromKeyStrings({ + keyStrings: ["k1", "k2", "k1", ""], + fingerprintBits: 32, + maxKickSteps: 500, + seed: Buffer.from("unit-test-seed"), + }); + + expect(vf.size()).toBe(2); + expect(vf.has("k1")).toBe(true); + expect(vf.has("k2")).toBe(true); + }); + + test("空数组输入:应返回空 filter", async () => { + const { buildVacuumFilterFromKeyStrings } = await import("@/lib/security/api-key-vacuum-filter"); + const vf = buildVacuumFilterFromKeyStrings({ + keyStrings: [], + fingerprintBits: 32, + maxKickSteps: 500, + seed: Buffer.from("unit-test-seed"), + }); + + expect(vf.size()).toBe(0); + }); + + test("全空字符串:应返回空 filter", async () => { + const { buildVacuumFilterFromKeyStrings } = await import("@/lib/security/api-key-vacuum-filter"); + const vf = buildVacuumFilterFromKeyStrings({ + keyStrings: ["", "", ""], + fingerprintBits: 32, + maxKickSteps: 500, + seed: Buffer.from("unit-test-seed"), + }); + + expect(vf.size()).toBe(0); + }); + + test("构建失败时应扩容重试", async () => { + vi.resetModules(); + const maxItemsSeen: number[] = []; + + vi.doMock("@/lib/vacuum-filter/vacuum-filter", () => { + class VacuumFilter { + private readonly maxItems: number; + + constructor(options: { maxItems: number }) { + this.maxItems = options.maxItems; + maxItemsSeen.push(options.maxItems); + } + + add(_keyString: string): boolean { + // 强制第一次失败(maxItems=128),第二次成功(maxItems=ceil(128*1.6)=205) + return this.maxItems >= 205; + } + } + + return { VacuumFilter }; + }); + + const { buildVacuumFilterFromKeyStrings } = await import("@/lib/security/api-key-vacuum-filter"); + buildVacuumFilterFromKeyStrings({ + keyStrings: ["k1"], + fingerprintBits: 32, + maxKickSteps: 500, + seed: Buffer.from("unit-test-seed"), + }); + + expect(maxItemsSeen).toEqual([128, 205]); + + vi.doUnmock("@/lib/vacuum-filter/vacuum-filter"); + vi.resetModules(); + }); +}); diff --git a/tests/unit/security/api-key-vacuum-filter-reloading.test.ts b/tests/unit/security/api-key-vacuum-filter-reloading.test.ts new file mode 100644 index 000000000..a8299b793 --- /dev/null +++ b/tests/unit/security/api-key-vacuum-filter-reloading.test.ts @@ -0,0 +1,83 @@ +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; + +function setEnv(values: Record): void { + for (const [key, value] of Object.entries(values)) { + if (value === undefined) { + // eslint-disable-next-line @typescript-eslint/no-dynamic-delete + delete process.env[key]; + } else { + process.env[key] = value; + } + } +} + +describe("ApiKeyVacuumFilter:重建窗口安全性", () => { + const originalEnv: Record = {}; + + beforeEach(() => { + vi.resetModules(); + + // eslint-disable-next-line @typescript-eslint/no-dynamic-delete + delete (globalThis as unknown as { __CCH_API_KEY_VACUUM_FILTER__?: unknown }) + .__CCH_API_KEY_VACUUM_FILTER__; + + for (const k of ["NEXT_RUNTIME", "ENABLE_API_KEY_VACUUM_FILTER"]) { + originalEnv[k] = process.env[k]; + } + setEnv({ + NEXT_RUNTIME: "nodejs", + ENABLE_API_KEY_VACUUM_FILTER: "true", + }); + }); + + afterEach(() => { + setEnv(originalEnv); + vi.unstubAllGlobals(); + vi.useRealTimers(); + // eslint-disable-next-line @typescript-eslint/no-dynamic-delete + delete (globalThis as unknown as { __CCH_API_KEY_VACUUM_FILTER__?: unknown }) + .__CCH_API_KEY_VACUUM_FILTER__; + }); + + test("loadingPromise 存在时应返回 null(不短路)", async () => { + const [{ apiKeyVacuumFilter }, { VacuumFilter }] = await Promise.all([ + import("@/lib/security/api-key-vacuum-filter"), + import("@/lib/vacuum-filter/vacuum-filter"), + ]); + + const vf = new VacuumFilter({ + maxItems: 16, + fingerprintBits: 32, + maxKickSteps: 100, + seed: "unit-test-reloading", + }); + expect(vf.add("k1")).toBe(true); + + (apiKeyVacuumFilter as unknown as { vf: VacuumFilter }).vf = vf; + (apiKeyVacuumFilter as unknown as { loadingPromise: Promise | null }).loadingPromise = + new Promise(() => {}); + + expect(apiKeyVacuumFilter.isDefinitelyNotPresent("k1")).toBeNull(); + expect(apiKeyVacuumFilter.isDefinitelyNotPresent("missing")).toBeNull(); + }); + + test("ENABLE_API_KEY_VACUUM_FILTER=0:应禁用过滤器(不短路)", async () => { + setEnv({ ENABLE_API_KEY_VACUUM_FILTER: "0" }); + const { apiKeyVacuumFilter } = await import("@/lib/security/api-key-vacuum-filter"); + + expect(apiKeyVacuumFilter.getStats().enabled).toBe(false); + expect(apiKeyVacuumFilter.isDefinitelyNotPresent("missing")).toBeNull(); + }); + + test("未设置 ENABLE_API_KEY_VACUUM_FILTER:应默认启用(仅负向短路)", async () => { + vi.resetModules(); + // eslint-disable-next-line @typescript-eslint/no-dynamic-delete + delete (globalThis as unknown as { __CCH_API_KEY_VACUUM_FILTER__?: unknown }) + .__CCH_API_KEY_VACUUM_FILTER__; + + setEnv({ NEXT_RUNTIME: "nodejs", ENABLE_API_KEY_VACUUM_FILTER: undefined }); + const { apiKeyVacuumFilter } = await import("@/lib/security/api-key-vacuum-filter"); + + expect(apiKeyVacuumFilter.getStats().enabled).toBe(true); + }); +}); diff --git a/tests/unit/security/api-key-vacuum-filter-shortcircuit.test.ts b/tests/unit/security/api-key-vacuum-filter-shortcircuit.test.ts new file mode 100644 index 000000000..add9a2254 --- /dev/null +++ b/tests/unit/security/api-key-vacuum-filter-shortcircuit.test.ts @@ -0,0 +1,42 @@ +import { describe, expect, test, vi } from "vitest"; + +const isDefinitelyNotPresent = vi.fn(() => true); + +vi.mock("@/lib/security/api-key-vacuum-filter", () => ({ + apiKeyVacuumFilter: { + isDefinitelyNotPresent, + noteExistingKey: vi.fn(), + startBackgroundReload: vi.fn(), + getStats: vi.fn(), + }, +})); + +// 如果 Vacuum Filter 没有短路成功,这些 DB 调用会触发并让测试失败 +vi.mock("@/drizzle/db", () => ({ + db: { + select: vi.fn(() => { + throw new Error("DB_ACCESS"); + }), + insert: vi.fn(() => { + throw new Error("DB_ACCESS"); + }), + update: vi.fn(() => { + throw new Error("DB_ACCESS"); + }), + }, +})); + +describe("API Key Vacuum Filter:负向短路(避免打 DB)", () => { + test("validateApiKeyAndGetUser:definitely not present 时应直接返回 null", async () => { + const { validateApiKeyAndGetUser } = await import("@/repository/key"); + await expect(validateApiKeyAndGetUser("invalid_key")).resolves.toBeNull(); + expect(isDefinitelyNotPresent).toHaveBeenCalledWith("invalid_key"); + }); + + test("findActiveKeyByKeyString:definitely not present 时应直接返回 null", async () => { + const { findActiveKeyByKeyString } = await import("@/repository/key"); + await expect(findActiveKeyByKeyString("invalid_key")).resolves.toBeNull(); + expect(isDefinitelyNotPresent).toHaveBeenCalledWith("invalid_key"); + }); +}); + diff --git a/tests/unit/security/auth-validateKey-cache.test.ts b/tests/unit/security/auth-validateKey-cache.test.ts new file mode 100644 index 000000000..ebe7f407d --- /dev/null +++ b/tests/unit/security/auth-validateKey-cache.test.ts @@ -0,0 +1,158 @@ +import { beforeEach, describe, expect, test, vi } from "vitest"; +import type { Key } from "@/types/key"; +import type { User } from "@/types/user"; + +const isDefinitelyNotPresent = vi.fn(() => false); +const noteExistingKey = vi.fn(); + +const getCachedActiveKey = vi.fn(); +const getCachedUser = vi.fn(); + +// 如果缓存路径未命中,这些 DB 调用会触发并让测试失败 +vi.mock("@/drizzle/db", () => ({ + db: { + select: vi.fn(() => { + throw new Error("DB_ACCESS"); + }), + insert: vi.fn(() => { + throw new Error("DB_ACCESS"); + }), + update: vi.fn(() => { + throw new Error("DB_ACCESS"); + }), + }, +})); + +vi.mock("@/lib/security/api-key-vacuum-filter", () => ({ + apiKeyVacuumFilter: { + isDefinitelyNotPresent, + noteExistingKey, + startBackgroundReload: vi.fn(), + invalidateAndReload: vi.fn(), + getStats: vi.fn(), + }, +})); + +vi.mock("@/lib/security/api-key-auth-cache", () => ({ + getCachedActiveKey, + getCachedUser, + cacheActiveKey: vi.fn(async () => {}), + cacheAuthResult: vi.fn(async () => {}), + cacheUser: vi.fn(async () => {}), + invalidateCachedKey: vi.fn(async () => {}), + invalidateCachedUser: vi.fn(async () => {}), +})); + +function buildKey(overrides?: Partial): Key { + const now = new Date("2026-02-08T00:00:00.000Z"); + return { + id: 1, + userId: 10, + name: "k1", + key: "sk-user-login", + isEnabled: true, + expiresAt: undefined, + canLoginWebUi: true, + limit5hUsd: null, + limitDailyUsd: null, + dailyResetMode: "fixed", + dailyResetTime: "00:00", + limitWeeklyUsd: null, + limitMonthlyUsd: null, + limitTotalUsd: null, + limitConcurrentSessions: 0, + providerGroup: null, + cacheTtlPreference: null, + createdAt: now, + updatedAt: now, + deletedAt: undefined, + ...overrides, + }; +} + +function buildUser(overrides?: Partial): User { + const now = new Date("2026-02-08T00:00:00.000Z"); + return { + id: 10, + name: "u1", + description: "", + role: "user", + rpm: null, + dailyQuota: null, + providerGroup: null, + tags: [], + createdAt: now, + updatedAt: now, + deletedAt: undefined, + dailyResetMode: "fixed", + dailyResetTime: "00:00", + isEnabled: true, + expiresAt: null, + allowedClients: [], + allowedModels: [], + ...overrides, + }; +} + +describe("auth.ts:validateKey(Vacuum Filter -> Redis -> DB)", () => { + beforeEach(() => { + vi.clearAllMocks(); + isDefinitelyNotPresent.mockReturnValue(false); + getCachedActiveKey.mockResolvedValue(null); + getCachedUser.mockResolvedValue(null); + }); + + test("Redis key+user 命中时:validateKey 应不访问 DB 且返回 session(保护 login 侧热路径)", async () => { + const cachedKey = buildKey({ key: "sk-user-login", canLoginWebUi: true, userId: 10 }); + const cachedUser = buildUser({ id: 10 }); + getCachedActiveKey.mockResolvedValueOnce(cachedKey); + getCachedUser.mockResolvedValueOnce(cachedUser); + + const { validateKey } = await import("@/lib/auth"); + await expect(validateKey("sk-user-login")).resolves.toEqual({ user: cachedUser, key: cachedKey }); + expect(isDefinitelyNotPresent).toHaveBeenCalledWith("sk-user-login"); + }); + + test("用户禁用:缓存命中也应拒绝(保护登录/会话)", async () => { + const cachedKey = buildKey({ key: "sk-user-disabled", canLoginWebUi: true, userId: 10 }); + const cachedUser = buildUser({ id: 10, isEnabled: false }); + getCachedActiveKey.mockResolvedValueOnce(cachedKey); + getCachedUser.mockResolvedValueOnce(cachedUser); + + const { validateKey } = await import("@/lib/auth"); + await expect(validateKey("sk-user-disabled")).resolves.toBeNull(); + }); + + test("用户过期:缓存命中也应拒绝(保护登录/会话)", async () => { + const cachedKey = buildKey({ key: "sk-user-expired", canLoginWebUi: true, userId: 10 }); + const cachedUser = buildUser({ id: 10, expiresAt: new Date("2000-01-01T00:00:00.000Z") }); + getCachedActiveKey.mockResolvedValueOnce(cachedKey); + getCachedUser.mockResolvedValueOnce(cachedUser); + + const { validateKey } = await import("@/lib/auth"); + await expect(validateKey("sk-user-expired")).resolves.toBeNull(); + }); + + test("canLoginWebUi=false 且 allowReadOnlyAccess=false:缓存命中也应拒绝", async () => { + const cachedKey = buildKey({ key: "sk-no-webui", canLoginWebUi: false, userId: 10 }); + const cachedUser = buildUser({ id: 10 }); + getCachedActiveKey.mockResolvedValueOnce(cachedKey); + getCachedUser.mockResolvedValueOnce(cachedUser); + + const { validateKey } = await import("@/lib/auth"); + await expect(validateKey("sk-no-webui", { allowReadOnlyAccess: false })).resolves.toBeNull(); + }); + + test("allowReadOnlyAccess=true:应允许 canLoginWebUi=false 的 key 登录只读页面", async () => { + const cachedKey = buildKey({ key: "sk-readonly", canLoginWebUi: false, userId: 10 }); + const cachedUser = buildUser({ id: 10 }); + getCachedActiveKey.mockResolvedValueOnce(cachedKey); + getCachedUser.mockResolvedValueOnce(cachedUser); + + const { validateKey } = await import("@/lib/auth"); + await expect(validateKey("sk-readonly", { allowReadOnlyAccess: true })).resolves.toEqual({ + user: cachedUser, + key: cachedKey, + }); + }); +}); diff --git a/tests/unit/vacuum-filter/vacuum-filter.test.ts b/tests/unit/vacuum-filter/vacuum-filter.test.ts new file mode 100644 index 000000000..a05a8db24 --- /dev/null +++ b/tests/unit/vacuum-filter/vacuum-filter.test.ts @@ -0,0 +1,133 @@ +import { describe, expect, test } from "vitest"; +import { VacuumFilter } from "@/lib/vacuum-filter/vacuum-filter"; + +describe("VacuumFilter", () => { + test("add/has/delete 基本语义正确", () => { + const vf = new VacuumFilter({ + maxItems: 1000, + fingerprintBits: 32, + maxKickSteps: 500, + seed: "unit-test-seed", + }); + + expect(vf.has("k1")).toBe(false); + expect(vf.add("k1")).toBe(true); + expect(vf.has("k1")).toBe(true); + + expect(vf.delete("k1")).toBe(true); + expect(vf.has("k1")).toBe(false); + + // 删除不存在的 key + expect(vf.delete("k1")).toBe(false); + }); + + test("高负载下插入与查询稳定(无假阴性)", () => { + const n = 20_000; + const vf = new VacuumFilter({ + maxItems: n, + fingerprintBits: 32, + maxKickSteps: 1000, + seed: "unit-test-high-load", + }); + + for (let i = 0; i < n; i++) { + const ok = vf.add(`key_${i}`); + expect(ok).toBe(true); + } + + for (let i = 0; i < n; i++) { + expect(vf.has(`key_${i}`)).toBe(true); + } + + // 删除一小部分(碰撞概率极低;使用 32-bit fingerprint 避免测试随机性) + for (let i = 0; i < 200; i++) { + expect(vf.delete(`key_${i}`)).toBe(true); + expect(vf.has(`key_${i}`)).toBe(false); + } + }); + + test("插入失败必须回滚(不丢元素,不引入假阴性)", () => { + const vf = new VacuumFilter({ + maxItems: 10, + fingerprintBits: 32, + maxKickSteps: 50, + seed: "unit-test-rollback-on-failure", + }); + + const inserted: string[] = []; + let failed = false; + + for (let i = 0; i < 5000; i++) { + const key = `key_${i}`; + const ok = vf.add(key); + if (!ok) { + failed = true; + break; + } + inserted.push(key); + } + + expect(failed).toBe(true); + expect(vf.size()).toBe(inserted.length); + + // 已插入的元素必须都能查到(无假阴性) + for (const key of inserted) { + expect(vf.has(key)).toBe(true); + } + }); + + test("构造参数包含 NaN 时应使用默认值(不崩溃)", () => { + const vf = new VacuumFilter({ + maxItems: 1000, + // @ts-expect-error: 用于覆盖运行时边界情况 + fingerprintBits: Number.NaN, + // @ts-expect-error: 用于覆盖运行时边界情况 + maxKickSteps: Number.NaN, + // @ts-expect-error: 用于覆盖运行时边界情况 + targetLoadFactor: Number.NaN, + seed: "unit-test-nan-options", + }); + + expect(vf.add("k1")).toBe(true); + expect(vf.has("k1")).toBe(true); + }); + + test("非 ASCII 字符串也应可用(UTF-8 编码路径)", () => { + const vf = new VacuumFilter({ + maxItems: 1000, + fingerprintBits: 32, + maxKickSteps: 500, + seed: "unit-test-non-ascii", + }); + + const keys = ["你好", "ключ", "テスト"]; + for (const key of keys) { + expect(vf.add(key)).toBe(true); + expect(vf.has(key)).toBe(true); + } + }); + + test("alternate index 应为可逆映射(alt(alt(i,tag),tag)=i)且不越界", () => { + const vf = new VacuumFilter({ + maxItems: 50_000, + fingerprintBits: 32, + maxKickSteps: 1000, + seed: "unit-test-alt-index-involution", + }); + + const numBuckets = vf.capacitySlots() / 4; + // @ts-expect-error: 单测需要覆盖私有方法的核心不变量 + const altIndex = (index: number, tag: number) => vf.altIndex(index, tag) as number; + + for (let i = 0; i < 10_000; i++) { + const index = i % numBuckets; + const tag = (i * 2654435761) >>> 0; + const alt = altIndex(index, tag); + expect(alt).toBeGreaterThanOrEqual(0); + expect(alt).toBeLessThan(numBuckets); + + const back = altIndex(alt, tag); + expect(back).toBe(index); + } + }); +});