Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions src/actions/keys.ts
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,10 @@ export async function editKey(
}
}

// 仅当调用方显式携带 expiresAt 字段时才更新/清除该字段:
// - 避免像“仅修改限额”这类局部更新把 expiresAt 意外清空
const hasExpiresAtField = Object.hasOwn(data, "expiresAt");

const validatedData = KeyFormSchema.parse(data);

// 服务端验证:Key限额不能超过用户限额
Expand Down Expand Up @@ -417,9 +421,23 @@ export async function editKey(

// 移除 providerGroup 子集校验(用户分组由 Key 分组自动计算)

// 转换 expiresAt: undefined → null(清除日期),string → Date(设置日期)
const expiresAt =
validatedData.expiresAt === undefined ? null : new Date(validatedData.expiresAt);
// 转换 expiresAt:
// - 未携带 expiresAt:不更新该字段
// - 携带 expiresAt 但为空:清除(永不过期)
// - 携带 expiresAt 且为字符串:设置为对应 Date
const expiresAt = hasExpiresAtField
? validatedData.expiresAt === undefined
? null
: new Date(validatedData.expiresAt)
: undefined;

if (expiresAt && Number.isNaN(expiresAt.getTime())) {
return {
ok: false,
error: tError("INVALID_FORMAT"),
errorCode: ERROR_CODES.INVALID_FORMAT,
};
}

const isAdmin = session.user.role === "admin";
const prevProviderGroup = normalizeProviderGroup(key.providerGroup);
Expand All @@ -428,7 +446,7 @@ export async function editKey(

await updateKey(keyId, {
name: validatedData.name,
expires_at: expiresAt,
...(hasExpiresAtField ? { expires_at: expiresAt } : {}),
can_login_web_ui: validatedData.canLoginWebUi,
...(data.isEnabled !== undefined ? { is_enabled: data.isEnabled } : {}),
limit_5h_usd: validatedData.limit5hUsd,
Expand Down
8 changes: 6 additions & 2 deletions src/lib/validation/schemas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,11 @@ export const UpdateUserSchema = z.object({
isEnabled: z.boolean().optional(),
expiresAt: z.preprocess(
(val) => {
// null/undefined/空字符串 -> 视为未设置
if (val === null || val === undefined || val === "") return undefined;
// 更新语义:
// - undefined:不更新该字段
// - null/空字符串:显式清除过期时间(永不过期)
if (val === undefined) return undefined;
if (val === null || val === "") return null;

// 已经是 Date 对象
if (val instanceof Date) {
Expand All @@ -222,6 +225,7 @@ export const UpdateUserSchema = z.object({
},
z
.date()
.nullable()
.optional()
.superRefine((date, ctx) => {
if (!date) {
Expand Down
146 changes: 146 additions & 0 deletions tests/unit/actions/keys-edit-key-expires-at-clear.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import { beforeEach, describe, expect, test, vi } from "vitest";

const getSessionMock = vi.fn();
vi.mock("@/lib/auth", () => ({
getSession: getSessionMock,
}));

vi.mock("next/cache", () => ({
revalidatePath: vi.fn(),
}));

const getTranslationsMock = vi.fn(async () => (key: string) => key);
vi.mock("next-intl/server", () => ({
getTranslations: getTranslationsMock,
}));

const findKeyByIdMock = vi.fn();
const updateKeyMock = vi.fn();

vi.mock("@/repository/key", () => ({
countActiveKeysByUser: vi.fn(async () => 1),
createKey: vi.fn(async () => ({})),
deleteKey: vi.fn(async () => true),
findActiveKeyByUserIdAndName: vi.fn(async () => null),
findKeyById: findKeyByIdMock,
findKeyList: vi.fn(async () => []),
findKeysWithStatistics: vi.fn(async () => []),
updateKey: updateKeyMock,
}));

const findUserByIdMock = vi.fn();
vi.mock("@/repository/user", async (importOriginal) => {
const actual = await importOriginal<typeof import("@/repository/user")>();
return {
...actual,
findUserById: findUserByIdMock,
};
});

describe("editKey: expiresAt 清除/不更新语义", () => {
beforeEach(() => {
vi.clearAllMocks();
getSessionMock.mockResolvedValue({ user: { id: 1, role: "admin" } });

findKeyByIdMock.mockResolvedValue({
id: 1,
userId: 10,
key: "sk-test",
name: "k",
isEnabled: true,
expiresAt: new Date("2026-01-04T23:59:59.999Z"),
canLoginWebUi: true,
limit5hUsd: null,
limitDailyUsd: null,
dailyResetMode: "fixed",
dailyResetTime: "00:00",
limitWeeklyUsd: null,
limitMonthlyUsd: null,
limitTotalUsd: null,
limitConcurrentSessions: 0,
providerGroup: "default",
cacheTtlPreference: null,
createdAt: new Date(),
updatedAt: new Date(),
deletedAt: null,
});

findUserByIdMock.mockResolvedValue({
id: 10,
name: "u",
description: "",
role: "user",
rpm: null,
dailyQuota: null,
providerGroup: "default",
tags: [],
limit5hUsd: null,
dailyResetMode: "fixed",
dailyResetTime: "00:00",
limitWeeklyUsd: null,
limitMonthlyUsd: null,
limitTotalUsd: null,
limitConcurrentSessions: null,
isEnabled: true,
expiresAt: null,
allowedClients: [],
allowedModels: [],
createdAt: new Date(),
updatedAt: new Date(),
deletedAt: null,
});

updateKeyMock.mockResolvedValue({ id: 1 });
});

test("不携带 expiresAt 字段时不应更新 expires_at", async () => {
const { editKey } = await import("@/actions/keys");

const res = await editKey(1, { name: "k2" });

expect(res.ok).toBe(true);
expect(updateKeyMock).toHaveBeenCalledTimes(1);

const updatePayload = updateKeyMock.mock.calls[0]?.[1] as Record<string, unknown>;
expect(Object.hasOwn(updatePayload, "expires_at")).toBe(false);
});

test("携带 expiresAt=undefined 时应清除 expires_at(写入 null)", async () => {
const { editKey } = await import("@/actions/keys");

const res = await editKey(1, { name: "k2", expiresAt: undefined });

expect(res.ok).toBe(true);
expect(updateKeyMock).toHaveBeenCalledTimes(1);
expect(updateKeyMock).toHaveBeenCalledWith(
1,
expect.objectContaining({
expires_at: null,
})
);
});

test("携带 expiresAt=YYYY-MM-DD 时应写入对应 Date", async () => {
const { editKey } = await import("@/actions/keys");

const res = await editKey(1, { name: "k2", expiresAt: "2026-01-04" });

expect(res.ok).toBe(true);
expect(updateKeyMock).toHaveBeenCalledTimes(1);

const updatePayload = updateKeyMock.mock.calls[0]?.[1] as Record<string, unknown>;
expect(updatePayload.expires_at).toBeInstanceOf(Date);
expect(Number.isNaN((updatePayload.expires_at as Date).getTime())).toBe(false);
});

test("携带非法 expiresAt 字符串应返回 INVALID_FORMAT", async () => {
const { editKey } = await import("@/actions/keys");

const res = await editKey(1, { name: "k2", expiresAt: "not-a-date" });

expect(res.ok).toBe(false);
if (!res.ok) {
expect(res.errorCode).toBe("INVALID_FORMAT");
}
});
});
49 changes: 49 additions & 0 deletions tests/unit/actions/users-edit-user-expires-at-clear.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import { beforeEach, describe, expect, test, vi } from "vitest";

const getSessionMock = vi.fn();
vi.mock("@/lib/auth", () => ({
getSession: getSessionMock,
}));

vi.mock("next/cache", () => ({
revalidatePath: vi.fn(),
}));

const getTranslationsMock = vi.fn(async () => (key: string) => key);
const getLocaleMock = vi.fn(async () => "en");
vi.mock("next-intl/server", () => ({
getTranslations: getTranslationsMock,
getLocale: getLocaleMock,
}));

const updateUserMock = vi.fn();
vi.mock("@/repository/user", async (importOriginal) => {
const actual = await importOriginal<typeof import("@/repository/user")>();
return {
...actual,
updateUser: updateUserMock,
};
});

describe("editUser: expiresAt 清除应写入数据库更新", () => {
beforeEach(() => {
vi.clearAllMocks();
getSessionMock.mockResolvedValue({ user: { id: 1, role: "admin" } });
updateUserMock.mockResolvedValue({ id: 123 });
});

test("传入 expiresAt=null 应调用 updateUser(..., { expiresAt: null })", async () => {
const { editUser } = await import("@/actions/users");

const res = await editUser(123, { expiresAt: null });

expect(res.ok).toBe(true);
expect(updateUserMock).toHaveBeenCalledTimes(1);
expect(updateUserMock).toHaveBeenCalledWith(
123,
expect.objectContaining({
expiresAt: null,
})
);
});
});
Loading
Loading