Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CLN] Cleaned up OpenAI JS EF #2108

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ repos:
- id: trailing-whitespace
- id: mixed-line-ending
- id: end-of-file-fixer
exclude: "go/migrations"
exclude: "go/migrations|.vscode"
- id: requirements-txt-fixer
- id: check-yaml
args: ["--allow-multiple-documents"]
Expand Down
103 changes: 38 additions & 65 deletions clients/js/src/embeddings/OpenAIEmbeddingFunction.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import { IEmbeddingFunction } from "./IEmbeddingFunction";

let OpenAIApi: any;
let openAiVersion = null;
let openAiMajorVersion = null;

interface OpenAIAPI {
createEmbedding: (params: {
Expand Down Expand Up @@ -30,14 +28,10 @@ class OpenAIAPIv3 implements OpenAIAPI {
user?: string;
}): Promise<number[][]> {
const embeddings: number[][] = [];
const response = await this.openai
.createEmbedding({
model: params.model,
input: params.input,
})
.catch((error: any) => {
throw error;
});
const response = await this.openai.createEmbedding({
model: params.model,
input: params.input,
});
// @ts-ignore
const data = response.data["data"];
for (let i = 0; i < data.length; i += 1) {
Expand All @@ -48,14 +42,10 @@ class OpenAIAPIv3 implements OpenAIAPI {
}

class OpenAIAPIv4 implements OpenAIAPI {
private readonly apiKey: any;
private openai: any;

constructor(apiKey: any) {
this.apiKey = apiKey;
this.openai = new OpenAIApi({
apiKey: this.apiKey,
});
constructor(configuration: { organization: string; apiKey: string }) {
this.openai = new OpenAIApi.OpenAI(configuration);
}

public async createEmbedding(params: {
Expand Down Expand Up @@ -95,63 +85,46 @@ export class OpenAIEmbeddingFunction implements IEmbeddingFunction {
this.model = openai_model || "text-embedding-ada-002";
}

private async loadClient() {
// cache the client
if (this.openaiApi) return;

try {
const { openai, version } = await OpenAIEmbeddingFunction.import();
OpenAIApi = openai;
let versionVar: string = version;
openAiVersion = versionVar.replace(/[^0-9.]/g, "");
openAiMajorVersion = parseInt(openAiVersion.split(".")[0]);
} catch (_a) {
// @ts-ignore
if (_a.code === "MODULE_NOT_FOUND") {
throw new Error(
"Please install the openai package to use the OpenAIEmbeddingFunction, `npm install -S openai`",
);
}
throw _a; // Re-throw other errors
}

if (openAiMajorVersion > 3) {
this.openaiApi = new OpenAIAPIv4(this.api_key);
} else {
this.openaiApi = new OpenAIAPIv3({
organization: this.org_id,
apiKey: this.api_key,
});
}
}

public async generate(texts: string[]): Promise<number[][]> {
await this.loadClient();
const openaiApi = await this.getOpenAIClient();

return await this.openaiApi!.createEmbedding({
return await openaiApi.createEmbedding({
model: this.model,
input: texts,
}).catch((error: any) => {
throw error;
});
}

/** @ignore */
static async import(): Promise<{
// @ts-ignore
openai: typeof import("openai");
version: string;
}> {
private async getOpenAIClient(): Promise<OpenAIAPI> {
if (this.openaiApi) return this.openaiApi;
try {
// @ts-ignore
const { default: openai } = await import("openai");
// @ts-ignore
const { VERSION } = await import("openai/version");
return { openai, version: VERSION };
// @ts-ignore - we need to dynamically import the openai package without TS errors
OpenAIApi = await import("openai");
// @ts-ignore - we need to dynamically import the openai package without TS errors
this.openaiApi = await import("openai/version")
.catch(() => ({ VERSION: "3" }))
.then(({ VERSION }) => {
if (VERSION.startsWith("4")) {
return new OpenAIAPIv4({
apiKey: this.api_key,
organization: this.org_id,
});
} else if (VERSION.startsWith("3")) {
return new OpenAIAPIv3({
organization: this.org_id,
apiKey: this.api_key,
});
} else {
throw new Error("Unsupported OpenAI library version");
}
});
return this.openaiApi;
} catch (e) {
throw new Error(
"Please install openai as a dependency with, e.g. `yarn add openai`",
);
// @ts-ignore
if (e.code === "MODULE_NOT_FOUND") {
throw new Error(
"Please install the openai package to use the OpenAIEmbeddingFunction, `npm install -S openai`",
);
}
throw e;
}
}
}
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ dependencies = [

[tool.black]
line-length = 88
required-version = "24.4.2" # Black will refuse to run if it's not this version.
required-version = "23.3.0" # Black will refuse to run if it's not this version.
target-version = ['py38', 'py39', 'py310', 'py311']

[tool.pytest.ini_options]
Expand Down
Loading