Finish implementation of webllm #3

Merged
merged 4 commits into from Apr 28, 2024
1 change: 1 addition & 0 deletions libs/langchain-community/langchain.config.js
@@ -166,6 +166,7 @@ export const config = {
"chat_models/portkey": "chat_models/portkey",
"chat_models/premai": "chat_models/premai",
"chat_models/togetherai": "chat_models/togetherai",
"chat_models/webllm": "chat_models/webllm",
"chat_models/yandex": "chat_models/yandex",
"chat_models/zhipuai": "chat_models/zhipuai",
// callbacks
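For context: the line added above registers `chat_models/webllm` as a package entrypoint, so consumers should be able to import the new chat model directly. A minimal sketch, assuming a build of `@langchain/community` that includes this change:

```typescript
// Hypothetical consumer code; the path mirrors the entrypoint mapping above.
import { ChatWebLLM } from "@langchain/community/chat_models/webllm";
```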
4 changes: 2 additions & 2 deletions libs/langchain-community/package.json
@@ -75,7 +75,7 @@
"@huggingface/inference": "^2.6.4",
"@jest/globals": "^29.5.0",
"@langchain/scripts": "~0.0",
"@mlc-ai/web-llm": "0.2.28",
"@mlc-ai/web-llm": "^0.2.28",
"@mozilla/readability": "^0.4.4",
"@neondatabase/serverless": "^0.9.1",
"@opensearch-project/opensearch": "^2.2.0",
@@ -205,7 +205,7 @@
"@google-ai/generativelanguage": "^0.2.1",
"@gradientai/nodejs-sdk": "^1.2.0",
"@huggingface/inference": "^2.6.4",
"@mlc-ai/web-llm": "0.2.28",
"@mlc-ai/web-llm": "^0.2.28",
"@mozilla/readability": "*",
"@neondatabase/serverless": "*",
"@opensearch-project/opensearch": "*",
84 changes: 51 additions & 33 deletions libs/langchain-community/src/chat_models/webllm.ts
@@ -6,11 +6,7 @@ import type { BaseLanguageModelCallOptions } from "@langchain/core/language_mode
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import { BaseMessage, AIMessageChunk } from "@langchain/core/messages";
import { ChatGenerationChunk } from "@langchain/core/outputs";
import {
ChatModule,
type ModelRecord,
InitProgressCallback,
} from "@mlc-ai/web-llm";
import * as webllm from "@mlc-ai/web-llm";
import { ChatCompletionMessageParam } from "@mlc-ai/web-llm/lib/openai_api_protocols";

// Code from jacoblee93 https://github.com/jacoblee93/fully-local-pdf-chatbot/blob/main/app/lib/chat_models/webllm.ts
@@ -20,7 +16,9 @@ import { ChatCompletionMessageParam } from '@mlc-ai/web-llm/lib/openai_api_proto
* can set this in the environment variable `LLAMA_PATH`.
*/
export interface WebLLMInputs extends BaseChatModelParams {
modelRecord: ModelRecord;
appConfig?: webllm.AppConfig;
chatOpts?: webllm.ChatOptions;
modelRecord: webllm.ModelRecord;
temperature?: number;
}

@@ -31,16 +29,19 @@ export interface WebLLMCallOptions extends BaseLanguageModelCallOptions {}
* This can be installed using `npm install -S @mlc-ai/web-llm`
* @example
* ```typescript
* // Initialize the ChatWebLLM model with the model record.
* // Initialize the ChatWebLLM model with the model record and chat options.
* // Note that if the appConfig field is set, the list of model records
* // must include the selected model record for the engine.
* const model = new ChatWebLLM({
* modelRecord: {
* "model_url": "https://huggingface.co/mlc-ai/phi-2-q4f32_1-MLC/resolve/main/",
* "local_id": "Phi2-q4f32_1",
* "model_lib_url": "https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/phi-2/phi-2-q4f32_1-ctx2k-webgpu.wasm",
* "vram_required_MB": 4032.48,
* "low_resource_required": false,
* "model_url": "https://huggingface.co/mlc-ai/Llama-3-8B-Instruct-q4f32_1-MLC/resolve/main/",
* "model_id": "Llama-3-8B-Instruct-q4f32_1",
* "model_lib_url": webllm.modelLibURLPrefix + webllm.modelVersion + "/Llama-3-8B-Instruct-q4f32_1-ctx4k_cs1k-webgpu.wasm,
* },
* temperature: 0.5,
* chatOpts: {
* temperature: 0.5,
* top_p: 0.9
* }
* });
*
* // Call the model with a message and await the response.
@@ -52,39 +53,54 @@ export interface WebLLMCallOptions extends BaseLanguageModelCallOptions {}
export class ChatWebLLM extends SimpleChatModel<WebLLMCallOptions> {
static inputs: WebLLMInputs;

protected _chatModule: ChatModule;
protected engine: webllm.EngineInterface;

modelRecord: ModelRecord;
/**
* Configures the list of models available to the engine via a list of ModelRecords.
* @example
* const appConfig: webllm.AppConfig = {
*   model_list: [
*     {
*       "model_url": "https://huggingface.co/mlc-ai/Llama-3-8B-Instruct-q4f32_1-MLC/resolve/main/",
*       "model_id": "Llama-3-8B-Instruct-q4f32_1",
*       "model_lib_url": webllm.modelLibURLPrefix + webllm.modelVersion + "/Llama-3-8B-Instruct-q4f32_1-ctx4k_cs1k-webgpu.wasm",
*     },
*   ]
* };
*/
appConfig?: webllm.AppConfig;

temperature?: number;
/**
* Configures model options (temperature, etc.).
*/
chatOpts?: webllm.ChatOptions;

modelRecord: webllm.ModelRecord;

static lc_name() {
return "ChatWebLLM";
}

constructor(inputs: WebLLMInputs) {
super(inputs);
this._chatModule = new ChatModule();
this.appConfig = inputs.appConfig;
this.chatOpts = inputs.chatOpts;
this.modelRecord = inputs.modelRecord;
this.temperature = inputs.temperature;
}

_llmType() {
return "web-llm";
return "ChatWebLLM: " + this.modelRecord.model_id
}

_modelType() {
return this.modelRecord.local_id
async initialize() {
// Assumes web-llm's Engine.reload signature is (modelId, chatOpts?, appConfig?).
this.engine = new webllm.Engine();
this.engine.setInitProgressCallback(() => {});
await this.engine.reload(this.modelRecord.model_id, this.chatOpts, this.appConfig);
}

async initialize(progressCallback?: InitProgressCallback) {
if (progressCallback !== undefined) {
this._chatModule.setInitProgressCallback(progressCallback);
}
await this._chatModule.reload(this.modelRecord.local_id, undefined, {
model_list: [this.modelRecord],
});
this._chatModule.setInitProgressCallback(() => {});
async reload(newModelRecord: webllm.ModelRecord, newAppConfig?: webllm.AppConfig, newChatOpts?: webllm.ChatOptions) {
if (this.engine !== undefined) {
await this.engine.reload(newModelRecord.model_id, newChatOpts, newAppConfig);
} else throw new Error("Initialize model before reloading.");
}

async *_streamResponseChunks(
@@ -120,23 +136,25 @@ export class ChatWebLLM extends SimpleChatModel<WebLLMCallOptions> {
};
},
);
const stream = this._chatModule.chatCompletionAsyncChunkGenerator(

const stream = await this.engine.chat.completions.create(
{
stream: true,
messages: messagesInput,
stop: options.stop,
temperature: this.temperature,
logprobs: true
}
);
for await (const chunk of stream) {
// Last chunk has undefined content
const text = chunk.choices[0].delta.content ?? "";
yield new ChatGenerationChunk({
text,
message: new AIMessageChunk({
content: text,
additional_kwargs: {
logprobs: chunk.choices[0].logprobs,
finish_reason: chunk.choices[0].finish_reason
},
}),
});
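Taken together, the refactored class would be driven roughly as follows. This is a sketch rather than code from the PR: the model record mirrors the docstring example above, and it assumes `initialize()` completes in a WebGPU-capable browser before the first call.

```typescript
import * as webllm from "@mlc-ai/web-llm";
import { HumanMessage } from "@langchain/core/messages";
import { ChatWebLLM } from "@langchain/community/chat_models/webllm";

// Illustrative model record, taken from the docstring example above.
const model = new ChatWebLLM({
  modelRecord: {
    model_url:
      "https://huggingface.co/mlc-ai/Llama-3-8B-Instruct-q4f32_1-MLC/resolve/main/",
    model_id: "Llama-3-8B-Instruct-q4f32_1",
    model_lib_url:
      webllm.modelLibURLPrefix +
      webllm.modelVersion +
      "/Llama-3-8B-Instruct-q4f32_1-ctx4k_cs1k-webgpu.wasm",
  },
  chatOpts: { temperature: 0.5 },
});

// Downloads weights and compiles WebGPU kernels; must finish before
// invoke()/stream() are called.
await model.initialize();

// Stream tokens as they are generated.
const stream = await model.stream([new HumanMessage("What is WebGPU?")]);
for await (const chunk of stream) {
  console.log(chunk.content);
}
```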
12 changes: 6 additions & 6 deletions yarn.lock
@@ -8953,7 +8953,7 @@ __metadata:
"@langchain/core": ~0.1.60
"@langchain/openai": ~0.0.28
"@langchain/scripts": ~0.0
"@mlc-ai/web-llm": 0.2.28
"@mlc-ai/web-llm": ^0.2.28
"@mozilla/readability": ^0.4.4
"@neondatabase/serverless": ^0.9.1
"@opensearch-project/opensearch": ^2.2.0
@@ -9088,7 +9088,7 @@ __metadata:
"@google-ai/generativelanguage": ^0.2.1
"@gradientai/nodejs-sdk": ^1.2.0
"@huggingface/inference": ^2.6.4
"@mlc-ai/web-llm": 0.2.28
"@mlc-ai/web-llm": ^0.2.28
"@mozilla/readability": "*"
"@neondatabase/serverless": "*"
"@opensearch-project/opensearch": "*"
@@ -10115,10 +10115,10 @@ __metadata:
languageName: node
linkType: hard

"@mlc-ai/web-llm@npm:0.2.28":
version: 0.2.28
resolution: "@mlc-ai/web-llm@npm:0.2.28"
checksum: 18188e18c5866e6cccbba587d04329c71215134fb184b6e9d285a76b6978218d8eed071511c11ae71153ba8bcda333bd4b10bbdce77f07a05e269e372d235f0e
"@mlc-ai/web-llm@npm:^0.2.28":
version: 0.2.35
resolution: "@mlc-ai/web-llm@npm:0.2.35"
checksum: 03c1d1847340f88474e1eeed7a91cc09e29299a1216e378385ffe5479c203d39a8656d98c9187864322453a91f046b874d7073662ab04033527079d9bb29bee3
languageName: node
linkType: hard
