Engine proxy draft #4

Open · wants to merge 1 commit into main
5 changes: 4 additions & 1 deletion packages/core/src/ExtensionMessager.ts
@@ -78,11 +78,14 @@ export class ExtensionMessagerClient<T extends ExtensionMessageRequestData> {
onMessageHandler: ((message: ExtensionMessageResponse) => void) | undefined
port: chrome.runtime.Port | undefined

-constructor({ name, port }: { name: string, port?: chrome.runtime.Port }) {
+constructor({ name, port }: { name?: string, port?: chrome.runtime.Port }) {
this.handlers = {}
if (port) {
this.port = port
}
+if (!name) {
+name = globalThis.location?.origin || 'unknown' // globalThis avoids a ReferenceError where window is undefined
+}
this.listen(name)
}

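Note: with name now optional, an in-page client can omit it and be identified by the page origin. A minimal sketch of the relaxed signature in use, assuming the types are imported from @ai-mask/core:

import { ExtensionMessagerClient, AIActions } from '@ai-mask/core'

// No name given: the constructor falls back to globalThis.location.origin.
const messager = new ExtensionMessagerClient<AIActions>({})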
16 changes: 16 additions & 0 deletions packages/core/src/actions.ts
@@ -16,13 +16,29 @@ export type AIMaskInferParams =
| ChatCompletionParams
| TranslationParams

+type EngineStubParamsBase = {
+engine: Model['engine']
+modelId: Model['id']
+engineParams: any
+callParams: any[]
+}
+
+export interface TransformersStubParams extends EngineStubParamsBase {
+engine: 'transformers.js'
+engineParams: {
+task: Model['task']
+}
+}
+
+export type EngineStubParams = TransformersStubParams

export type AIActions =
| ExtensionMessageRequestData<'infer', {
modelId: Model['id']
task: Model['task']
params: AIMaskInferParams
}>
+| ExtensionMessageRequestData<'engine_stub', EngineStubParams>
| ExtensionMessageRequestData<'get_models'>

export type AIAction<T> = AIActions & { action: T }
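Note: for reference, a request built against the new action type could look like the sketch below. The callParams tuple mirrors a transformers.js pipeline call (inputs, then options); the pooling options shown are transformers.js conventions, not defined by this PR.

import { AIAction } from '@ai-mask/core'

// Hypothetical 'engine_stub' request for a feature-extraction model.
const request: AIAction<'engine_stub'> = {
    action: 'engine_stub',
    params: {
        engine: 'transformers.js',
        modelId: 'Xenova/all-MiniLM-L6-v2',
        engineParams: { task: 'feature-extraction' },
        callParams: [['hello world'], { pooling: 'mean', normalize: true }],
    },
}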
12 changes: 9 additions & 3 deletions packages/core/src/models.ts
@@ -2,7 +2,7 @@ export type Model = {
id: string
name: string
engine: 'web-llm' | 'transformers.js',
-task: 'completion' | 'chat' | 'translation',
+task: 'completion' | 'chat' | 'translation' | 'feature-extraction',
}

export const models: Model[] = [
@@ -36,9 +36,15 @@ export const models: Model[] = [
engine: 'transformers.js',
task: 'translation',
},
+{
+id: 'Xenova/all-MiniLM-L6-v2',
+name: 'MiniLM L6 v2',
+engine: 'transformers.js',
+task: 'feature-extraction'
+},
]

-export function getModel(id: Model['id']): Model | undefined {
-return models.find(model => model.id == id)
+export function getModel(id: Model['id'], engine?: Model['engine']): Model | undefined {
+return models.find(model => model.id == id && (engine ? model.engine == engine : true))
}
export type LLMDID = typeof models[number]['id']
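Note: the new engine argument lets callers disambiguate models that may share an id across engines. A usage sketch:

import { getModel } from '@ai-mask/core'

// Only matches if the model is registered for the transformers.js engine.
const model = getModel('Xenova/all-MiniLM-L6-v2', 'transformers.js')
if (!model) throw new Error('model not found for that engine')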
4 changes: 3 additions & 1 deletion packages/extension/src/background/main.ts
@@ -1,3 +1,4 @@
+import { AIMaskService } from "../lib/AIMaskService";
import { InternalMessager } from "../lib/InternalMessager";

console.log('hello from background')
@@ -42,6 +43,7 @@ async function setupOffscreenDocument() {
}

InternalMessager.listen(async () => {
-console.log('received message')
+// console.log('received message')
await setupOffscreenDocument()
})
+new AIMaskService()
29 changes: 22 additions & 7 deletions packages/extension/src/lib/AIMaskInfer.ts
@@ -1,8 +1,8 @@
import { ChatModule, InitProgressReport } from "@mlc-ai/web-llm";
import { GenerateProgressCallback } from "@mlc-ai/web-llm/lib/types";
-import { env, pipeline, TranslationPipeline } from '@xenova/transformers';
+import { env, FeatureExtractionPipeline, pipeline, TranslationPipeline } from '@xenova/transformers';

-import { config, AIActionParams, MessagerStreamHandler, Model, ChatCompletionParams, CompletionParams, TranslationParams } from "@ai-mask/core";
+import { config, AIActionParams, MessagerStreamHandler, Model, ChatCompletionParams, CompletionParams, TranslationParams, EngineStubParams } from "@ai-mask/core";

export interface ModelLoadReport {
progress: number
@@ -11,16 +11,17 @@ export interface ModelLoadReport {
}

// https://github.com/xenova/transformers.js/pull/462
-// env.backends.onnx.wasm.numThreads = 1 // not needed in offscreen
+env.backends.onnx.wasm.numThreads = 1 // not needed in offscreen
env.allowLocalModels = false;

export class AIMaskInferer {
model: Model
inMemory: boolean = false
-engineInstance: TranslationPipeline | ChatModule | null = null
-
-constructor(model: Model) {
+engineInstance: TranslationPipeline | FeatureExtractionPipeline | ChatModule | null = null
+engineParams: EngineStubParams['engineParams']
+constructor(model: Model, engineParams?: EngineStubParams['engineParams']) {
this.model = model;
+this.engineParams = engineParams
}

isReady(): boolean {
@@ -35,10 +36,13 @@ export class AIMaskInferer {
async load(progressCallback?: (progress: ModelLoadReport) => void) {
switch (this.model.engine) {
case 'transformers.js':
+if (this.model.task === 'completion' || this.model.task === 'chat') {
+throw new Error(`task ${this.model.task} is not supported by the transformers.js engine`)
+}
let files: any[] = []

this.engineInstance = await pipeline(
-'translation',
+this.model.task,
this.model.id,
{
progress_callback: (event: any) => {
@@ -123,6 +127,17 @@
}
}

+async stub(call_params: any[], streamhandler: MessagerStreamHandler<string>): Promise<any> {
+if (this.model.engine !== 'transformers.js') throw new Error(`engine ${this.model.engine} not supported for stubbing`)
+if (!this.isReady()) throw new Error('inferer not ready')
+
+console.log(call_params)
+const engine = this.getEngine() as FeatureExtractionPipeline
+
+const response = await engine(call_params[0] as string[], call_params[1])
+
+return response.tolist();
+}

async taskChat(params: ChatCompletionParams, streamhandler: MessagerStreamHandler<string>): Promise<string> {
if (this.model.engine !== 'web-llm') throw new Error(`engine ${this.model.engine} not supported for completion`)
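Note: stub() forwards raw pipeline arguments with no per-task handling. A sketch of a direct call, assuming the feature-extraction model registered in models.ts (the streamhandler argument is accepted for signature parity but unused):

import { getModel } from '@ai-mask/core'
import { AIMaskInferer } from './AIMaskInfer'

const model = getModel('Xenova/all-MiniLM-L6-v2', 'transformers.js')!
const inferer = new AIMaskInferer(model, { task: model.task })
await inferer.load()

// callParams mirror a transformers.js call: (texts, options).
const embeddings = await inferer.stub([['hello world'], { pooling: 'mean', normalize: true }], () => { })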
38 changes: 27 additions & 11 deletions packages/extension/src/lib/AIMaskService.ts
@@ -1,4 +1,4 @@
-import { ExtensionMessagerServer, MessagerStreamHandler, AIActions, AIActionParams, getModel, models } from "@ai-mask/core";
+import { ExtensionMessagerServer, MessagerStreamHandler, AIActions, AIActionParams, getModel, models, Model, EngineStubParams } from "@ai-mask/core";
import { extensionState } from "./State";
import { InternalMessage, InternalMessager } from "./InternalMessager";
import { ModelLoadReport, AIMaskInferer } from "./AIMaskInfer";
@@ -17,7 +17,7 @@ export class AIMaskService {
extensionState.init().catch(console.error)
}

-async getInferer(params: AIActionParams<'infer'>): Promise<AIMaskInferer> {
+async getInferer(modelId: Model['id'], engine?: Model['engine'], engineParams?: EngineStubParams['engineParams']): Promise<AIMaskInferer> {
const progressHandler = (report: ModelLoadReport) => {
if (!this.inferer) return
const modelId = this.inferer.model.id
@@ -27,7 +27,7 @@
}

if (this.inferer) {
-if (this.inferer.model.id === params.modelId) {
+if (this.inferer.model.id === modelId) {
if (this.inferer.isReady()) return this.inferer
else {
await extensionState.set('status', 'loading')
@@ -38,18 +38,15 @@
await this.unloadModel()
}
}
-const model = await getModel(params.modelId)
+const model = await getModel(modelId, engine)
if (!model) {
-throw new Error('model not found')
-}
-if (model.task !== params.task) {
-throw new Error('incompatible task and model')
+throw new Error(`model ${modelId} not found`)
}

await extensionState.set('status', 'loading')
-this.inferer = new AIMaskInferer(model)
+this.inferer = new AIMaskInferer(model, engineParams)
await this.inferer.load(progressHandler)
-await extensionState.setCached(params.modelId)
+await extensionState.setCached(modelId)
await extensionState.set('status', 'loaded')
await extensionState.set('loaded_model', model.id)
return this.inferer
@@ -81,7 +78,7 @@
if (await extensionState.get('status') === 'infering') throw new Error('already infering')

try {
-const inferer = await this.getInferer(params)
+const inferer = await this.getInferer(params.modelId)
await extensionState.set('status', 'infering')
const response = await inferer.infer(params, streamhandler)

Expand All @@ -94,6 +91,23 @@ export class AIMaskService {
}
}

+async onEngineStub(params: AIActionParams<'engine_stub'>, streamhandler: MessagerStreamHandler<string>): Promise<string> {
+if (await extensionState.get('status') === 'infering') throw new Error('already infering')
+
+try {
+const inferer = await this.getInferer(params.modelId, params.engine, params.engineParams)
+await extensionState.set('status', 'infering')
+const response = await inferer.stub(params.callParams, streamhandler)
+
+await extensionState.set('status', 'loaded')
+return response
+} catch (error) {
+await extensionState.set('status', 'error')
+await extensionState.set('error', JSON.stringify(error))
+throw error
+}
+}

async handleInternalMessage(message: InternalMessage) {
switch (message.type) {
case 'get_state':
@@ -112,6 +126,8 @@
switch (request.action) {
case 'infer':
return this.onInfer(request.params, streamhandler)
+case 'engine_stub':
+return this.onEngineStub(request.params, streamhandler)
case 'get_models':
return models
default:
2 changes: 1 addition & 1 deletion packages/extension/src/offscreen/main.ts
@@ -2,4 +2,4 @@ import { AIMaskService } from "../lib/AIMaskService";

console.log('hello from offscreen')

-new AIMaskService()
+//new AIMaskService()
8 changes: 4 additions & 4 deletions packages/extension/src/popup/components/ModelRow.tsx
@@ -40,17 +40,17 @@ export default function ModelRow({ model, extensionState }: { model: Model, exte
}

<div className="w-full flex justify-stretch bg-black">
<div className={clsx("flex flex-col w-20 items-center p-2", status_colors[status])} >
<div className={clsx("flex flex-col w-16 items-center p-2", status_colors[status])} >
<h4 className="text-orange-950">Status</h4>
<p className="text-orange-900">{status}</p>
</div>
<div className="flex flex-col w-16 items-center p-2 bg-orange-300">
<div className="flex flex-col w-14 items-center p-2 bg-orange-300">
<h4 className="text-orange-950">VRAM</h4>
<p className="text-orange-900">{vrams[model.id] || '?'} GB</p>
</div>
<div className={clsx("flex flex-col w-24 items-center p-2 bg-green-300", task_colors[model.task])}>
<div className={clsx("flex flex-col w-28 items-center px-1 py-2 bg-green-300", task_colors[model.task])}>
<h4 className="text-orange-950">Task</h4>
<p className="text-orange-900">{model.task}</p>
<p className="text-orange-900 text-xs text-center">{model.task}</p>
</div>
<div className={clsx("flex flex-col flex-1 items-center p-2 bg-green-100")}>
<h4 className="text-orange-950">Engine</h4>
10 changes: 8 additions & 2 deletions packages/sdk/src/AIMaskClient.ts
@@ -1,5 +1,6 @@
import { Model, ExtensionMessagerClient, MessagerStreamHandler, ExtensionMessageRequestData, AIActions, AIAction, ChatCompletionParams, TranslationParams } from "@ai-mask/core";
import { createFakePort } from './utils'
+import { AIMaskPassTransformers } from "./AIMaskPassTransformers";

export interface InferOptionsBase {
modelId: Model['id']
@@ -16,9 +17,14 @@ export type InferOptions = InferOptionsNonStreaming | InferOptionsStreaming

export class AIMaskClient {
messager: ExtensionMessagerClient<AIActions>
+transformers: AIMaskPassTransformers

constructor(params?: { name?: string, port?: chrome.runtime.Port }) {
+if (!params?.port && !AIMaskClient.isExtensionAvailable()) {
+throw new Error('AI-Mask extension is not available')
+}
this.messager = new ExtensionMessagerClient<AIActions>({ name: params?.name || 'ai-mask-app', port: params?.port })
+this.transformers = new AIMaskPassTransformers(this)
}

static isExtensionAvailable() {
@@ -36,11 +42,11 @@
this.messager.dispose()
}

-private async request<T>(request: AIAction<T>, streamCallback?: MessagerStreamHandler): Promise<any> {
+async request<T>(request: AIAction<T>, streamCallback?: MessagerStreamHandler): Promise<any> {
return this.messager.send(request, streamCallback)
}

-private async *requestStream<T>(request: AIAction<T>): AsyncGenerator<string> {
+async *requestStream<T>(request: AIAction<T>): AsyncGenerator<string> {
let resolve: (data: string) => void
let done = false
let promise = new Promise<string>(r => resolve = r)
30 changes: 30 additions & 0 deletions packages/sdk/src/AIMaskPassTransformers.ts
@@ -0,0 +1,30 @@
+import { AIAction, Model, TransformersStubParams } from "@ai-mask/core"
+import { AIMaskClient } from "."
+
+export class AIMaskPassTransformers {
+client: AIMaskClient
+constructor(client: AIMaskClient) {
+this.client = client
+}
+
+pipeline(task: Model['task'], modelId: string) {
+return async (texts: string[], options?: any) => {
+const request: AIAction<'engine_stub'> = {
+action: 'engine_stub',
+params: {
+engine: 'transformers.js',
+modelId,
+engineParams: {
+task: task,
+},
+callParams: [
+texts,
+options
+]
+}
+}
+const response = await this.client.request(request) // await so request errors surface inside this async wrapper
+return response
+}
+}
+}
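Note: end to end, the pass-through is meant to read like the transformers.js API. A usage sketch, assuming the SDK is consumed as @ai-mask/sdk and the extension is installed:

import { AIMaskClient } from '@ai-mask/sdk'

const aiMask = new AIMaskClient()
// pipeline(task, modelId) returns an async callable, as in transformers.js.
const extractor = aiMask.transformers.pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2')
// Arguments are forwarded verbatim as callParams to the extension.
const embeddings = await extractor(['hello world'], { pooling: 'mean', normalize: true })
console.log(embeddings) // nested number arrays, from Tensor.tolist()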