Skip to content

Commit

Permalink
work on updating smart-embed to use new pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
Brian Joseph Petro committed Sep 18, 2024
1 parent ca3aae7 commit 07d078a
Show file tree
Hide file tree
Showing 11 changed files with 304 additions and 223 deletions.
23 changes: 19 additions & 4 deletions smart-embed-model/adapters/_adapter.js
Original file line number Diff line number Diff line change
@@ -1,9 +1,24 @@
/**
 * Base class for embedding-model adapters.
 * Concrete adapters (API, transformers, worker, iframe) override the async
 * methods below; `load()` is optional, the rest must be implemented.
 */
export class SmartEmbedAdapter {
  /**
   * @param {Object} smart_embed - parent model instance; supplies settings,
   *   model_config, and http_adapter, which are cached on the adapter.
   */
  constructor(smart_embed) {
    this.smart_embed = smart_embed;
    this.settings = smart_embed.settings;
    this.model_config = smart_embed.model_config;
    this.http_adapter = smart_embed.http_adapter;
  }

  /** Optional async initialization hook; intentionally a no-op here. */
  async load() {
    // Implement in subclasses if needed
  }

  /** Counts tokens in `input`; must be implemented by a concrete adapter. */
  async count_tokens(input) {
    throw new Error('count_tokens method not implemented');
  }

  /** Embeds a single input; must be implemented by a concrete adapter. */
  async embed(input) {
    throw new Error('embed method not implemented');
  }

  /** Embeds a batch of items; must be implemented by a concrete adapter. */
  async embed_batch(inputs) {
    throw new Error('embed_batch method not implemented');
  }
}
126 changes: 126 additions & 0 deletions smart-embed-model/adapters/_api.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import { SmartHttpRequest } from "smart-http-request";
import { SmartEmbedAdapter } from "./_adapter.js";

/**
 * Adapter for embedding models served over an HTTP API (OpenAI-style endpoint).
 * Handles batching, request construction, retry-on-429, and response parsing.
 * Concrete subclasses implement prepare_embed_input, prepare_request_body,
 * parse_response, is_error, and count_tokens.
 */
export class SmartEmbedModelApiAdapter extends SmartEmbedAdapter {
  /**
   * @param {Object} smart_embed - parent model; supplies model_key, model_config,
   *   settings, and opts. Endpoint/max_tokens/dims come from model_config.
   */
  constructor(smart_embed) {
    super(smart_embed);
    this.model_key = this.smart_embed.model_key;
    this.model_config = this.smart_embed.model_config;
    this.endpoint = this.model_config.endpoint;
    this.max_tokens = this.model_config.max_tokens;
    this.dims = this.model_config.dims;
  }

  /** Lazily resolves the HTTP adapter: opts.http_adapter if provided, else a default SmartHttpRequest. */
  get http_adapter() {
    if (!this._http_adapter) {
      if (this.smart_embed.opts.http_adapter) this._http_adapter = this.smart_embed.opts.http_adapter;
      else this._http_adapter = new SmartHttpRequest();
    }
    return this._http_adapter;
  }

  /** API key from user settings, falling back to the model config. */
  get api_key() {
    return this.settings.api_key || this.model_config.api_key;
  }

  /** Exact token count — must be implemented by a concrete adapter. */
  async count_tokens(input) {
    throw new Error("count_tokens not implemented");
  }

  /** Rough token estimate: ~3.7 characters per token; objects are JSON-stringified first. */
  estimate_tokens(input) {
    if (typeof input === 'object') input = JSON.stringify(input);
    return Math.ceil(input.length / 3.7);
  }

  /**
   * Embeds a batch of items in place.
   * @param {Array<Object>} inputs - items carrying an embed_input string; each
   *   successful item is mutated with `vec` and `tokens`.
   * @returns {Promise<Array<Object>>} the embedded items; [] when the batch is
   *   empty or the embedding request failed.
   */
  async embed_batch(inputs) {
    inputs = inputs.filter(item => item.embed_input?.length > 0);
    if (inputs.length === 0) {
      console.log("empty batch (or all items have empty embed_input)");
      return [];
    }
    const embed_inputs = await Promise.all(inputs.map(item => this.prepare_embed_input(item.embed_input)));
    const embeddings = await this.request_embedding(embed_inputs);
    // Fixed: previously `return console.error(inputs)` returned undefined;
    // return [] so callers can always iterate the result safely.
    if (!embeddings) {
      console.error(inputs);
      return [];
    }
    // NOTE(review): assumes embeddings align 1:1 with inputs — confirm, since
    // request_embedding filters null/empty strings again and could shift indices.
    return inputs.map((item, i) => {
      item.vec = embeddings[i].vec;
      item.tokens = embeddings[i].tokens;
      return item;
    });
  }

