Skip to content

Commit

Permalink
✨ feat: support duplicate topic
Browse files Browse the repository at this point in the history
  • Loading branch information
arvinxx committed Jan 4, 2024
1 parent 4c5a345 commit 9074d69
Show file tree
Hide file tree
Showing 10 changed files with 334 additions and 16 deletions.
57 changes: 44 additions & 13 deletions src/app/chat/features/TopicListContent/Topic/TopicContent.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { ActionIcon, EditableText, Icon } from '@lobehub/ui';
import { App, Dropdown, type MenuProps, Typography } from 'antd';
import { createStyles } from 'antd-style';
import { MoreVertical, PencilLine, Star, Trash, Wand2 } from 'lucide-react';
import { LucideCopy, MoreVertical, PencilLine, Star, Trash, Wand2 } from 'lucide-react';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { Flexbox } from 'react-layout-kit';
Expand Down Expand Up @@ -33,14 +33,23 @@ interface TopicContentProps {
const TopicContent = memo<TopicContentProps>(({ id, title, fav, showMore }) => {
const { t } = useTranslation('common');

const [editing, favoriteTopic, updateTopicTitle, removeTopic, autoRenameTopicTitle] =
useChatStore((s) => [
s.topicRenamingId === id,
s.favoriteTopic,
s.updateTopicTitle,
s.removeTopic,
s.autoRenameTopicTitle,
]);
const [
editing,
favoriteTopic,
updateTopicTitle,
removeTopic,
autoRenameTopicTitle,
duplicateTopic,
// sessionId,
] = useChatStore((s) => [
s.topicRenamingId === id,
s.favoriteTopic,
s.updateTopicTitle,
s.removeTopic,
s.autoRenameTopicTitle,
s.duplicateTopic,
s.activeId,
]);
const { styles, theme } = useStyles();

const toggleEditing = (visible?: boolean) => {
Expand All @@ -51,6 +60,14 @@ const TopicContent = memo<TopicContentProps>(({ id, title, fav, showMore }) => {

const items = useMemo<MenuProps['items']>(
() => [
{
icon: <Icon icon={Wand2} />,
key: 'autoRename',
label: t('topic.actions.autoRename', { ns: 'chat' }),
onClick: () => {
autoRenameTopicTitle(id);
},
},
{
icon: <Icon icon={PencilLine} />,
key: 'rename',
Expand All @@ -60,14 +77,28 @@ const TopicContent = memo<TopicContentProps>(({ id, title, fav, showMore }) => {
},
},
{
icon: <Icon icon={Wand2} />,
key: 'autoRename',
label: t('topic.actions.autoRename', { ns: 'chat' }),
type: 'divider',
},
{
icon: <Icon icon={LucideCopy} />,
key: 'duplicate',
label: t('topic.actions.duplicate', { ns: 'chat' }),
onClick: () => {
autoRenameTopicTitle(id);
duplicateTopic(id);
},
},
// {
// icon: <Icon icon={LucideDownload} />,
// key: 'export',
// label: t('topic.actions.export', { ns: 'chat' }),
// onClick: () => {
// configService.exportSingleTopic(sessionId, id);
// },
// },
{
type: 'divider',
},
// {
// icon: <Icon icon={Share2} />,
// key: 'share',
// label: t('share'),
Expand Down
48 changes: 48 additions & 0 deletions src/database/models/__tests__/message.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { afterEach, beforeEach, describe, expect, it } from 'vitest';

import { DB_Message } from '@/database/schemas/message';
import { ChatMessage } from '@/types/message';

import { CreateMessageParams, MessageModel } from '../message';
Expand Down Expand Up @@ -324,4 +325,51 @@ describe('MessageModel', () => {
expect(messagesInDb).toHaveLength(0);
});
});

describe('duplicateMessages', () => {
it('should duplicate messages and update parentId for copied messages', async () => {
// 创建原始消息和父消息
const parentMessageData: DB_Message = {
content: 'Parent message content',
role: 'user',
sessionId: 'session1',
};
const parentMessage = await MessageModel.create(parentMessageData);

const childMessageData: DB_Message = {
content: 'Child message content',
role: 'user',
sessionId: 'session1',
parentId: parentMessage.id,
};

await MessageModel.create(childMessageData);

// 获取数据库中的消息以进行复制
const originalMessages = await MessageModel.queryAll();

// 执行复制操作
const duplicatedMessages = await MessageModel.duplicateMessages(originalMessages);

// 验证复制的消息数量是否正确
expect(duplicatedMessages.length).toBe(originalMessages.length);

// 验证每个复制的消息是否具有新的唯一ID,并且parentId被正确更新
for (const original of originalMessages) {
const copied = duplicatedMessages.find((m) => m.content === original.content);
expect(copied).toBeDefined();
expect(copied).not.toBeNull();
expect(copied!.id).not.toBe(original.id);
if (original.parentId) {
const originalParent = originalMessages.find((m) => m.id === original.parentId);
expect(originalParent).toBeDefined();
const copiedParent = duplicatedMessages.find(
(m) => m.content === originalParent!.content,
);

expect(copied!.parentId).toBe(copiedParent!.id);
}
}
});
});
});
135 changes: 135 additions & 0 deletions src/database/models/__tests__/topic.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import { afterEach, beforeEach, describe, expect, it } from 'vitest';

import { DBModel } from '@/database/core/types/db';
import { CreateMessageParams, MessageModel } from '@/database/models/message';
import { DB_Message } from '@/database/schemas/message';
import { DB_Topic } from '@/database/schemas/topic';
import { nanoid } from '@/utils/uuid';
import * as uuidUtils from '@/utils/uuid';

Expand Down Expand Up @@ -287,4 +289,137 @@ describe('TopicModel', () => {
expect(messagesInDb1).toHaveLength(0);
expect(messagesInDb2).toHaveLength(0);
});

describe('duplicateTopic', () => {
let originalTopic: DBModel<DB_Topic>;
let originalMessages: any[];

beforeEach(async () => {
// 创建一个原始主题
const { id } = await TopicModel.create({
title: 'Original Topic',
sessionId: 'session1',
favorite: false,
});
originalTopic = await TopicModel.findById(id);

// 创建一些关联到原始主题的消息
originalMessages = await Promise.all(
['Message 1', 'Message 2'].map((text) =>
MessageModel.create({
content: text,
topicId: originalTopic.id,
sessionId: originalTopic.sessionId!,
role: 'user',
}),
),
);
});

afterEach(async () => {
// 清理数据库中的所有主题和消息
await TopicModel.clearTable();
await MessageModel.clearTable();
});

it('should duplicate a topic with all associated messages', async () => {
// 执行复制操作
await TopicModel.duplicateTopic(originalTopic.id);

// 验证复制后的主题是否存在
const duplicatedTopic = await TopicModel.findBySessionId(originalTopic.sessionId!);
expect(duplicatedTopic).toHaveLength(2);

// 验证复制后的消息是否存在
const duplicatedMessages = await MessageModel.query({
sessionId: originalTopic.sessionId!,
topicId: duplicatedTopic[1].id, // 假设复制的主题是第二个
});
expect(duplicatedMessages).toHaveLength(originalMessages.length);
});

it('should throw an error if the topic does not exist', async () => {
// 尝试复制一个不存在的主题
const nonExistentTopicId = nanoid();
await expect(TopicModel.duplicateTopic(nonExistentTopicId)).rejects.toThrow(
`Topic with id ${nonExistentTopicId} not found`,
);
});

it('should preserve the properties of the duplicated topic', async () => {
// 执行复制操作
await TopicModel.duplicateTopic(originalTopic.id);

// 获取复制的主题
const topics = await TopicModel.findBySessionId(originalTopic.sessionId!);
const duplicatedTopic = topics.find((topic) => topic.id !== originalTopic.id);

// 验证复制的主题是否保留了原始主题的属性
expect(duplicatedTopic).toBeDefined();
expect(duplicatedTopic).toMatchObject({
title: originalTopic.title,
favorite: originalTopic.favorite,
sessionId: originalTopic.sessionId,
});
// 确保生成了新的 ID
expect(duplicatedTopic.id).not.toBe(originalTopic.id);
});

it('should properly handle the messages hierarchy when duplicating', async () => {
// 创建一个子消息关联到其中一个原始消息
const { id } = await MessageModel.create({
content: 'Child Message',
topicId: originalTopic.id,
parentId: originalMessages[0].id,
sessionId: originalTopic.sessionId!,
role: 'user',
});
const childMessage = await MessageModel.findById(id);

// 执行复制操作
await TopicModel.duplicateTopic(originalTopic.id);

// 获取复制的消息
const duplicatedMessages = await MessageModel.queryBySessionId(originalTopic.sessionId!);

// 验证复制的子消息是否存在并且 parentId 已更新
const duplicatedChildMessage = duplicatedMessages.find(
(message) => message.content === childMessage.content && message.id !== childMessage.id,
);

expect(duplicatedChildMessage).toBeDefined();
expect(duplicatedChildMessage.parentId).not.toBe(childMessage.parentId);
expect(duplicatedChildMessage.parentId).toBeDefined();
});

it('should fail if the database transaction fails', async () => {
// 强制数据库事务失败,例如通过在复制过程中抛出异常
const dbTransactionFailedError = new Error('DB transaction failed');
vi.spyOn(TopicModel['db'], 'transaction').mockImplementation(async () => {
throw dbTransactionFailedError;
});

// 尝试复制主题并捕捉期望的错误
await expect(TopicModel.duplicateTopic(originalTopic.id)).rejects.toThrow(
dbTransactionFailedError,
);
});

it('should not create partial duplicates if the process fails at some point', async () => {
// 假设复制消息的过程中发生了错误
vi.spyOn(MessageModel, 'duplicateMessages').mockImplementation(async () => {
throw new Error('Failed to duplicate messages');
});

// 尝试复制主题,期望会抛出错误
await expect(TopicModel.duplicateTopic(originalTopic.id)).rejects.toThrow();

// 确保没有创建任何副本
const topics = await TopicModel.findBySessionId(originalTopic.sessionId!);
expect(topics).toHaveLength(1); // 只有原始主题

const messages = await MessageModel.queryBySessionId(originalTopic.sessionId!);
expect(messages).toHaveLength(originalMessages.length); // 只有原始消息
});
});
});
35 changes: 35 additions & 0 deletions src/database/models/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,41 @@ class _MessageModel extends BaseModel {
return this.table.where('sessionId').equals(sessionId).toArray();
}

queryByTopicId = async (topicId: string) => {
const dbMessages = await this.table.where('topicId').equals(topicId).toArray();

return dbMessages.map((message) => this.mapToChatMessage(message));
};

async duplicateMessages(messages: ChatMessage[]): Promise<ChatMessage[]> {
const duplicatedMessages = await this.createDuplicateMessages(messages);
// 批量添加复制后的消息到数据库
await this.batchCreate(duplicatedMessages);
return duplicatedMessages;
}

async createDuplicateMessages(messages: ChatMessage[]): Promise<ChatMessage[]> {
// 创建一个映射来存储原始消息ID和复制消息ID之间的关系
const idMapping = new Map<string, string>();

// 首先复制所有消息,并为每个复制的消息生成新的ID
const duplicatedMessages = messages.map((originalMessage) => {
const newId = nanoid();
idMapping.set(originalMessage.id, newId);

return { ...originalMessage, id: newId };
});

// 更新 parentId 为复制后的新ID
for (const duplicatedMessage of duplicatedMessages) {
if (duplicatedMessage.parentId && idMapping.has(duplicatedMessage.parentId)) {
duplicatedMessage.parentId = idMapping.get(duplicatedMessage.parentId);
}
}

return duplicatedMessages;
}

private mapChatMessageToDBMessage(message: ChatMessage): DB_Message {
const { extra, ...messageData } = message;

Expand Down
35 changes: 33 additions & 2 deletions src/database/models/topic.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import { BaseModel } from '@/database/core';
import { DBModel } from '@/database/core/types/db';
import { MessageModel } from '@/database/models/message';
import { DB_Topic, DB_TopicSchema } from '@/database/schemas/topic';
import { ChatTopic } from '@/types/topic';
import { nanoid } from '@/utils/uuid';
Expand Down Expand Up @@ -53,14 +55,14 @@ class _TopicModel extends BaseModel {
// handle pageSize
const pagedTopics = sortedTopics.slice(offset, offset + pageSize);

return pagedTopics.map((i) => ({ ...i, favorite: !!i.favorite }));
return pagedTopics.map((i) => this.mapToChatTopic(i));
}

async findBySessionId(sessionId: string) {
return this.table.where({ sessionId }).toArray();
}

async findById(id: string) {
async findById(id: string): Promise<DBModel<DB_Topic>> {
return this.table.get(id);
}

Expand Down Expand Up @@ -198,6 +200,35 @@ class _TopicModel extends BaseModel {
console.timeEnd('queryTopicsByKeyword');
return uniqueTopics.map((i) => ({ ...i, favorite: !!i.favorite }));
}

async duplicateTopic(topicId: string) {
return this.db.transaction('rw', this.db.topics, this.db.messages, async () => {
// Step 1: get DB_Topic
const topic = await this.findById(topicId);

if (!topic) {
throw new Error(`Topic with id ${topicId} not found`);
}

// Step 3: 查询与 `topic` 关联的 `messages`
const originalMessages = await MessageModel.queryByTopicId(topicId);

const duplicateMessages = await MessageModel.duplicateMessages(originalMessages);

const { id } = await this.create({
...this.mapToChatTopic(topic),
messages: duplicateMessages.map((m) => m.id),
sessionId: topic.sessionId!,
});

return id;
});
}

private mapToChatTopic = (dbTopic: DBModel<DB_Topic>): ChatTopic => ({
...dbTopic,
favorite: !!dbTopic.favorite,
});
}

export const TopicModel = new _TopicModel();
Loading

0 comments on commit 9074d69

Please sign in to comment.