diff --git a/src/actions/keys.ts b/src/actions/keys.ts index db11e2308..d2ac784ac 100644 --- a/src/actions/keys.ts +++ b/src/actions/keys.ts @@ -343,6 +343,10 @@ export async function editKey( } } + // 仅当调用方显式携带 expiresAt 字段时才更新/清除该字段: + // - 避免像“仅修改限额”这类局部更新把 expiresAt 意外清空 + const hasExpiresAtField = Object.hasOwn(data, "expiresAt"); + const validatedData = KeyFormSchema.parse(data); // 服务端验证:Key限额不能超过用户限额 @@ -417,9 +421,23 @@ export async function editKey( // 移除 providerGroup 子集校验(用户分组由 Key 分组自动计算) - // 转换 expiresAt: undefined → null(清除日期),string → Date(设置日期) - const expiresAt = - validatedData.expiresAt === undefined ? null : new Date(validatedData.expiresAt); + // 转换 expiresAt: + // - 未携带 expiresAt:不更新该字段 + // - 携带 expiresAt 但为空:清除(永不过期) + // - 携带 expiresAt 且为字符串:设置为对应 Date + const expiresAt = hasExpiresAtField + ? validatedData.expiresAt === undefined + ? null + : new Date(validatedData.expiresAt) + : undefined; + + if (expiresAt && Number.isNaN(expiresAt.getTime())) { + return { + ok: false, + error: tError("INVALID_FORMAT"), + errorCode: ERROR_CODES.INVALID_FORMAT, + }; + } const isAdmin = session.user.role === "admin"; const prevProviderGroup = normalizeProviderGroup(key.providerGroup); @@ -428,7 +446,7 @@ export async function editKey( await updateKey(keyId, { name: validatedData.name, - expires_at: expiresAt, + ...(hasExpiresAtField ? { expires_at: expiresAt } : {}), can_login_web_ui: validatedData.canLoginWebUi, ...(data.isEnabled !== undefined ? { is_enabled: data.isEnabled } : {}), limit_5h_usd: validatedData.limit5hUsd, diff --git a/src/lib/validation/schemas.ts b/src/lib/validation/schemas.ts index 85f538629..406157740 100644 --- a/src/lib/validation/schemas.ts +++ b/src/lib/validation/schemas.ts @@ -199,8 +199,11 @@ export const UpdateUserSchema = z.object({ isEnabled: z.boolean().optional(), expiresAt: z.preprocess( (val) => { - // null/undefined/空字符串 -> 视为未设置 - if (val === null || val === undefined || val === "") return undefined; + // 更新语义: + // - undefined:不更新该字段 + // - null/空字符串:显式清除过期时间(永不过期) + if (val === undefined) return undefined; + if (val === null || val === "") return null; // 已经是 Date 对象 if (val instanceof Date) { @@ -222,6 +225,7 @@ export const UpdateUserSchema = z.object({ }, z .date() + .nullable() .optional() .superRefine((date, ctx) => { if (!date) { diff --git a/tests/unit/actions/keys-edit-key-expires-at-clear.test.ts b/tests/unit/actions/keys-edit-key-expires-at-clear.test.ts new file mode 100644 index 000000000..9a71a34b1 --- /dev/null +++ b/tests/unit/actions/keys-edit-key-expires-at-clear.test.ts @@ -0,0 +1,146 @@ +import { beforeEach, describe, expect, test, vi } from "vitest"; + +const getSessionMock = vi.fn(); +vi.mock("@/lib/auth", () => ({ + getSession: getSessionMock, +})); + +vi.mock("next/cache", () => ({ + revalidatePath: vi.fn(), +})); + +const getTranslationsMock = vi.fn(async () => (key: string) => key); +vi.mock("next-intl/server", () => ({ + getTranslations: getTranslationsMock, +})); + +const findKeyByIdMock = vi.fn(); +const updateKeyMock = vi.fn(); + +vi.mock("@/repository/key", () => ({ + countActiveKeysByUser: vi.fn(async () => 1), + createKey: vi.fn(async () => ({})), + deleteKey: vi.fn(async () => true), + findActiveKeyByUserIdAndName: vi.fn(async () => null), + findKeyById: findKeyByIdMock, + findKeyList: vi.fn(async () => []), + findKeysWithStatistics: vi.fn(async () => []), + updateKey: updateKeyMock, +})); + +const findUserByIdMock = vi.fn(); +vi.mock("@/repository/user", async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + findUserById: findUserByIdMock, + }; +}); + +describe("editKey: expiresAt 清除/不更新语义", () => { + beforeEach(() => { + vi.clearAllMocks(); + getSessionMock.mockResolvedValue({ user: { id: 1, role: "admin" } }); + + findKeyByIdMock.mockResolvedValue({ + id: 1, + userId: 10, + key: "sk-test", + name: "k", + isEnabled: true, + expiresAt: new Date("2026-01-04T23:59:59.999Z"), + canLoginWebUi: true, + limit5hUsd: null, + limitDailyUsd: null, + dailyResetMode: "fixed", + dailyResetTime: "00:00", + limitWeeklyUsd: null, + limitMonthlyUsd: null, + limitTotalUsd: null, + limitConcurrentSessions: 0, + providerGroup: "default", + cacheTtlPreference: null, + createdAt: new Date(), + updatedAt: new Date(), + deletedAt: null, + }); + + findUserByIdMock.mockResolvedValue({ + id: 10, + name: "u", + description: "", + role: "user", + rpm: null, + dailyQuota: null, + providerGroup: "default", + tags: [], + limit5hUsd: null, + dailyResetMode: "fixed", + dailyResetTime: "00:00", + limitWeeklyUsd: null, + limitMonthlyUsd: null, + limitTotalUsd: null, + limitConcurrentSessions: null, + isEnabled: true, + expiresAt: null, + allowedClients: [], + allowedModels: [], + createdAt: new Date(), + updatedAt: new Date(), + deletedAt: null, + }); + + updateKeyMock.mockResolvedValue({ id: 1 }); + }); + + test("不携带 expiresAt 字段时不应更新 expires_at", async () => { + const { editKey } = await import("@/actions/keys"); + + const res = await editKey(1, { name: "k2" }); + + expect(res.ok).toBe(true); + expect(updateKeyMock).toHaveBeenCalledTimes(1); + + const updatePayload = updateKeyMock.mock.calls[0]?.[1] as Record; + expect(Object.hasOwn(updatePayload, "expires_at")).toBe(false); + }); + + test("携带 expiresAt=undefined 时应清除 expires_at(写入 null)", async () => { + const { editKey } = await import("@/actions/keys"); + + const res = await editKey(1, { name: "k2", expiresAt: undefined }); + + expect(res.ok).toBe(true); + expect(updateKeyMock).toHaveBeenCalledTimes(1); + expect(updateKeyMock).toHaveBeenCalledWith( + 1, + expect.objectContaining({ + expires_at: null, + }) + ); + }); + + test("携带 expiresAt=YYYY-MM-DD 时应写入对应 Date", async () => { + const { editKey } = await import("@/actions/keys"); + + const res = await editKey(1, { name: "k2", expiresAt: "2026-01-04" }); + + expect(res.ok).toBe(true); + expect(updateKeyMock).toHaveBeenCalledTimes(1); + + const updatePayload = updateKeyMock.mock.calls[0]?.[1] as Record; + expect(updatePayload.expires_at).toBeInstanceOf(Date); + expect(Number.isNaN((updatePayload.expires_at as Date).getTime())).toBe(false); + }); + + test("携带非法 expiresAt 字符串应返回 INVALID_FORMAT", async () => { + const { editKey } = await import("@/actions/keys"); + + const res = await editKey(1, { name: "k2", expiresAt: "not-a-date" }); + + expect(res.ok).toBe(false); + if (!res.ok) { + expect(res.errorCode).toBe("INVALID_FORMAT"); + } + }); +}); diff --git a/tests/unit/actions/users-edit-user-expires-at-clear.test.ts b/tests/unit/actions/users-edit-user-expires-at-clear.test.ts new file mode 100644 index 000000000..59688791b --- /dev/null +++ b/tests/unit/actions/users-edit-user-expires-at-clear.test.ts @@ -0,0 +1,49 @@ +import { beforeEach, describe, expect, test, vi } from "vitest"; + +const getSessionMock = vi.fn(); +vi.mock("@/lib/auth", () => ({ + getSession: getSessionMock, +})); + +vi.mock("next/cache", () => ({ + revalidatePath: vi.fn(), +})); + +const getTranslationsMock = vi.fn(async () => (key: string) => key); +const getLocaleMock = vi.fn(async () => "en"); +vi.mock("next-intl/server", () => ({ + getTranslations: getTranslationsMock, + getLocale: getLocaleMock, +})); + +const updateUserMock = vi.fn(); +vi.mock("@/repository/user", async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + updateUser: updateUserMock, + }; +}); + +describe("editUser: expiresAt 清除应写入数据库更新", () => { + beforeEach(() => { + vi.clearAllMocks(); + getSessionMock.mockResolvedValue({ user: { id: 1, role: "admin" } }); + updateUserMock.mockResolvedValue({ id: 123 }); + }); + + test("传入 expiresAt=null 应调用 updateUser(..., { expiresAt: null })", async () => { + const { editUser } = await import("@/actions/users"); + + const res = await editUser(123, { expiresAt: null }); + + expect(res.ok).toBe(true); + expect(updateUserMock).toHaveBeenCalledTimes(1); + expect(updateUserMock).toHaveBeenCalledWith( + 123, + expect.objectContaining({ + expiresAt: null, + }) + ); + }); +}); diff --git a/tests/unit/dashboard/edit-key-form-expiry-clear-ui.test.tsx b/tests/unit/dashboard/edit-key-form-expiry-clear-ui.test.tsx new file mode 100644 index 000000000..bea163770 --- /dev/null +++ b/tests/unit/dashboard/edit-key-form-expiry-clear-ui.test.tsx @@ -0,0 +1,137 @@ +/** + * @vitest-environment happy-dom + */ + +import fs from "node:fs"; +import path from "node:path"; +import type { ReactNode } from "react"; +import { act } from "react"; +import { createRoot } from "react-dom/client"; +import { NextIntlClientProvider } from "next-intl"; +import { beforeEach, describe, expect, test, vi } from "vitest"; +import { Dialog } from "@/components/ui/dialog"; +import { EditKeyForm } from "@/app/[locale]/dashboard/_components/user/forms/edit-key-form"; + +vi.mock("next/navigation", () => ({ + useRouter: () => ({ refresh: vi.fn() }), +})); + +const sonnerMocks = vi.hoisted(() => ({ + toast: { + success: vi.fn(), + error: vi.fn(), + }, +})); +vi.mock("sonner", () => sonnerMocks); + +const keysActionMocks = vi.hoisted(() => ({ + editKey: vi.fn(async () => ({ ok: true })), +})); +vi.mock("@/actions/keys", () => keysActionMocks); + +const providersActionMocks = vi.hoisted(() => ({ + getAvailableProviderGroups: vi.fn(async () => []), +})); +vi.mock("@/actions/providers", () => providersActionMocks); + +function loadMessages() { + const base = path.join(process.cwd(), "messages/en"); + const read = (name: string) => JSON.parse(fs.readFileSync(path.join(base, name), "utf8")); + + return { + common: read("common.json"), + errors: read("errors.json"), + quota: read("quota.json"), + ui: read("ui.json"), + dashboard: read("dashboard.json"), + forms: read("forms.json"), + }; +} + +function render(node: ReactNode) { + const container = document.createElement("div"); + document.body.appendChild(container); + const root = createRoot(container); + + act(() => { + root.render(node); + }); + + return { + container, + unmount: () => { + act(() => root.unmount()); + container.remove(); + }, + }; +} + +function clickButtonByText(text: string) { + const buttons = Array.from(document.body.querySelectorAll("button")); + const btn = buttons.find((b) => (b.textContent || "").includes(text)); + if (!btn) { + throw new Error(`未找到按钮: ${text}`); + } + btn.dispatchEvent(new MouseEvent("click", { bubbles: true })); +} + +describe("EditKeyForm: 清除 expiresAt 后应携带 expiresAt 字段提交(用于触发后端清除)", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + test("点击 Clear Date 后提交应调用 editKey 并携带 expiresAt 字段", async () => { + const messages = loadMessages(); + + const { unmount } = render( + + {}}> + + + + ); + + await act(async () => { + clickButtonByText("2026-01-04"); + }); + + await act(async () => { + clickButtonByText("Clear Date"); + }); + + const submit = document.body.querySelector('button[type="submit"]') as HTMLButtonElement | null; + expect(submit).toBeTruthy(); + + await act(async () => { + submit?.dispatchEvent(new MouseEvent("click", { bubbles: true })); + await new Promise((r) => setTimeout(r, 0)); + }); + + expect(keysActionMocks.editKey).toHaveBeenCalledTimes(1); + const [, payload] = keysActionMocks.editKey.mock.calls[0] as [number, any]; + + // 关键点:必须显式携带 expiresAt 字段(即使为 undefined),后端才会识别为“清除” + expect(Object.hasOwn(payload, "expiresAt")).toBe(true); + + unmount(); + }); +}); diff --git a/tests/unit/dashboard/user-form-expiry-clear-ui.test.tsx b/tests/unit/dashboard/user-form-expiry-clear-ui.test.tsx new file mode 100644 index 000000000..9e50f7b77 --- /dev/null +++ b/tests/unit/dashboard/user-form-expiry-clear-ui.test.tsx @@ -0,0 +1,121 @@ +/** + * @vitest-environment happy-dom + */ + +import fs from "node:fs"; +import path from "node:path"; +import type { ReactNode } from "react"; +import { act } from "react"; +import { createRoot } from "react-dom/client"; +import { NextIntlClientProvider } from "next-intl"; +import { beforeEach, describe, expect, test, vi } from "vitest"; +import { Dialog } from "@/components/ui/dialog"; +import { UserForm } from "@/app/[locale]/dashboard/_components/user/forms/user-form"; + +vi.mock("next/navigation", () => ({ + useRouter: () => ({ refresh: vi.fn() }), +})); + +const sonnerMocks = vi.hoisted(() => ({ + toast: { + success: vi.fn(), + error: vi.fn(), + }, +})); +vi.mock("sonner", () => sonnerMocks); + +const usersActionMocks = vi.hoisted(() => ({ + editUser: vi.fn(async () => ({ ok: true })), + addUser: vi.fn(async () => ({ ok: true, data: { user: { id: 1 } } })), +})); +vi.mock("@/actions/users", () => usersActionMocks); + +const providersActionMocks = vi.hoisted(() => ({ + getAvailableProviderGroups: vi.fn(async () => []), +})); +vi.mock("@/actions/providers", () => providersActionMocks); + +function loadMessages() { + const base = path.join(process.cwd(), "messages/en"); + const read = (name: string) => JSON.parse(fs.readFileSync(path.join(base, name), "utf8")); + + return { + common: read("common.json"), + errors: read("errors.json"), + notifications: read("notifications.json"), + ui: read("ui.json"), + dashboard: read("dashboard.json"), + forms: read("forms.json"), + }; +} + +function render(node: ReactNode) { + const container = document.createElement("div"); + document.body.appendChild(container); + const root = createRoot(container); + + act(() => { + root.render(node); + }); + + return { + container, + unmount: () => { + act(() => root.unmount()); + container.remove(); + }, + }; +} + +function clickButtonByText(text: string) { + const buttons = Array.from(document.body.querySelectorAll("button")); + const btn = buttons.find((b) => (b.textContent || "").includes(text)); + if (!btn) { + throw new Error(`未找到按钮: ${text}`); + } + btn.dispatchEvent(new MouseEvent("click", { bubbles: true })); +} + +describe("UserForm: 清除 expiresAt 后应提交 null", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + test("编辑模式:点击 Clear Date 后提交应调用 editUser(..., { expiresAt: null })", async () => { + const messages = loadMessages(); + const expiresAt = new Date("2026-01-04T23:59:59.999Z"); + + const { unmount } = render( + + {}}> + + + + ); + + await act(async () => { + clickButtonByText("2026-01-04"); + }); + + await act(async () => { + clickButtonByText("Clear Date"); + }); + + const submit = document.body.querySelector('button[type="submit"]') as HTMLButtonElement | null; + expect(submit).toBeTruthy(); + + await act(async () => { + submit?.dispatchEvent(new MouseEvent("click", { bubbles: true })); + await new Promise((r) => setTimeout(r, 0)); + }); + + expect(usersActionMocks.editUser).toHaveBeenCalledTimes(1); + const [, payload] = usersActionMocks.editUser.mock.calls[0] as [number, any]; + expect(payload.expiresAt).toBeNull(); + + unmount(); + }); +}); diff --git a/tests/unit/validation/user-schemas-expires-at-clear.test.ts b/tests/unit/validation/user-schemas-expires-at-clear.test.ts new file mode 100644 index 000000000..a99360b5e --- /dev/null +++ b/tests/unit/validation/user-schemas-expires-at-clear.test.ts @@ -0,0 +1,94 @@ +import { describe, expect, test } from "vitest"; +import { CreateUserSchema, UpdateUserSchema } from "@/lib/validation/schemas"; + +describe("UpdateUserSchema: expiresAt 清除语义", () => { + test("expiresAt=null 应解析为 null(显式清除)", () => { + const parsed = UpdateUserSchema.parse({ expiresAt: null }); + expect(parsed.expiresAt).toBeNull(); + }); + + test("expiresAt='' 应解析为 null(显式清除)", () => { + const parsed = UpdateUserSchema.parse({ expiresAt: "" }); + expect(parsed.expiresAt).toBeNull(); + }); + + test("expiresAt 缺省应保持 undefined(不更新字段)", () => { + const parsed = UpdateUserSchema.parse({}); + expect(parsed.expiresAt).toBeUndefined(); + }); + + test("expiresAt=ISO 字符串应解析为 Date", () => { + const parsed = UpdateUserSchema.parse({ expiresAt: "2026-01-04T23:59:59.999Z" }); + expect(parsed.expiresAt).toBeInstanceOf(Date); + }); + + test("expiresAt=非法字符串应校验失败", () => { + const result = UpdateUserSchema.safeParse({ expiresAt: "not-a-date" }); + expect(result.success).toBe(false); + }); + + test("expiresAt=非法 Date 应校验失败", () => { + const bad = new Date("not-a-date"); + const result = UpdateUserSchema.safeParse({ expiresAt: bad }); + expect(result.success).toBe(false); + }); + + test("expiresAt=非字符串/非 Date 类型应校验失败", () => { + const result = UpdateUserSchema.safeParse({ expiresAt: 123 }); + expect(result.success).toBe(false); + }); + + test("expiresAt 超过 10 年应被拒绝", () => { + const tooFar = new Date(); + tooFar.setFullYear(tooFar.getFullYear() + 11); + + const result = UpdateUserSchema.safeParse({ expiresAt: tooFar }); + expect(result.success).toBe(false); + }); +}); + +describe("CreateUserSchema: expiresAt 兼容性", () => { + test("CreateUserSchema 仍将 expiresAt=null 视为未设置", () => { + const parsed = CreateUserSchema.parse({ name: "test-user", expiresAt: null }); + expect(parsed.expiresAt).toBeUndefined(); + }); + + test("CreateUserSchema 支持 expiresAt=Date(未来时间)", () => { + const future = new Date(); + future.setDate(future.getDate() + 1); + const parsed = CreateUserSchema.parse({ name: "test-user", expiresAt: future }); + expect(parsed.expiresAt).toBeInstanceOf(Date); + }); + + test("CreateUserSchema 支持 expiresAt=ISO 字符串(未来时间)", () => { + const future = new Date(); + future.setDate(future.getDate() + 1); + const parsed = CreateUserSchema.parse({ name: "test-user", expiresAt: future.toISOString() }); + expect(parsed.expiresAt).toBeInstanceOf(Date); + }); + + test("CreateUserSchema: expiresAt=过去时间应被拒绝", () => { + const past = new Date(); + past.setDate(past.getDate() - 1); + const result = CreateUserSchema.safeParse({ name: "test-user", expiresAt: past }); + expect(result.success).toBe(false); + }); + + test("CreateUserSchema: expiresAt 超过 10 年应被拒绝", () => { + const farFuture = new Date(); + farFuture.setFullYear(farFuture.getFullYear() + 11); + const result = CreateUserSchema.safeParse({ name: "test-user", expiresAt: farFuture }); + expect(result.success).toBe(false); + }); + + test("CreateUserSchema: expiresAt=非法 Date 应校验失败", () => { + const bad = new Date("not-a-date"); + const result = CreateUserSchema.safeParse({ name: "test-user", expiresAt: bad }); + expect(result.success).toBe(false); + }); + + test("CreateUserSchema: expiresAt=非法字符串应校验失败", () => { + const result = CreateUserSchema.safeParse({ name: "test-user", expiresAt: "not-a-date" }); + expect(result.success).toBe(false); + }); +});