From 5ffff81eef817573da5670a519642aee350426f0 Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Wed, 7 Aug 2024 23:23:11 -0700 Subject: [PATCH 1/2] Add Vertex embeddings to integration package --- .../src/embeddings/googlevertexai.ts | 5 + .../src/utils/googlevertexai-connection.ts | 4 +- .../langchain-google-common/src/connection.ts | 8 +- .../langchain-google-common/src/embeddings.ts | 202 ++++++++++++++++++ libs/langchain-google-common/src/index.ts | 1 + libs/langchain-google-gauth/src/auth.ts | 4 +- libs/langchain-google-gauth/src/embeddings.ts | 39 ++++ libs/langchain-google-gauth/src/index.ts | 1 + .../src/embeddings.ts | 25 +++ .../src/index.ts | 1 + .../src/tests/embeddings.int.test.ts | 29 +++ .../src/embeddings.ts | 25 +++ libs/langchain-google-vertexai/src/index.ts | 1 + .../src/embeddings.ts | 38 ++++ libs/langchain-google-webauth/src/index.ts | 1 + 15 files changed, 377 insertions(+), 7 deletions(-) create mode 100644 libs/langchain-google-common/src/embeddings.ts create mode 100644 libs/langchain-google-gauth/src/embeddings.ts create mode 100644 libs/langchain-google-vertexai-web/src/embeddings.ts create mode 100644 libs/langchain-google-vertexai-web/src/tests/embeddings.int.test.ts create mode 100644 libs/langchain-google-vertexai/src/embeddings.ts create mode 100644 libs/langchain-google-webauth/src/embeddings.ts diff --git a/libs/langchain-community/src/embeddings/googlevertexai.ts b/libs/langchain-community/src/embeddings/googlevertexai.ts index fa08b1c0e3..f059b45aea 100644 --- a/libs/langchain-community/src/embeddings/googlevertexai.ts +++ b/libs/langchain-community/src/embeddings/googlevertexai.ts @@ -10,6 +10,7 @@ import { import { GoogleVertexAILLMConnection } from "../utils/googlevertexai-connection.js"; /** + * @deprecated Import and use from @langchain/google-vertexai or @langchain/google-vertexai-web * Defines the parameters required to initialize a * GoogleVertexAIEmbeddings instance. It extends EmbeddingsParams and * GoogleVertexAIConnectionParams. @@ -19,12 +20,14 @@ export interface GoogleVertexAIEmbeddingsParams GoogleVertexAIBaseLLMInput {} /** + * @deprecated Import and use from @langchain/google-vertexai or @langchain/google-vertexai-web * Defines additional options specific to the * GoogleVertexAILLMEmbeddingsInstance. It extends AsyncCallerCallOptions. */ interface GoogleVertexAILLMEmbeddingsOptions extends AsyncCallerCallOptions {} /** + * @deprecated Import and use from @langchain/google-vertexai or @langchain/google-vertexai-web * Represents an instance for generating embeddings using the Google * Vertex AI API. It contains the content to be embedded. */ @@ -33,6 +36,7 @@ interface GoogleVertexAILLMEmbeddingsInstance { } /** + * @deprecated Import and use from @langchain/google-vertexai or @langchain/google-vertexai-web * Defines the structure of the embeddings results returned by the Google * Vertex AI API. It extends GoogleVertexAIBasePrediction and contains the * embeddings and their statistics. @@ -48,6 +52,7 @@ interface GoogleVertexEmbeddingsResults extends GoogleVertexAIBasePrediction { } /** + * @deprecated Import and use from @langchain/google-vertexai or @langchain/google-vertexai-web * Enables calls to the Google Cloud's Vertex AI API to access * the embeddings generated by Large Language Models. * diff --git a/libs/langchain-community/src/utils/googlevertexai-connection.ts b/libs/langchain-community/src/utils/googlevertexai-connection.ts index 7a03352588..5843bbc2b3 100644 --- a/libs/langchain-community/src/utils/googlevertexai-connection.ts +++ b/libs/langchain-community/src/utils/googlevertexai-connection.ts @@ -212,7 +212,9 @@ export class GoogleVertexAILLMConnection< } const projectId = await this.client.getProjectId(); - + console.log( + `https://${this.endpoint}/v1/projects/${projectId}/locations/${this.location}/publishers/google/models/${this.model}:${method}` + ); return `https://${this.endpoint}/v1/projects/${projectId}/locations/${this.location}/publishers/google/models/${this.model}:${method}`; } diff --git a/libs/langchain-google-common/src/connection.ts b/libs/langchain-google-common/src/connection.ts index bf73c40787..d00dab8851 100644 --- a/libs/langchain-google-common/src/connection.ts +++ b/libs/langchain-google-common/src/connection.ts @@ -166,8 +166,8 @@ export abstract class GoogleHostConnection< } export abstract class GoogleAIConnection< - CallOptions extends BaseLanguageModelCallOptions, - MessageType, + CallOptions extends AsyncCallerCallOptions, + InputType, AuthOptions > extends GoogleHostConnection @@ -232,12 +232,12 @@ export abstract class GoogleAIConnection< } abstract formatData( - input: MessageType, + input: InputType, parameters: GoogleAIModelRequestParams ): unknown; async request( - input: MessageType, + input: InputType, parameters: GoogleAIModelRequestParams, options: CallOptions ): Promise { diff --git a/libs/langchain-google-common/src/embeddings.ts b/libs/langchain-google-common/src/embeddings.ts new file mode 100644 index 0000000000..b0367d4fb8 --- /dev/null +++ b/libs/langchain-google-common/src/embeddings.ts @@ -0,0 +1,202 @@ +import { Embeddings, EmbeddingsParams } from "@langchain/core/embeddings"; +import { + AsyncCaller, + AsyncCallerCallOptions, +} from "@langchain/core/utils/async_caller"; +import { chunkArray } from "@langchain/core/utils/chunk_array"; +import { GoogleAIConnection } from "./connection.js"; +import { ApiKeyGoogleAuth, GoogleAbstractedClient } from "./auth.js"; +import { GoogleAIModelRequestParams, GoogleConnectionParams } from "./types.js"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; + +class EmbeddingsConnection< + CallOptions extends AsyncCallerCallOptions, + AuthOptions +> extends GoogleAIConnection< + CallOptions, + GoogleEmbeddingsInstance[], + AuthOptions +> { + convertSystemMessageToHumanContent: boolean | undefined; + + constructor( + fields: GoogleConnectionParams | undefined, + caller: AsyncCaller, + client: GoogleAbstractedClient, + streaming: boolean + ) { + super(fields, caller, client, streaming); + } + + async buildUrlMethod(): Promise { + return "predict"; + } + + formatData( + input: GoogleEmbeddingsInstance[], + parameters: GoogleAIModelRequestParams + ): unknown { + return { + instances: input, + parameters, + }; + } +} + +/** + * Defines the parameters required to initialize a + * GoogleEmbeddings instance. It extends EmbeddingsParams and + * GoogleConnectionParams. + */ +export interface BaseGoogleEmbeddingsParams + extends EmbeddingsParams, + GoogleConnectionParams { + model: string; +} + +/** + * Defines additional options specific to the + * GoogleEmbeddingsInstance. It extends AsyncCallerCallOptions. + */ +export interface BaseGoogleEmbeddingsOptions extends AsyncCallerCallOptions {} + +/** + * Represents an instance for generating embeddings using the Google + * Vertex AI API. It contains the content to be embedded. + */ +export interface GoogleEmbeddingsInstance { + content: string; +} + +/** + * Defines the structure of the embeddings results returned by the Google + * Vertex AI API. It extends GoogleBasePrediction and contains the + * embeddings and their statistics. + */ +export interface BaseGoogleEmbeddingsResults { + embeddings: { + statistics: { + token_count: number; + truncated: boolean; + }; + values: number[]; + }; +} + +/** + * Enables calls to the Google Cloud's Vertex AI API to access + * the embeddings generated by Large Language Models. + * + * To use, you will need to have one of the following authentication + * methods in place: + * - You are logged into an account permitted to the Google Cloud project + * using Vertex AI. + * - You are running this on a machine using a service account permitted to + * the Google Cloud project using Vertex AI. + * - The `GOOGLE_APPLICATION_CREDENTIALS` environment variable is set to the + * path of a credentials file for a service account permitted to the + * Google Cloud project using Vertex AI. + * @example + * ```typescript + * const model = new GoogleEmbeddings(); + * const res = await model.embedQuery( + * "What would be a good company name for a company that makes colorful socks?" + * ); + * console.log({ res }); + * ``` + */ +export abstract class BaseGoogleEmbeddings + extends Embeddings + implements BaseGoogleEmbeddingsParams +{ + model: string; + + private connection: GoogleAIConnection< + BaseGoogleEmbeddingsOptions, + GoogleEmbeddingsInstance[], + GoogleConnectionParams + >; + + constructor(fields: BaseGoogleEmbeddingsParams) { + super(fields); + + this.model = fields.model; + this.connection = new EmbeddingsConnection( + { ...fields, ...this }, + this.caller, + this.buildClient(fields), + false + ); + } + + abstract buildAbstractedClient( + fields?: GoogleConnectionParams + ): GoogleAbstractedClient; + + buildApiKeyClient(apiKey: string): GoogleAbstractedClient { + return new ApiKeyGoogleAuth(apiKey); + } + + buildApiKey( + fields?: GoogleConnectionParams + ): string | undefined { + return fields?.apiKey ?? getEnvironmentVariable("GOOGLE_API_KEY"); + } + + buildClient( + fields?: GoogleConnectionParams + ): GoogleAbstractedClient { + const apiKey = this.buildApiKey(fields); + if (apiKey) { + return this.buildApiKeyClient(apiKey); + } else { + return this.buildAbstractedClient(fields); + } + } + + /** + * Takes an array of documents as input and returns a promise that + * resolves to a 2D array of embeddings for each document. It splits the + * documents into chunks and makes requests to the Google Vertex AI API to + * generate embeddings. + * @param documents An array of documents to be embedded. + * @returns A promise that resolves to a 2D array of embeddings for each document. + */ + async embedDocuments(documents: string[]): Promise { + const instanceChunks: GoogleEmbeddingsInstance[][] = chunkArray( + documents.map((document) => ({ + content: document, + })), + 5 + ); // Vertex AI accepts max 5 instances per prediction + const parameters = {}; + const options = {}; + const responses = await Promise.all( + instanceChunks.map((instances) => + this.connection.request(instances, parameters, options) + ) + ); + const result: number[][] = + responses + ?.map( + (response) => + (response?.data as any)?.predictions?.map( + (result: any) => result.embeddings.values + ) ?? [] + ) + .flat() ?? []; + return result; + } + + /** + * Takes a document as input and returns a promise that resolves to an + * embedding for the document. It calls the embedDocuments method with the + * document as the input. + * @param document A document to be embedded. + * @returns A promise that resolves to an embedding for the document. + */ + async embedQuery(document: string): Promise { + const data = await this.embedDocuments([document]); + return data[0]; + } +} diff --git a/libs/langchain-google-common/src/index.ts b/libs/langchain-google-common/src/index.ts index 3e4311e2b0..18e62d415f 100644 --- a/libs/langchain-google-common/src/index.ts +++ b/libs/langchain-google-common/src/index.ts @@ -1,5 +1,6 @@ export * from "./chat_models.js"; export * from "./llms.js"; +export * from "./embeddings.js"; export * from "./auth.js"; export * from "./connection.js"; diff --git a/libs/langchain-google-gauth/src/auth.ts b/libs/langchain-google-gauth/src/auth.ts index 51bc1cbba3..21093bcbce 100644 --- a/libs/langchain-google-gauth/src/auth.ts +++ b/libs/langchain-google-gauth/src/auth.ts @@ -3,7 +3,7 @@ import { ensureAuthOptionScopes, GoogleAbstractedClient, GoogleAbstractedClientOps, - GoogleBaseLLMInput, + GoogleConnectionParams, JsonStream, } from "@langchain/google-common"; import { GoogleAuth, GoogleAuthOptions } from "google-auth-library"; @@ -27,7 +27,7 @@ export class NodeJsonStream extends JsonStream { export class GAuthClient implements GoogleAbstractedClient { gauth: GoogleAuth; - constructor(fields?: GoogleBaseLLMInput) { + constructor(fields?: GoogleConnectionParams) { const options = ensureAuthOptionScopes( fields?.authOptions, "scopes", diff --git a/libs/langchain-google-gauth/src/embeddings.ts b/libs/langchain-google-gauth/src/embeddings.ts new file mode 100644 index 0000000000..e9e668da92 --- /dev/null +++ b/libs/langchain-google-gauth/src/embeddings.ts @@ -0,0 +1,39 @@ +import { + GoogleAbstractedClient, + GoogleConnectionParams, + BaseGoogleEmbeddings, + BaseGoogleEmbeddingsParams, +} from "@langchain/google-common"; +import { GoogleAuthOptions } from "google-auth-library"; +import { GAuthClient } from "./auth.js"; + +/** + * Input to LLM class. + */ +export interface GoogleEmbeddingsInput + extends BaseGoogleEmbeddingsParams {} + +/** + * Integration with an LLM. + */ +export class GoogleEmbeddings + extends BaseGoogleEmbeddings + implements GoogleEmbeddingsInput +{ + // Used for tracing, replace with the same name as your class + static lc_name() { + return "GoogleEmbeddings"; + } + + lc_serializable = true; + + constructor(fields: GoogleEmbeddingsInput) { + super(fields); + } + + buildAbstractedClient( + fields?: GoogleConnectionParams + ): GoogleAbstractedClient { + return new GAuthClient(fields); + } +} diff --git a/libs/langchain-google-gauth/src/index.ts b/libs/langchain-google-gauth/src/index.ts index 2c8aa4ecb4..7f420a4ed6 100644 --- a/libs/langchain-google-gauth/src/index.ts +++ b/libs/langchain-google-gauth/src/index.ts @@ -1,2 +1,3 @@ export * from "./chat_models.js"; export * from "./llms.js"; +export * from "./embeddings.js"; diff --git a/libs/langchain-google-vertexai-web/src/embeddings.ts b/libs/langchain-google-vertexai-web/src/embeddings.ts new file mode 100644 index 0000000000..a4e86019ec --- /dev/null +++ b/libs/langchain-google-vertexai-web/src/embeddings.ts @@ -0,0 +1,25 @@ +import { + type GoogleEmbeddingsInput, + GoogleEmbeddings, +} from "@langchain/google-webauth"; + +/** + * Input to chat model class. + */ +export interface GoogleVertexAIEmbeddingsInput extends GoogleEmbeddingsInput {} + +/** + * Integration with a chat model. + */ +export class GoogleVertexAIEmbeddings extends GoogleEmbeddings { + static lc_name() { + return "GoogleVertexAIEmbeddings"; + } + + constructor(fields: GoogleVertexAIEmbeddingsInput) { + super({ + ...fields, + platformType: "gcp", + }); + } +} diff --git a/libs/langchain-google-vertexai-web/src/index.ts b/libs/langchain-google-vertexai-web/src/index.ts index 2c8aa4ecb4..7f420a4ed6 100644 --- a/libs/langchain-google-vertexai-web/src/index.ts +++ b/libs/langchain-google-vertexai-web/src/index.ts @@ -1,2 +1,3 @@ export * from "./chat_models.js"; export * from "./llms.js"; +export * from "./embeddings.js"; diff --git a/libs/langchain-google-vertexai-web/src/tests/embeddings.int.test.ts b/libs/langchain-google-vertexai-web/src/tests/embeddings.int.test.ts new file mode 100644 index 0000000000..94ed3b48f5 --- /dev/null +++ b/libs/langchain-google-vertexai-web/src/tests/embeddings.int.test.ts @@ -0,0 +1,29 @@ +import { test, expect } from "@jest/globals"; +import { GoogleVertexAIEmbeddings } from "../embeddings.js"; + +test("Test GoogleVertexAIEmbeddings.embedQuery", async () => { + const embeddings = new GoogleVertexAIEmbeddings({ + model: "textembedding-gecko", + }); + const res = await embeddings.embedQuery("Hello world"); + expect(typeof res[0]).toBe("number"); +}); + +test("Test GoogleVertexAIEmbeddings.embedDocuments", async () => { + const embeddings = new GoogleVertexAIEmbeddings({ + model: "text-embedding-004", + }); + const res = await embeddings.embedDocuments([ + "Hello world", + "Bye bye", + "we need", + "at least", + "six documents", + "to test pagination", + ]); + // console.log(res); + expect(res).toHaveLength(6); + res.forEach((r) => { + expect(typeof r[0]).toBe("number"); + }); +}); diff --git a/libs/langchain-google-vertexai/src/embeddings.ts b/libs/langchain-google-vertexai/src/embeddings.ts new file mode 100644 index 0000000000..13ab0a85f4 --- /dev/null +++ b/libs/langchain-google-vertexai/src/embeddings.ts @@ -0,0 +1,25 @@ +import { + type GoogleEmbeddingsInput, + GoogleEmbeddings, +} from "@langchain/google-gauth"; + +/** + * Input to chat model class. + */ +export interface GoogleVertexAIEmbeddingsInput extends GoogleEmbeddingsInput {} + +/** + * Integration with a chat model. + */ +export class GoogleVertexAIEmbeddings extends GoogleEmbeddings { + static lc_name() { + return "GoogleVertexAIEmbeddings"; + } + + constructor(fields: GoogleVertexAIEmbeddingsInput) { + super({ + ...fields, + platformType: "gcp", + }); + } +} diff --git a/libs/langchain-google-vertexai/src/index.ts b/libs/langchain-google-vertexai/src/index.ts index 2c8aa4ecb4..7f420a4ed6 100644 --- a/libs/langchain-google-vertexai/src/index.ts +++ b/libs/langchain-google-vertexai/src/index.ts @@ -1,2 +1,3 @@ export * from "./chat_models.js"; export * from "./llms.js"; +export * from "./embeddings.js"; diff --git a/libs/langchain-google-webauth/src/embeddings.ts b/libs/langchain-google-webauth/src/embeddings.ts new file mode 100644 index 0000000000..246f12ce1a --- /dev/null +++ b/libs/langchain-google-webauth/src/embeddings.ts @@ -0,0 +1,38 @@ +import { + GoogleAbstractedClient, + GoogleConnectionParams, + BaseGoogleEmbeddings, + BaseGoogleEmbeddingsParams, +} from "@langchain/google-common"; +import { WebGoogleAuth, WebGoogleAuthOptions } from "./auth.js"; + +/** + * Input to LLM class. + */ +export interface GoogleEmbeddingsInput + extends BaseGoogleEmbeddingsParams {} + +/** + * Integration with an LLM. + */ +export class GoogleEmbeddings + extends BaseGoogleEmbeddings + implements GoogleEmbeddingsInput +{ + // Used for tracing, replace with the same name as your class + static lc_name() { + return "GoogleEmbeddings"; + } + + lc_serializable = true; + + constructor(fields: GoogleEmbeddingsInput) { + super(fields); + } + + buildAbstractedClient( + fields?: GoogleConnectionParams + ): GoogleAbstractedClient { + return new WebGoogleAuth(fields); + } +} diff --git a/libs/langchain-google-webauth/src/index.ts b/libs/langchain-google-webauth/src/index.ts index 2c8aa4ecb4..7f420a4ed6 100644 --- a/libs/langchain-google-webauth/src/index.ts +++ b/libs/langchain-google-webauth/src/index.ts @@ -1,2 +1,3 @@ export * from "./chat_models.js"; export * from "./llms.js"; +export * from "./embeddings.js"; From e413dcc58f5141343236e39ff123b5aaad954c85 Mon Sep 17 00:00:00 2001 From: "local-dev-korbit-ai-mentor[bot]" <130798245+local-dev-korbit-ai-mentor[bot]@users.noreply.github.com> Date: Thu, 15 Aug 2024 17:38:56 +0000 Subject: [PATCH 2/2] [skip ci]