  /** Converts one embed_input string into the form the API expects — subclass responsibility. */
  async prepare_embed_input(embed_input) {
    throw new Error("prepare_embed_input not implemented");
  }

  /**
   * Maps each item's embed_input through prepare_embed_input.
   * NOTE(review): prepare_embed_input is async in subclasses, so this returns
   * an array of promises — callers must await them (confirm against callers).
   */
  prepare_batch_input(items) {
    return items.map(item => this.prepare_embed_input(item.embed_input));
  }

  /** Builds the JSON request body for the endpoint — subclass responsibility. */
  prepare_request_body(embed_input) {
    throw new Error("prepare_request_body not implemented");
  }

  /** Builds JSON/Bearer headers, overlaying any custom headers from opts.headers. */
  prepare_request_headers() {
    let headers = {
      "Content-Type": "application/json",
      "Authorization": `Bearer ${this.api_key}`
    };
    if (this.smart_embed.opts.headers) {
      headers = { ...headers, ...this.smart_embed.opts.headers };
    }
    return headers;
  }

  /**
   * POSTs the prepared inputs to the endpoint and parses the response.
   * @param {Array<string>} embed_input - prepared inputs; nulls/empties are dropped.
   * @returns {Promise<Array|null>} parsed embeddings, or null when nothing to send / request failed.
   */
  async request_embedding(embed_input) {
    embed_input = embed_input.filter(input => input !== null && input.length > 0);
    if (embed_input.length === 0) {
      console.log("embed_input is empty after filtering null and empty strings");
      return null;
    }
    const request = {
      url: this.endpoint,
      method: "POST",
      body: JSON.stringify(this.prepare_request_body(embed_input)),
      headers: this.prepare_request_headers()
    };
    const resp = await this.request(request);
    return this.parse_response(resp);
  }

  /** Extracts embeddings from a successful response — subclass responsibility. */
  parse_response(resp) {
    throw new Error("parse_response not implemented");
  }

  /** Detects an error payload in the parsed JSON — subclass responsibility. */
  is_error(resp_json) {
    throw new Error("is_error not implemented");
  }

  /** Normalizes fetch-style responses (json() method) and adapter-style responses (json property). */
  async get_resp_json(resp) {
    return (typeof resp.json === 'function') ? await resp.json() : await resp.json;
  }

  /**
   * Sends the request through http_adapter, routing error payloads and thrown
   * errors through handle_request_err (which may retry).
   * @returns {Promise<Object|null>} parsed JSON, or null after unrecoverable failure.
   */
  async request(req, retries = 0) {
    try {
      // presumably tells the adapter not to throw on HTTP error statuses,
      // so errors surface via is_error() instead — confirm with SmartHttpRequest.
      req.throw = false;
      const resp = await this.http_adapter.request({ url: this.endpoint, ...req });
      const resp_json = await this.get_resp_json(resp);
      if (this.is_error(resp_json)) {
        return await this.handle_request_err(resp_json, req, retries);
      }
      return resp_json;
    } catch (error) {
      return await this.handle_request_err(error, req, retries);
    }
  }

  /**
   * Retries 429 (rate-limit) responses up to 3 times with quadratic backoff;
   * logs and returns null for anything else.
   */
  async handle_request_err(error, req, retries) {
    if (error.status === 429 && retries < 3) {
      const backoff = Math.pow(retries + 1, 2);
      console.log(`Retrying request (429) in ${backoff} seconds...`);
      await new Promise(r => setTimeout(r, 1000 * backoff));
      return await this.request(req, retries + 1);
    }
    console.error(error);
    return null;
  }
}
102 changes: 2 additions & 100 deletions smart-embed-model/adapters/openai.js
Original file line number Diff line number Diff line change
@@ -1,21 +1,11 @@
import { SmartEmbedAdapter } from "./_adapter.js";
import { SmartEmbedModelApiAdapter } from "./_api.js";
import { Tiktoken } from 'js-tiktoken/lite';
import cl100k_base from "../cl100k_base.json" assert { type: "json" };

