Merge pull request #111 from solidSpoon/migrate-to-ai-sdk
migrate from LangChain to AI SDK
solidSpoon authored Feb 15, 2025
2 parents fc07a5c + 35af2c9 commit afe656a
Showing 22 changed files with 235 additions and 352 deletions.
6 changes: 3 additions & 3 deletions package.json
@@ -1,7 +1,7 @@
{
"name": "dash-player",
"productName": "DashPlayer",
"version": "5.1.2",
"version": "5.1.3",
"description": "My Electron application description",
"main": ".vite/build/main.js",
"scripts": {
@@ -55,11 +55,11 @@
"vitest": "^1.3.1"
},
"dependencies": {
"@ai-sdk/openai": "^1.1.11",
"@electron-forge/publisher-github": "^7.4.0",
"@ffmpeg-installer/ffmpeg": "^1.1.0",
"@floating-ui/react": "^0.26.9",
"@hookform/resolvers": "^3.9.1",
"@langchain/core": "^0.3.5",
"@radix-ui/react-aspect-ratio": "^1.1.0",
"@radix-ui/react-checkbox": "^1.0.4",
"@radix-ui/react-context-menu": "^2.1.5",
@@ -82,6 +82,7 @@
"@types/fluent-ffmpeg": "^2.1.24",
"@uidotdev/usehooks": "^2.4.1",
"@vitejs/plugin-react": "^4.2.1",
"ai": "^4.1.41",
"axios": "^1.6.8",
"better-sqlite3": "^9.4.0",
"class-variance-authority": "^0.7.0",
@@ -106,7 +107,6 @@
"iconv-lite": "^0.6.3",
"inversify": "^6.0.2",
"jschardet": "^3.1.2",
"langchain": "^0.3.2",
"leven": "^3.1.0",
"lucide-react": "^0.378.0",
"moment": "^2.30.1",
7 changes: 3 additions & 4 deletions src/backend/controllers/AiFuncController.ts
@@ -2,13 +2,13 @@ import TtsService from '@/backend/services/TtsService';
import registerRoute from '@/common/api/register';
import AiServiceImpl from '@/backend/services/AiServiceImpl';
import ChatServiceImpl from '@/backend/services/impl/ChatServiceImpl';
import { MsgT, toLangChainMsg } from '@/common/types/msg/interfaces/MsgT';
import UrlUtil from '@/common/utils/UrlUtil';
import { inject, injectable } from 'inversify';
import Controller from '@/backend/interfaces/controller';
import TYPES from '@/backend/ioc/types';
import DpTaskService from '@/backend/services/DpTaskService';
import WhisperService from '@/backend/services/WhisperService';
import { CoreMessage } from 'ai';

@injectable()
export default class AiFuncController implements Controller {
@@ -83,10 +83,9 @@ export default class AiFuncController implements Controller {
return UrlUtil.dp(await TtsService.tts(string));
}

public async chat({ msgs }: { msgs: MsgT[] }): Promise<number> {
public async chat({ msgs }: { msgs: CoreMessage[] }): Promise<number> {
const taskId = await this.dpTaskService.create();
const ms = msgs.map((msg) => toLangChainMsg(msg));
this.chatService.chat(taskId, ms).then();
this.chatService.chat(taskId, msgs).then();
return taskId;
}

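For context on the new parameter type: `CoreMessage` is the AI SDK's plain message shape (a `role` plus `content`), so the controller can pass renderer messages straight through to `ChatServiceImpl` without the old `toLangChainMsg` conversion. A minimal sketch of a valid payload, with the renderer-side call shown only as a hypothetical:

```ts
import type { CoreMessage } from 'ai';

// Roles are 'system' | 'user' | 'assistant' | 'tool'; plain-text content is enough here.
const msgs: CoreMessage[] = [
    { role: 'system', content: 'You are an English learning assistant.' },
    { role: 'user', content: 'Explain the phrase "look forward to".' }
];

// Hypothetical renderer-side invocation of the 'ai-func/chat' route:
// const taskId: number = await api.call('ai-func/chat', { msgs });
```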
4 changes: 2 additions & 2 deletions src/backend/ioc/inversify.config.ts
@@ -49,7 +49,6 @@ import MediaServiceImpl from '@/backend/services/impl/MediaServiceImpl';
import ClientProviderService from '@/backend/services/ClientProviderService';
import YouDaoProvider from '@/backend/services/impl/clients/YouDaoProvider';
import TencentProvider from '@/backend/services/impl/clients/TencentProvider';
import { ChatOpenAI } from '@langchain/openai';
import TranslateServiceImpl from '@/backend/services/impl/TranslateServiceImpl';
import TranslateService from '@/backend/services/AiTransServiceImpl';
import TagServiceImpl from '@/backend/services/impl/TagServiceImpl';
@@ -62,13 +61,14 @@ import WatchHistoryServiceImpl from '@/backend/services/impl/WatchHistoryService
import WatchHistoryController from '@/backend/controllers/WatchHistoryController';
import { OpenAIServiceImpl } from '@/backend/services/impl/OpenAIServiceImpl';
import { OpenAiService } from '@/backend/services/OpenAiService';
import AiProviderService from '@/backend/services/AiProviderService';


const container = new Container();
// Clients
container.bind<ClientProviderService<YouDaoClient>>(TYPES.YouDaoClientProvider).to(YouDaoProvider).inSingletonScope();
container.bind<ClientProviderService<TencentClient>>(TYPES.TencentClientProvider).to(TencentProvider).inSingletonScope();
container.bind<ClientProviderService<ChatOpenAI>>(TYPES.OpenAiClientProvider).to(AiProviderServiceImpl).inSingletonScope();
container.bind<AiProviderService>(TYPES.AiProviderService).to(AiProviderServiceImpl).inSingletonScope();
// Controllers
container.bind<Controller>(TYPES.Controller).to(FavoriteClipsController).inSingletonScope();
container.bind<Controller>(TYPES.Controller).to(DownloadVideoController).inSingletonScope();
2 changes: 1 addition & 1 deletion src/backend/ioc/types.ts
@@ -25,7 +25,7 @@ const TYPES = {
// Clients
YouDaoClientProvider: Symbol('YouDaoClientProvider'),
TencentClientProvider: Symbol('TencentClientProvider'),
OpenAiClientProvider: Symbol('OpenAiClientProvider'),
AiProviderService: Symbol('AiProviderService'),
};

export default TYPES;
5 changes: 5 additions & 0 deletions src/backend/services/AiProviderService.ts
@@ -0,0 +1,5 @@
import { LanguageModelV1 } from 'ai';

export default interface AiProviderService {
getModel(): LanguageModelV1 | null;
}
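Returning the AI SDK's provider-agnostic `LanguageModelV1` here, rather than a concrete `ChatOpenAI` client, is what lets `ChatServiceImpl` stay unchanged if the backing provider ever changes. Purely as an illustration of that point (not part of this PR, and `@ai-sdk/anthropic` is an assumed extra dependency), a different provider could implement the same interface:

```ts
import { injectable } from 'inversify';
import type { LanguageModelV1 } from 'ai';
import { createAnthropic } from '@ai-sdk/anthropic'; // assumed dependency, for illustration only
import AiProviderService from '@/backend/services/AiProviderService';

@injectable()
export default class AnthropicProviderServiceImpl implements AiProviderService {
    public getModel(): LanguageModelV1 | null {
        const apiKey = process.env.ANTHROPIC_API_KEY;
        if (!apiKey) {
            return null;
        }
        // createAnthropic(...)('model-id') yields a LanguageModelV1, same as the OpenAI provider below.
        return createAnthropic({ apiKey })('claude-3-5-sonnet-latest');
    }
}
```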
6 changes: 4 additions & 2 deletions src/backend/services/AiServiceImpl.ts
@@ -15,7 +15,6 @@ import { getSubtitleContent, srtSlice } from '@/common/utils/srtSlice';
import { inject, injectable } from 'inversify';
import TYPES from '@/backend/ioc/types';
import ChatService from '@/backend/services/ChatService';
import { HumanMessage } from '@langchain/core/messages';

export interface AiService {
polish(taskId: number, sentence: string): Promise<void>;
@@ -55,7 +54,10 @@ export default class AiServiceImpl implements AiService {

public async formatSplit(taskId: number, text: string) {
// await AiFunc.run(taskId, null, AiFuncFormatSplitPrompt.promptFunc(text));
await this.chatService.chat(taskId, [new HumanMessage(AiFuncFormatSplitPrompt.promptFunc(text))]);
await this.chatService.chat(taskId, [{
role: 'user',
content: AiFuncFormatSplitPrompt.promptFunc(text)
}]);
}

public async analyzeWord(taskId: number, sentence: string) {
4 changes: 2 additions & 2 deletions src/backend/services/ChatService.ts
@@ -1,8 +1,8 @@
import { BaseMessage } from '@langchain/core/messages';
import { ZodObject } from 'zod';
import { CoreMessage } from 'ai';

export default interface ChatService {
chat(taskId: number, msgs: BaseMessage[]): Promise<void>;
chat(taskId: number, msgs: CoreMessage[]): Promise<void>;
run(taskId: number, resultSchema: ZodObject<any>, promptStr: string): Promise<void>;
}

68 changes: 28 additions & 40 deletions src/backend/services/impl/ChatServiceImpl.ts
@@ -1,44 +1,41 @@
import RateLimiter from '@/common/utils/RateLimiter';
import { BaseMessage } from '@langchain/core/messages';
import { inject, injectable } from 'inversify';
import DpTaskService from '@/backend/services/DpTaskService';
import TYPES from '@/backend/ioc/types';
import ChatService from '@/backend/services/ChatService';
import { ChatOpenAI } from '@langchain/openai';
import ClientProviderService from '@/backend/services/ClientProviderService';
import { ZodObject } from 'zod';
import { storeGet } from '@/backend/store';


import { CoreMessage, streamObject, streamText } from 'ai';
import AiProviderService from '@/backend/services/AiProviderService';
@injectable()
export default class ChatServiceImpl implements ChatService {

@inject(TYPES.DpTaskService)
private dpTaskService!: DpTaskService;

@inject(TYPES.OpenAiClientProvider)
private aiProviderService!: ClientProviderService<ChatOpenAI>;
@inject(TYPES.AiProviderService)
private aiProviderService!: AiProviderService;


public async chat(taskId: number, msgs: BaseMessage[]) {
public async chat(taskId: number, msgs: CoreMessage[]) {
await RateLimiter.wait('gpt');
const chat = this.aiProviderService.getClient();
if (chat) {
const model = this.aiProviderService.getModel();
if (!model) {
this.dpTaskService.fail(taskId, {
progress: 'OpenAI api key or endpoint is empty'
});
return;
}
this.dpTaskService.process(taskId, {
progress: 'AI is thinking...'
});
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore
const resStream = await chat.stream(msgs);
const chunks = [];

const result = streamText({
model: model,
messages: msgs
});
let res = '';
for await (const chunk of resStream) {
res += chunk.content;
chunks.push(chunk);
for await (const chunk of result.textStream) {
res += chunk;
this.dpTaskService.process(taskId, {
progress: `AI typing, ${res.length} characters`,
result: res
@@ -52,38 +49,29 @@ export default class ChatServiceImpl implements ChatService {

public async run(taskId: number, resultSchema: ZodObject<any>, promptStr: string) {
await RateLimiter.wait('gpt');
const chat = this.aiProviderService.getClient();
if (!chat) {
const model = this.aiProviderService.getModel();
if (!model) {
this.dpTaskService.fail(taskId, {
progress: 'OpenAI api key or endpoint is empty'
});
return;
}
const structuredLlm = chat.withStructuredOutput(resultSchema);

const { partialObjectStream } = streamObject({
model: model,
schema: resultSchema,
prompt: promptStr,
});
this.dpTaskService.process(taskId, {
progress: 'AI is analyzing...'
});

const streaming = storeGet('apiKeys.openAi.stream') === 'on';

let resStr = null;
if (streaming) {
const resStream = await structuredLlm.stream(promptStr);
for await (const chunk of resStream) {
resStr = JSON.stringify(chunk);
this.dpTaskService.process(taskId, {
progress: 'AI is analyzing...',
result: resStr
});
}
} else {
const res = await structuredLlm.invoke(promptStr);
resStr = JSON.stringify(res);
for await (const partialObject of partialObjectStream) {
this.dpTaskService.process(taskId, {
progress: 'AI is analyzing...',
result: JSON.stringify(partialObject)
});
}
this.dpTaskService.finish(taskId, {
progress: 'AI has responded',
result: resStr
progress: 'AI has responded'
});
}
}
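To make the two new call paths easier to follow, here is a minimal standalone sketch of the AI SDK primitives used above — `streamText` for free-form chat and `streamObject` for schema-constrained output. The API key and model id are placeholders; DashPlayer resolves both through `AiProviderService`:

```ts
import { createOpenAI } from '@ai-sdk/openai';
import { streamText, streamObject, type CoreMessage } from 'ai';
import { z } from 'zod';

async function demo(): Promise<void> {
    const openai = createOpenAI({ apiKey: process.env.OPENAI_API_KEY ?? '' });
    const model = openai('gpt-4o-mini'); // placeholder model id

    // streamText: incremental text chunks, consumed exactly like chat() does above.
    const msgs: CoreMessage[] = [{ role: 'user', content: 'Say hello.' }];
    const textResult = streamText({ model, messages: msgs });
    let res = '';
    for await (const chunk of textResult.textStream) {
        res += chunk;
    }
    console.log(res);

    // streamObject: partial objects validated against a Zod schema, as in run() above.
    const { partialObjectStream } = streamObject({
        model,
        schema: z.object({ summary: z.string() }),
        prompt: 'Summarize this transcript: ...'
    });
    for await (const partialObject of partialObjectStream) {
        console.log(partialObject);
    }
}
```

Neither call needs an `await` at construction time: each returns a handle right away and the chunks are consumed by iterating the streams, which presumably is why the old `storeGet('apiKeys.openAi.stream')` toggle is no longer needed.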
23 changes: 11 additions & 12 deletions src/backend/services/impl/clients/AiProviderServiceImpl.ts
@@ -1,14 +1,16 @@
import { ChatOpenAI } from '@langchain/openai';
import { storeGet } from '@/backend/store';
import StrUtil from '@/common/utils/str-util';
import { joinUrl } from '@/common/utils/Util';
import { injectable } from 'inversify';
import ClientProviderService from '@/backend/services/ClientProviderService';
import AiProviderService from '@/backend/services/AiProviderService';
import { createOpenAI } from '@ai-sdk/openai';
import { LanguageModelV1 } from 'ai';


@injectable()
export default class AiProviderServiceImpl implements ClientProviderService<ChatOpenAI> {
public getClient(): ChatOpenAI | null {
export default class AiProviderServiceImpl implements AiProviderService {

public getModel():LanguageModelV1 | null {
const apiKey = storeGet('apiKeys.openAi.key');
const endpoint = storeGet('apiKeys.openAi.endpoint');
let model = storeGet('model.gpt.default');
@@ -18,14 +20,11 @@ export default class AiProviderServiceImpl implements ClientProviderService<Chat
if (StrUtil.hasBlank(apiKey, endpoint)) {
return null;
}
console.log(apiKey, endpoint);
return new ChatOpenAI({
modelName: model,
temperature: 0.7,
openAIApiKey: apiKey,
configuration: {
baseURL: joinUrl(endpoint, '/v1')
},
const openai = createOpenAI({
compatibility: 'compatible',
baseURL: joinUrl(endpoint, '/v1'),
apiKey: apiKey
});
return openai(model);
}
}
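A note on the provider setup above: `compatibility: 'compatible'` targets OpenAI-compatible gateways (`'strict'` is intended for the official OpenAI API), and the returned handle is an ordinary `LanguageModelV1` that any AI SDK call accepts. A minimal sketch with placeholder endpoint, key, and model id:

```ts
import { createOpenAI } from '@ai-sdk/openai';
import { generateText } from 'ai';

async function ping(): Promise<string> {
    // Placeholders; in DashPlayer the real values come from storeGet(...).
    const openai = createOpenAI({
        compatibility: 'compatible',
        baseURL: 'https://example-gateway.local/v1',
        apiKey: 'sk-placeholder'
    });
    const model = openai('gpt-4o-mini');

    // One-shot call with the same kind of model handle that getModel() returns.
    const { text } = await generateText({ model, prompt: 'ping' });
    return text;
}
```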
4 changes: 2 additions & 2 deletions src/common/api/api-def.ts
@@ -1,4 +1,3 @@
import { MsgT } from '@/common/types/msg/interfaces/MsgT';
import { DpTask } from '@/backend/db/tables/dpTask';
import { YdRes } from '@/common/types/YdRes';
import { ChapterParseResult } from '@/common/types/chapter-result';
@@ -16,6 +15,7 @@ import { ClipQuery } from '@/common/api/dto';
import { ClipMeta, OssBaseMeta } from '@/common/types/clipMeta';
import WatchHistoryVO from '@/common/types/WatchHistoryVO';
import { COOKIE } from '@/common/types/DlVideoType';
import { CoreMessage } from 'ai';

interface ApiDefinition {
'eg': { params: string, return: number },
@@ -32,7 +32,7 @@ interface AiFuncDef {
'ai-func/analyze-grammars': { params: string, return: number };
'ai-func/analyze-new-phrases': { params: string, return: number };
'ai-func/analyze-new-words': { params: string, return: number };
'ai-func/chat': { params: { msgs: MsgT[] }, return: number };
'ai-func/chat': { params: { msgs: CoreMessage[] }, return: number };
'ai-func/transcript': { params: { filePath: string }, return: number };
'ai-func/explain-select-with-context': { params: { sentence: string, selectedWord: string }, return: number };
'ai-func/explain-select': { params: { word: string }, return: number };
33 changes: 0 additions & 33 deletions src/common/types/ChatMessage.ts

This file was deleted.

8 changes: 4 additions & 4 deletions src/common/types/msg/AiCtxMenuExplainSelectMessage.ts
@@ -1,9 +1,9 @@
import CustomMessage, { MsgType } from '@/common/types/msg/interfaces/CustomMessage';
import { MsgT } from '@/common/types/msg/interfaces/MsgT';
import { codeBlock } from 'common-tags';
import { Topic } from '@/fronted/hooks/useChatPanel';
import { AiFuncExplainSelectRes } from '@/common/types/aiRes/AiFuncExplainSelectRes';
import { getDpTaskResult } from '@/fronted/hooks/useDpTaskCenter';
import { CoreMessage } from 'ai';

export default class AiCtxMenuExplainSelectMessage implements CustomMessage<AiCtxMenuExplainSelectMessage> {
public taskId: number;
@@ -22,7 +22,7 @@ export default class AiCtxMenuExplainSelectMessage implements CustomMessage<AiCt

msgType: MsgType = 'ai-func-explain-select';

async toMsg(): Promise<MsgT[]> {
async toMsg(): Promise<CoreMessage[]> {

const resp = await getDpTaskResult<AiFuncExplainSelectRes>(this.taskId);
// 根据以上信息编造一个假的回复
@@ -38,10 +38,10 @@ export default class AiCtxMenuExplainSelectMessage implements CustomMessage<AiCt
- 例句3:${resp?.examplesSentence3}
`
return [{
type:'human',
role:'user',
content: `请帮我理解这个单词/短语 ${this.word}`
},{
type:'ai',
role:'assistant',
content: aiResp
}];
}
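The substantive change in this file is the message shape: the LangChain-style `type: 'human' | 'ai'` entries become `CoreMessage`s with `role: 'user' | 'assistant'`. If any remaining call site still produces the old shape, an adapter along these lines would bridge it — the `LegacyMsg` type below is an assumption about the removed `MsgT` interface, not something taken from this PR:

```ts
import type { CoreMessage } from 'ai';

// Assumed shape of the removed MsgT messages.
type LegacyMsg = { type: 'human' | 'ai' | 'system'; content: string };

function toCoreMessage(msg: LegacyMsg): CoreMessage {
    switch (msg.type) {
        case 'human':
            return { role: 'user', content: msg.content };
        case 'ai':
            return { role: 'assistant', content: msg.content };
        default:
            return { role: 'system', content: msg.content };
    }
}

// Usage: const msgs: CoreMessage[] = legacyHistory.map(toCoreMessage);
```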
(diffs for the remaining changed files are not shown)
