Database > ExtensionState
pacoccino committed Feb 28, 2024
1 parent 394fc47 commit 095179b
Showing 15 changed files with 250 additions and 179 deletions.
3 changes: 0 additions & 3 deletions README.md
@@ -138,6 +138,3 @@ pnpm build
- [ ] Implement more tasks
- [ ] Add more models
- [ ] Unload model from memory after being inactive for a while
-- [ ] Improve extension panel:
-- [ ] search/filter models
-- [ ] download/remove individual models
2 changes: 1 addition & 1 deletion packages/extension/manifest.config.ts
@@ -19,7 +19,7 @@ const externally_connectable_urls = [
"<all_urls>",
]
let permissions: chrome.runtime.ManifestPermissions[] = [
-//"storage",
+"storage",
]
let side_panel: chrome.sidePanel.SidePanel | undefined = undefined

3 changes: 1 addition & 2 deletions packages/extension/package.json
@@ -1,6 +1,6 @@
{
"name": "@ai-mask/chrome-extension",
-"version": "0.4.2",
+"version": "0.4.3",
"description": "",
"type": "module",
"scripts": {
@@ -34,7 +34,6 @@
"@ai-mask/core": "workspace:^",
"@xenova/transformers": "^2.15.1",
"clsx": "^2.1.0",
-"localforage": "^1.10.0",
"react": "^18.2.0",
"react-dom": "^18.2.0"
}
33 changes: 18 additions & 15 deletions packages/extension/src/lib/AIMaskService.ts
@@ -1,5 +1,5 @@
import { ExtensionMessager, MessagerStreamHandler, AIActions, AIActionParams, getModel, models } from "@ai-mask/core";
-import { database } from "./Database";
+import { extensionState } from "./State";
import { InternalMessage, InternalMessager } from "./InternalMessager";
import { ModelLoadReport, AIMaskInferer } from "./AIMaskInfer";

@@ -14,14 +14,14 @@ export class AIMaskService {
// @ts-ignore
this.messager = new ExtensionMessager<AIActions>(this.handleAppMessage.bind(this))
InternalMessager.listen(this.handleInternalMessage.bind(this))
-database.init().catch(console.error)
+extensionState.init().catch(console.error)
}

async getInferer(params: AIActionParams<'infer'>): Promise<AIMaskInferer> {
const progressHandler = (report: ModelLoadReport) => {
if (!this.inferer) return
const modelId = this.inferer.model.id
-database
+extensionState
.setProgress(modelId, report.progress)
.catch(console.error)
}
@@ -30,7 +30,7 @@
if (this.inferer.model.id === params.modelId) {
if (this.inferer.isReady()) return this.inferer
else {
-await database.set('status', 'loading')
+await extensionState.set('status', 'loading')
await this.inferer.load(progressHandler)
return this.inferer
}
@@ -45,20 +45,21 @@
if (model.task !== params.task) {
throw new Error('incompatible task and model')
}

+await extensionState.set('status', 'loading')
this.inferer = new AIMaskInferer(model)
-await database.set('status', 'loading')
await this.inferer.load(progressHandler)
-await database.set('status', 'loaded')
-await database.setCached(params.modelId)
-await database.set('loaded_model', model.id)
+await extensionState.setCached(params.modelId)
+await extensionState.set('status', 'loaded')
+await extensionState.set('loaded_model', model.id)
return this.inferer
}

async unloadModel() {
if (this.inferer) {
this.inferer.unload()
-await database.set('status', 'initialied')
-await database.set('loaded_model', undefined)
+await extensionState.set('status', 'initialied')
+await extensionState.set('loaded_model', undefined)
await InternalMessager.send({
type: 'models_updated',
}, true)
@@ -73,26 +74,28 @@
caches.delete(cacheKey)
})
})
-await database.init(true)
+await extensionState.init(true)
}

async onInfer(params: AIActionParams<'infer'>, streamhandler: MessagerStreamHandler<string>): Promise<string> {
-if (await database.get('status') === 'infering') throw new Error('already infering')
+if (await extensionState.get('status') === 'infering') throw new Error('already infering')

try {
const inferer = await this.getInferer(params)
-await database.set('status', 'infering')
+await extensionState.set('status', 'infering')
const response = await inferer.infer(params, streamhandler)
+await extensionState.set('status', 'loaded')
return response
} catch (e) {
+await extensionState.set('status', 'error')
throw e
-} finally {
-await database.set('status', 'loaded')
}
}

async handleInternalMessage(message: InternalMessage) {
switch (message.type) {
+case 'get_state':
+return await extensionState.getAll()
case 'clear_models_cache':
return await this.clearModelsCache()
case 'unload_model':
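With the state now centralized, onInfer records failures instead of unconditionally resetting to 'loaded' in a finally block. A minimal sketch of that pattern, assuming only the get/set surface that extensionState exposes (illustrative only, not the actual service class):

```typescript
// Status values mirror State_Type below ('initialied' and 'infering' are the
// literal strings used in the codebase).
type Status = 'uninitialized' | 'initialied' | 'loading' | 'loaded' | 'error' | 'infering'

interface StatusStore {
    get(key: 'status'): Promise<Status>
    set(key: 'status', value: Status): Promise<void>
}

// Sketch of the new inference flow: reject concurrent requests, mark progress,
// and surface failures as an 'error' status instead of silently returning to 'loaded'.
async function inferWithStatus<T>(state: StatusStore, run: () => Promise<T>): Promise<T> {
    if (await state.get('status') === 'infering') throw new Error('already infering')
    try {
        await state.set('status', 'infering')
        const result = await run()
        await state.set('status', 'loaded') // success: the model stays resident
        return result
    } catch (e) {
        await state.set('status', 'error')  // failure is now visible to the side panel
        throw e
    }
}
```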
4 changes: 2 additions & 2 deletions packages/extension/src/lib/InternalMessager.ts
@@ -1,10 +1,10 @@
export type InternalMessage = {
type: string
-data: any
+data?: any
}

export class InternalMessager {
-static async send(message: any, broadcast?: true): Promise<any> {
+static async send(message: InternalMessage, broadcast?: true): Promise<any> {
// console.log('[InternalMessager] send', message, broadcast)
try {
const response = await chrome.runtime.sendMessage(message)
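The narrower signature means call sites pass a typed InternalMessage and may now omit data. Hypothetical call sites, using message types that appear elsewhere in this commit (the exact callers are not shown here):

```typescript
import { InternalMessager } from './InternalMessager'

async function examples() {
    // Broadcast with no payload, as unloadModel does after releasing a model.
    await InternalMessager.send({ type: 'models_updated' }, true)

    // Broadcast a keyed update, roughly the shape State.notifyUpdate sends.
    await InternalMessager.send({
        type: 'state_updated',
        data: { key: 'status', value: 'loaded' },
    }, true)
}
```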
108 changes: 108 additions & 0 deletions packages/extension/src/lib/State.ts
@@ -0,0 +1,108 @@
import { Model } from "@ai-mask/core";
import { InternalMessager } from "./InternalMessager";

export interface State_Type {
cached_models: { [id: string]: boolean }
cache_progress: { [id: string]: number }
loaded_model: Model['id'] | undefined
cache_size: number
status: 'uninitialized' | 'initialied' | 'loading' | 'loaded' | 'error' | 'infering'
}

export const INITIAL_State: State_Type = {
cached_models: {},
cache_progress: {},
loaded_model: undefined,
cache_size: 0,
status: 'uninitialized',
}

type State_Type_Keys = keyof State_Type
type State_Type_Value = State_Type[State_Type_Keys]

class State {
state: State_Type = INITIAL_State

async init(reset = false) {
if (reset) {
this.state = INITIAL_State
await chrome.storage.local.clear()
}
const storedValues = await chrome.storage.local.get(['cached_models'])
this.state.cached_models = storedValues.cached_models || {}
await this.notifyUpdate()
await this.updateCachedModelsSize()
/*
const initialized = await this.get('status')
if (reset || initialized !== 'initialied') {
for (const key in INITIAL_State) {
await this.set(key as State_Type_Keys, INITIAL_State[key as State_Type_Keys] as State_Type_Value)
}
}
this.set('status', 'initialied')
this.set('loaded_model', undefined)
*/
}

async updateCachedModelsSize() {
const estimate = await navigator.storage.estimate()
this.state.cache_size = estimate.usage || 0
await this.notifyUpdate('cache_size')
/*
const cachesKeys = await caches.keys()
for (const cachesKey of cachesKeys) {
const cache = await caches.open(cachesKey)
const requests = await cache.keys()
for (const request of requests) {
const response = await caches.match(request)
const responseSize = response ? Number(response.headers.get('content-length')) : 0;
console.log(request.url, responseSize / 1024 / 1024)
// TODO match url with model id and aggregate
}
}
*/
}

async notifyUpdate(key?: State_Type_Keys) {
const value = key && this.state[key]
await InternalMessager.send({
type: 'state_updated',
data: {
key,
value,
state: this.state,
}
}, true)
}

async set<T extends State_Type_Keys>(key: T, value: State_Type_Value): Promise<void> {
this.state[key] = value as typeof INITIAL_State[T]
await this.notifyUpdate(key)
}

async get<T extends State_Type_Keys>(key: T): Promise<typeof INITIAL_State[T]> {
return this.state[key]
}

async getAll(): Promise<State_Type> {
return this.state
}

async setProgress(id: Model['id'], progress: number) {
this.state.cache_progress[id] = progress
await this.notifyUpdate('cache_progress')
}

async setCached(id: Model['id'], cached: boolean = true) {
this.state.cached_models[id] = cached
await this.notifyUpdate('cached_models')
await chrome.storage.local.set({
cached_models: this.state.cached_models
})
await this.updateCachedModelsSize()
}
}

export const extensionState = new State()
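Every set, setProgress, and setCached call funnels through notifyUpdate, which broadcasts a 'state_updated' message carrying the changed key and the full snapshot. A hypothetical listener on the receiving side (the handler shape mirrors handleInternalMessage in AIMaskService.ts; this listener itself is not part of the commit):

```typescript
import { InternalMessager, type InternalMessage } from './InternalMessager'
import type { State_Type } from './State'

// Assumed: InternalMessager.listen registers a chrome.runtime.onMessage handler,
// as AIMaskService does on the service-worker side.
InternalMessager.listen(async (message: InternalMessage) => {
    if (message.type !== 'state_updated') return
    const { key, value, state } = message.data as {
        key?: keyof State_Type
        value?: State_Type[keyof State_Type]
        state: State_Type
    }
    console.log('extension state changed:', key, value, state.status)
})
```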
77 changes: 0 additions & 77 deletions packages/extension/src/lib/database.ts

This file was deleted.

6 changes: 3 additions & 3 deletions packages/extension/src/side_panel/App.tsx
@@ -1,5 +1,5 @@
import Models from "./components/Models";
-import { useDb } from "./hooks/db";
+import { useExtensionState } from "./hooks/state";
import clsx from 'clsx'

const state_colors = {
@@ -20,8 +20,8 @@ const state_text = {
}

export default function App() {
-const db = useDb()
-const status = db?.status || 'uninitialized'
+const extensionState = useExtensionState()
+const status = extensionState?.status || 'uninitialized'

return (
<div className={clsx("flex flex-col items-center", import.meta.env.DEV ? 'w-screen h-screen' : 'w-[350px] h-[500px]')} >
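The panel now reads everything through useExtensionState instead of useDb. The hook lives in side_panel/hooks/state and is not shown in this diff; a plausible sketch, assuming it seeds from the new 'get_state' internal message and then follows the 'state_updated' broadcasts:

```typescript
import { useEffect, useState } from 'react'
import { InternalMessager, type InternalMessage } from '../../lib/InternalMessager'
import type { State_Type } from '../../lib/State'

// Hypothetical implementation; the names and import paths are assumptions.
export function useExtensionState(): State_Type | undefined {
    const [state, setState] = useState<State_Type | undefined>(undefined)

    useEffect(() => {
        // Ask the service worker for the current snapshot (answered by 'get_state' above).
        InternalMessager.send({ type: 'get_state' }).then(setState).catch(console.error)
        // Follow subsequent broadcasts from State.notifyUpdate.
        InternalMessager.listen(async (message: InternalMessage) => {
            if (message.type === 'state_updated') setState({ ...message.data.state })
        })
    }, [])

    return state
}
```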
6 changes: 3 additions & 3 deletions packages/extension/src/side_panel/components/ModelRow.tsx
@@ -1,5 +1,5 @@
import { Model, config } from "@ai-mask/core";
-import { DB_Type } from "../../lib/Database";
+import { State_Type } from "../../lib/State";
import clsx from "clsx";
import { modelStatus } from "../lib/models";

@@ -21,12 +21,12 @@ const vrams: any = config.mlc.appConfig.model_list.reduce((acc: any, item) => {
return acc as any
})

-export default function ModelRow({ model, db }: { model: Model, db: DB_Type }) {
+export default function ModelRow({ model, extensionState }: { model: Model, extensionState: State_Type }) {
const {
progress,
loading,
status,
-} = modelStatus(model, db)
+} = modelStatus(model, extensionState)

return (
<div key={model.id} className="">
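modelStatus comes from ../lib/models and is not part of this diff; it now derives per-model display state from State_Type rather than DB_Type. A purely hypothetical shape, using only fields defined in State.ts, to show how the new prop could map onto the row's progress/loading/status destructuring:

```typescript
import type { Model } from '@ai-mask/core'
import type { State_Type } from '../../lib/State'

// Hypothetical helper; the real ../lib/models implementation may differ.
export function modelStatus(model: Model, extensionState: State_Type) {
    const cached = !!extensionState.cached_models[model.id]
    const progress = extensionState.cache_progress[model.id] ?? (cached ? 1 : 0)
    const loaded = extensionState.loaded_model === model.id
    const loading = loaded && extensionState.status === 'loading'
    const status = loaded ? extensionState.status : cached ? 'cached' : 'not cached'
    return { progress, loading, status }
}
```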