Skip to content

Commit

Permalink
♻️ refactor: refactor service to a uniform interface (lobehub#2062)
Browse files Browse the repository at this point in the history
* ♻️ refactor: refactor the session service interface

* ♻️ refactor: add file service interface

* ♻️ refactor: add message service interface

* ♻️ refactor: add topic service interface

* ✅ test: add test for agent action
  • Loading branch information
arvinxx authored and TheNameIsNigel committed May 15, 2024
1 parent 5261a70 commit e6a3d51
Show file tree
Hide file tree
Showing 40 changed files with 691 additions and 320 deletions.
14 changes: 7 additions & 7 deletions contributing/Basic/Feature-Development.md
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ This requirement involves upgrading the Sessions feature to transform it from a

To handle these groups, we need to refactor the implementation logic of `useFetchSessions`. Here are the key changes:

1. Use the `sessionService.getSessionsWithGroup` method to call the backend API and retrieve the grouped session data.
1. Use the `sessionService.getGroupedSessions` method to call the backend API and retrieve the grouped session data.
2. Save the retrieved data into three different state fields: `pinnedSessions`, `customSessionGroups`, and `defaultSessions`.

#### `useFetchSessions` Method
Expand All @@ -247,7 +247,7 @@ export const createSessionSlice: StateCreator<
> = (set, get) => ({
// ... other methods
useFetchSessions: () =>
useSWR<ChatSessionList>(FETCH_SESSIONS_KEY, sessionService.getSessionsWithGroup, {
useSWR<ChatSessionList>(FETCH_SESSIONS_KEY, sessionService.getGroupedSessions, {
onSuccess: (data) => {
set(
{
Expand All @@ -267,23 +267,23 @@ export const createSessionSlice: StateCreator<

After successfully retrieving the data, we use the `set` method to update the `customSessionGroups`, `defaultSessions`, `pinnedSessions`, and `sessions` states. This ensures that the states are synchronized with the latest session data.

#### `sessionService.getSessionsWithGroup` Method
#### `sessionService.getGroupedSessions` Method

The `sessionService.getSessionsWithGroup` method is responsible for calling the backend API `SessionModel.queryWithGroups()`.
The `sessionService.getGroupedSessions` method is responsible for calling the backend API `SessionModel.queryWithGroups()`.

```typescript
class SessionService {
// ... other SessionGroup related implementations

async getSessionsWithGroup(): Promise<ChatSessionList> {
async getGroupedSessions(): Promise<ChatSessionList> {
return SessionModel.queryWithGroups();
}
}
```

#### `SessionModel.queryWithGroups` Method

This method is the core method called by `sessionService.getSessionsWithGroup`, and it is responsible for querying and organizing session data. The code is as follows:
This method is the core method called by `sessionService.getGroupedSessions`, and it is responsible for querying and organizing session data. The code is as follows:

```typescript
class _SessionModel extends BaseModel {
Expand Down Expand Up @@ -617,7 +617,7 @@ class ConfigService {
// ... Other code omitted

exportSessions = async () => {
const sessions = await sessionService.getSessions();
const sessions = await sessionService.getAllSessions();
+ const sessionGroups = await sessionService.getSessionGroups();
const messages = await messageService.getAllMessages();
const topics = await topicService.getAllTopics();
Expand Down
14 changes: 7 additions & 7 deletions contributing/Basic/Feature-Development.zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ export const createSessionGroupSlice: StateCreator<

为了处理这些分组,我们需要改造 `useFetchSessions` 的实现逻辑。以下是关键的改动点:

1. 使用 `sessionService.getSessionsWithGroup` 方法负责调用后端接口来获取分组后的会话数据;
1. 使用 `sessionService.getGroupedSessions` 方法负责调用后端接口来获取分组后的会话数据;
2. 将获取后的数据保存为三到不同的状态字段中:`pinnedSessions``customSessionGroups``defaultSessions`

#### `useFetchSessions` 方法
Expand All @@ -247,7 +247,7 @@ export const createSessionSlice: StateCreator<
> = (set, get) => ({
// ... 其他方法
useFetchSessions: () =>
useSWR<ChatSessionList>(FETCH_SESSIONS_KEY, sessionService.getSessionsWithGroup, {
useSWR<ChatSessionList>(FETCH_SESSIONS_KEY, sessionService.getGroupedSessions, {
onSuccess: (data) => {
set(
{
Expand All @@ -267,23 +267,23 @@ export const createSessionSlice: StateCreator<

在成功获取数据后,我们使用 `set` 方法来更新 `customSessionGroups``defaultSessions``pinnedSessions``sessions` 状态。这将保证状态与最新的会话数据同步。

#### getSessionsWithGroup
#### getGroupedSessions

使用 `sessionService.getSessionsWithGroup` 方法负责调用后端接口 `SessionModel.queryWithGroups()`
使用 `sessionService.getGroupedSessions` 方法负责调用后端接口 `SessionModel.queryWithGroups()`

```typescript
class SessionService {
// ... 其他 SessionGroup 相关的实现

async getSessionsWithGroup(): Promise<ChatSessionList> {
async getGroupedSessions(): Promise<ChatSessionList> {
return SessionModel.queryWithGroups();
}
}
```

#### `SessionModel.queryWithGroups` 方法

此方法是 `sessionService.getSessionsWithGroup` 调用的核心方法,它负责查询和组织会话数据,代码如下:
此方法是 `sessionService.getGroupedSessions` 调用的核心方法,它负责查询和组织会话数据,代码如下:

```typescript
class _SessionModel extends BaseModel {
Expand Down Expand Up @@ -611,7 +611,7 @@ class ConfigService {
// ... 省略其他

exportSessions = async () => {
const sessions = await sessionService.getSessions();
const sessions = await sessionService.getAllSessions();
+ const sessionGroups = await sessionService.getSessionGroups();
const messages = await messageService.getAllMessages();
const topics = await topicService.getAllTopics();
Expand Down
4 changes: 2 additions & 2 deletions src/app/home/Redirect.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import { sessionService } from '@/services/session';

const checkHasConversation = async () => {
const hasMessages = await messageService.hasMessages();
const hasAgents = await sessionService.hasSessions();
return hasMessages || hasAgents;
const hasAgents = await sessionService.countSessions();
return hasMessages || hasAgents === 0;
};

const Redirect = memo(() => {
Expand Down
4 changes: 1 addition & 3 deletions src/database/client/models/__tests__/session.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,10 @@ describe('SessionModel', () => {

expect(updatedSession).toHaveProperty('group', 'newGroup');
});
});

describe('updatePinned', () => {
it('should update pinned status of a session', async () => {
const createdSession = await SessionModel.create('agent', sessionData);
await SessionModel.updatePinned(createdSession.id, true);
await SessionModel.update(createdSession.id, { pinned: 1 });
const updatedSession = await SessionModel.findById(createdSession.id);
expect(updatedSession).toHaveProperty('pinned', 1);
});
Expand Down
8 changes: 4 additions & 4 deletions src/database/client/models/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,10 @@ class _SessionModel extends BaseModel {
return (await this.table.count()) === 0;
}

async count() {
return this.table.count();
}

// **************** Create *************** //

async create(type: 'agent' | 'group', defaultValue: Partial<LobeAgentSession>, id = uuid()) {
Expand Down Expand Up @@ -238,10 +242,6 @@ class _SessionModel extends BaseModel {
return super._updateWithSync(id, data);
}

async updatePinned(id: string, pinned: boolean) {
return this.update(id, { pinned: pinned ? 1 : 0 });
}

async updateConfig(id: string, data: DeepPartial<LobeAgentConfig>) {
const session = await this.findById(id);
if (!session) return;
Expand Down
8 changes: 4 additions & 4 deletions src/services/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class ConfigService {
return await sessionService.batchCreateSessions(sessions);
};
importMessages = async (messages: ChatMessage[]) => {
return messageService.batchCreate(messages);
return messageService.batchCreateMessages(messages);
};
importSettings = async (settings: GlobalSettings) => {
useGlobalStore.getState().importAppSettings(settings);
Expand Down Expand Up @@ -105,7 +105,7 @@ class ConfigService {
* export all agents
*/
exportAgents = async () => {
const agents = await sessionService.getAllAgents();
const agents = await sessionService.getSessionsByType('agent');
const sessionGroups = await sessionService.getSessionGroups();

const config = createConfigFile('agents', { sessionGroups, sessions: agents });
Expand All @@ -117,7 +117,7 @@ class ConfigService {
* export all sessions
*/
exportSessions = async () => {
const sessions = await sessionService.getSessions();
const sessions = await sessionService.getSessionsByType();
const sessionGroups = await sessionService.getSessionGroups();
const messages = await messageService.getAllMessages();
const topics = await topicService.getAllTopics();
Expand Down Expand Up @@ -188,7 +188,7 @@ class ConfigService {
* export all data
*/
exportAll = async () => {
const sessions = await sessionService.getSessions();
const sessions = await sessionService.getSessionsByType();
const sessionGroups = await sessionService.getSessionGroups();
const messages = await messageService.getAllMessages();
const topics = await topicService.getAllTopics();
Expand Down
4 changes: 2 additions & 2 deletions src/services/file/client.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ import { Mock, beforeEach, describe, expect, it, vi } from 'vitest';
import { FileModel } from '@/database/client/models/file';
import { DB_File } from '@/database/client/schemas/files';

import { FileService } from './client';
import { ClientService } from './client';

const fileService = new FileService();
const fileService = new ClientService();

// Mocks for the FileModel
vi.mock('@/database/client/models/file', () => ({
Expand Down
68 changes: 35 additions & 33 deletions src/services/file/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,9 @@ import { FilePreview } from '@/types/files';
import compressImage from '@/utils/compressImage';

import { API_ENDPOINTS } from '../_url';
import { IFileService } from './type';

export class FileService {
private isImage(fileType: string) {
const imageRegex = /^image\//;
return imageRegex.test(fileType);
}
export class ClientService implements IFileService {
async uploadFile(file: DB_File) {
// 跳过图片上传测试
const isTestData = file.size === 1;
Expand All @@ -22,26 +19,6 @@ export class FileService {
return FileModel.create(file);
}

async uploadImageFile(file: DB_File) {
// 加载图片
const url = file.url || URL.createObjectURL(new Blob([file.data]));

const img = new Image();
img.src = url;
await (() =>
new Promise((resolve) => {
img.addEventListener('load', resolve);
}))();

// 压缩图片
const base64String = compressImage({ img, type: file.fileType });
const binaryString = atob(base64String.split('base64,')[1]);
const uint8Array = Uint8Array.from(binaryString, (char) => char.charCodeAt(0));
file.data = uint8Array.buffer;

return FileModel.create(file);
}

async uploadImageByUrl(url: string, file: Pick<DB_File, 'name' | 'metadata'>) {
const res = await fetch(API_ENDPOINTS.proxy, { body: url, method: 'POST' });
const data = await res.arrayBuffer();
Expand All @@ -57,14 +34,6 @@ export class FileService {
});
}

async removeFile(id: string) {
return FileModel.delete(id);
}

async removeAllFiles() {
return FileModel.clear();
}

async getFile(id: string): Promise<FilePreview> {
const item = await FileModel.findById(id);
if (!item) {
Expand All @@ -83,4 +52,37 @@ export class FileService {
url,
};
}

async removeFile(id: string) {
return FileModel.delete(id);
}

async removeAllFiles() {
return FileModel.clear();
}

private isImage(fileType: string) {
const imageRegex = /^image\//;
return imageRegex.test(fileType);
}

private async uploadImageFile(file: DB_File) {
// 加载图片
const url = file.url || URL.createObjectURL(new Blob([file.data]));

const img = new Image();
img.src = url;
await (() =>
new Promise((resolve) => {
img.addEventListener('load', resolve);
}))();

// 压缩图片
const base64String = compressImage({ img, type: file.fileType });
const binaryString = atob(base64String.split('base64,')[1]);
const uint8Array = Uint8Array.from(binaryString, (char) => char.charCodeAt(0));
file.data = uint8Array.buffer;

return FileModel.create(file);
}
}
10 changes: 8 additions & 2 deletions src/services/file/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
import { FileService } from './client';
// import { getClientConfig } from '@/config/client';
import { ClientService } from './client';

export const fileService = new FileService();
// import { ServerService } from './server';
//
// const { ENABLED_SERVER_SERVICE } = getClientConfig();
//
// export const fileService = ENABLED_SERVER_SERVICE ? new ServerService() : new ClientService();
export const fileService = new ClientService();
11 changes: 11 additions & 0 deletions src/services/file/type.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
/* eslint-disable typescript-sort-keys/interface */
import { DB_File } from '@/database/client/schemas/files';
import { FilePreview } from '@/types/files';

export interface IFileService {
uploadFile(file: DB_File): Promise<any>;
uploadImageByUrl(url: string, file: Pick<DB_File, 'name' | 'metadata'>): Promise<any>;
removeFile(id: string): Promise<any>;
removeAllFiles(): Promise<any>;
getFile(id: string): Promise<FilePreview>;
}
Loading

0 comments on commit e6a3d51

Please sign in to comment.