[Worker] Move reload upon worker termination to WebWorker (#533)
In #471, we added `modelId` and `chatOpts` to `ServiceWorkerMLCEngineHandler` and
`ExtensionServiceWorkerMLCEngineHandler`. The motivation was to reload the
model once the handler detects that the model the frontend engine expects is
out of sync with what the low-level engine actually has loaded. This
out-of-sync state usually happens when a service worker is killed
unexpectedly.

Recently, we observed that the same out-of-sync state can occur in web workers
as well. Therefore, we move the logic to `WebWorkerMLCEngineHandler`. The
aforementioned handlers call `super.onmessage()` at the end, so no change in
their behavior is expected.

Tested in WebLLMChat: intentionally terminating the service worker and sending
another message triggers the same behavior as before.
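
For illustration, here is a minimal sketch of the recovery check that now lives in `WebWorkerMLCEngineHandler` (see the `src/web_worker.ts` diff below). The standalone helper shape and the name `ensureBackendInSync` are hypothetical; in the actual commit the check is inlined in the `chatCompletionNonStreaming` and `chatCompletionStreamInit` branches of `onmessage`.

```ts
import log from "loglevel";
import { ChatOptions, MLCEngine } from "@mlc-ai/web-llm";

// Illustrative helper (not part of the commit): reload the backend engine when
// the model the frontend expects does not match what the worker has loaded,
// e.g. because the worker was killed and restarted between requests.
async function ensureBackendInSync(
  handler: { modelId?: string; chatOpts?: ChatOptions; engine: MLCEngine },
  expectedModelId: string,
  expectedChatOpts?: ChatOptions,
): Promise<void> {
  if (handler.modelId !== expectedModelId) {
    log.warn(
      "Frontend expects a model to be loaded in the worker, but it is not. " +
        "The worker may have been killed unexpectedly; reloading.",
    );
    await handler.engine.reload(expectedModelId, expectedChatOpts);
    handler.modelId = expectedModelId;
    handler.chatOpts = expectedChatOpts;
  }
}
```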
CharlieFRuan authored Aug 10, 2024
1 parent b0486fc commit 5472977
Showing 3 changed files with 41 additions and 156 deletions.
76 changes: 1 addition & 75 deletions src/extension_service_worker.ts
@@ -1,20 +1,13 @@
import * as tvmjs from "tvmjs";
import log from "loglevel";
import { ChatOptions, MLCEngineConfig } from "./config";
import {
ReloadParams,
WorkerRequest,
ChatCompletionNonStreamingParams,
ChatCompletionStreamInitParams,
} from "./message";
import { MLCEngineInterface } from "./types";
import { ReloadParams, WorkerRequest } from "./message";
import {
ChatWorker,
WebWorkerMLCEngineHandler,
WebWorkerMLCEngine,
} from "./web_worker";
import { areChatOptionsEqual } from "./utils";
import { ChatCompletionChunk } from "./openai_api_protocols/index";
import { WebGPUNotFoundError } from "./error";

export interface ExtensionMLCEngineConfig extends MLCEngineConfig {
@@ -39,17 +32,6 @@ export interface ExtensionMLCEngineConfig extends MLCEngineConfig {
* });
*/
export class ServiceWorkerMLCEngineHandler extends WebWorkerMLCEngineHandler {
/**
* The modelId and chatOpts that the underlying engine (backend) is currently loaded with.
*
* TODO(webllm-team): This is always in-sync with `this.engine` unless device is lost due to
* unexpected reason. Therefore, we should get it from `this.engine` directly and make handler
* stateless. We should also perhaps make `engine` of type `MLCEngine` instead. Besides, consider
* if we should add appConfig, or use engine's API to find the corresponding model record rather
* than relying on just the modelId.
*/
modelId?: string;
chatOpts?: ChatOptions;
port: chrome.runtime.Port | null;

constructor(port: chrome.runtime.Port) {
@@ -114,62 +96,6 @@ export class ServiceWorkerMLCEngineHandler extends WebWorkerMLCEngineHandler {
return;
}

// Unset modelId and chatOpts since backend unloads the model
if (msg.kind === "unload") {
this.handleTask(msg.uuid, async () => {
await this.engine.unload();
this.modelId = undefined;
this.chatOpts = undefined;
return null;
});
return;
}

if (msg.kind === "chatCompletionNonStreaming") {
// Directly return the ChatCompletion response
this.handleTask(msg.uuid, async () => {
const params = msg.content as ChatCompletionNonStreamingParams;
// Check whether frontend expectation matches with backend (modelId and chatOpts)
// If not (due to possibly killed service worker), we reload here.
if (this.modelId !== params.modelId) {
log.warn(
"ServiceWorkerMLCEngine expects model is loaded in ServiceWorkerMLCEngineHandler, " +
"but it is not. This may due to service worker is unexpectedly killed. ",
);
log.info("Reloading engine in ServiceWorkerMLCEngineHandler.");
await this.engine.reload(params.modelId, params.chatOpts);
}
const res = await this.engine.chatCompletion(params.request);
return res;
});
return;
}

if (msg.kind === "chatCompletionStreamInit") {
// One-time set up that instantiates the chunk generator in worker
this.handleTask(msg.uuid, async () => {
const params = msg.content as ChatCompletionStreamInitParams;
// Check whether frontend expectation matches with backend (modelId and chatOpts)
// If not (due to possibly killed service worker), we reload here.
if (this.modelId !== params.modelId) {
log.warn(
"ServiceWorkerMLCEngine expects model is loaded in ServiceWorkerMLCEngineHandler, " +
"but it is not. This may due to service worker is unexpectedly killed. ",
);
log.info("Reloading engine in ServiceWorkerMLCEngineHandler.");
await this.engine.reload(params.modelId, params.chatOpts);
}
this.chatCompletionAsyncChunkGenerator =
(await this.engine.chatCompletion(params.request)) as AsyncGenerator<
ChatCompletionChunk,
void,
void
>;
return null;
});
return;
}

// All rest of message handling are the same as WebWorkerMLCEngineHandler
super.onmessage(event);
}
81 changes: 1 addition & 80 deletions src/service_worker.ts
@@ -1,21 +1,14 @@
import * as tvmjs from "tvmjs";
import log from "loglevel";
import { ChatOptions, MLCEngineConfig } from "./config";
import {
ReloadParams,
WorkerRequest,
WorkerResponse,
ChatCompletionNonStreamingParams,
ChatCompletionStreamInitParams,
} from "./message";
import { ReloadParams, WorkerRequest, WorkerResponse } from "./message";
import { InitProgressReport } from "./types";
import {
WebWorkerMLCEngineHandler,
WebWorkerMLCEngine,
ChatWorker,
} from "./web_worker";
import { areChatOptionsEqual } from "./utils";
import { ChatCompletionChunk } from "./openai_api_protocols/index";
import {
NoServiceWorkerAPIError,
NonWorkerEnvironmentError,
@@ -43,18 +36,6 @@ type IServiceWorker = globalThis.ServiceWorker;
* });
*/
export class ServiceWorkerMLCEngineHandler extends WebWorkerMLCEngineHandler {
/**
* The modelId and chatOpts that the underlying engine (backend) is currently loaded with.
*
* TODO(webllm-team): This is always in-sync with `this.engine` unless device is lost due to
* unexpected reason. Therefore, we should get it from `this.engine` directly and make handler
* stateless. We should also perhaps make `engine` of type `MLCEngine` instead. Besides, consider
* if we should add appConfig, or use engine's API to find the corresponding model record rather
* than relying on just the modelId.
*/
modelId?: string;
chatOpts?: ChatOptions;

private clientRegistry = new Map<
string,
IServiceWorker | Client | MessagePort
@@ -162,66 +143,6 @@ export class ServiceWorkerMLCEngineHandler extends WebWorkerMLCEngineHandler {
return;
}

if (msg.kind === "unload") {
this.handleTask(msg.uuid, async () => {
await this.engine.unload();
onComplete?.(null);
this.modelId = undefined;
this.chatOpts = undefined;
return null;
});
return;
}

if (msg.kind === "chatCompletionNonStreaming") {
// Directly return the ChatCompletion response
this.handleTask(msg.uuid, async () => {
const params = msg.content as ChatCompletionNonStreamingParams;
// Check whether frontend expectation matches with backend (modelId and chatOpts)
// If not (due to possibly killed service worker), we reload here.
if (this.modelId !== params.modelId) {
log.warn(
"ServiceWorkerMLCEngine expects model is loaded in ServiceWorkerMLCEngineHandler, " +
"but it is not. This may due to service worker is unexpectedly killed. ",
);
log.info("Reloading engine in ServiceWorkerMLCEngineHandler.");
this.initRequestUuid = msg.uuid;
await this.engine.reload(params.modelId, params.chatOpts);
}
const res = await this.engine.chatCompletion(params.request);
onComplete?.(res);
return res;
});
return;
}

if (msg.kind === "chatCompletionStreamInit") {
// One-time set up that instantiates the chunk generator in worker
this.handleTask(msg.uuid, async () => {
const params = msg.content as ChatCompletionStreamInitParams;
// Check whether frontend expectation matches with backend (modelId and chatOpts)
// If not (due to possibly killed service worker), we reload here.
if (this.modelId !== params.modelId) {
log.warn(
"ServiceWorkerMLCEngine expects model is loaded in ServiceWorkerMLCEngineHandler, " +
"but it is not. This may due to service worker is unexpectedly killed. ",
);
log.info("Reloading engine in ServiceWorkerMLCEngineHandler.");
this.initRequestUuid = msg.uuid;
await this.engine.reload(params.modelId, params.chatOpts);
}
this.chatCompletionAsyncChunkGenerator =
(await this.engine.chatCompletion(params.request)) as AsyncGenerator<
ChatCompletionChunk,
void,
void
>;
onComplete?.(null);
return null;
});
return;
}

// All rest of message handling are the same as WebWorkerMLCEngineHandler
super.onmessage(msg, onComplete, onError);
}
40 changes: 39 additions & 1 deletion src/web_worker.ts
@@ -52,6 +52,18 @@ import {
* onmessage = handler.onmessage;
*/
export class WebWorkerMLCEngineHandler {
/**
* The modelId and chatOpts that the underlying engine (backend) is currently loaded with.
*
* TODO(webllm-team): This is always in-sync with `this.engine` unless device is lost due to
* unexpected reason. Therefore, we should get it from `this.engine` directly and make handler
* stateless. We should also perhaps make `engine` of type `MLCEngine` instead. Besides, consider
* if we should add appConfig, or use engine's API to find the corresponding model record rather
* than relying on just the modelId.
*/
modelId?: string;
chatOpts?: ChatOptions;

public engine: MLCEngine;
protected chatCompletionAsyncChunkGenerator?: AsyncGenerator<
ChatCompletionChunk,
@@ -124,6 +136,8 @@ export class WebWorkerMLCEngineHandler {
this.handleTask(msg.uuid, async () => {
const params = msg.content as ReloadParams;
await this.engine.reload(params.modelId, params.chatOpts);
this.modelId = params.modelId;
this.chatOpts = params.chatOpts;
onComplete?.(null);
return null;
});
@@ -170,6 +184,17 @@ export class WebWorkerMLCEngineHandler {
// Directly return the ChatCompletion response
this.handleTask(msg.uuid, async () => {
const params = msg.content as ChatCompletionNonStreamingParams;
// Check whether frontend expectation matches with backend (modelId and chatOpts)
// If not (due to possibly killed service worker), we reload here.
if (this.modelId !== params.modelId) {
log.warn(
"WebWorkerMLCEngine expects model is loaded in WebWorkerMLCEngineHandler, " +
"but it is not. This may due to web/service worker is unexpectedly killed. ",
);
log.info("Reloading engine in WebWorkerMLCEngineHandler.");
await this.engine.reload(params.modelId, params.chatOpts);
}

const res = await this.engine.chatCompletion(params.request);
onComplete?.(res);
return res;
@@ -180,6 +205,16 @@ export class WebWorkerMLCEngineHandler {
// One-time set up that instantiates the chunk generator in worker
this.handleTask(msg.uuid, async () => {
const params = msg.content as ChatCompletionStreamInitParams;
// Check whether frontend expectation matches with backend (modelId and chatOpts)
// If not (due to possibly killed service worker), we reload here.
if (this.modelId !== params.modelId) {
log.warn(
"WebWorkerMLCEngine expects model is loaded in WebWorkerMLCEngineHandler, " +
"but it is not. This may due to web/service worker is unexpectedly killed. ",
);
log.info("Reloading engine in WebWorkerMLCEngineHandler.");
await this.engine.reload(params.modelId, params.chatOpts);
}
this.chatCompletionAsyncChunkGenerator =
(await this.engine.chatCompletion(
params.request,
@@ -221,8 +256,11 @@ export class WebWorkerMLCEngineHandler {
return;
}
case "unload": {
// Unset modelId and chatOpts since backend unloads the model
this.handleTask(msg.uuid, async () => {
await this.engine.unload();
this.modelId = undefined;
this.chatOpts = undefined;
onComplete?.(null);
return null;
});
@@ -337,7 +375,7 @@ export class WebWorkerMLCEngine implements MLCEngineInterface {
/**
* The modelId and chatOpts that the frontend expects the backend engine is currently loaded
* with. Needed for service worker. It is the backend and handler's job to match up with the
* expectation despite the service worker possibly being killed.
* expectation despite the web/service worker possibly being killed.
*/
modelId?: string;
chatOpts?: ChatOptions;
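
As a usage note (not part of the diff), the sketch below shows the web-worker setup this change targets, following the pattern in the `WebWorkerMLCEngineHandler` docstring: the worker script forwards messages to the handler, and the page talks to it through a `WebWorkerMLCEngine` created via `CreateWebWorkerMLCEngine`. File names and the model ID are illustrative. With this commit, if the worker is terminated and restarted between requests, the handler reloads the model on the next chat completion instead of failing.

```ts
// worker.ts (module worker): forward messages to the handler.
import { WebWorkerMLCEngineHandler } from "@mlc-ai/web-llm";

const handler = new WebWorkerMLCEngineHandler();
self.onmessage = (msg: MessageEvent) => {
  handler.onmessage(msg);
};
```

```ts
// main.ts: frontend engine. modelId/chatOpts accompany each request so the
// handler can detect an out-of-sync worker and reload.
import { CreateWebWorkerMLCEngine } from "@mlc-ai/web-llm";

const engine = await CreateWebWorkerMLCEngine(
  new Worker(new URL("./worker.ts", import.meta.url), { type: "module" }),
  "Llama-3-8B-Instruct-q4f32_1-MLC", // illustrative model ID
);

// Even if the worker was terminated after the initial load, this request
// triggers a reload in the handler rather than failing.
const reply = await engine.chat.completions.create({
  messages: [{ role: "user", content: "Hello!" }],
});
console.log(reply.choices[0].message.content);
```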