export class SmartEmbedOpenAIAdapter extends SmartEmbedAdapter {
export class SmartEmbedOpenAIAdapter extends SmartEmbedModelApiAdapter {
constructor(smart_embed) {
super(smart_embed);
this.model_key = smart_embed.opts.model_key || "text-embedding-ada-002";
this.endpoint = "https://api.openai.com/v1/embeddings";
this.max_tokens = 8191; // Default max tokens for OpenAI embeddings
this.dims = smart_embed.opts.dims || 1536; // Default dimensions for OpenAI embeddings
this.enc = null;
this.request_adapter = smart_embed.env.opts.request_adapter;
}
get api_key() {
return this.smart_embed.opts.api_key
|| this.smart_embed.env.smart_connections_plugin?.settings?.api_key // temporary for backwards compatibility in SC OP
;
}

async load() {
Expand All @@ -27,30 +17,6 @@ export class SmartEmbedOpenAIAdapter extends SmartEmbedAdapter {
return this.enc.encode(input).length;
}

estimate_tokens(input) {
if (typeof input === 'object') input = JSON.stringify(input);
return Math.ceil(input.length / 3.7);
}

// Embeds a batch of items in place: filters out items with no embed_input,
// prepares each input, requests embeddings, and writes vec/tokens back onto
// each item. Verbose console.log calls trace batch sizes at each stage.
async embed_batch(inputs) {
console.log(`Original inputs length: ${inputs.length}`);
inputs = inputs.filter(item => item.embed_input?.length > 0);
console.log(`Filtered inputs length: ${inputs.length}`);
if (inputs.length === 0) {
console.log("empty batch (or all items have empty embed_input)");
return [];
}
const embed_inputs = await Promise.all(inputs.map(item => this.prepare_embed_input(item.embed_input)));
console.log(`Prepared embed_inputs length: ${embed_inputs.length}`);
const embeddings = await this.request_embedding(embed_inputs);
// NOTE(review): console.error returns undefined, so a failed request makes this
// method return undefined rather than an array — callers must guard for that.
if (!embeddings) return console.error(inputs);
// Mutates the caller's items, attaching the embedding vector and token count.
return inputs.map((item, i) => {
item.vec = embeddings[i].vec;
item.tokens = embeddings[i].tokens;
return item
});
}

async prepare_embed_input(embed_input) {
if (typeof embed_input !== 'string') {
throw new TypeError('embed_input must be a string');
Expand All @@ -69,14 +35,12 @@ export class SmartEmbedOpenAIAdapter extends SmartEmbedAdapter {
const reduce_ratio = (tokens_ct - this.max_tokens) / tokens_ct;
const new_length = Math.floor(embed_input.length * (1 - reduce_ratio));

// Trim the input to the new length, ensuring we don't cut off in the middle of a word
let trimmed_input = embed_input.slice(0, new_length);
const last_space_index = trimmed_input.lastIndexOf(' ');
if (last_space_index > 0) {
trimmed_input = trimmed_input.slice(0, last_space_index);
}

// Recursively call prepare_embed_input to ensure we're within token limit
const prepared_input = await this.prepare_embed_input(trimmed_input);
if (prepared_input === null) {
console.log("Warning: prepare_embed_input resulted in an empty string after trimming");
Expand All @@ -85,10 +49,6 @@ export class SmartEmbedOpenAIAdapter extends SmartEmbedAdapter {
return prepared_input;
}

prepare_batch_input(items) {
return items.map(item => this.prepare_embed_input(item.embed_input));
}

prepare_request_body(embed_input) {
const body = {
model: this.model_key,
Expand All @@ -100,33 +60,6 @@ export class SmartEmbedOpenAIAdapter extends SmartEmbedAdapter {
return body;
}

prepare_request_headers() {
let headers = {
"Content-Type": "application/json",
"Authorization": `Bearer ${this.api_key}`
};
if (this.smart_embed.opts.headers) {
headers = { ...headers, ...this.smart_embed.opts.headers };
}
return headers;
}

async request_embedding(embed_input) {
embed_input = embed_input.filter(input => input !== null && input.length > 0);
if (embed_input.length === 0) {
console.log("embed_input is empty after filtering null and empty strings");
return null;
}
const request = {
url: this.endpoint,
method: "POST",
body: JSON.stringify(this.prepare_request_body(embed_input)),
headers: this.prepare_request_headers()
};
const resp = await this.request(request);
return this.parse_response(resp);
}

parse_response(resp) {
return resp.data.map(item => ({
vec: item.embedding,
Expand All @@ -137,35 +70,4 @@ export class SmartEmbedOpenAIAdapter extends SmartEmbedAdapter {
is_error(resp_json) {
return !resp_json.data || !resp_json.usage;
}

async get_resp_json(resp) {
return (typeof resp.json === 'function') ? await resp.json() : await resp.json;
}

// Sends the HTTP request via the configured request_adapter when present,
// falling back to global fetch; error payloads and thrown errors are routed
// through handle_request_err, which may retry.
async request(req, retries = 0) {
try {
// presumably tells the adapter not to throw on HTTP error statuses, so errors
// surface via is_error() instead — confirm against the request_adapter contract
req.throw = false;
const resp = this.request_adapter
? await this.request_adapter({ url: this.endpoint, ...req })
: await fetch(this.endpoint, req);
const resp_json = await this.get_resp_json(resp);
if (this.is_error(resp_json)) {
return await this.handle_request_err(resp_json, req, retries);
}
return resp_json;
} catch (error) {
return await this.handle_request_err(error, req, retries);
}
}

/**
 * Retries rate-limited (429) requests up to 3 times with quadratic backoff;
 * any other failure is logged and resolved to null.
 */
async handle_request_err(error, req, retries) {
  const should_retry = error.status === 429 && retries < 3;
  if (!should_retry) {
    console.error(error);
    return null;
  }
  const backoff = Math.pow(retries + 1, 2);
  console.log(`Retrying request (429) in ${backoff} seconds...`);
  await new Promise(resolve => setTimeout(resolve, 1000 * backoff));
  return await this.request(req, retries + 1);
}
}
24 changes: 17 additions & 7 deletions smart-embed-model/adapters/transformers.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,32 @@ import { SmartEmbedAdapter } from "./_adapter.js";
export class SmartEmbedTransformersAdapter extends SmartEmbedAdapter {
// Caches the model key/config and reserves slots for the transformers.js
// pipeline and tokenizer, both of which are created asynchronously in load().
constructor(smart_embed) {
super(smart_embed);
this.model_key = this.smart_embed.model_key;
// NOTE(review): the base adapter already copies model_config from smart_embed;
// this re-assignment appears redundant but harmless — confirm before removing.
this.model_config = this.smart_embed.model_config;
this.model = null; // feature-extraction pipeline, set in load()
this.tokenizer = null; // AutoTokenizer instance, set in load()
}

get batch_size() {
if(this.use_gpu && this.smart_embed.opts.gpu_batch_size) return this.smart_embed.opts.gpu_batch_size;
return this.smart_embed.opts.batch_size || 1;
return this.smart_embed.batch_size;
}

get max_tokens() {
return this.smart_embed.max_tokens;
}

get use_gpu() {
return this.smart_embed.opts.use_gpu || false;
}
get max_tokens() { return this.smart_embed.opts.max_tokens || 512; }
get use_gpu() { return this.smart_embed.opts.use_gpu || false; }

async load() {
const { pipeline, env, AutoTokenizer } = await import('@xenova/transformers');
// const { pipeline, env, AutoTokenizer } = await import('https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.0-alpha.9');

env.allowLocalModels = false;
const pipeline_opts = {
quantized: true,
};

if (this.use_gpu) {
console.log("[Transformers] Using GPU");
pipeline_opts.device = 'webgpu';
Expand All @@ -28,8 +37,9 @@ export class SmartEmbedTransformersAdapter extends SmartEmbedAdapter {
console.log("[Transformers] Using CPU");
env.backends.onnx.wasm.numThreads = 8;
}
this.model = await pipeline('feature-extraction', this.smart_embed.opts.model_key, pipeline_opts);
this.tokenizer = await AutoTokenizer.from_pretrained(this.smart_embed.opts.model_key);

this.model = await pipeline('feature-extraction', this.model_key, pipeline_opts);
this.tokenizer = await AutoTokenizer.from_pretrained(this.model_key);
}

async count_tokens(input) {
Expand Down
2 changes: 1 addition & 1 deletion smart-embed-model/adapters/transformers_iframe.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ export class SmartEmbedTransformersIframeAdapter extends SmartEmbedIframeAdapter
constructor(smart_embed) {
super(smart_embed);
this.connector = transformers_connector;
if(this.smart_embed.env.settings.legacy_transformers){
if(this.smart_embed.settings.legacy_transformers){
this.connector = this.connector
.replace('@xenova/transformers', 'https://cdn.jsdelivr.net/npm/@xenova/transformers@2.17.2')
;
Expand Down
14 changes: 6 additions & 8 deletions smart-embed-model/adapters/worker.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,17 @@ export class SmartEmbedWorkerAdapter extends SmartEmbedAdapter {
this.worker_id = `smart_embed_worker_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
this.message_prefix = `msg_${Math.random().toString(36).substr(2, 9)}_`;
}
get main() { return this.smart_embed.env.main; }
get main() { return this.smart_embed; }

async load() {
console.log('loading worker adapter', this.smart_embed.opts);

const global_key = `smart_embed_worker_${this.smart_embed.opts.embed_model_key}`;
const global_key = `smart_embed_worker_${this.smart_embed.model_key}`;

if (!this.main[global_key]) {
this.main[global_key] = new Worker(this.worker_url, { type: 'module' });
console.log('new worker created', this.main[global_key]);
if (!this.smart_embed[global_key]) {
this.smart_embed[global_key] = new Worker(this.worker_url, { type: 'module' });
console.log('new worker created', this.smart_embed[global_key]);
}

this.worker = this.main[global_key];
this.worker = this.smart_embed[global_key];
console.log('worker', this.worker);
console.log('worker_url', this.worker_url);

Expand Down
Loading

0 comments on commit 07d078a

Please sign in to comment.