diff --git a/package.json b/package.json index 3c8b6d924..dc846764d 100644 --- a/package.json +++ b/package.json @@ -19,6 +19,7 @@ "test:integration": "vitest run --config vitest.integration.config.ts --reporter=verbose", "test:coverage": "vitest run --coverage", "test:coverage:quota": "vitest run --config vitest.quota.config.ts --coverage", + "test:coverage:my-usage": "vitest run --config vitest.my-usage.config.ts --coverage", "test:ci": "vitest run --reporter=default --reporter=junit --outputFile.junit=reports/vitest-junit.xml", "cui": "npx cui-server --host 0.0.0.0 --port 30000 --token a7564bc8882aa9a2d25d8b4ea6ea1e2e", "db:generate": "drizzle-kit generate && node scripts/validate-migrations.js", diff --git a/src/actions/my-usage.ts b/src/actions/my-usage.ts index 401aa7f0d..59a267298 100644 --- a/src/actions/my-usage.ts +++ b/src/actions/my-usage.ts @@ -1,6 +1,6 @@ "use server"; -import { and, eq, gte, isNull, sql } from "drizzle-orm"; +import { and, eq, gte, isNull, lt, sql } from "drizzle-orm"; import { db } from "@/drizzle/db"; import { keys as keysTable, messageRequest } from "@/drizzle/schema"; import { getSession } from "@/lib/auth"; @@ -20,6 +20,9 @@ import { import type { BillingModelSource } from "@/types/system-config"; import type { ActionResult } from "./types"; +// Warmup 抢答请求只用于探测/预热:日志可见,但不计入任何聚合统计 +const EXCLUDE_WARMUP_CONDITION = sql`(${messageRequest.blockedBy} IS NULL OR ${messageRequest.blockedBy} <> 'warmup')`; + export interface MyUsageMetadata { keyName: string; keyProviderGroup: string | null; @@ -303,8 +306,9 @@ export async function getMyTodayStats(): Promise> { and( eq(messageRequest.key, session.key.key), isNull(messageRequest.deletedAt), + EXCLUDE_WARMUP_CONDITION, gte(messageRequest.createdAt, timeRange.startTime), - sql`${messageRequest.createdAt} < ${timeRange.endTime}` + lt(messageRequest.createdAt, timeRange.endTime) ) ); @@ -322,8 +326,9 @@ export async function getMyTodayStats(): Promise> { and( eq(messageRequest.key, session.key.key), isNull(messageRequest.deletedAt), + EXCLUDE_WARMUP_CONDITION, gte(messageRequest.createdAt, timeRange.startTime), - sql`${messageRequest.createdAt} < ${timeRange.endTime}` + lt(messageRequest.createdAt, timeRange.endTime) ) ) .groupBy(messageRequest.model, messageRequest.originalModel); diff --git a/src/app/api/actions/[...route]/route.ts b/src/app/api/actions/[...route]/route.ts index 1cb5d5a25..8f763e4c9 100644 --- a/src/app/api/actions/[...route]/route.ts +++ b/src/app/api/actions/[...route]/route.ts @@ -20,6 +20,7 @@ import { z } from "zod"; import * as activeSessionActions from "@/actions/active-sessions"; import * as keyActions from "@/actions/keys"; import * as modelPriceActions from "@/actions/model-prices"; +import * as myUsageActions from "@/actions/my-usage"; import * as notificationBindingActions from "@/actions/notification-bindings"; import * as notificationActions from "@/actions/notifications"; import * as overviewActions from "@/actions/overview"; @@ -55,6 +56,14 @@ app.openAPIRegistry.registerComponent("securitySchemes", "cookieAuth", { "HTTP Cookie 认证。请先通过 Web UI 登录获取 auth-token Cookie,或从浏览器开发者工具中复制 Cookie 值用于 API 调用。详见上方「认证方式」章节。", }); +app.openAPIRegistry.registerComponent("securitySchemes", "bearerAuth", { + type: "http", + scheme: "bearer", + bearerFormat: "API Key", + description: + "Authorization: Bearer 方式认证(适合脚本/CLI 调用)。注意:token 与 Cookie 中 auth-token 值一致。", +}); + // ==================== 用户管理 ==================== const { route: getUsersRoute, handler: getUsersHandler } = createActionRoute( @@ -607,6 +616,194 @@ const { route: getStatusCodeListRoute, handler: getStatusCodeListHandler } = cre ); app.openapi(getStatusCodeListRoute, getStatusCodeListHandler); +// ==================== 我的用量(只读 Key 可访问) ==================== + +const { route: getMyUsageMetadataRoute, handler: getMyUsageMetadataHandler } = createActionRoute( + "my-usage", + "getMyUsageMetadata", + myUsageActions.getMyUsageMetadata, + { + requestSchema: z.object({}).describe("无需请求参数"), + responseSchema: z.object({ + keyName: z.string().describe("当前 Key 名称"), + keyProviderGroup: z.string().nullable().describe("Key 供应商分组(可为空)"), + keyExpiresAt: z.string().nullable().describe("Key 过期时间(ISO 字符串,可为空)"), + keyIsEnabled: z.boolean().describe("Key 是否启用"), + userName: z.string().describe("当前用户名称"), + userProviderGroup: z.string().nullable().describe("用户供应商分组(可为空)"), + userExpiresAt: z.string().nullable().describe("用户过期时间(ISO 字符串,可为空)"), + userIsEnabled: z.boolean().describe("用户是否启用"), + dailyResetMode: z.enum(["fixed", "rolling"]).describe("日限额重置模式"), + dailyResetTime: z.string().describe("日限额重置时间(HH:mm)"), + currencyCode: z.string().describe("货币显示(如 USD)"), + }), + description: "获取当前会话的基础信息(仅返回自己的数据)", + summary: "获取我的用量元信息", + tags: ["概览"], + allowReadOnlyAccess: true, + } +); +app.openapi(getMyUsageMetadataRoute, getMyUsageMetadataHandler); + +const { route: getMyQuotaRoute, handler: getMyQuotaHandler } = createActionRoute( + "my-usage", + "getMyQuota", + myUsageActions.getMyQuota, + { + requestSchema: z.object({}).describe("无需请求参数"), + responseSchema: z.object({ + keyLimit5hUsd: z.number().nullable(), + keyLimitDailyUsd: z.number().nullable(), + keyLimitWeeklyUsd: z.number().nullable(), + keyLimitMonthlyUsd: z.number().nullable(), + keyLimitTotalUsd: z.number().nullable(), + keyLimitConcurrentSessions: z.number().nullable(), + keyCurrent5hUsd: z.number(), + keyCurrentDailyUsd: z.number(), + keyCurrentWeeklyUsd: z.number(), + keyCurrentMonthlyUsd: z.number(), + keyCurrentTotalUsd: z.number(), + keyCurrentConcurrentSessions: z.number(), + + userLimit5hUsd: z.number().nullable(), + userLimitWeeklyUsd: z.number().nullable(), + userLimitMonthlyUsd: z.number().nullable(), + userLimitTotalUsd: z.number().nullable(), + userLimitConcurrentSessions: z.number().nullable(), + userCurrent5hUsd: z.number(), + userCurrentDailyUsd: z.number(), + userCurrentWeeklyUsd: z.number(), + userCurrentMonthlyUsd: z.number(), + userCurrentTotalUsd: z.number(), + userCurrentConcurrentSessions: z.number(), + + userLimitDailyUsd: z.number().nullable(), + userExpiresAt: z.string().nullable(), + userProviderGroup: z.string().nullable(), + userName: z.string(), + userIsEnabled: z.boolean(), + + keyProviderGroup: z.string().nullable(), + keyName: z.string(), + keyIsEnabled: z.boolean(), + + expiresAt: z.string().nullable(), + dailyResetMode: z.enum(["fixed", "rolling"]), + dailyResetTime: z.string(), + }), + description: "获取当前会话的限额与当前使用量(仅返回自己的数据)", + summary: "获取我的限额与用量", + tags: ["密钥管理"], + allowReadOnlyAccess: true, + } +); +app.openapi(getMyQuotaRoute, getMyQuotaHandler); + +const { route: getMyTodayStatsRoute, handler: getMyTodayStatsHandler } = createActionRoute( + "my-usage", + "getMyTodayStats", + myUsageActions.getMyTodayStats, + { + requestSchema: z.object({}).describe("无需请求参数"), + responseSchema: z.object({ + calls: z.number(), + inputTokens: z.number(), + outputTokens: z.number(), + costUsd: z.number(), + modelBreakdown: z.array( + z.object({ + model: z.string().nullable(), + billingModel: z.string().nullable(), + calls: z.number(), + costUsd: z.number(), + inputTokens: z.number(), + outputTokens: z.number(), + }) + ), + currencyCode: z.string(), + billingModelSource: z.enum(["original", "redirected"]), + }), + description: "获取当前会话的“今日”使用统计(按 Key 的日重置配置计算)", + summary: "获取我的今日使用统计", + tags: ["统计分析"], + allowReadOnlyAccess: true, + } +); +app.openapi(getMyTodayStatsRoute, getMyTodayStatsHandler); + +const { route: getMyUsageLogsRoute, handler: getMyUsageLogsHandler } = createActionRoute( + "my-usage", + "getMyUsageLogs", + myUsageActions.getMyUsageLogs, + { + requestSchema: z.object({ + startDate: z.string().optional().describe("开始日期(YYYY-MM-DD,可为空)"), + endDate: z.string().optional().describe("结束日期(YYYY-MM-DD,可为空)"), + model: z.string().optional(), + endpoint: z.string().optional(), + statusCode: z.number().optional(), + excludeStatusCode200: z.boolean().optional(), + minRetryCount: z.number().int().nonnegative().optional(), + pageSize: z.number().int().positive().max(100).default(20).optional(), + page: z.number().int().positive().default(1).optional(), + }), + responseSchema: z.object({ + logs: z.array( + z.object({ + id: z.number(), + createdAt: z.string().nullable(), + model: z.string().nullable(), + billingModel: z.string().nullable(), + modelRedirect: z.string().nullable(), + inputTokens: z.number(), + outputTokens: z.number(), + cost: z.number(), + statusCode: z.number().nullable(), + duration: z.number().nullable(), + endpoint: z.string().nullable(), + cacheCreationInputTokens: z.number().nullable(), + cacheReadInputTokens: z.number().nullable(), + cacheCreation5mInputTokens: z.number().nullable(), + cacheCreation1hInputTokens: z.number().nullable(), + cacheTtlApplied: z.string().nullable(), + }) + ), + total: z.number(), + page: z.number(), + pageSize: z.number(), + currencyCode: z.string(), + billingModelSource: z.enum(["original", "redirected"]), + }), + description: "获取当前会话的使用日志(仅返回自己的数据)", + summary: "获取我的使用日志", + tags: ["使用日志"], + allowReadOnlyAccess: true, + } +); +app.openapi(getMyUsageLogsRoute, getMyUsageLogsHandler); + +const { route: getMyAvailableModelsRoute, handler: getMyAvailableModelsHandler } = + createActionRoute("my-usage", "getMyAvailableModels", myUsageActions.getMyAvailableModels, { + requestSchema: z.object({}).describe("无需请求参数"), + responseSchema: z.array(z.string()), + description: "获取当前会话日志中出现过的模型列表(仅返回自己的数据)", + summary: "获取我的模型筛选项", + tags: ["使用日志"], + allowReadOnlyAccess: true, + }); +app.openapi(getMyAvailableModelsRoute, getMyAvailableModelsHandler); + +const { route: getMyAvailableEndpointsRoute, handler: getMyAvailableEndpointsHandler } = + createActionRoute("my-usage", "getMyAvailableEndpoints", myUsageActions.getMyAvailableEndpoints, { + requestSchema: z.object({}).describe("无需请求参数"), + responseSchema: z.array(z.string()), + description: "获取当前会话日志中出现过的 endpoint 列表(仅返回自己的数据)", + summary: "获取我的 endpoint 筛选项", + tags: ["使用日志"], + allowReadOnlyAccess: true, + }); +app.openapi(getMyAvailableEndpointsRoute, getMyAvailableEndpointsHandler); + // ==================== 概览数据 ==================== const { route: getOverviewDataRoute, handler: getOverviewDataHandler } = createActionRoute( diff --git a/src/lib/api/action-adapter-openapi.ts b/src/lib/api/action-adapter-openapi.ts index 21278e43c..0e5dc34e3 100644 --- a/src/lib/api/action-adapter-openapi.ts +++ b/src/lib/api/action-adapter-openapi.ts @@ -15,6 +15,15 @@ import type { ActionResult } from "@/actions/types"; import { validateKey } from "@/lib/auth"; import { logger } from "@/lib/logger"; +function getBearerTokenFromAuthHeader(raw: string | undefined): string | null { + const trimmed = raw?.trim(); + if (!trimmed) return null; + + const match = /^Bearer\s+(.+)$/i.exec(trimmed); + const token = match?.[1]?.trim(); + return token ? token : null; +} + // Server Action 函数签名 (支持两种格式) type ServerAction = // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -57,6 +66,17 @@ export interface ActionRouteOptions { */ requiresAuth?: boolean; + /** + * 允许仅访问只读页面/接口(如 my-usage),跳过 canLoginWebUi 校验 + * + * 注意: + * - 这是一个“白名单开关”,仅应对“只读且强制绑定当前会话”的端点开启 + * - 绝不能用于允许传入 userId/keyId 等可导致越权的管理型接口 + * + * @default false + */ + allowReadOnlyAccess?: boolean; + /** * 权限要求 */ @@ -243,6 +263,7 @@ export function createActionRoute( summary, tags = [module], requiresAuth = true, + allowReadOnlyAccess = false, requiredRole, requestExamples, argsMapper, // 新增:参数映射函数 @@ -269,7 +290,7 @@ export function createActionRoute( responses: createResponseSchemas(responseSchema), // 安全定义 (可选,需要在 OpenAPI 文档中配置) ...(requiresAuth && { - security: [{ cookieAuth: [] }], + security: [{ cookieAuth: [] }, { bearerAuth: [] }], }), }); @@ -281,13 +302,14 @@ export function createActionRoute( try { // 0. 认证检查 (如果需要) if (requiresAuth) { - const authToken = getCookie(c, "auth-token"); + const authToken = + getCookie(c, "auth-token") ?? getBearerTokenFromAuthHeader(c.req.header("authorization")); if (!authToken) { logger.warn(`[ActionAPI] ${fullPath} 认证失败: 缺少 auth-token`); return c.json({ ok: false, error: "未认证" }, 401); } - const session = await validateKey(authToken); + const session = await validateKey(authToken, { allowReadOnlyAccess }); if (!session) { logger.warn(`[ActionAPI] ${fullPath} 认证失败: 无效的 auth-token`); return c.json({ ok: false, error: "认证无效或已过期" }, 401); diff --git a/src/lib/auth.ts b/src/lib/auth.ts index 4fb65b0a6..0362df5fd 100644 --- a/src/lib/auth.ts +++ b/src/lib/auth.ts @@ -1,4 +1,4 @@ -import { cookies } from "next/headers"; +import { cookies, headers } from "next/headers"; import { config } from "@/lib/config/config"; import { getEnvConfig } from "@/lib/config/env.schema"; import { findActiveKeyByKeyString } from "@/repository/key"; @@ -119,10 +119,29 @@ export async function getSession(options?: { */ allowReadOnlyAccess?: boolean; }): Promise { - const keyString = await getAuthCookie(); + const keyString = await getAuthToken(); if (!keyString) { return null; } return validateKey(keyString, options); } + +function parseBearerToken(raw: string | null | undefined): string | undefined { + const trimmed = raw?.trim(); + if (!trimmed) return undefined; + + const match = /^Bearer\s+(.+)$/i.exec(trimmed); + const token = match?.[1]?.trim(); + return token || undefined; +} + +async function getAuthToken(): Promise { + // 优先使用 Cookie(兼容现有 Web UI 的登录态) + const cookieToken = await getAuthCookie(); + if (cookieToken) return cookieToken; + + // Cookie 缺失时,允许通过 Authorization: Bearer 自助调用只读接口 + const headersStore = await headers(); + return parseBearerToken(headersStore.get("authorization")); +} diff --git a/src/repository/system-config.ts b/src/repository/system-config.ts index d70b6aac6..dc8f8ee1f 100644 --- a/src/repository/system-config.ts +++ b/src/repository/system-config.ts @@ -71,6 +71,67 @@ function isTableMissingError(error: unknown, depth = 0): boolean { return false; } +function isUndefinedColumnError(error: unknown, depth = 0): boolean { + if (!error || depth > 5) { + return false; + } + + if (typeof error === "string") { + const normalized = error.toLowerCase(); + return ( + normalized.includes("42703") || + (normalized.includes("column") && + (normalized.includes("does not exist") || + normalized.includes("doesn't exist") || + normalized.includes("不存在"))) + ); + } + + if (typeof error === "object") { + const err = error as { + code?: unknown; + message?: unknown; + cause?: unknown; + errors?: unknown; + originalError?: unknown; + }; + + if (typeof err.code === "string" && err.code.toUpperCase() === "42703") { + return true; + } + + if (typeof err.message === "string" && isUndefinedColumnError(err.message, depth + 1)) { + return true; + } + + if ("cause" in err && err.cause && isUndefinedColumnError(err.cause, depth + 1)) { + return true; + } + + if (Array.isArray(err.errors)) { + return err.errors.some((item) => isUndefinedColumnError(item, depth + 1)); + } + + if (err.originalError && isUndefinedColumnError(err.originalError, depth + 1)) { + return true; + } + + const stringified = (() => { + try { + return String(error); + } catch { + return undefined; + } + })(); + + if (stringified) { + return isUndefinedColumnError(stringified, depth + 1); + } + } + + return false; +} + function createFallbackSettings(): SystemSettings { const now = new Date(); return { @@ -96,33 +157,61 @@ function createFallbackSettings(): SystemSettings { * 获取系统设置,如果不存在则创建默认记录 */ export async function getSystemSettings(): Promise { + async function selectSettingsRow() { + const fullSelection = { + id: systemSettings.id, + siteTitle: systemSettings.siteTitle, + allowGlobalUsageView: systemSettings.allowGlobalUsageView, + currencyDisplay: systemSettings.currencyDisplay, + billingModelSource: systemSettings.billingModelSource, + enableAutoCleanup: systemSettings.enableAutoCleanup, + cleanupRetentionDays: systemSettings.cleanupRetentionDays, + cleanupSchedule: systemSettings.cleanupSchedule, + cleanupBatchSize: systemSettings.cleanupBatchSize, + enableClientVersionCheck: systemSettings.enableClientVersionCheck, + verboseProviderError: systemSettings.verboseProviderError, + enableHttp2: systemSettings.enableHttp2, + interceptAnthropicWarmupRequests: systemSettings.interceptAnthropicWarmupRequests, + createdAt: systemSettings.createdAt, + updatedAt: systemSettings.updatedAt, + }; + + try { + const [row] = await db.select(fullSelection).from(systemSettings).limit(1); + return row ?? null; + } catch (error) { + // 兼容旧版本数据库:system_settings 表存在但列未迁移齐全 + if (isUndefinedColumnError(error)) { + logger.warn("system_settings 表列缺失,使用降级字段集读取(建议运行数据库迁移)。", { + error, + }); + + const minimalSelection = { + id: systemSettings.id, + siteTitle: systemSettings.siteTitle, + allowGlobalUsageView: systemSettings.allowGlobalUsageView, + currencyDisplay: systemSettings.currencyDisplay, + billingModelSource: systemSettings.billingModelSource, + createdAt: systemSettings.createdAt, + updatedAt: systemSettings.updatedAt, + }; + + const [row] = await db.select(minimalSelection).from(systemSettings).limit(1); + return row ?? null; + } + + throw error; + } + } + try { - const [settings] = await db - .select({ - id: systemSettings.id, - siteTitle: systemSettings.siteTitle, - allowGlobalUsageView: systemSettings.allowGlobalUsageView, - currencyDisplay: systemSettings.currencyDisplay, - billingModelSource: systemSettings.billingModelSource, - enableAutoCleanup: systemSettings.enableAutoCleanup, - cleanupRetentionDays: systemSettings.cleanupRetentionDays, - cleanupSchedule: systemSettings.cleanupSchedule, - cleanupBatchSize: systemSettings.cleanupBatchSize, - enableClientVersionCheck: systemSettings.enableClientVersionCheck, - verboseProviderError: systemSettings.verboseProviderError, - enableHttp2: systemSettings.enableHttp2, - interceptAnthropicWarmupRequests: systemSettings.interceptAnthropicWarmupRequests, - createdAt: systemSettings.createdAt, - updatedAt: systemSettings.updatedAt, - }) - .from(systemSettings) - .limit(1); + const settings = await selectSettingsRow(); if (settings) { return toSystemSettings(settings); } - const [created] = await db + await db .insert(systemSettings) .values({ siteTitle: DEFAULT_SITE_TITLE, @@ -130,51 +219,9 @@ export async function getSystemSettings(): Promise { currencyDisplay: "USD", billingModelSource: "original", }) - .onConflictDoNothing() - .returning({ - id: systemSettings.id, - siteTitle: systemSettings.siteTitle, - allowGlobalUsageView: systemSettings.allowGlobalUsageView, - currencyDisplay: systemSettings.currencyDisplay, - billingModelSource: systemSettings.billingModelSource, - enableAutoCleanup: systemSettings.enableAutoCleanup, - cleanupRetentionDays: systemSettings.cleanupRetentionDays, - cleanupSchedule: systemSettings.cleanupSchedule, - cleanupBatchSize: systemSettings.cleanupBatchSize, - enableClientVersionCheck: systemSettings.enableClientVersionCheck, - verboseProviderError: systemSettings.verboseProviderError, - enableHttp2: systemSettings.enableHttp2, - interceptAnthropicWarmupRequests: systemSettings.interceptAnthropicWarmupRequests, - createdAt: systemSettings.createdAt, - updatedAt: systemSettings.updatedAt, - }); - - if (created) { - return toSystemSettings(created); - } - - // 如果并发导致没有返回,重新查询一次 - const [fallback] = await db - .select({ - id: systemSettings.id, - siteTitle: systemSettings.siteTitle, - allowGlobalUsageView: systemSettings.allowGlobalUsageView, - currencyDisplay: systemSettings.currencyDisplay, - billingModelSource: systemSettings.billingModelSource, - enableAutoCleanup: systemSettings.enableAutoCleanup, - cleanupRetentionDays: systemSettings.cleanupRetentionDays, - cleanupSchedule: systemSettings.cleanupSchedule, - cleanupBatchSize: systemSettings.cleanupBatchSize, - enableClientVersionCheck: systemSettings.enableClientVersionCheck, - verboseProviderError: systemSettings.verboseProviderError, - enableHttp2: systemSettings.enableHttp2, - interceptAnthropicWarmupRequests: systemSettings.interceptAnthropicWarmupRequests, - createdAt: systemSettings.createdAt, - updatedAt: systemSettings.updatedAt, - }) - .from(systemSettings) - .limit(1); + .onConflictDoNothing(); + const fallback = await selectSettingsRow(); if (!fallback) { throw new Error("Failed to initialize system settings"); } diff --git a/src/repository/usage-logs.ts b/src/repository/usage-logs.ts index 071f7342f..1d5585880 100644 --- a/src/repository/usage-logs.ts +++ b/src/repository/usage-logs.ts @@ -267,10 +267,11 @@ export async function getTotalUsageForKey(keyString: string): Promise { export async function getDistinctModelsForKey(keyString: string): Promise { const result = await db.execute( - sql`select distinct coalesce(${messageRequest.originalModel}, ${messageRequest.model}) as model + sql`select distinct ${messageRequest.model} as model from ${messageRequest} where ${messageRequest.key} = ${keyString} and ${messageRequest.deletedAt} is null + and ${messageRequest.model} is not null order by model asc` ); diff --git a/tests/api/action-adapter-openapi.unit.test.ts b/tests/api/action-adapter-openapi.unit.test.ts new file mode 100644 index 000000000..0a99203ca --- /dev/null +++ b/tests/api/action-adapter-openapi.unit.test.ts @@ -0,0 +1,197 @@ +import { describe, expect, test, vi } from "vitest"; +import { z } from "@hono/zod-openapi"; +import { + IdParamSchema, + PaginationSchema, + SortSchema, + createActionRoute, + createActionRoutes, + createParamSchema, +} from "@/lib/api/action-adapter-openapi"; + +/** + * 说明: + * - 这些测试只覆盖 adapter 的“通用执行器”逻辑 + * - 不依赖 Next/Hono 的完整运行时 + * - 重点验证:参数映射、返回值包装、错误/异常处理、requiresAuth=false 分支 + */ + +function createMockContext(options?: { body?: unknown; jsonThrows?: boolean }) { + const body = options?.body ?? {}; + const jsonThrows = options?.jsonThrows ?? false; + + return { + req: { + json: async () => { + if (jsonThrows) { + throw new Error("invalid json"); + } + return body; + }, + }, + json: (payload: unknown, status = 200) => + new Response(JSON.stringify(payload), { + status, + headers: { "content-type": "application/json" }, + }), + } as const; +} + +describe("Action Adapter:createActionRoute(单元测试)", () => { + test("requiresAuth=false:返回非 ActionResult 时自动包装 {ok:true,data}", async () => { + const { handler } = createActionRoute( + "test", + "returnsRaw", + async () => { + return { hello: "world" }; + }, + { requiresAuth: false } + ); + + const response = (await handler(createMockContext({ body: {} }) as any)) as Response; + expect(response.status).toBe(200); + await expect(response.json()).resolves.toEqual({ ok: true, data: { hello: "world" } }); + }); + + test("默认参数推断:schema 单字段时应传入该字段值", async () => { + const action = vi.fn(async (id: number) => ({ id })); + const { handler } = createActionRoute("test", "singleArg", action as any, { + requiresAuth: false, + requestSchema: z.object({ id: z.number() }), + }); + + const response = (await handler(createMockContext({ body: { id: 123 } }) as any)) as Response; + expect(action).toHaveBeenCalledWith(123); + expect(response.status).toBe(200); + await expect(response.json()).resolves.toEqual({ ok: true, data: { id: 123 } }); + }); + + test("默认参数推断:多字段 schema 传入整个 body(单参)", async () => { + const action = vi.fn(async (body: { a: string; b: string }) => body); + const { handler } = createActionRoute("test", "multiKey", action as any, { + requiresAuth: false, + requestSchema: z.object({ a: z.string(), b: z.string() }), + }); + + const response = (await handler( + createMockContext({ body: { a: "x", b: "y" } }) as any + )) as Response; + expect(action).toHaveBeenCalledWith({ a: "x", b: "y" }); + expect(response.status).toBe(200); + await expect(response.json()).resolves.toEqual({ ok: true, data: { a: "x", b: "y" } }); + }); + + test("argsMapper:应优先使用显式映射以支持多参数 action", async () => { + const action = vi.fn(async (userId: number, data: { name: string }) => ({ userId, data })); + const { handler } = createActionRoute("test", "mappedArgs", action as any, { + requiresAuth: false, + requestSchema: z.object({ + userId: z.number(), + data: z.object({ name: z.string() }), + }), + argsMapper: (body: { userId: number; data: { name: string } }) => [body.userId, body.data], + }); + + const response = (await handler( + createMockContext({ body: { userId: 7, data: { name: "alice" } } }) as any + )) as Response; + expect(action).toHaveBeenCalledWith(7, { name: "alice" }); + expect(response.status).toBe(200); + await expect(response.json()).resolves.toEqual({ + ok: true, + data: { userId: 7, data: { name: "alice" } }, + }); + }); + + test("action 返回 ok=false:应返回 400 且透传 errorCode/errorParams", async () => { + const { handler } = createActionRoute( + "test", + "returnsError", + async () => ({ + ok: false, + error: "业务错误", + errorCode: "BIZ_ERROR", + errorParams: { field: "name" }, + }), + { requiresAuth: false } + ); + + const response = (await handler(createMockContext({ body: {} }) as any)) as Response; + expect(response.status).toBe(400); + await expect(response.json()).resolves.toEqual({ + ok: false, + error: "业务错误", + errorCode: "BIZ_ERROR", + errorParams: { field: "name" }, + }); + }); + + test("action 抛出 Error:应返回 500 且返回 error.message", async () => { + const { handler } = createActionRoute( + "test", + "throwsError", + async () => { + throw new Error("boom"); + }, + { requiresAuth: false } + ); + + const response = (await handler(createMockContext({ body: {} }) as any)) as Response; + expect(response.status).toBe(500); + await expect(response.json()).resolves.toEqual({ ok: false, error: "boom" }); + }); + + test("action 抛出非 Error:应返回 500 且返回通用错误消息", async () => { + const { handler } = createActionRoute( + "test", + "throwsUnknown", + async () => { + // eslint-disable-next-line no-throw-literal + throw "boom"; + }, + { requiresAuth: false } + ); + + const response = (await handler(createMockContext({ body: {} }) as any)) as Response; + expect(response.status).toBe(500); + await expect(response.json()).resolves.toEqual({ ok: false, error: "服务器内部错误" }); + }); + + test("请求体不是 JSON:应降级为 {} 并继续执行", async () => { + const action = vi.fn(async () => "ok"); + const { handler } = createActionRoute("test", "badJson", action as any, { + requiresAuth: false, + }); + + const response = (await handler(createMockContext({ jsonThrows: true }) as any)) as Response; + expect(action).toHaveBeenCalledTimes(1); + expect(response.status).toBe(200); + await expect(response.json()).resolves.toEqual({ ok: true, data: "ok" }); + }); +}); + +describe("Action Adapter:辅助导出函数(单元测试)", () => { + test("createActionRoutes:应批量生成 route/handler", () => { + const routes = createActionRoutes( + "demo", + { + a: async () => ({ ok: true, data: 1 }), + b: async () => 2, + }, + { + b: { requiresAuth: false }, + } + ); + + expect(routes).toHaveLength(2); + }); + + test("通用 schemas:应支持解析与默认值", () => { + const schema = createParamSchema({ name: z.string() }); + expect(schema.parse({ name: "x" })).toEqual({ name: "x" }); + + expect(IdParamSchema.parse({ id: 1 })).toEqual({ id: 1 }); + expect(PaginationSchema.parse({})).toEqual({ page: 1, pageSize: 20 }); + expect(SortSchema.parse({})).toEqual({ sortBy: undefined, sortOrder: "desc" }); + }); +}); diff --git a/tests/api/api-actions-integrity.test.ts b/tests/api/api-actions-integrity.test.ts index 6deb843b5..b6bf152e1 100644 --- a/tests/api/api-actions-integrity.test.ts +++ b/tests/api/api-actions-integrity.test.ts @@ -113,6 +113,22 @@ describe("OpenAPI 端点完整性检查", () => { } }); + test("我的用量模块的所有端点应该被注册", () => { + const expectedPaths = [ + "/api/actions/my-usage/getMyUsageMetadata", + "/api/actions/my-usage/getMyQuota", + "/api/actions/my-usage/getMyTodayStats", + "/api/actions/my-usage/getMyUsageLogs", + "/api/actions/my-usage/getMyAvailableModels", + "/api/actions/my-usage/getMyAvailableEndpoints", + ]; + + for (const path of expectedPaths) { + expect(openApiDoc.paths[path]).toBeDefined(); + expect(openApiDoc.paths[path].post).toBeDefined(); + } + }); + test("概览模块的所有端点应该被注册", () => { const expectedPaths = ["/api/actions/overview/getOverviewData"]; diff --git a/tests/api/auth.unit.test.ts b/tests/api/auth.unit.test.ts new file mode 100644 index 000000000..e05966d5b --- /dev/null +++ b/tests/api/auth.unit.test.ts @@ -0,0 +1,254 @@ +import { afterAll, beforeEach, describe, expect, test, vi } from "vitest"; +import { inArray } from "drizzle-orm"; +import { db } from "@/drizzle/db"; +import { keys, users } from "@/drizzle/schema"; +import { + clearAuthCookie, + getAuthCookie, + getLoginRedirectTarget, + getSession, + setAuthCookie, + validateKey, +} from "@/lib/auth"; + +/** + * 说明: + * - 本文件用于覆盖 auth.ts 的权限边界与 Cookie 行为 + * - 重点验证:allowReadOnlyAccess 白名单语义 + * - 以及 getSession/cookie 的读写一致性 + */ + +let currentCookieValue: string | undefined; +let currentAuthorizationValue: string | undefined; +const cookieSet = vi.fn((name: string, value: string) => { + if (name === "auth-token") currentCookieValue = value; +}); +const cookieDelete = vi.fn((name: string) => { + if (name === "auth-token") currentCookieValue = undefined; +}); + +vi.mock("next/headers", () => ({ + cookies: () => ({ + get: (name: string) => { + if (name !== "auth-token") return undefined; + return currentCookieValue ? { value: currentCookieValue } : undefined; + }, + set: cookieSet, + delete: cookieDelete, + has: (name: string) => name === "auth-token" && Boolean(currentCookieValue), + }), + headers: () => ({ + get: (name: string) => { + if (name.toLowerCase() !== "authorization") return null; + return currentAuthorizationValue ?? null; + }, + }), +})); + +type TestUser = { id: number; name: string }; +type TestKey = { id: number; userId: number; key: string; canLoginWebUi: boolean }; + +async function createTestUser(name: string): Promise { + const [row] = await db + .insert(users) + .values({ name }) + .returning({ id: users.id, name: users.name }); + if (!row) throw new Error("创建测试用户失败:未返回插入结果"); + return row; +} + +async function createTestKey(params: { + userId: number; + key: string; + canLoginWebUi: boolean; +}): Promise { + const [row] = await db + .insert(keys) + .values({ + userId: params.userId, + key: params.key, + name: `key-${params.key}`, + canLoginWebUi: params.canLoginWebUi, + dailyResetMode: "rolling", + dailyResetTime: "00:00", + }) + .returning({ + id: keys.id, + userId: keys.userId, + key: keys.key, + canLoginWebUi: keys.canLoginWebUi, + }); + + if (!row) throw new Error("创建测试 Key 失败:未返回插入结果"); + return row; +} + +describe("auth.ts:validateKey / getSession(安全边界)", () => { + const createdUserIds: number[] = []; + const createdKeyIds: number[] = []; + + afterAll(async () => { + const now = new Date(); + if (createdKeyIds.length > 0) { + await db + .update(keys) + .set({ deletedAt: now, updatedAt: now }) + .where(inArray(keys.id, createdKeyIds)); + } + if (createdUserIds.length > 0) { + await db + .update(users) + .set({ deletedAt: now, updatedAt: now }) + .where(inArray(users.id, createdUserIds)); + } + }); + + beforeEach(() => { + currentCookieValue = undefined; + currentAuthorizationValue = undefined; + cookieSet.mockClear(); + cookieDelete.mockClear(); + }); + + test("admin token:应返回 admin session(无需 DB)", async () => { + const adminToken = process.env.ADMIN_TOKEN; + expect(adminToken).toBeTruthy(); + + const session = await validateKey(adminToken as string); + expect(session?.user.role).toBe("admin"); + expect(session?.key.canLoginWebUi).toBe(true); + }); + + test("不存在的 key:validateKey 应返回 null", async () => { + const session = await validateKey(`non-existent-${Date.now()}`); + expect(session).toBeNull(); + }); + + test("canLoginWebUi=false 且 allowReadOnlyAccess=false:应拒绝", async () => { + const unique = `auth-${Date.now()}-${Math.random().toString(16).slice(2)}`; + const user = await createTestUser(`Test ${unique}`); + createdUserIds.push(user.id); + const key = await createTestKey({ + userId: user.id, + key: `test-key-${unique}`, + canLoginWebUi: false, + }); + createdKeyIds.push(key.id); + + const session = await validateKey(key.key, { allowReadOnlyAccess: false }); + expect(session).toBeNull(); + }); + + test("allowReadOnlyAccess=true:应允许只读 key 查询自己的数据", async () => { + const unique = `auth-ro-${Date.now()}-${Math.random().toString(16).slice(2)}`; + const user = await createTestUser(`Test ${unique}`); + createdUserIds.push(user.id); + const key = await createTestKey({ + userId: user.id, + key: `test-ro-key-${unique}`, + canLoginWebUi: false, + }); + createdKeyIds.push(key.id); + + const session = await validateKey(key.key, { allowReadOnlyAccess: true }); + expect(session?.key.key).toBe(key.key); + expect(session?.key.canLoginWebUi).toBe(false); + }); + + test("用户被软删除:validateKey 应返回 null", async () => { + const unique = `auth-del-${Date.now()}-${Math.random().toString(16).slice(2)}`; + const user = await createTestUser(`Test ${unique}`); + createdUserIds.push(user.id); + const key = await createTestKey({ + userId: user.id, + key: `test-key-${unique}`, + canLoginWebUi: true, + }); + createdKeyIds.push(key.id); + + const now = new Date(); + await db + .update(users) + .set({ deletedAt: now, updatedAt: now }) + .where(inArray(users.id, [user.id])); + + const session = await validateKey(key.key, { allowReadOnlyAccess: true }); + expect(session).toBeNull(); + }); + + test("getSession:无 Cookie 时返回 null;有 Cookie 时返回 session", async () => { + const noCookie = await getSession({ allowReadOnlyAccess: true }); + expect(noCookie).toBeNull(); + + const unique = `auth-sess-${Date.now()}-${Math.random().toString(16).slice(2)}`; + const user = await createTestUser(`Test ${unique}`); + createdUserIds.push(user.id); + const key = await createTestKey({ + userId: user.id, + key: `test-key-${unique}`, + canLoginWebUi: false, + }); + createdKeyIds.push(key.id); + + currentCookieValue = key.key; + const session = await getSession({ allowReadOnlyAccess: true }); + expect(session?.key.key).toBe(key.key); + }); + + test("getSession:仅 Authorization: Bearer 时也应返回 session", async () => { + const unique = `auth-bearer-${Date.now()}-${Math.random().toString(16).slice(2)}`; + const user = await createTestUser(`Test ${unique}`); + createdUserIds.push(user.id); + const key = await createTestKey({ + userId: user.id, + key: `test-key-${unique}`, + canLoginWebUi: false, + }); + createdKeyIds.push(key.id); + + currentAuthorizationValue = `Bearer ${key.key}`; + const session = await getSession({ allowReadOnlyAccess: true }); + expect(session?.key.key).toBe(key.key); + }); +}); + +describe("auth.ts:Cookie 工具函数与跳转目标", () => { + beforeEach(() => { + currentCookieValue = undefined; + currentAuthorizationValue = undefined; + cookieSet.mockClear(); + cookieDelete.mockClear(); + }); + + test("set/get/clear auth cookie:应读写一致", async () => { + await setAuthCookie("abc"); + expect(cookieSet).toHaveBeenCalled(); + + const value = await getAuthCookie(); + expect(value).toBe("abc"); + + await clearAuthCookie(); + expect(cookieDelete).toHaveBeenCalledWith("auth-token"); + expect(await getAuthCookie()).toBeUndefined(); + }); + + test("getLoginRedirectTarget:应根据 role 与 canLoginWebUi 决定跳转", () => { + const adminTarget = getLoginRedirectTarget({ + user: { role: "admin" } as any, + key: { canLoginWebUi: false } as any, + }); + expect(adminTarget).toBe("/dashboard"); + + const webUiTarget = getLoginRedirectTarget({ + user: { role: "user" } as any, + key: { canLoginWebUi: true } as any, + }); + expect(webUiTarget).toBe("/dashboard"); + + const readonlyTarget = getLoginRedirectTarget({ + user: { role: "user" } as any, + key: { canLoginWebUi: false } as any, + }); + expect(readonlyTarget).toBe("/my-usage"); + }); +}); diff --git a/tests/api/my-usage-readonly.test.ts b/tests/api/my-usage-readonly.test.ts new file mode 100644 index 000000000..7c938d492 --- /dev/null +++ b/tests/api/my-usage-readonly.test.ts @@ -0,0 +1,436 @@ +import { afterAll, beforeEach, describe, expect, test, vi } from "vitest"; +import { inArray } from "drizzle-orm"; +import { db } from "@/drizzle/db"; +import { keys, messageRequest, users } from "@/drizzle/schema"; +import { callActionsRoute } from "../test-utils"; + +/** + * 说明: + * - /api/actions 的鉴权在 adapter 层支持 Cookie 与 Authorization: Bearer + * - my-usage 的业务逻辑在 action 层仍会调用 getSession()(next/headers cookies/headers) + * - 测试环境下需要 mock next/headers,否则 getSession 无法读取认证信息 + * + * 这里用一个可变的 currentAuthToken 作为“当前请求 Cookie”,并确保: + * - adapter 校验用的 Cookie(callActionsRoute.authToken) + * - action 读取到的 Cookie(currentAuthToken) + * 两者保持一致,避免出现“adapter 通过但 action 读不到 session”的假失败。 + */ +let currentAuthToken: string | undefined; +let currentAuthorization: string | undefined; + +vi.mock("next/headers", () => ({ + cookies: () => ({ + get: (name: string) => { + if (name !== "auth-token") return undefined; + return currentAuthToken ? { value: currentAuthToken } : undefined; + }, + set: vi.fn(), + delete: vi.fn(), + has: (name: string) => name === "auth-token" && Boolean(currentAuthToken), + }), + headers: () => ({ + get: (name: string) => { + if (name.toLowerCase() !== "authorization") return null; + return currentAuthorization ?? null; + }, + }), +})); + +type TestKey = { id: number; userId: number; key: string; name: string }; +type TestUser = { id: number; name: string }; + +async function createTestUser(name: string): Promise { + const [row] = await db + .insert(users) + .values({ + name, + }) + .returning({ id: users.id, name: users.name }); + + if (!row) { + throw new Error("创建测试用户失败:未返回插入结果"); + } + + return row; +} + +async function createTestKey(params: { + userId: number; + key: string; + name: string; + canLoginWebUi: boolean; +}): Promise { + const [row] = await db + .insert(keys) + .values({ + userId: params.userId, + key: params.key, + name: params.name, + canLoginWebUi: params.canLoginWebUi, + // 为避免跨时区/临界点导致“今日”边界不稳定,这里固定使用 rolling + dailyResetMode: "rolling", + dailyResetTime: "00:00", + }) + .returning({ id: keys.id, userId: keys.userId, key: keys.key, name: keys.name }); + + if (!row) { + throw new Error("创建测试 Key 失败:未返回插入结果"); + } + + return row; +} + +async function createMessage(params: { + userId: number; + key: string; + model: string; + originalModel?: string; + endpoint?: string | null; + costUsd?: string | null; + inputTokens?: number | null; + outputTokens?: number | null; + blockedBy?: string | null; + createdAt: Date; +}): Promise { + const [row] = await db + .insert(messageRequest) + .values({ + providerId: 0, + userId: params.userId, + key: params.key, + model: params.model, + originalModel: params.originalModel ?? params.model, + endpoint: params.endpoint ?? "/v1/messages", + costUsd: params.costUsd ?? "0", + inputTokens: params.inputTokens ?? 0, + outputTokens: params.outputTokens ?? 0, + blockedBy: params.blockedBy ?? null, + createdAt: params.createdAt, + updatedAt: params.createdAt, + }) + .returning({ id: messageRequest.id }); + + if (!row?.id) { + throw new Error("创建 message_request 失败:未返回 id"); + } + + return row.id; +} + +describe("my-usage API:只读 Key 自助查询", () => { + const createdUserIds: number[] = []; + const createdKeyIds: number[] = []; + const createdMessageIds: number[] = []; + + afterAll(async () => { + // 软删除更安全:避免潜在外键约束或其他测试依赖 + const now = new Date(); + if (createdMessageIds.length > 0) { + await db + .update(messageRequest) + .set({ deletedAt: now, updatedAt: now }) + .where(inArray(messageRequest.id, createdMessageIds)); + } + + if (createdKeyIds.length > 0) { + await db + .update(keys) + .set({ deletedAt: now, updatedAt: now }) + .where(inArray(keys.id, createdKeyIds)); + } + + if (createdUserIds.length > 0) { + await db + .update(users) + .set({ deletedAt: now, updatedAt: now }) + .where(inArray(users.id, createdUserIds)); + } + }); + + beforeEach(() => { + currentAuthToken = undefined; + currentAuthorization = undefined; + }); + + test("未携带 auth-token:my-usage 端点应返回 401", async () => { + const { response, json } = await callActionsRoute({ + method: "POST", + pathname: "/api/actions/my-usage/getMyTodayStats", + body: {}, + }); + + expect(response.status).toBe(401); + expect(json).toMatchObject({ ok: false }); + }); + + test("只读 Key:允许访问 my-usage 端点,但禁止访问其他 WebUI API", async () => { + const unique = `my-usage-readonly-${Date.now()}-${Math.random().toString(16).slice(2)}`; + const user = await createTestUser(`Test ${unique}`); + createdUserIds.push(user.id); + + const readonlyKey = await createTestKey({ + userId: user.id, + key: `test-readonly-key-${unique}`, + name: `readonly-${unique}`, + canLoginWebUi: false, + }); + createdKeyIds.push(readonlyKey.id); + + currentAuthToken = readonlyKey.key; + + // 允许访问 my-usage(allowReadOnlyAccess 白名单) + const meta = await callActionsRoute({ + method: "POST", + pathname: "/api/actions/my-usage/getMyUsageMetadata", + authToken: readonlyKey.key, + body: {}, + }); + expect(meta.response.status).toBe(200); + expect(meta.json).toMatchObject({ ok: true }); + + const quota = await callActionsRoute({ + method: "POST", + pathname: "/api/actions/my-usage/getMyQuota", + authToken: readonlyKey.key, + body: {}, + }); + expect(quota.response.status).toBe(200); + expect(quota.json).toMatchObject({ ok: true }); + + // 禁止访问需要 WebUI 权限的 actions(默认 validateKey 会拒绝 canLoginWebUi=false 的 key) + const usersApi = await callActionsRoute({ + method: "POST", + pathname: "/api/actions/users/getUsers", + authToken: readonlyKey.key, + body: {}, + }); + expect(usersApi.response.status).toBe(401); + expect(usersApi.json).toMatchObject({ ok: false }); + + const usageLogsApi = await callActionsRoute({ + method: "POST", + pathname: "/api/actions/usage-logs/getUsageLogs", + authToken: readonlyKey.key, + body: {}, + }); + expect(usageLogsApi.response.status).toBe(401); + expect(usageLogsApi.json).toMatchObject({ ok: false }); + }); + + test("Bearer-only:仅 Authorization 也应可查询 my-usage,但仍禁止访问 WebUI API", async () => { + const unique = `my-usage-bearer-${Date.now()}-${Math.random().toString(16).slice(2)}`; + const user = await createTestUser(`Test ${unique}`); + createdUserIds.push(user.id); + + const readonlyKey = await createTestKey({ + userId: user.id, + key: `test-readonly-key-${unique}`, + name: `readonly-${unique}`, + canLoginWebUi: false, + }); + createdKeyIds.push(readonlyKey.id); + + const now = new Date(); + const msgId = await createMessage({ + userId: user.id, + key: readonlyKey.key, + model: "gpt-4.1-mini", + endpoint: "/v1/messages", + costUsd: "0.0100", + inputTokens: 10, + outputTokens: 20, + createdAt: new Date(now.getTime() - 60 * 1000), + }); + createdMessageIds.push(msgId); + + currentAuthorization = `Bearer ${readonlyKey.key}`; + + const stats = await callActionsRoute({ + method: "POST", + pathname: "/api/actions/my-usage/getMyTodayStats", + headers: { Authorization: currentAuthorization }, + body: {}, + }); + expect(stats.response.status).toBe(200); + expect(stats.json).toMatchObject({ ok: true }); + expect((stats.json as any).data.calls).toBe(1); + + const usersApi = await callActionsRoute({ + method: "POST", + pathname: "/api/actions/users/getUsers", + headers: { Authorization: currentAuthorization }, + body: {}, + }); + expect(usersApi.response.status).toBe(401); + expect(usersApi.json).toMatchObject({ ok: false }); + }); + + test("今日统计:应与 message_request 数据一致,并排除 warmup 与其他 Key 数据", async () => { + const unique = `my-usage-stats-${Date.now()}-${Math.random().toString(16).slice(2)}`; + + const userA = await createTestUser(`Test ${unique}-A`); + createdUserIds.push(userA.id); + const keyA = await createTestKey({ + userId: userA.id, + key: `test-readonly-key-A-${unique}`, + name: `readonly-A-${unique}`, + canLoginWebUi: false, + }); + createdKeyIds.push(keyA.id); + + const userB = await createTestUser(`Test ${unique}-B`); + createdUserIds.push(userB.id); + const keyB = await createTestKey({ + userId: userB.id, + key: `test-readonly-key-B-${unique}`, + name: `readonly-B-${unique}`, + canLoginWebUi: false, + }); + createdKeyIds.push(keyB.id); + + const now = new Date(); + const t0 = new Date(now.getTime() - 60 * 1000); + + // A:两条正常计费请求 + 一条 warmup(应被排除) + const a1 = await createMessage({ + userId: userA.id, + key: keyA.key, + model: "gpt-4.1", + originalModel: "gpt-4.1-original", + endpoint: "/v1/messages", + costUsd: "0.0125", + inputTokens: 100, + outputTokens: 200, + createdAt: t0, + }); + const a2 = await createMessage({ + userId: userA.id, + key: keyA.key, + model: "gpt-4.1-mini", + originalModel: "gpt-4.1-mini-original", + endpoint: "/v1/chat/completions", + costUsd: "0.0075", + inputTokens: 50, + outputTokens: 80, + createdAt: t0, + }); + const warmup = await createMessage({ + userId: userA.id, + key: keyA.key, + model: "gpt-4.1-mini", + originalModel: "gpt-4.1-mini", + endpoint: "/v1/messages", + costUsd: null, + inputTokens: 999, + outputTokens: 999, + blockedBy: "warmup", + createdAt: t0, + }); + createdMessageIds.push(a1, a2, warmup); + + // B:一条正常请求(不应泄漏给 A) + const b1 = await createMessage({ + userId: userB.id, + key: keyB.key, + model: "gpt-4.1", + originalModel: "gpt-4.1", + endpoint: "/v1/messages", + costUsd: "0.1000", + inputTokens: 1000, + outputTokens: 1000, + createdAt: t0, + }); + createdMessageIds.push(b1); + + currentAuthToken = keyA.key; + + const { response, json } = await callActionsRoute({ + method: "POST", + pathname: "/api/actions/my-usage/getMyTodayStats", + authToken: keyA.key, + body: {}, + }); + + expect(response.status).toBe(200); + expect(json).toMatchObject({ ok: true }); + + const data = (json as any).data as { + calls: number; + inputTokens: number; + outputTokens: number; + costUsd: number; + modelBreakdown: Array<{ + model: string | null; + billingModel: string | null; + calls: number; + costUsd: number; + inputTokens: number; + outputTokens: number; + }>; + billingModelSource: "original" | "redirected"; + }; + + // warmup 排除后:只剩两条 + expect(data.calls).toBe(2); + expect(data.inputTokens).toBe(150); + expect(data.outputTokens).toBe(280); + expect(data.costUsd).toBeCloseTo(0.02, 10); + + // breakdown:至少包含两个模型 + const breakdownByModel = new Map(data.modelBreakdown.map((row) => [row.model, row])); + expect(breakdownByModel.get("gpt-4.1")?.calls).toBe(1); + expect(breakdownByModel.get("gpt-4.1-mini")?.calls).toBe(1); + + // billingModelSource 不假设固定值,但要求 billingModel 字段与配置一致 + const originalModelByModel = new Map([ + ["gpt-4.1", "gpt-4.1-original"], + ["gpt-4.1-mini", "gpt-4.1-mini-original"], + ]); + for (const row of data.modelBreakdown) { + if (!row.model) continue; + const expectedBillingModel = + data.billingModelSource === "original" ? originalModelByModel.get(row.model) : row.model; + expect(row.billingModel).toBe(expectedBillingModel); + } + + // 同时验证 usage logs:不应返回 B 的日志(不泄漏) + const logs = await callActionsRoute({ + method: "POST", + pathname: "/api/actions/my-usage/getMyUsageLogs", + authToken: keyA.key, + body: { page: 1, pageSize: 50 }, + }); + + expect(logs.response.status).toBe(200); + expect(logs.json).toMatchObject({ ok: true }); + + const logIds = ((logs.json as any).data.logs as Array<{ id: number }>).map((l) => l.id); + expect(logIds).toContain(a1); + expect(logIds).toContain(a2); + // warmup 行是否展示不做强约束(日志口径可见),但绝不能泄漏 B + expect(logIds).not.toContain(b1); + + // 筛选项接口:模型与端点列表应可用 + const models = await callActionsRoute({ + method: "POST", + pathname: "/api/actions/my-usage/getMyAvailableModels", + authToken: keyA.key, + body: {}, + }); + expect(models.response.status).toBe(200); + expect((models.json as any).ok).toBe(true); + expect((models.json as any).data).toEqual(expect.arrayContaining(["gpt-4.1", "gpt-4.1-mini"])); + + const endpoints = await callActionsRoute({ + method: "POST", + pathname: "/api/actions/my-usage/getMyAvailableEndpoints", + authToken: keyA.key, + body: {}, + }); + expect(endpoints.response.status).toBe(200); + expect((endpoints.json as any).ok).toBe(true); + expect((endpoints.json as any).data).toEqual( + expect.arrayContaining(["/v1/messages", "/v1/chat/completions"]) + ); + }); +}); diff --git a/tests/cleanup-utils.ts b/tests/cleanup-utils.ts index 36f00f590..69230812e 100644 --- a/tests/cleanup-utils.ts +++ b/tests/cleanup-utils.ts @@ -4,9 +4,9 @@ * 用途:在测试后自动清理创建的测试数据 */ -import { and, isNull, like, or, sql } from "drizzle-orm"; +import { and, inArray, isNull, like, or, sql } from "drizzle-orm"; import { db } from "@/drizzle/db"; -import { users } from "@/drizzle/schema"; +import { keys as keysTable, users } from "@/drizzle/schema"; /** * 清理所有测试用户及其关联数据 @@ -53,26 +53,25 @@ export async function cleanupTestUsers(options?: { const testUserIds = testUsers.map((u) => u.id); // 2. 软删除关联的 Keys - const deletedKeys = await db.execute(sql` - UPDATE keys - SET deleted_at = NOW(), updated_at = NOW() - WHERE user_id = ANY(${testUserIds}) - AND deleted_at IS NULL - `); + const now = new Date(); + const deletedKeys = await db + .update(keysTable) + .set({ deletedAt: now, updatedAt: now }) + .where(and(inArray(keysTable.userId, testUserIds), isNull(keysTable.deletedAt))) + .returning({ id: keysTable.id }); // 3. 软删除测试用户 - const _deletedUsers = await db.execute(sql` - UPDATE users - SET deleted_at = NOW(), updated_at = NOW() - WHERE id = ANY(${testUserIds}) - AND deleted_at IS NULL - `); + await db + .update(users) + .set({ deletedAt: now, updatedAt: now }) + .where(and(inArray(users.id, testUserIds), isNull(users.deletedAt))) + .returning({ id: users.id }); console.log(`✅ 清理完成:删除 ${testUsers.length} 个用户和对应的 Keys`); return { deletedUsers: testUsers.length, - deletedKeys: deletedKeys.count ?? 0, + deletedKeys: deletedKeys.length, userNames: testUsers.map((u) => u.name), }; } catch (error) { diff --git a/tests/nextjs.mock.ts b/tests/nextjs.mock.ts index f72d77767..01aec6c95 100644 --- a/tests/nextjs.mock.ts +++ b/tests/nextjs.mock.ts @@ -26,6 +26,9 @@ vi.mock("next/headers", () => ({ delete: vi.fn(), has: vi.fn((name: string) => name === "auth-token" && !!process.env.TEST_ADMIN_TOKEN), })), + headers: vi.fn(() => ({ + get: vi.fn(() => null), + })), })); // ==================== Mock next-intl ==================== diff --git a/tests/setup.ts b/tests/setup.ts index 006f561ca..f8360ee65 100644 --- a/tests/setup.ts +++ b/tests/setup.ts @@ -10,22 +10,22 @@ import { afterAll, beforeAll } from "vitest"; // ==================== 加载环境变量 ==================== // 优先加载 .env.test(如果存在) -config({ path: ".env.test" }); +config({ path: ".env.test", quiet: true }); // 降级加载 .env -config({ path: ".env" }); +config({ path: ".env", quiet: true }); // ==================== 全局前置钩子 ==================== beforeAll(async () => { - console.log("\n🧪 Vitest 测试环境初始化...\n"); + console.log("\nVitest 测试环境初始化...\n"); // 安全检查:确保使用测试数据库 const dsn = process.env.DSN || ""; const dbName = dsn.split("/").pop() || ""; if (process.env.NODE_ENV === "production") { - throw new Error("❌ 禁止在生产环境运行测试"); + throw new Error("禁止在生产环境运行测试"); } // 强制要求:测试必须使用包含 'test' 的数据库(CI 和本地都检查) @@ -33,7 +33,7 @@ beforeAll(async () => { // 允许通过环境变量显式跳过检查(仅用于特殊情况) if (process.env.ALLOW_NON_TEST_DB !== "true") { throw new Error( - `❌ 安全检查失败: 数据库名称必须包含 'test' 字样\n` + + `安全检查失败: 数据库名称必须包含 'test' 字样\n` + ` 当前数据库: ${dbName}\n` + ` 建议使用测试专用数据库(如 claude_code_hub_test)\n` + ` 如需跳过检查,请设置环境变量: ALLOW_NON_TEST_DB=true` @@ -41,13 +41,13 @@ beforeAll(async () => { } // 即使跳过检查也要发出警告 - console.warn("⚠️ 警告: 当前数据库不包含 'test' 字样"); + console.warn("警告: 当前数据库不包含 'test' 字样"); console.warn(` 数据库: ${dbName}`); console.warn(" 建议使用独立的测试数据库避免数据污染\n"); } // 显示测试配置 - console.log("📋 测试配置:"); + console.log("测试配置:"); console.log(` - 数据库: ${dbName || "未配置"}`); console.log(` - Redis: ${process.env.REDIS_URL?.split("//")[1]?.split("@")[1] || "未配置"}`); console.log(` - API Base: ${process.env.API_BASE_URL || "http://localhost:13500"}`); @@ -58,36 +58,105 @@ beforeAll(async () => { try { const { syncDefaultErrorRules } = await import("@/repository/error-rules"); await syncDefaultErrorRules(); - console.log("✅ 默认错误规则已同步\n"); + console.log("默认错误规则已同步\n"); } catch (error) { - console.warn("⚠️ 无法同步默认错误规则:", error); + console.warn("无法同步默认错误规则:", error); } } + + // ==================== 并行 Worker 清理协调 ==================== + // setupFiles 会在每个 worker 中执行;如果每个 worker 都在 afterAll 清理数据库,会出现“互相清理”的竞态。 + // 这里用 Redis 计数器实现:只有最后一个结束的 worker 才执行 cleanup。 + try { + const shouldCleanup = Boolean(dsn) && process.env.AUTO_CLEANUP_TEST_DATA !== "false"; + if (!shouldCleanup) return; + + const dbNameForKey = dbName || "unknown"; + const counterKey = `cch:vitest:cleanup_workers:${dbNameForKey}`; + const { getRedisClient } = await import("@/lib/redis"); + const redis = getRedisClient(); + if (!redis) return; + + // 等待连接就绪(enableOfflineQueue=false,未 ready 时发命令会直接报错) + if (redis.status !== "ready") { + await new Promise((resolve) => { + const timeout = setTimeout(resolve, 2000); + redis.once("ready", () => { + clearTimeout(timeout); + resolve(); + }); + }); + } + + if (redis.status !== "ready") { + console.warn("Redis 未就绪,跳过并行清理协调(不影响测试结果)"); + return; + } + + const current = await redis.incr(counterKey); + if (current === 1) { + // 防止异常退出导致计数器常驻 + await redis.expire(counterKey, 60 * 15); + } + process.env.__VITEST_CLEANUP_COUNTER_KEY__ = counterKey; + } catch (error) { + console.warn("并行清理协调初始化失败(不影响测试结果):", error); + } }); // ==================== 全局清理钩子 ==================== afterAll(async () => { - console.log("\n🧹 Vitest 测试环境清理...\n"); + console.log("\nVitest 测试环境清理...\n"); // 清理测试期间创建的用户(仅清理最近 10 分钟内的) const dsn = process.env.DSN || ""; if (dsn && process.env.AUTO_CLEANUP_TEST_DATA !== "false") { try { - const { cleanupRecentTestData } = await import("./cleanup-utils"); - const result = await cleanupRecentTestData(); - if (result.deletedUsers > 0) { - console.log(`✅ 自动清理:删除 ${result.deletedUsers} 个测试用户\n`); + // 仅最后一个 worker 执行清理,避免并发互相删除 + const counterKey = process.env.__VITEST_CLEANUP_COUNTER_KEY__; + const { getRedisClient } = await import("@/lib/redis"); + const redis = counterKey ? getRedisClient() : null; + + if (counterKey && redis) { + if (redis.status !== "ready") { + await new Promise((resolve) => { + const timeout = setTimeout(resolve, 2000); + redis.once("ready", () => { + clearTimeout(timeout); + resolve(); + }); + }); + } + + if (redis.status === "ready") { + const remaining = await redis.decr(counterKey); + if (remaining <= 0) { + const { cleanupRecentTestData } = await import("./cleanup-utils"); + const result = await cleanupRecentTestData(); + if (result.deletedUsers > 0) { + console.log(`自动清理:删除 ${result.deletedUsers} 个测试用户\n`); + } + await redis.del(counterKey); + } else { + // 非最后一个 worker:跳过清理 + } + } else { + console.warn("Redis 未就绪,跳过自动清理(不影响测试结果)"); + } + } else { + // 无 Redis 协调:为了避免竞态,默认跳过清理 + console.warn("未启用清理协调,跳过自动清理(不影响测试结果)"); } } catch (error) { console.warn( - "⚠️ 自动清理失败(不影响测试结果):", + "自动清理失败(不影响测试结果):", error instanceof Error ? error.message : error ); } } - console.log("🧹 Vitest 测试环境清理完成\n"); + console.log("Vitest 测试环境清理完成\n"); }); // ==================== 全局 Mock 配置(可选)==================== diff --git a/tests/test-utils.ts b/tests/test-utils.ts index f21782add..dd13d3021 100644 --- a/tests/test-utils.ts +++ b/tests/test-utils.ts @@ -13,6 +13,7 @@ * - 以及需要校验“缺少 Cookie 直接 401”的端点 */ +import { Request as UndiciRequest } from "undici"; import { GET, POST } from "@/app/api/actions/[...route]/route"; export type ActionsRouteCallOptions = { @@ -52,13 +53,14 @@ export async function callActionsRoute(options: ActionsRouteCallOptions): Promis if (options.authToken) { const existing = headers.Cookie ? `${headers.Cookie}; ` : ""; headers.Cookie = `${existing}auth-token=${options.authToken}`; + headers.Authorization = headers.Authorization ?? `Bearer ${options.authToken}`; } if (options.method === "POST") { headers["Content-Type"] = headers["Content-Type"] ?? "application/json"; } - const request = new Request(url, { + const request = new UndiciRequest(url, { method: options.method, headers, body: options.method === "POST" ? JSON.stringify(options.body ?? {}) : undefined, diff --git a/tests/unit/proxy/session-guard-warmup-intercept.test.ts b/tests/unit/proxy/session-guard-warmup-intercept.test.ts index 7e12dc9ea..a63ac88fd 100644 --- a/tests/unit/proxy/session-guard-warmup-intercept.test.ts +++ b/tests/unit/proxy/session-guard-warmup-intercept.test.ts @@ -40,8 +40,10 @@ vi.mock("@/lib/session-tracker", () => ({ vi.mock("@/lib/logger", () => ({ logger: { debug: vi.fn(), + info: vi.fn(), warn: vi.fn(), error: vi.fn(), + fatal: vi.fn(), trace: vi.fn(), }, })); diff --git a/tests/unit/repository/warmup-stats-exclusion.test.ts b/tests/unit/repository/warmup-stats-exclusion.test.ts index b26dfc05a..1339321f6 100644 --- a/tests/unit/repository/warmup-stats-exclusion.test.ts +++ b/tests/unit/repository/warmup-stats-exclusion.test.ts @@ -179,8 +179,12 @@ describe("Warmup 请求:不计入任何聚合统计", () => { vi.doMock("@/lib/logger", () => ({ logger: { + debug: vi.fn(), + info: vi.fn(), + warn: vi.fn(), trace: vi.fn(), error: vi.fn(), + fatal: vi.fn(), }, })); diff --git a/vitest.my-usage.config.ts b/vitest.my-usage.config.ts new file mode 100644 index 000000000..40533a6bd --- /dev/null +++ b/vitest.my-usage.config.ts @@ -0,0 +1,58 @@ +import path from "node:path"; +import { defineConfig } from "vitest/config"; + +/** + * my-usage 专项覆盖率配置 + * + * 目的: + * - 仅统计本次改动相关模块,避免把需要完整 Next/Redis/Bull 的重模块纳入全局阈值 + * - 对“只读 Key 自助查询”这类安全敏感接口设置更高覆盖率门槛(>= 80%) + */ +export default defineConfig({ + test: { + globals: true, + environment: "node", + setupFiles: ["./tests/setup.ts"], + + include: [ + "tests/api/my-usage-readonly.test.ts", + "tests/api/api-actions-integrity.test.ts", + "tests/api/auth.unit.test.ts", + "tests/api/action-adapter-openapi.unit.test.ts", + ], + exclude: ["node_modules", ".next", "dist", "build", "coverage", "tests/integration/**"], + + coverage: { + provider: "v8", + reporter: ["text", "html", "json"], + reportsDirectory: "./coverage-my-usage", + + include: [ + "src/actions/my-usage.ts", + "src/lib/auth.ts", + "src/lib/api/action-adapter-openapi.ts", + ], + exclude: ["node_modules/", "tests/", "**/*.d.ts", ".next/"], + + thresholds: { + lines: 80, + functions: 80, + branches: 70, + statements: 80, + }, + }, + + reporters: ["verbose"], + isolate: true, + mockReset: true, + restoreMocks: true, + clearMocks: true, + }, + + resolve: { + alias: { + "@": path.resolve(__dirname, "./src"), + "server-only": path.resolve(__dirname, "./tests/server-only.mock.ts"), + }, + }, +});