Merge pull request #5 from pacoccino/embeddings
embeddings task
pacoccino authored Mar 23, 2024
2 parents c072962 + 9d9c25d commit fa2b954
Showing 7 changed files with 96 additions and 32 deletions.
2 changes: 2 additions & 0 deletions packages/core/src/OpenAI_stubs.ts
@@ -26,5 +26,7 @@ export interface ChatCompletionParams {
max_tokens?: number
temperature?: number
}
export type ChatCompletionResponse = string
export type ChatCompletionResponseStream = string

// TODO directly use OpenAI SDK type
26 changes: 25 additions & 1 deletion packages/core/src/actions.ts
@@ -1,6 +1,6 @@
import { ExtensionMessageRequestData } from "./ExtensionMessager"
import { Model } from "./models"
import { ChatCompletionParams } from "./OpenAI_stubs"
import { ChatCompletionParams, ChatCompletionResponse, ChatCompletionResponseStream } from "./OpenAI_stubs"

export type CompletionParams = {
prompt: string
@@ -11,11 +11,35 @@ export type TranslationParams = {
destLang: string
}

export type FeatureExtractionParams = {
texts: string[]
pooling?: 'none' | 'mean' | 'cls'
normalize?: boolean
}

export type AIMaskInferParams =
| CompletionParams
| ChatCompletionParams
| TranslationParams
| FeatureExtractionParams

export type CompletionResponse = string
export type CompletionResponseStream = string
export type TranslationResponse = string
export type TranslationResponseStream = string
export type FeatureExtractionResponse = number[][]
export type FeatureExtractionResponseStream = number[][]

export type AIMaskInferResponse =
| CompletionResponse
| ChatCompletionResponse
| TranslationResponse
| FeatureExtractionResponse
export type AIMaskInferResponseStream =
| CompletionResponseStream
| ChatCompletionResponseStream
| TranslationResponseStream
| FeatureExtractionResponseStream

export type AIActions =
| ExtensionMessageRequestData<'infer', {
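The new feature-extraction types pair string[] inputs with number[][] outputs, one embedding vector per input text. A minimal sketch of how a consumer might compare two returned embeddings, assuming normalize: true was requested so that cosine similarity reduces to a plain dot product (the helper below is illustrative and not part of this PR):

import type { FeatureExtractionResponse } from "@ai-mask/core"

// Illustrative helper, not part of this PR: for normalized embeddings,
// cosine similarity is simply the dot product of the two vectors.
function cosineSimilarity(a: number[], b: number[]): number {
    return a.reduce((sum, ai, i) => sum + ai * b[i], 0)
}

// Compare the first two embeddings of a feature-extraction response.
function similarityOfFirstTwo(embeddings: FeatureExtractionResponse): number {
    return cosineSimilarity(embeddings[0], embeddings[1])
}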
16 changes: 9 additions & 7 deletions packages/core/src/models.ts
@@ -1,11 +1,5 @@
export type Model = {
id: string
name: string
engine: 'web-llm' | 'transformers.js',
task: 'completion' | 'chat' | 'translation',
}

export const models: Model[] = [
export const models = [
{
id: 'gemma-2b-it-q4f32_1',
name: 'Gemma 2B',
@@ -36,8 +30,16 @@ export const models: Model[] = [
engine: 'transformers.js',
task: 'translation',
},
{
id: 'Xenova/all-MiniLM-L6-v2',
name: 'MiniLM L6 v2',
engine: 'transformers.js',
task: 'feature-extraction',
},
]

export type Model = typeof models[number]

export function getModel(id: Model['id']): Model | undefined {
return models.find(model => model.id == id)
}
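With this change the Model type is no longer declared by hand but derived from the array itself (typeof models[number]), so adding an entry such as MiniLM automatically flows into the type. A hedged usage sketch (the filter below is illustrative; note that without an as const assertion the inferred fields widen to plain string):

import { models, getModel, Model } from "@ai-mask/core"

// Illustrative, not part of this PR: pick out models that can produce embeddings.
const embeddingModels: Model[] = models.filter(m => m.task === 'feature-extraction')

// getModel returns undefined for unknown ids, so narrow before use.
const model = getModel('Xenova/all-MiniLM-L6-v2')
if (model) console.log(model.name) // "MiniLM L6 v2"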
42 changes: 32 additions & 10 deletions packages/extension/src/lib/AIMaskInfer.ts
@@ -1,8 +1,8 @@
import { ChatModule, InitProgressReport } from "@mlc-ai/web-llm";
import { GenerateProgressCallback } from "@mlc-ai/web-llm/lib/types";
import { env, pipeline, TranslationPipeline } from '@xenova/transformers';
import { env, pipeline, Pipeline, TranslationPipeline, FeatureExtractionPipeline } from '@xenova/transformers';

import { config, AIActionParams, MessagerStreamHandler, Model, ChatCompletionParams, CompletionParams, TranslationParams } from "@ai-mask/core";
import { config, AIActionParams, MessagerStreamHandler, Model, ChatCompletionParams, FeatureExtractionParams, CompletionParams, TranslationParams, AIMaskInferResponseStream, AIMaskInferResponse, FeatureExtractionResponseStream, ChatCompletionResponseStream, ChatCompletionResponse, TranslationResponseStream, TranslationResponse, CompletionResponseStream, CompletionResponse, FeatureExtractionResponse } from "@ai-mask/core";

export interface ModelLoadReport {
progress: number
@@ -17,7 +17,7 @@ env.allowLocalModels = false;
export class AIMaskInferer {
model: Model
inMemory: boolean = false
engineInstance: TranslationPipeline | ChatModule | null = null
engineInstance: Pipeline | ChatModule | null = null

constructor(model: Model) {
this.model = model;
@@ -36,9 +36,10 @@ export class AIMaskInferer {
switch (this.model.engine) {
case 'transformers.js':
let files: any[] = []

// @ts-ignore
this.engineInstance = await pipeline(
'translation',
// @ts-ignore
this.model.task,
this.model.id,
{
progress_callback: (event: any) => {
@@ -91,7 +92,7 @@

switch (this.model.engine) {
case 'transformers.js': {
const engine = this.engineInstance as TranslationPipeline
const engine = this.engineInstance as Pipeline
await engine.dispose()
break;
}
@@ -107,7 +108,7 @@
this.inMemory = false
}

async infer(params: AIActionParams<'infer'>, streamhandler: MessagerStreamHandler<string>) {
async infer(params: AIActionParams<'infer'>, streamhandler: MessagerStreamHandler<AIMaskInferResponseStream>): Promise<AIMaskInferResponse> {
if (!this.isReady()) throw new Error('inferer not ready')
if (this.model.id !== params.modelId) throw new Error('model id mismatch')

@@ -116,6 +117,8 @@
return this.taskChat(params.params as ChatCompletionParams, streamhandler)
case 'completion':
return this.taskCompletion(params.params as CompletionParams, streamhandler)
case 'feature-extraction':
return this.taskFeatureExtraction(params.params as FeatureExtractionParams, streamhandler)
case 'translation':
return this.taskTranslation(params.params as TranslationParams, streamhandler)
default:
@@ -124,7 +127,7 @@
}


async taskChat(params: ChatCompletionParams, streamhandler: MessagerStreamHandler<string>): Promise<string> {
async taskChat(params: ChatCompletionParams, streamhandler: MessagerStreamHandler<ChatCompletionResponseStream>): Promise<ChatCompletionResponse> {
if (this.model.engine !== 'web-llm') throw new Error(`engine ${this.model.engine} not supported for completion`)
const engine = this.getEngine() as ChatModule
await engine.resetChat()
@@ -148,7 +151,7 @@
return message
}

async taskCompletion(params: CompletionParams, streamhandler: MessagerStreamHandler<string>): Promise<string> {
async taskCompletion(params: CompletionParams, streamhandler: MessagerStreamHandler<CompletionResponseStream>): Promise<CompletionResponse> {
if (this.model.engine !== 'web-llm') throw new Error(`engine ${this.model.engine} not supported for completion`)
const engine = this.getEngine() as ChatModule
await engine.resetChat()
@@ -161,9 +164,10 @@
return response
}

async taskTranslation(params: TranslationParams, streamhandler: MessagerStreamHandler<string>): Promise<string> {
async taskTranslation(params: TranslationParams, streamhandler: MessagerStreamHandler<TranslationResponseStream>): Promise<TranslationResponse> {
if (this.model.engine !== 'transformers.js') throw new Error(`engine ${this.model.engine} not supported for translation`)

// @ts-ignore
const engine = this.getEngine() as TranslationPipeline
const translationParams = params

@@ -183,4 +187,22 @@
// @ts-ignore
return response[0].translation_text
}
async taskFeatureExtraction(params: FeatureExtractionParams, streamhandler: MessagerStreamHandler<FeatureExtractionResponseStream>): Promise<FeatureExtractionResponse> {
if (this.model.engine !== 'transformers.js') throw new Error(`engine ${this.model.engine} not supported for feature extraction`)

// @ts-ignore
const engine = this.getEngine() as FeatureExtractionPipeline

const response = await engine(
params.texts,
{
pooling: params.pooling,
normalize: params.normalize,
},
)
const list = response.tolist()
streamhandler(list)

return list
}
}
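taskFeatureExtraction wraps the transformers.js feature-extraction pipeline: the pipeline call returns a Tensor, and tolist() converts it to the number[][] that the response type promises. A standalone sketch of the same call outside the extension (the 384-dimensional output is a property of all-MiniLM-L6-v2, assumed here):

import { pipeline } from '@xenova/transformers'

// Standalone sketch of what taskFeatureExtraction does, outside the extension.
const extractor = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2')
const output = await extractor(['hello world', 'bonjour le monde'], {
    pooling: 'mean', // average token embeddings into one vector per input
    normalize: true, // unit-length vectors, convenient for cosine similarity
})
const embeddings: number[][] = output.tolist() // two 384-dimensional vectors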
4 changes: 2 additions & 2 deletions packages/extension/src/lib/AIMaskService.ts
@@ -1,4 +1,4 @@
import { ExtensionMessagerServer, MessagerStreamHandler, AIActions, AIActionParams, getModel, models } from "@ai-mask/core";
import { ExtensionMessagerServer, MessagerStreamHandler, AIActions, AIActionParams, getModel, models, AIMaskInferResponse, AIMaskInferResponseStream } from "@ai-mask/core";
import { extensionState } from "./State";
import { InternalMessage, InternalMessager } from "./InternalMessager";
import { ModelLoadReport, AIMaskInferer } from "./AIMaskInfer";
@@ -77,7 +77,7 @@ export class AIMaskService {
await extensionState.init(true)
}

async onInfer(params: AIActionParams<'infer'>, streamhandler: MessagerStreamHandler<string>): Promise<string> {
async onInfer(params: AIActionParams<'infer'>, streamhandler: MessagerStreamHandler<AIMaskInferResponseStream>): Promise<AIMaskInferResponse> {
if (await extensionState.get('status') === 'infering') throw new Error('already infering')

try {
11 changes: 6 additions & 5 deletions packages/extension/src/popup/components/ModelRow.tsx
@@ -10,10 +10,11 @@ const status_colors = {
'Uncached': 'bg-orange-50',
}

const task_colors = {
const task_colors: Record<Model['task'], string> = {
'chat': 'bg-orange-400',
'completion': 'bg-orange-400',
'translation': 'bg-blue-400',
'feature-extraction': 'bg-yellow-400',
}

const vrams: any = config.mlc.appConfig.model_list.reduce((acc: any, item) => {
@@ -40,17 +41,17 @@ export default function ModelRow({ model, extensionState }: { model: Model, exte
}

<div className="w-full flex justify-stretch bg-black">
<div className={clsx("flex flex-col w-20 items-center p-2", status_colors[status])} >
<div className={clsx("flex flex-col w-16 items-center p-2", status_colors[status])} >
<h4 className="text-orange-950">Status</h4>
<p className="text-orange-900">{status}</p>
</div>
<div className="flex flex-col w-16 items-center p-2 bg-orange-300">
<div className="flex flex-col w-14 items-center p-2 bg-orange-300">
<h4 className="text-orange-950">VRAM</h4>
<p className="text-orange-900">{vrams[model.id] || '?'} GB</p>
</div>
<div className={clsx("flex flex-col w-24 items-center p-2 bg-green-300", task_colors[model.task])}>
<div className={clsx("flex flex-col w-28 items-center p-2 bg-green-300 px-1 py-2", task_colors[model.task])}>
<h4 className="text-orange-950">Task</h4>
<p className="text-orange-900">{model.task}</p>
<p className="text-orange-900 text-xs text-center">{model.task}</p>
</div>
<div className={clsx("flex flex-col flex-1 items-center p-2 bg-green-100")}>
<h4 className="text-orange-950">Engine</h4>
Expand Down
27 changes: 20 additions & 7 deletions packages/sdk/src/AIMaskClient.ts
@@ -1,5 +1,6 @@
import { Model, ExtensionMessagerClient, MessagerStreamHandler, ExtensionMessageRequestData, AIActions, AIAction, ChatCompletionParams, TranslationParams } from "@ai-mask/core";
import { Model, ExtensionMessagerClient, MessagerStreamHandler, AIActions, AIAction, ChatCompletionParams, TranslationParams, TranslationResponseStream } from "@ai-mask/core";
import { createFakePort } from './utils'
import { FeatureExtractionParams, FeatureExtractionResponse, TranslationResponse, ChatCompletionResponse, ChatCompletionResponseStream } from "@ai-mask/core";

export interface InferOptionsBase {
modelId: Model['id']
@@ -68,9 +69,9 @@ export class AIMaskClient {
}
*/

async chat(params: ChatCompletionParams, options: InferOptionsNonStreaming): Promise<string>
async chat(params: ChatCompletionParams, options: InferOptionsStreaming): Promise<AsyncGenerator<string>>
async chat(params: ChatCompletionParams, options: InferOptions): Promise<string | AsyncGenerator<string>> {
async chat(params: ChatCompletionParams, options: InferOptionsNonStreaming): Promise<ChatCompletionResponse>
async chat(params: ChatCompletionParams, options: InferOptionsStreaming): Promise<AsyncGenerator<ChatCompletionResponseStream>>
async chat(params: ChatCompletionParams, options: InferOptions): Promise<ChatCompletionResponse | AsyncGenerator<ChatCompletionResponseStream>> {
const request: AIAction<'infer'> = {
action: 'infer',
params: {
@@ -87,9 +88,9 @@
}
}

async translate(params: TranslationParams, options: InferOptionsNonStreaming): Promise<string>
async translate(params: TranslationParams, options: InferOptionsStreaming): Promise<AsyncGenerator<string>>
async translate(params: TranslationParams, options: InferOptions): Promise<string | AsyncGenerator<string>> {
async translate(params: TranslationParams, options: InferOptionsNonStreaming): Promise<TranslationResponse>
async translate(params: TranslationParams, options: InferOptionsStreaming): Promise<AsyncGenerator<TranslationResponseStream>>
async translate(params: TranslationParams, options: InferOptions): Promise<TranslationResponse | AsyncGenerator<TranslationResponseStream>> {
const request: AIAction<'infer'> = {
action: 'infer',
params: {
@@ -106,6 +107,18 @@
}
}

async featureExtraction(params: FeatureExtractionParams, options: InferOptionsNonStreaming): Promise<FeatureExtractionResponse> {
const request: AIAction<'infer'> = {
action: 'infer',
params: {
modelId: options.modelId,
task: 'feature-extraction',
params,
},
}
return this.request(request)
}

async getModels(): Promise<Model[]> {
return this.request({
action: 'get_models',
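Unlike chat and translate, featureExtraction has no streaming variant: the embeddings come back in a single response. A hypothetical call from a page using the SDK (the package name and AIMaskClient constructor arguments are assumed from the repo layout and are not shown in this diff):

import { AIMaskClient } from '@ai-mask/sdk'

// Hypothetical usage sketch; constructor arguments assumed.
const aiMask = new AIMaskClient()
const vectors = await aiMask.featureExtraction(
    { texts: ['a query', 'a matching document'], pooling: 'mean', normalize: true },
    { modelId: 'Xenova/all-MiniLM-L6-v2' },
)
// vectors is number[][]; with normalize: true, dot products are cosine similarities.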
