From 831036f1a15897100f9efcc4007ffb27d49a9637 Mon Sep 17 00:00:00 2001 From: Tomas Dvorak Date: Fri, 13 Dec 2024 12:52:59 +0100 Subject: [PATCH 1/2] feat(adapters): add embedding support for IBM vLLM Ref: #176 Signed-off-by: Tomas Dvorak --- package.json | 2 +- .../ibm_vllm_generate_protos.sh | 8 +- .../tsconfig.proto.json | 2 +- .../ibm_vllm_generate_protos/tsup.config.ts | 17 +- src/adapters/ibm-vllm/client.ts | 78 +- src/adapters/ibm-vllm/llm.ts | 48 +- .../proto/caikit_data_model_caikit_nlp.proto | 123 ++ .../proto/caikit_data_model_common.proto | 199 ++ .../proto/caikit_data_model_nlp.proto | 199 ++ .../proto/caikit_data_model_runtime.proto | 58 + .../ibm-vllm/proto/caikit_runtime_Nlp.proto | 248 +++ src/adapters/ibm-vllm/types.ts | 1598 +++++++++++++++++ src/internals/types.ts | 4 + yarn.lock | 29 +- 14 files changed, 2565 insertions(+), 48 deletions(-) create mode 100644 src/adapters/ibm-vllm/proto/caikit_data_model_caikit_nlp.proto create mode 100644 src/adapters/ibm-vllm/proto/caikit_data_model_common.proto create mode 100644 src/adapters/ibm-vllm/proto/caikit_data_model_nlp.proto create mode 100644 src/adapters/ibm-vllm/proto/caikit_data_model_runtime.proto create mode 100644 src/adapters/ibm-vllm/proto/caikit_runtime_Nlp.proto diff --git a/package.json b/package.json index b2dbaed7..8eed8816 100644 --- a/package.json +++ b/package.json @@ -179,7 +179,7 @@ "mathjs": "^14.0.0", "mustache": "^4.2.0", "object-hash": "^3.0.0", - "p-queue": "^8.0.1", + "p-queue-compat": "^1.0.227", "p-throttle": "^7.0.0", "pino": "^9.5.0", "promise-based-task": "^3.1.1", diff --git a/scripts/ibm_vllm_generate_protos/ibm_vllm_generate_protos.sh b/scripts/ibm_vllm_generate_protos/ibm_vllm_generate_protos.sh index 1b2e7589..8b412ac1 100755 --- a/scripts/ibm_vllm_generate_protos/ibm_vllm_generate_protos.sh +++ b/scripts/ibm_vllm_generate_protos/ibm_vllm_generate_protos.sh @@ -19,8 +19,8 @@ GRPC_PROTO_PATH="./src/adapters/ibm-vllm/proto" 
GRPC_TYPES_PATH="./src/adapters/ibm-vllm/types.ts" SCRIPT_DIR="$(dirname "$0")" -OUTPUT_RELATIVE_PATH="dist/generation.d.ts" -GRPC_TYPES_TMP_PATH=types +OUTPUT_RELATIVE_PATH="dist/merged.d.ts" +GRPC_TYPES_TMP_PATH="types" rm -f "$GRPC_TYPES_PATH" @@ -39,7 +39,7 @@ yarn run proto-loader-gen-types \ cd "$SCRIPT_DIR" - tsup --dts-only + ENTRY="$(basename "$OUTPUT_RELATIVE_PATH" ".d.ts")" tsup --dts-only sed -i.bak '$ d' "$OUTPUT_RELATIVE_PATH" sed -i.bak -E "s/^interface/export interface/" "$OUTPUT_RELATIVE_PATH" sed -i.bak -E "s/^type/export type/" "$OUTPUT_RELATIVE_PATH" @@ -50,4 +50,4 @@ rm -rf "${SCRIPT_DIR}"/{dist,dts,types} yarn run lint:fix "${GRPC_TYPES_PATH}" yarn prettier --write "${GRPC_TYPES_PATH}" -yarn copyright +TARGETS="$GRPC_TYPES_PATH" yarn copyright diff --git a/scripts/ibm_vllm_generate_protos/tsconfig.proto.json b/scripts/ibm_vllm_generate_protos/tsconfig.proto.json index 3f3a9b39..0e4a32c2 100644 --- a/scripts/ibm_vllm_generate_protos/tsconfig.proto.json +++ b/scripts/ibm_vllm_generate_protos/tsconfig.proto.json @@ -4,7 +4,7 @@ "rootDir": ".", "baseUrl": ".", "target": "ESNext", - "module": "ES6", + "module": "ESNext", "outDir": "dist", "declaration": true, "emitDeclarationOnly": true, diff --git a/scripts/ibm_vllm_generate_protos/tsup.config.ts b/scripts/ibm_vllm_generate_protos/tsup.config.ts index 3cf25ebf..71acab22 100644 --- a/scripts/ibm_vllm_generate_protos/tsup.config.ts +++ b/scripts/ibm_vllm_generate_protos/tsup.config.ts @@ -15,14 +15,27 @@ */ import { defineConfig } from "tsup"; +import fs from "node:fs"; + +if (!process.env.ENTRY) { + throw new Error(`Entry file was not provided!`); +} +const target = `types/${process.env.ENTRY}.ts`; +await fs.promises.writeFile( + target, + [ + `export { ProtoGrpcType as A } from "./caikit_runtime_Nlp.js"`, + `export { ProtoGrpcType as B } from "./generation.js"`, + ].join("\n"), +); export default defineConfig({ - entry: ["types/generation.ts"], + entry: [target], tsconfig: "./tsconfig.proto.json", 
sourcemap: false, dts: true, format: ["esm"], - treeshake: false, + treeshake: true, legacyOutput: false, skipNodeModulesBundle: true, bundle: true, diff --git a/src/adapters/ibm-vllm/client.ts b/src/adapters/ibm-vllm/client.ts index be4b8746..b294abf6 100644 --- a/src/adapters/ibm-vllm/client.ts +++ b/src/adapters/ibm-vllm/client.ts @@ -25,25 +25,31 @@ import * as R from "remeda"; // eslint-disable-next-line no-restricted-imports import { UnaryCallback } from "@grpc/grpc-js/build/src/client.js"; import { FrameworkError, ValueError } from "@/errors.js"; -import protoLoader from "@grpc/proto-loader"; +import protoLoader, { Options } from "@grpc/proto-loader"; import { BatchedGenerationRequest, BatchedGenerationResponse__Output, BatchedTokenizeRequest, BatchedTokenizeResponse__Output, + type EmbeddingTasksRequest, GenerationRequest__Output, ModelInfoRequest, ModelInfoResponse__Output, ProtoGrpcType as GenerationProtoGentypes, + ProtoGrpcType$1 as CaikitProtoGentypes, SingleGenerationRequest, + EmbeddingResults__Output, + type SubtypeConstructor, } from "@/adapters/ibm-vllm/types.js"; import { parseEnv } from "@/internals/env.js"; import { z } from "zod"; import { Cache } from "@/cache/decoratorCache.js"; import { Serializable } from "@/internals/serializable.js"; +import PQueue from "p-queue-compat"; const GENERATION_PROTO_PATH = new URL("./proto/generation.proto", import.meta.url); +const NLP_PROTO_PATH = new URL("./proto/caikit_runtime_Nlp.proto", import.meta.url); interface ClientOptions { modelRouterSubdomain?: string; @@ -55,6 +61,11 @@ interface ClientOptions { }; grpcClientOptions: GRPCClientOptions; clientShutdownDelay: number; + limits?: { + concurrency?: { + embeddings?: number; + }; + }; } const defaultOptions = { @@ -66,18 +77,24 @@ const defaultOptions = { }, }; -const generationPackageObject = grpc.loadPackageDefinition( - protoLoader.loadSync([GENERATION_PROTO_PATH.pathname], { - longs: Number, - enums: String, - arrays: true, - objects: true, - 
oneofs: true, - keepCase: true, - defaults: true, - }), +const grpcConfig: Options = { + longs: Number, + enums: String, + arrays: true, + objects: true, + oneofs: true, + keepCase: true, + defaults: true, +}; + +const generationPackage = grpc.loadPackageDefinition( + protoLoader.loadSync([GENERATION_PROTO_PATH.pathname], grpcConfig), ) as unknown as GenerationProtoGentypes; +const embeddingsPackage = grpc.loadPackageDefinition( + protoLoader.loadSync([NLP_PROTO_PATH.pathname], grpcConfig), +) as unknown as CaikitProtoGentypes; + const GRPC_CLIENT_TTL = 15 * 60 * 1000; type CallOptions = GRPCCallOptions & { signal?: AbortSignal }; @@ -88,9 +105,12 @@ export class Client extends Serializable { private usedDefaultCredentials = false; @Cache({ ttl: GRPC_CLIENT_TTL }) - protected getClient(modelId: string) { + protected getClient void }>( + modelId: string, + factory: SubtypeConstructor, + ): T { const modelSpecificUrl = this.options.url.replace(/{model_id}/, modelId.replaceAll("/", "--")); - const client = new generationPackageObject.fmaas.GenerationService( + const client = new factory( modelSpecificUrl, grpc.credentials.createSsl( Buffer.from(this.options.credentials.rootCert), @@ -129,33 +149,47 @@ export class Client extends Serializable { } async modelInfo(request: RequiredModel, options?: CallOptions) { - const client = this.getClient(request.model_id); + const client = this.getClient(request.model_id, generationPackage.fmaas.GenerationService); return this.wrapGrpcCall( client.modelInfo.bind(client), )(request, options); } async generate(request: RequiredModel, options?: CallOptions) { - const client = this.getClient(request.model_id); + const client = this.getClient(request.model_id, generationPackage.fmaas.GenerationService); return this.wrapGrpcCall( client.generate.bind(client), )(request, options); } async generateStream(request: RequiredModel, options?: CallOptions) { - const client = this.getClient(request.model_id); + const client = 
this.getClient(request.model_id, generationPackage.fmaas.GenerationService); return this.wrapGrpcStream( client.generateStream.bind(client), )(request, options); } async tokenize(request: RequiredModel, options?: CallOptions) { - const client = this.getClient(request.model_id); + const client = this.getClient(request.model_id, generationPackage.fmaas.GenerationService); return this.wrapGrpcCall( client.tokenize.bind(client), )(request, options); } + async embed(request: RequiredModel, options?: CallOptions) { + const client = this.getClient( + request.model_id, + embeddingsPackage.caikit.runtime.Nlp.NlpService, + ); + return this.queues.embeddings.add( + () => + this.wrapGrpcCall( + client.embeddingTasksPredict.bind(client), + )(request, options), + { throwOnTimeout: true }, + ); + } + protected wrapGrpcCall( fn: ( request: TRequest, @@ -213,4 +247,14 @@ export class Client extends Serializable { Object.assign(this, snapshot); this.options.credentials = this.getDefaultCredentials(); } + + @Cache({ enumerable: false }) + protected get queues() { + return { + embeddings: new PQueue({ + concurrency: this.options.limits?.concurrency?.embeddings ?? 
5, + throwOnTimeout: true, + }), + }; + } } diff --git a/src/adapters/ibm-vllm/llm.ts b/src/adapters/ibm-vllm/llm.ts index 068aa8b1..16ec5efc 100644 --- a/src/adapters/ibm-vllm/llm.ts +++ b/src/adapters/ibm-vllm/llm.ts @@ -27,8 +27,12 @@ import { LLMError, LLMMeta, } from "@/llms/base.js"; -import { isEmpty, isString } from "remeda"; -import type { DecodingParameters, SingleGenerationRequest } from "@/adapters/ibm-vllm/types.js"; +import { chunk, isEmpty, isString } from "remeda"; +import type { + DecodingParameters, + SingleGenerationRequest, + EmbeddingTasksRequest, +} from "@/adapters/ibm-vllm/types.js"; import { LLM, LLMEvents, LLMInput } from "@/llms/llm.js"; import { Emitter } from "@/emitter/emitter.js"; import { GenerationResponse__Output } from "@/adapters/ibm-vllm/types.js"; @@ -39,6 +43,7 @@ import { ServiceError } from "@grpc/grpc-js"; import { Client } from "@/adapters/ibm-vllm/client.js"; import { GetRunContext } from "@/context.js"; import { BatchedGenerationRequest } from "./types.js"; +import { OmitPrivateKeys } from "@/internals/types.js"; function isGrpcServiceError(err: unknown): err is ServiceError { return ( @@ -100,6 +105,12 @@ export type IBMvLLMParameters = NonNullable< export interface IBMvLLMGenerateOptions extends GenerateOptions {} +export interface IBMvLLMEmbeddingOptions + extends EmbeddingOptions, + Omit, "texts"> { + chunkSize?: number; +} + export type IBMvLLMEvents = LLMEvents; export class IBMvLLM extends LLM { @@ -128,9 +139,36 @@ export class IBMvLLM extends LLM { }; } - // eslint-disable-next-line unused-imports/no-unused-vars - async embed(input: LLMInput[], options?: EmbeddingOptions): Promise { - throw new NotImplementedError(); + async embed( + input: LLMInput[], + { chunkSize, signal, ...options }: IBMvLLMEmbeddingOptions = {}, + ): Promise { + const results = await Promise.all( + chunk(input, chunkSize ?? 
100).map(async (texts) => { + const response = await this.client.embed( + { + model_id: this.modelId, + truncate_input_tokens: options?.truncate_input_tokens ?? 512, + texts, + }, + { + signal, + }, + ); + const embeddings = response.results?.vectors.map((vector) => { + const embedding = vector[vector.data]?.values; + if (!embedding) { + throw new LLMError("Missing embedding"); + } + return embedding; + }); + if (embeddings?.length !== texts.length) { + throw new LLMError("Missing embedding"); + } + return embeddings; + }), + ); + return { embeddings: results.flat() }; } async tokenize(input: LLMInput): Promise { diff --git a/src/adapters/ibm-vllm/proto/caikit_data_model_caikit_nlp.proto b/src/adapters/ibm-vllm/proto/caikit_data_model_caikit_nlp.proto new file mode 100644 index 00000000..d449471f --- /dev/null +++ b/src/adapters/ibm-vllm/proto/caikit_data_model_caikit_nlp.proto @@ -0,0 +1,123 @@ +/** + * Copyright 2024 IBM Corp. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// Source: https://github.com/IBM/vllm/blob/main/proto/caikit_data_model_caikit_nlp.proto + +/*------------------------------------------------------------------------------ + * AUTO GENERATED + *----------------------------------------------------------------------------*/ + +syntax = "proto3"; +package caikit_data_model.caikit_nlp; +import "google/protobuf/struct.proto"; +import "caikit_data_model_common.proto"; + + +/*-- MESSAGES ----------------------------------------------------------------*/ + +message EmbeddingResult { + + /*-- fields --*/ + caikit_data_model.common.Vector1D result = 1; + caikit_data_model.common.ProducerId producer_id = 2; + int64 input_token_count = 3; +} + +message EmbeddingResults { + + /*-- fields --*/ + caikit_data_model.common.ListOfVector1D results = 1; + caikit_data_model.common.ProducerId producer_id = 2; + int64 input_token_count = 3; +} + +message ExponentialDecayLengthPenalty { + + /*-- fields --*/ + int64 start_index = 1; + double decay_factor = 2; +} + +message GenerationTrainRecord { + + /*-- fields --*/ + string input = 1; + string output = 2; +} + +message RerankResult { + + /*-- fields --*/ + caikit_data_model.caikit_nlp.RerankScores result = 1; + caikit_data_model.common.ProducerId producer_id = 2; + int64 input_token_count = 3; +} + +message RerankResults { + + /*-- fields --*/ + repeated caikit_data_model.caikit_nlp.RerankScores results = 1; + caikit_data_model.common.ProducerId producer_id = 2; + int64 input_token_count = 3; +} + +message RerankScore { + + /*-- fields --*/ + google.protobuf.Struct document = 1; + int64 index = 2; + double score = 3; + string text = 4; +} + +message RerankScores { + + /*-- fields --*/ + string query = 1; + repeated caikit_data_model.caikit_nlp.RerankScore scores = 2; +} + +message SentenceSimilarityResult { + + /*-- fields --*/ + caikit_data_model.caikit_nlp.SentenceSimilarityScores result = 1; + caikit_data_model.common.ProducerId producer_id = 2; + int64 input_token_count = 
3; +} + +message SentenceSimilarityResults { + + /*-- fields --*/ + repeated caikit_data_model.caikit_nlp.SentenceSimilarityScores results = 1; + caikit_data_model.common.ProducerId producer_id = 2; + int64 input_token_count = 3; +} + +message SentenceSimilarityScores { + + /*-- fields --*/ + repeated double scores = 1; +} + +message TuningConfig { + + /*-- fields --*/ + int64 num_virtual_tokens = 1; + string prompt_tuning_init_text = 2; + string prompt_tuning_init_method = 3; + string prompt_tuning_init_source_model = 4; + repeated string output_model_types = 5; +} diff --git a/src/adapters/ibm-vllm/proto/caikit_data_model_common.proto b/src/adapters/ibm-vllm/proto/caikit_data_model_common.proto new file mode 100644 index 00000000..a212c2b7 --- /dev/null +++ b/src/adapters/ibm-vllm/proto/caikit_data_model_common.proto @@ -0,0 +1,199 @@ +/** + * Copyright 2024 IBM Corp. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// Source: https://github.com/IBM/vllm/blob/main/proto/caikit_data_model_common.proto + +/*------------------------------------------------------------------------------ + * AUTO GENERATED + *----------------------------------------------------------------------------*/ + +syntax = "proto3"; +package caikit_data_model.common; + + +/*-- ENUMS -------------------------------------------------------------------*/ + +enum TrainingStatus { + PLACEHOLDER_UNSET = 0; + QUEUED = 1; + RUNNING = 2; + COMPLETED = 3; + CANCELED = 4; + ERRORED = 5; +} + + +/*-- MESSAGES ----------------------------------------------------------------*/ + +message BoolSequence { + + /*-- fields --*/ + repeated bool values = 1; +} + +message ConnectionInfo { + + /*-- nested messages --*/ + + /*-- fields --*/ + string hostname = 1; + optional int64 port = 2; + optional caikit_data_model.common.ConnectionTlsInfo tls = 3; + optional int64 timeout = 4; + map options = 5; +} + +message ConnectionTlsInfo { + + /*-- fields --*/ + optional bool enabled = 1; + optional bool insecure_verify = 2; + optional string ca_file = 3; + optional string cert_file = 4; + optional string key_file = 5; +} + +message Directory { + + /*-- fields --*/ + string dirname = 1; + string extension = 2; +} + +message File { + + /*-- fields --*/ + bytes data = 1; + string filename = 2; + string type = 3; +} + +message FileReference { + + /*-- fields --*/ + string filename = 1; +} + +message FloatSequence { + + /*-- fields --*/ + repeated double values = 1; +} + +message IntSequence { + + /*-- fields --*/ + repeated int64 values = 1; +} + +message ListOfFileReferences { + + /*-- fields --*/ + repeated string files = 1; +} + +message ListOfVector1D { + + /*-- fields --*/ + repeated caikit_data_model.common.Vector1D vectors = 1; +} + +message NpFloat32Sequence { + + /*-- fields --*/ + repeated float values = 1; +} + +message NpFloat64Sequence { + + /*-- fields --*/ + repeated double values = 1; +} + +message ProducerId { + + 
/*-- fields --*/ + string name = 1; + string version = 2; +} + +message ProducerPriority { + + /*-- fields --*/ + repeated caikit_data_model.common.ProducerId producers = 1; +} + +message PyFloatSequence { + + /*-- fields --*/ + repeated double values = 1; +} + +message S3Base { + + /*-- fields --*/ + string endpoint = 2; + string region = 3; + string bucket = 4; + string accessKey = 5; + string secretKey = 6; + string IAM_id = 7; + string IAM_api_key = 8; +} + +message S3Files { + + /*-- fields --*/ + string endpoint = 2; + string region = 3; + string bucket = 4; + string accessKey = 5; + string secretKey = 6; + string IAM_id = 7; + string IAM_api_key = 8; + repeated string files = 1; +} + +message S3Path { + + /*-- fields --*/ + string endpoint = 2; + string region = 3; + string bucket = 4; + string accessKey = 5; + string secretKey = 6; + string IAM_id = 7; + string IAM_api_key = 8; + string path = 1; +} + +message StrSequence { + + /*-- fields --*/ + repeated string values = 1; +} + +message Vector1D { + + /*-- fields --*/ + + /*-- oneofs --*/ + oneof data { + caikit_data_model.common.PyFloatSequence data_pyfloatsequence = 1; + caikit_data_model.common.NpFloat32Sequence data_npfloat32sequence = 2; + caikit_data_model.common.NpFloat64Sequence data_npfloat64sequence = 3; + } +} diff --git a/src/adapters/ibm-vllm/proto/caikit_data_model_nlp.proto b/src/adapters/ibm-vllm/proto/caikit_data_model_nlp.proto new file mode 100644 index 00000000..c1be0169 --- /dev/null +++ b/src/adapters/ibm-vllm/proto/caikit_data_model_nlp.proto @@ -0,0 +1,199 @@ +/** + * Copyright 2024 IBM Corp. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Source: https://github.com/IBM/vllm/blob/main/proto/caikit_data_model_nlp.proto + +/*------------------------------------------------------------------------------ + * AUTO GENERATED + *----------------------------------------------------------------------------*/ + +syntax = "proto3"; +package caikit_data_model.nlp; +import "caikit_data_model_common.proto"; + + +/*-- ENUMS -------------------------------------------------------------------*/ + +enum FinishReason { + NOT_FINISHED = 0; + MAX_TOKENS = 1; + EOS_TOKEN = 2; + CANCELLED = 3; + TIME_LIMIT = 4; + STOP_SEQUENCE = 5; + TOKEN_LIMIT = 6; + ERROR = 7; +} + +enum InputWarningReason { + UNSUITABLE_INPUT = 0; +} + + +/*-- MESSAGES ----------------------------------------------------------------*/ + +message ClassificationResult { + + /*-- fields --*/ + string label = 1; + double score = 2; +} + +message ClassificationResults { + + /*-- fields --*/ + repeated caikit_data_model.nlp.ClassificationResult results = 1; +} + +message ClassificationTrainRecord { + + /*-- fields --*/ + string text = 1; + repeated string labels = 2; +} + +message ClassifiedGeneratedTextResult { + + /*-- fields --*/ + string generated_text = 1; + caikit_data_model.nlp.TextGenTokenClassificationResults token_classification_results = 2; + caikit_data_model.nlp.FinishReason finish_reason = 3; + int64 generated_token_count = 4; + uint64 seed = 5; + int64 input_token_count = 6; + repeated caikit_data_model.nlp.InputWarning warnings = 9; + repeated caikit_data_model.nlp.GeneratedToken tokens = 10; + repeated 
caikit_data_model.nlp.GeneratedToken input_tokens = 11; +} + +message ClassifiedGeneratedTextStreamResult { + + /*-- fields --*/ + string generated_text = 1; + caikit_data_model.nlp.TextGenTokenClassificationResults token_classification_results = 2; + caikit_data_model.nlp.FinishReason finish_reason = 3; + int64 generated_token_count = 4; + uint64 seed = 5; + int64 input_token_count = 6; + repeated caikit_data_model.nlp.InputWarning warnings = 9; + repeated caikit_data_model.nlp.GeneratedToken tokens = 10; + repeated caikit_data_model.nlp.GeneratedToken input_tokens = 11; + int64 processed_index = 7; + int64 start_index = 8; +} + +message GeneratedTextResult { + + /*-- fields --*/ + string generated_text = 1; + int64 generated_tokens = 2; + caikit_data_model.nlp.FinishReason finish_reason = 3; + caikit_data_model.common.ProducerId producer_id = 4; + int64 input_token_count = 5; + uint64 seed = 6; + repeated caikit_data_model.nlp.GeneratedToken tokens = 7; + repeated caikit_data_model.nlp.GeneratedToken input_tokens = 8; +} + +message GeneratedTextStreamResult { + + /*-- fields --*/ + string generated_text = 1; + repeated caikit_data_model.nlp.GeneratedToken tokens = 2; + caikit_data_model.nlp.TokenStreamDetails details = 3; + caikit_data_model.common.ProducerId producer_id = 4; + repeated caikit_data_model.nlp.GeneratedToken input_tokens = 5; +} + +message GeneratedToken { + + /*-- fields --*/ + string text = 1; + double logprob = 3; +} + +message InputWarning { + + /*-- fields --*/ + caikit_data_model.nlp.InputWarningReason id = 1; + string message = 2; +} + +message TextGenTokenClassificationResults { + + /*-- fields --*/ + repeated caikit_data_model.nlp.TokenClassificationResult input = 10; + repeated caikit_data_model.nlp.TokenClassificationResult output = 20; +} + +message Token { + + /*-- fields --*/ + int64 start = 1; + int64 end = 2; + string text = 3; +} + +message TokenClassificationResult { + + /*-- fields --*/ + int64 start = 1; + int64 end = 2; + 
string word = 3; + string entity = 4; + string entity_group = 5; + double score = 6; + int64 token_count = 7; +} + +message TokenClassificationResults { + + /*-- fields --*/ + repeated caikit_data_model.nlp.TokenClassificationResult results = 1; +} + +message TokenClassificationStreamResult { + + /*-- fields --*/ + repeated caikit_data_model.nlp.TokenClassificationResult results = 1; + int64 processed_index = 2; + int64 start_index = 3; +} + +message TokenStreamDetails { + + /*-- fields --*/ + caikit_data_model.nlp.FinishReason finish_reason = 1; + uint32 generated_tokens = 2; + uint64 seed = 3; + int64 input_token_count = 4; +} + +message TokenizationResults { + + /*-- fields --*/ + repeated caikit_data_model.nlp.Token results = 1; + int64 token_count = 4; +} + +message TokenizationStreamResult { + + /*-- fields --*/ + repeated caikit_data_model.nlp.Token results = 1; + int64 token_count = 4; + int64 processed_index = 2; + int64 start_index = 3; +} diff --git a/src/adapters/ibm-vllm/proto/caikit_data_model_runtime.proto b/src/adapters/ibm-vllm/proto/caikit_data_model_runtime.proto new file mode 100644 index 00000000..24ce5669 --- /dev/null +++ b/src/adapters/ibm-vllm/proto/caikit_data_model_runtime.proto @@ -0,0 +1,58 @@ +/** + * Copyright 2024 IBM Corp. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// Source: https://github.com/IBM/vllm/blob/main/proto/caikit_data_model_runtime.proto + +/*------------------------------------------------------------------------------ + * AUTO GENERATED + *----------------------------------------------------------------------------*/ + +syntax = "proto3"; +package caikit_data_model.runtime; +import "google/protobuf/timestamp.proto"; +import "caikit_data_model_common.proto"; + + +/*-- MESSAGES ----------------------------------------------------------------*/ + +message ModelPointer { + + /*-- fields --*/ + string model_id = 1; +} + +message TrainingInfoRequest { + + /*-- fields --*/ + string training_id = 1; +} + +message TrainingJob { + + /*-- fields --*/ + string training_id = 1; + string model_name = 2; +} + +message TrainingStatusResponse { + + /*-- fields --*/ + string training_id = 1; + caikit_data_model.common.TrainingStatus state = 2; + google.protobuf.Timestamp submission_timestamp = 3; + google.protobuf.Timestamp completion_timestamp = 4; + repeated string reasons = 5; +} diff --git a/src/adapters/ibm-vllm/proto/caikit_runtime_Nlp.proto b/src/adapters/ibm-vllm/proto/caikit_runtime_Nlp.proto new file mode 100644 index 00000000..d51af243 --- /dev/null +++ b/src/adapters/ibm-vllm/proto/caikit_runtime_Nlp.proto @@ -0,0 +1,248 @@ +/** + * Copyright 2024 IBM Corp. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// Source: https://github.com/IBM/vllm/blob/main/proto/caikit_runtime_Nlp.proto + +/*------------------------------------------------------------------------------ + * AUTO GENERATED + *----------------------------------------------------------------------------*/ + +syntax = "proto3"; +package caikit.runtime.Nlp; +import "google/protobuf/struct.proto"; +import "caikit_data_model_caikit_nlp.proto"; +import "caikit_data_model_common.proto"; +import "caikit_data_model_nlp.proto"; +import "caikit_data_model_runtime.proto"; + + +/*-- MESSAGES ----------------------------------------------------------------*/ + +message BidiStreamingTokenClassificationTaskRequest { + + /*-- fields --*/ + string text_stream = 1; + optional double threshold = 2; +} + +message DataStreamSourceGenerationTrainRecord { + + /*-- fields --*/ + + /*-- oneofs --*/ + oneof data_stream { + caikit.runtime.Nlp.DataStreamSourceGenerationTrainRecordJsonData jsondata = 1; + caikit_data_model.common.FileReference file = 2; + caikit_data_model.common.ListOfFileReferences list_of_files = 3; + caikit_data_model.common.Directory directory = 4; + caikit_data_model.common.S3Files s3files = 5; + } +} + +message DataStreamSourceGenerationTrainRecordJsonData { + + /*-- fields --*/ + repeated caikit_data_model.caikit_nlp.GenerationTrainRecord data = 1; +} + +message EmbeddingTaskRequest { + + /*-- fields --*/ + string text = 1; + optional int64 truncate_input_tokens = 2; +} + +message EmbeddingTasksRequest { + + /*-- fields --*/ + repeated string texts = 1; + optional int64 truncate_input_tokens = 2; +} + +message RerankTaskRequest { + + /*-- fields --*/ + string query = 1; + repeated google.protobuf.Struct documents = 2; + optional int64 top_n = 3; + optional int64 truncate_input_tokens = 4; + optional bool return_documents = 5; + optional bool return_query = 6; + optional bool return_text = 7; +} + +message RerankTasksRequest { + + /*-- fields --*/ + repeated string queries = 1; + repeated 
google.protobuf.Struct documents = 2; + optional int64 top_n = 3; + optional int64 truncate_input_tokens = 4; + optional bool return_documents = 5; + optional bool return_queries = 6; + optional bool return_text = 7; +} + +message SentenceSimilarityTaskRequest { + + /*-- fields --*/ + string source_sentence = 1; + repeated string sentences = 2; + optional int64 truncate_input_tokens = 3; +} + +message SentenceSimilarityTasksRequest { + + /*-- fields --*/ + repeated string source_sentences = 1; + repeated string sentences = 2; + optional int64 truncate_input_tokens = 3; +} + +message ServerStreamingTextGenerationTaskRequest { + + /*-- fields --*/ + string text = 1; + optional int64 max_new_tokens = 2; + optional int64 min_new_tokens = 3; + optional int64 truncate_input_tokens = 4; + optional string decoding_method = 5; + optional int64 top_k = 6; + optional double top_p = 7; + optional double typical_p = 8; + optional double temperature = 9; + optional double repetition_penalty = 10; + optional double max_time = 11; + optional caikit_data_model.caikit_nlp.ExponentialDecayLengthPenalty exponential_decay_length_penalty = 12; + repeated string stop_sequences = 13; + optional uint64 seed = 14; + optional bool preserve_input_text = 15; +} + +message TextClassificationTaskRequest { + + /*-- fields --*/ + string text = 1; +} + +message TextGenerationTaskPeftPromptTuningTrainParameters { + + /*-- fields --*/ + string base_model = 1; + caikit.runtime.Nlp.DataStreamSourceGenerationTrainRecord train_stream = 2; + caikit_data_model.caikit_nlp.TuningConfig tuning_config = 3; + optional caikit.runtime.Nlp.DataStreamSourceGenerationTrainRecord val_stream = 4; + optional string device = 5; + optional string tuning_type = 6; + optional int64 num_epochs = 7; + optional double learning_rate = 8; + optional string verbalizer = 9; + optional int64 batch_size = 10; + optional int64 max_source_length = 11; + optional int64 max_target_length = 12; + optional int64 accumulate_steps = 13; + 
optional string torch_dtype = 14; + optional bool silence_progress_bars = 15; + optional int64 seed = 16; +} + +message TextGenerationTaskPeftPromptTuningTrainRequest { + + /*-- fields --*/ + string model_name = 1; + caikit_data_model.common.S3Path output_path = 2; + caikit.runtime.Nlp.TextGenerationTaskPeftPromptTuningTrainParameters parameters = 3; +} + +message TextGenerationTaskRequest { + + /*-- fields --*/ + string text = 1; + optional int64 max_new_tokens = 2; + optional int64 min_new_tokens = 3; + optional int64 truncate_input_tokens = 4; + optional string decoding_method = 5; + optional int64 top_k = 6; + optional double top_p = 7; + optional double typical_p = 8; + optional double temperature = 9; + optional double repetition_penalty = 10; + optional double max_time = 11; + optional caikit_data_model.caikit_nlp.ExponentialDecayLengthPenalty exponential_decay_length_penalty = 12; + repeated string stop_sequences = 13; + optional uint64 seed = 14; + optional bool preserve_input_text = 15; +} + +message TextGenerationTaskTextGenerationTrainParameters { + + /*-- fields --*/ + string base_model = 1; + caikit.runtime.Nlp.DataStreamSourceGenerationTrainRecord train_stream = 2; + optional string torch_dtype = 3; + optional int64 max_source_length = 4; + optional int64 max_target_length = 5; + optional int64 batch_size = 6; + optional int64 num_epochs = 7; + optional int64 accumulate_steps = 8; + optional int64 random_seed = 9; + optional double lr = 10; + optional bool use_iterable_dataset = 11; +} + +message TextGenerationTaskTextGenerationTrainRequest { + + /*-- fields --*/ + string model_name = 1; + caikit_data_model.common.S3Path output_path = 2; + caikit.runtime.Nlp.TextGenerationTaskTextGenerationTrainParameters parameters = 3; +} + +message TokenClassificationTaskRequest { + + /*-- fields --*/ + string text = 1; + optional double threshold = 2; +} + +message TokenizationTaskRequest { + + /*-- fields --*/ + string text = 1; +} + + +/*-- SERVICES 
----------------------------------------------------------------*/ + +service NlpService { + rpc BidiStreamingTokenClassificationTaskPredict(stream caikit.runtime.Nlp.BidiStreamingTokenClassificationTaskRequest) returns (stream caikit_data_model.nlp.TokenClassificationStreamResult); + rpc EmbeddingTaskPredict(caikit.runtime.Nlp.EmbeddingTaskRequest) returns (caikit_data_model.caikit_nlp.EmbeddingResult); + rpc EmbeddingTasksPredict(caikit.runtime.Nlp.EmbeddingTasksRequest) returns (caikit_data_model.caikit_nlp.EmbeddingResults); + rpc RerankTaskPredict(caikit.runtime.Nlp.RerankTaskRequest) returns (caikit_data_model.caikit_nlp.RerankResult); + rpc RerankTasksPredict(caikit.runtime.Nlp.RerankTasksRequest) returns (caikit_data_model.caikit_nlp.RerankResults); + rpc SentenceSimilarityTaskPredict(caikit.runtime.Nlp.SentenceSimilarityTaskRequest) returns (caikit_data_model.caikit_nlp.SentenceSimilarityResult); + rpc SentenceSimilarityTasksPredict(caikit.runtime.Nlp.SentenceSimilarityTasksRequest) returns (caikit_data_model.caikit_nlp.SentenceSimilarityResults); + rpc ServerStreamingTextGenerationTaskPredict(caikit.runtime.Nlp.ServerStreamingTextGenerationTaskRequest) returns (stream caikit_data_model.nlp.GeneratedTextStreamResult); + rpc TextClassificationTaskPredict(caikit.runtime.Nlp.TextClassificationTaskRequest) returns (caikit_data_model.nlp.ClassificationResults); + rpc TextGenerationTaskPredict(caikit.runtime.Nlp.TextGenerationTaskRequest) returns (caikit_data_model.nlp.GeneratedTextResult); + rpc TokenClassificationTaskPredict(caikit.runtime.Nlp.TokenClassificationTaskRequest) returns (caikit_data_model.nlp.TokenClassificationResults); + rpc TokenizationTaskPredict(caikit.runtime.Nlp.TokenizationTaskRequest) returns (caikit_data_model.nlp.TokenizationResults); +} + +service NlpTrainingService { + rpc TextGenerationTaskPeftPromptTuningTrain(caikit.runtime.Nlp.TextGenerationTaskPeftPromptTuningTrainRequest) returns (caikit_data_model.runtime.TrainingJob); + rpc 
TextGenerationTaskTextGenerationTrain(caikit.runtime.Nlp.TextGenerationTaskTextGenerationTrainRequest) returns (caikit_data_model.runtime.TrainingJob); +} diff --git a/src/adapters/ibm-vllm/types.ts b/src/adapters/ibm-vllm/types.ts index 55bcc902..42fe39b9 100644 --- a/src/adapters/ibm-vllm/types.ts +++ b/src/adapters/ibm-vllm/types.ts @@ -22,6 +22,1604 @@ import { EnumTypeDefinition, } from "@grpc/proto-loader"; +export interface BidiStreamingTokenClassificationTaskRequest { + text_stream?: string; + threshold?: number | string; + _threshold?: "threshold"; +} +export interface BidiStreamingTokenClassificationTaskRequest__Output { + text_stream: string; + threshold?: number; + _threshold: "threshold"; +} + +export interface ClassificationResult { + label?: string; + score?: number | string; +} +export interface ClassificationResult__Output { + label: string; + score: number; +} + +export interface ClassificationResults { + results?: ClassificationResult[]; +} +export interface ClassificationResults__Output { + results: ClassificationResult__Output[]; +} + +export interface PyFloatSequence { + values?: (number | string)[]; +} +export interface PyFloatSequence__Output { + values: number[]; +} + +export interface NpFloat32Sequence { + values?: (number | string)[]; +} +export interface NpFloat32Sequence__Output { + values: number[]; +} + +export interface NpFloat64Sequence { + values?: (number | string)[]; +} +export interface NpFloat64Sequence__Output { + values: number[]; +} + +export interface Vector1D { + data_pyfloatsequence?: PyFloatSequence | null; + data_npfloat32sequence?: NpFloat32Sequence | null; + data_npfloat64sequence?: NpFloat64Sequence | null; + data?: "data_pyfloatsequence" | "data_npfloat32sequence" | "data_npfloat64sequence"; +} +export interface Vector1D__Output { + data_pyfloatsequence?: PyFloatSequence__Output | null; + data_npfloat32sequence?: NpFloat32Sequence__Output | null; + data_npfloat64sequence?: NpFloat64Sequence__Output | null; + data: 
"data_pyfloatsequence" | "data_npfloat32sequence" | "data_npfloat64sequence"; +} + +export interface ProducerId { + name?: string; + version?: string; +} +export interface ProducerId__Output { + name: string; + version: string; +} + +export interface EmbeddingResult { + result?: Vector1D | null; + producer_id?: ProducerId | null; + input_token_count?: number | string | Long; +} +export interface EmbeddingResult__Output { + result: Vector1D__Output | null; + producer_id: ProducerId__Output | null; + input_token_count: number; +} + +export interface ListOfVector1D { + vectors?: Vector1D[]; +} +export interface ListOfVector1D__Output { + vectors: Vector1D__Output[]; +} + +export interface EmbeddingResults { + results?: ListOfVector1D | null; + producer_id?: ProducerId | null; + input_token_count?: number | string | Long; +} +export interface EmbeddingResults__Output { + results: ListOfVector1D__Output | null; + producer_id: ProducerId__Output | null; + input_token_count: number; +} + +export interface EmbeddingTaskRequest { + text?: string; + truncate_input_tokens?: number | string | Long; + _truncate_input_tokens?: "truncate_input_tokens"; +} +export interface EmbeddingTaskRequest__Output { + text: string; + truncate_input_tokens?: number; + _truncate_input_tokens: "truncate_input_tokens"; +} + +export interface EmbeddingTasksRequest { + texts?: string[]; + truncate_input_tokens?: number | string | Long; + _truncate_input_tokens?: "truncate_input_tokens"; +} +export interface EmbeddingTasksRequest__Output { + texts: string[]; + truncate_input_tokens?: number; + _truncate_input_tokens: "truncate_input_tokens"; +} + +declare const FinishReason: { + readonly NOT_FINISHED: "NOT_FINISHED"; + readonly MAX_TOKENS: "MAX_TOKENS"; + readonly EOS_TOKEN: "EOS_TOKEN"; + readonly CANCELLED: "CANCELLED"; + readonly TIME_LIMIT: "TIME_LIMIT"; + readonly STOP_SEQUENCE: "STOP_SEQUENCE"; + readonly TOKEN_LIMIT: "TOKEN_LIMIT"; + readonly ERROR: "ERROR"; +}; +export type FinishReason = + 
| "NOT_FINISHED" + | 0 + | "MAX_TOKENS" + | 1 + | "EOS_TOKEN" + | 2 + | "CANCELLED" + | 3 + | "TIME_LIMIT" + | 4 + | "STOP_SEQUENCE" + | 5 + | "TOKEN_LIMIT" + | 6 + | "ERROR" + | 7; +export type FinishReason__Output = (typeof FinishReason)[keyof typeof FinishReason]; + +export interface GeneratedToken { + text?: string; + logprob?: number | string; +} +export interface GeneratedToken__Output { + text: string; + logprob: number; +} + +export interface GeneratedTextResult { + generated_text?: string; + generated_tokens?: number | string | Long; + finish_reason?: FinishReason; + producer_id?: ProducerId | null; + input_token_count?: number | string | Long; + seed?: number | string | Long; + tokens?: GeneratedToken[]; + input_tokens?: GeneratedToken[]; +} +export interface GeneratedTextResult__Output { + generated_text: string; + generated_tokens: number; + finish_reason: FinishReason__Output; + producer_id: ProducerId__Output | null; + input_token_count: number; + seed: number; + tokens: GeneratedToken__Output[]; + input_tokens: GeneratedToken__Output[]; +} + +export interface TokenStreamDetails { + finish_reason?: FinishReason; + generated_tokens?: number; + seed?: number | string | Long; + input_token_count?: number | string | Long; +} +export interface TokenStreamDetails__Output { + finish_reason: FinishReason__Output; + generated_tokens: number; + seed: number; + input_token_count: number; +} + +export interface GeneratedTextStreamResult { + generated_text?: string; + tokens?: GeneratedToken[]; + details?: TokenStreamDetails | null; + producer_id?: ProducerId | null; + input_tokens?: GeneratedToken[]; +} +export interface GeneratedTextStreamResult__Output { + generated_text: string; + tokens: GeneratedToken__Output[]; + details: TokenStreamDetails__Output | null; + producer_id: ProducerId__Output | null; + input_tokens: GeneratedToken__Output[]; +} + +declare const NullValue: { + readonly NULL_VALUE: "NULL_VALUE"; +}; +export type NullValue = "NULL_VALUE" | 0; 
+export type NullValue__Output = (typeof NullValue)[keyof typeof NullValue]; + +export interface ListValue { + values?: Value[]; +} +export interface ListValue__Output { + values: Value__Output[]; +} + +export interface Value { + nullValue?: NullValue; + numberValue?: number | string; + stringValue?: string; + boolValue?: boolean; + structValue?: Struct | null; + listValue?: ListValue | null; + kind?: "nullValue" | "numberValue" | "stringValue" | "boolValue" | "structValue" | "listValue"; +} +export interface Value__Output { + nullValue?: NullValue__Output; + numberValue?: number; + stringValue?: string; + boolValue?: boolean; + structValue?: Struct__Output | null; + listValue?: ListValue__Output | null; + kind: "nullValue" | "numberValue" | "stringValue" | "boolValue" | "structValue" | "listValue"; +} + +export interface Struct { + fields?: Record<string, Value>; +} +export interface Struct__Output { + fields: Record<string, Value__Output>; +} + +export interface RerankScore { + document?: Struct | null; + index?: number | string | Long; + score?: number | string; + text?: string; +} +export interface RerankScore__Output { + document: Struct__Output | null; + index: number; + score: number; + text: string; +} + +export interface RerankScores { + query?: string; + scores?: RerankScore[]; +} +export interface RerankScores__Output { + query: string; + scores: RerankScore__Output[]; +} + +export interface RerankResult { + result?: RerankScores | null; + producer_id?: ProducerId | null; + input_token_count?: number | string | Long; +} +export interface RerankResult__Output { + result: RerankScores__Output | null; + producer_id: ProducerId__Output | null; + input_token_count: number; +} + +export interface RerankResults { + results?: RerankScores[]; + producer_id?: ProducerId | null; + input_token_count?: number | string | Long; +} +export interface RerankResults__Output { + results: RerankScores__Output[]; + producer_id: ProducerId__Output | null; + input_token_count: number; +} + +export interface 
RerankTaskRequest { + query?: string; + documents?: Struct[]; + top_n?: number | string | Long; + truncate_input_tokens?: number | string | Long; + return_documents?: boolean; + return_query?: boolean; + return_text?: boolean; + _top_n?: "top_n"; + _truncate_input_tokens?: "truncate_input_tokens"; + _return_documents?: "return_documents"; + _return_query?: "return_query"; + _return_text?: "return_text"; +} +export interface RerankTaskRequest__Output { + query: string; + documents: Struct__Output[]; + top_n?: number; + truncate_input_tokens?: number; + return_documents?: boolean; + return_query?: boolean; + return_text?: boolean; + _top_n: "top_n"; + _truncate_input_tokens: "truncate_input_tokens"; + _return_documents: "return_documents"; + _return_query: "return_query"; + _return_text: "return_text"; +} + +export interface RerankTasksRequest { + queries?: string[]; + documents?: Struct[]; + top_n?: number | string | Long; + truncate_input_tokens?: number | string | Long; + return_documents?: boolean; + return_queries?: boolean; + return_text?: boolean; + _top_n?: "top_n"; + _truncate_input_tokens?: "truncate_input_tokens"; + _return_documents?: "return_documents"; + _return_queries?: "return_queries"; + _return_text?: "return_text"; +} +export interface RerankTasksRequest__Output { + queries: string[]; + documents: Struct__Output[]; + top_n?: number; + truncate_input_tokens?: number; + return_documents?: boolean; + return_queries?: boolean; + return_text?: boolean; + _top_n: "top_n"; + _truncate_input_tokens: "truncate_input_tokens"; + _return_documents: "return_documents"; + _return_queries: "return_queries"; + _return_text: "return_text"; +} + +export interface SentenceSimilarityScores { + scores?: (number | string)[]; +} +export interface SentenceSimilarityScores__Output { + scores: number[]; +} + +export interface SentenceSimilarityResult { + result?: SentenceSimilarityScores | null; + producer_id?: ProducerId | null; + input_token_count?: number | string | 
Long; +} +export interface SentenceSimilarityResult__Output { + result: SentenceSimilarityScores__Output | null; + producer_id: ProducerId__Output | null; + input_token_count: number; +} + +export interface SentenceSimilarityResults { + results?: SentenceSimilarityScores[]; + producer_id?: ProducerId | null; + input_token_count?: number | string | Long; +} +export interface SentenceSimilarityResults__Output { + results: SentenceSimilarityScores__Output[]; + producer_id: ProducerId__Output | null; + input_token_count: number; +} + +export interface SentenceSimilarityTaskRequest { + source_sentence?: string; + sentences?: string[]; + truncate_input_tokens?: number | string | Long; + _truncate_input_tokens?: "truncate_input_tokens"; +} +export interface SentenceSimilarityTaskRequest__Output { + source_sentence: string; + sentences: string[]; + truncate_input_tokens?: number; + _truncate_input_tokens: "truncate_input_tokens"; +} + +export interface SentenceSimilarityTasksRequest { + source_sentences?: string[]; + sentences?: string[]; + truncate_input_tokens?: number | string | Long; + _truncate_input_tokens?: "truncate_input_tokens"; +} +export interface SentenceSimilarityTasksRequest__Output { + source_sentences: string[]; + sentences: string[]; + truncate_input_tokens?: number; + _truncate_input_tokens: "truncate_input_tokens"; +} + +export interface ExponentialDecayLengthPenalty { + start_index?: number | string | Long; + decay_factor?: number | string; +} +export interface ExponentialDecayLengthPenalty__Output { + start_index: number; + decay_factor: number; +} + +export interface ServerStreamingTextGenerationTaskRequest { + text?: string; + max_new_tokens?: number | string | Long; + min_new_tokens?: number | string | Long; + truncate_input_tokens?: number | string | Long; + decoding_method?: string; + top_k?: number | string | Long; + top_p?: number | string; + typical_p?: number | string; + temperature?: number | string; + repetition_penalty?: number | string; + 
max_time?: number | string; + exponential_decay_length_penalty?: ExponentialDecayLengthPenalty | null; + stop_sequences?: string[]; + seed?: number | string | Long; + preserve_input_text?: boolean; + _max_new_tokens?: "max_new_tokens"; + _min_new_tokens?: "min_new_tokens"; + _truncate_input_tokens?: "truncate_input_tokens"; + _decoding_method?: "decoding_method"; + _top_k?: "top_k"; + _top_p?: "top_p"; + _typical_p?: "typical_p"; + _temperature?: "temperature"; + _repetition_penalty?: "repetition_penalty"; + _max_time?: "max_time"; + _exponential_decay_length_penalty?: "exponential_decay_length_penalty"; + _seed?: "seed"; + _preserve_input_text?: "preserve_input_text"; +} +export interface ServerStreamingTextGenerationTaskRequest__Output { + text: string; + max_new_tokens?: number; + min_new_tokens?: number; + truncate_input_tokens?: number; + decoding_method?: string; + top_k?: number; + top_p?: number; + typical_p?: number; + temperature?: number; + repetition_penalty?: number; + max_time?: number; + exponential_decay_length_penalty?: ExponentialDecayLengthPenalty__Output | null; + stop_sequences: string[]; + seed?: number; + preserve_input_text?: boolean; + _max_new_tokens: "max_new_tokens"; + _min_new_tokens: "min_new_tokens"; + _truncate_input_tokens: "truncate_input_tokens"; + _decoding_method: "decoding_method"; + _top_k: "top_k"; + _top_p: "top_p"; + _typical_p: "typical_p"; + _temperature: "temperature"; + _repetition_penalty: "repetition_penalty"; + _max_time: "max_time"; + _exponential_decay_length_penalty: "exponential_decay_length_penalty"; + _seed: "seed"; + _preserve_input_text: "preserve_input_text"; +} + +export interface TextClassificationTaskRequest { + text?: string; +} +export interface TextClassificationTaskRequest__Output { + text: string; +} + +export interface TextGenerationTaskRequest { + text?: string; + max_new_tokens?: number | string | Long; + min_new_tokens?: number | string | Long; + truncate_input_tokens?: number | string | Long; + 
decoding_method?: string; + top_k?: number | string | Long; + top_p?: number | string; + typical_p?: number | string; + temperature?: number | string; + repetition_penalty?: number | string; + max_time?: number | string; + exponential_decay_length_penalty?: ExponentialDecayLengthPenalty | null; + stop_sequences?: string[]; + seed?: number | string | Long; + preserve_input_text?: boolean; + _max_new_tokens?: "max_new_tokens"; + _min_new_tokens?: "min_new_tokens"; + _truncate_input_tokens?: "truncate_input_tokens"; + _decoding_method?: "decoding_method"; + _top_k?: "top_k"; + _top_p?: "top_p"; + _typical_p?: "typical_p"; + _temperature?: "temperature"; + _repetition_penalty?: "repetition_penalty"; + _max_time?: "max_time"; + _exponential_decay_length_penalty?: "exponential_decay_length_penalty"; + _seed?: "seed"; + _preserve_input_text?: "preserve_input_text"; +} +export interface TextGenerationTaskRequest__Output { + text: string; + max_new_tokens?: number; + min_new_tokens?: number; + truncate_input_tokens?: number; + decoding_method?: string; + top_k?: number; + top_p?: number; + typical_p?: number; + temperature?: number; + repetition_penalty?: number; + max_time?: number; + exponential_decay_length_penalty?: ExponentialDecayLengthPenalty__Output | null; + stop_sequences: string[]; + seed?: number; + preserve_input_text?: boolean; + _max_new_tokens: "max_new_tokens"; + _min_new_tokens: "min_new_tokens"; + _truncate_input_tokens: "truncate_input_tokens"; + _decoding_method: "decoding_method"; + _top_k: "top_k"; + _top_p: "top_p"; + _typical_p: "typical_p"; + _temperature: "temperature"; + _repetition_penalty: "repetition_penalty"; + _max_time: "max_time"; + _exponential_decay_length_penalty: "exponential_decay_length_penalty"; + _seed: "seed"; + _preserve_input_text: "preserve_input_text"; +} + +export interface TokenClassificationResult { + start?: number | string | Long; + end?: number | string | Long; + word?: string; + entity?: string; + entity_group?: string; 
+ score?: number | string; + token_count?: number | string | Long; +} +export interface TokenClassificationResult__Output { + start: number; + end: number; + word: string; + entity: string; + entity_group: string; + score: number; + token_count: number; +} + +export interface TokenClassificationResults { + results?: TokenClassificationResult[]; +} +export interface TokenClassificationResults__Output { + results: TokenClassificationResult__Output[]; +} + +export interface TokenClassificationStreamResult { + results?: TokenClassificationResult[]; + processed_index?: number | string | Long; + start_index?: number | string | Long; +} +export interface TokenClassificationStreamResult__Output { + results: TokenClassificationResult__Output[]; + processed_index: number; + start_index: number; +} + +export interface TokenClassificationTaskRequest { + text?: string; + threshold?: number | string; + _threshold?: "threshold"; +} +export interface TokenClassificationTaskRequest__Output { + text: string; + threshold?: number; + _threshold: "threshold"; +} + +export interface Token { + start?: number | string | Long; + end?: number | string | Long; + text?: string; +} +export interface Token__Output { + start: number; + end: number; + text: string; +} + +export interface TokenizationResults { + results?: Token[]; + token_count?: number | string | Long; +} +export interface TokenizationResults__Output { + results: Token__Output[]; + token_count: number; +} + +export interface TokenizationTaskRequest { + text?: string; +} +export interface TokenizationTaskRequest__Output { + text: string; +} + +export interface NlpServiceClient extends grpc.Client { + BidiStreamingTokenClassificationTaskPredict( + metadata: grpc.Metadata, + options?: grpc.CallOptions, + ): grpc.ClientDuplexStream< + BidiStreamingTokenClassificationTaskRequest, + TokenClassificationStreamResult__Output + >; + BidiStreamingTokenClassificationTaskPredict( + options?: grpc.CallOptions, + ): grpc.ClientDuplexStream< + 
BidiStreamingTokenClassificationTaskRequest, + TokenClassificationStreamResult__Output + >; + bidiStreamingTokenClassificationTaskPredict( + metadata: grpc.Metadata, + options?: grpc.CallOptions, + ): grpc.ClientDuplexStream< + BidiStreamingTokenClassificationTaskRequest, + TokenClassificationStreamResult__Output + >; + bidiStreamingTokenClassificationTaskPredict( + options?: grpc.CallOptions, + ): grpc.ClientDuplexStream< + BidiStreamingTokenClassificationTaskRequest, + TokenClassificationStreamResult__Output + >; + EmbeddingTaskPredict( + argument: EmbeddingTaskRequest, + metadata: grpc.Metadata, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + EmbeddingTaskPredict( + argument: EmbeddingTaskRequest, + metadata: grpc.Metadata, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + EmbeddingTaskPredict( + argument: EmbeddingTaskRequest, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + EmbeddingTaskPredict( + argument: EmbeddingTaskRequest, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + embeddingTaskPredict( + argument: EmbeddingTaskRequest, + metadata: grpc.Metadata, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + embeddingTaskPredict( + argument: EmbeddingTaskRequest, + metadata: grpc.Metadata, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + embeddingTaskPredict( + argument: EmbeddingTaskRequest, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + embeddingTaskPredict( + argument: EmbeddingTaskRequest, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + EmbeddingTasksPredict( + argument: EmbeddingTasksRequest, + metadata: grpc.Metadata, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + EmbeddingTasksPredict( + argument: EmbeddingTasksRequest, + metadata: grpc.Metadata, + callback: grpc.requestCallback, + ): 
grpc.ClientUnaryCall; + EmbeddingTasksPredict( + argument: EmbeddingTasksRequest, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + EmbeddingTasksPredict( + argument: EmbeddingTasksRequest, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + embeddingTasksPredict( + argument: EmbeddingTasksRequest, + metadata: grpc.Metadata, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + embeddingTasksPredict( + argument: EmbeddingTasksRequest, + metadata: grpc.Metadata, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + embeddingTasksPredict( + argument: EmbeddingTasksRequest, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + embeddingTasksPredict( + argument: EmbeddingTasksRequest, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + RerankTaskPredict( + argument: RerankTaskRequest, + metadata: grpc.Metadata, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + RerankTaskPredict( + argument: RerankTaskRequest, + metadata: grpc.Metadata, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + RerankTaskPredict( + argument: RerankTaskRequest, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + RerankTaskPredict( + argument: RerankTaskRequest, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + rerankTaskPredict( + argument: RerankTaskRequest, + metadata: grpc.Metadata, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + rerankTaskPredict( + argument: RerankTaskRequest, + metadata: grpc.Metadata, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + rerankTaskPredict( + argument: RerankTaskRequest, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + rerankTaskPredict( + argument: RerankTaskRequest, + callback: grpc.requestCallback, + ): 
grpc.ClientUnaryCall; + RerankTasksPredict( + argument: RerankTasksRequest, + metadata: grpc.Metadata, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + RerankTasksPredict( + argument: RerankTasksRequest, + metadata: grpc.Metadata, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + RerankTasksPredict( + argument: RerankTasksRequest, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + RerankTasksPredict( + argument: RerankTasksRequest, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + rerankTasksPredict( + argument: RerankTasksRequest, + metadata: grpc.Metadata, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + rerankTasksPredict( + argument: RerankTasksRequest, + metadata: grpc.Metadata, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + rerankTasksPredict( + argument: RerankTasksRequest, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + rerankTasksPredict( + argument: RerankTasksRequest, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + SentenceSimilarityTaskPredict( + argument: SentenceSimilarityTaskRequest, + metadata: grpc.Metadata, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + SentenceSimilarityTaskPredict( + argument: SentenceSimilarityTaskRequest, + metadata: grpc.Metadata, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + SentenceSimilarityTaskPredict( + argument: SentenceSimilarityTaskRequest, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + SentenceSimilarityTaskPredict( + argument: SentenceSimilarityTaskRequest, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + sentenceSimilarityTaskPredict( + argument: SentenceSimilarityTaskRequest, + metadata: grpc.Metadata, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): 
grpc.ClientUnaryCall; + sentenceSimilarityTaskPredict( + argument: SentenceSimilarityTaskRequest, + metadata: grpc.Metadata, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + sentenceSimilarityTaskPredict( + argument: SentenceSimilarityTaskRequest, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + sentenceSimilarityTaskPredict( + argument: SentenceSimilarityTaskRequest, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + SentenceSimilarityTasksPredict( + argument: SentenceSimilarityTasksRequest, + metadata: grpc.Metadata, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + SentenceSimilarityTasksPredict( + argument: SentenceSimilarityTasksRequest, + metadata: grpc.Metadata, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + SentenceSimilarityTasksPredict( + argument: SentenceSimilarityTasksRequest, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + SentenceSimilarityTasksPredict( + argument: SentenceSimilarityTasksRequest, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + sentenceSimilarityTasksPredict( + argument: SentenceSimilarityTasksRequest, + metadata: grpc.Metadata, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + sentenceSimilarityTasksPredict( + argument: SentenceSimilarityTasksRequest, + metadata: grpc.Metadata, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + sentenceSimilarityTasksPredict( + argument: SentenceSimilarityTasksRequest, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + sentenceSimilarityTasksPredict( + argument: SentenceSimilarityTasksRequest, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + ServerStreamingTextGenerationTaskPredict( + argument: ServerStreamingTextGenerationTaskRequest, + metadata: grpc.Metadata, + options?: grpc.CallOptions, + ): 
grpc.ClientReadableStream<GeneratedTextStreamResult__Output>; + ServerStreamingTextGenerationTaskPredict( + argument: ServerStreamingTextGenerationTaskRequest, + options?: grpc.CallOptions, + ): grpc.ClientReadableStream<GeneratedTextStreamResult__Output>; + serverStreamingTextGenerationTaskPredict( + argument: ServerStreamingTextGenerationTaskRequest, + metadata: grpc.Metadata, + options?: grpc.CallOptions, + ): grpc.ClientReadableStream<GeneratedTextStreamResult__Output>; + serverStreamingTextGenerationTaskPredict( + argument: ServerStreamingTextGenerationTaskRequest, + options?: grpc.CallOptions, + ): grpc.ClientReadableStream<GeneratedTextStreamResult__Output>; + TextClassificationTaskPredict( + argument: TextClassificationTaskRequest, + metadata: grpc.Metadata, + options: grpc.CallOptions, + callback: grpc.requestCallback<ClassificationResults__Output>, + ): grpc.ClientUnaryCall; + TextClassificationTaskPredict( + argument: TextClassificationTaskRequest, + metadata: grpc.Metadata, + callback: grpc.requestCallback<ClassificationResults__Output>, + ): grpc.ClientUnaryCall; + TextClassificationTaskPredict( + argument: TextClassificationTaskRequest, + options: grpc.CallOptions, + callback: grpc.requestCallback<ClassificationResults__Output>, + ): grpc.ClientUnaryCall; + TextClassificationTaskPredict( + argument: TextClassificationTaskRequest, + callback: grpc.requestCallback<ClassificationResults__Output>, + ): grpc.ClientUnaryCall; + textClassificationTaskPredict( + argument: TextClassificationTaskRequest, + metadata: grpc.Metadata, + options: grpc.CallOptions, + callback: grpc.requestCallback<ClassificationResults__Output>, + ): grpc.ClientUnaryCall; + textClassificationTaskPredict( + argument: TextClassificationTaskRequest, + metadata: grpc.Metadata, + callback: grpc.requestCallback<ClassificationResults__Output>, + ): grpc.ClientUnaryCall; + textClassificationTaskPredict( + argument: TextClassificationTaskRequest, + options: grpc.CallOptions, + callback: grpc.requestCallback<ClassificationResults__Output>, + ): grpc.ClientUnaryCall; + textClassificationTaskPredict( + argument: TextClassificationTaskRequest, + callback: grpc.requestCallback<ClassificationResults__Output>, + ): grpc.ClientUnaryCall; + TextGenerationTaskPredict( + argument: TextGenerationTaskRequest, + metadata: grpc.Metadata, + options: grpc.CallOptions, + callback: 
grpc.requestCallback, + ): grpc.ClientUnaryCall; + TextGenerationTaskPredict( + argument: TextGenerationTaskRequest, + metadata: grpc.Metadata, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + TextGenerationTaskPredict( + argument: TextGenerationTaskRequest, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + TextGenerationTaskPredict( + argument: TextGenerationTaskRequest, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + textGenerationTaskPredict( + argument: TextGenerationTaskRequest, + metadata: grpc.Metadata, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + textGenerationTaskPredict( + argument: TextGenerationTaskRequest, + metadata: grpc.Metadata, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + textGenerationTaskPredict( + argument: TextGenerationTaskRequest, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + textGenerationTaskPredict( + argument: TextGenerationTaskRequest, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + TokenClassificationTaskPredict( + argument: TokenClassificationTaskRequest, + metadata: grpc.Metadata, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + TokenClassificationTaskPredict( + argument: TokenClassificationTaskRequest, + metadata: grpc.Metadata, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + TokenClassificationTaskPredict( + argument: TokenClassificationTaskRequest, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + TokenClassificationTaskPredict( + argument: TokenClassificationTaskRequest, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + tokenClassificationTaskPredict( + argument: TokenClassificationTaskRequest, + metadata: grpc.Metadata, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + 
tokenClassificationTaskPredict( + argument: TokenClassificationTaskRequest, + metadata: grpc.Metadata, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + tokenClassificationTaskPredict( + argument: TokenClassificationTaskRequest, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + tokenClassificationTaskPredict( + argument: TokenClassificationTaskRequest, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + TokenizationTaskPredict( + argument: TokenizationTaskRequest, + metadata: grpc.Metadata, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + TokenizationTaskPredict( + argument: TokenizationTaskRequest, + metadata: grpc.Metadata, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + TokenizationTaskPredict( + argument: TokenizationTaskRequest, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + TokenizationTaskPredict( + argument: TokenizationTaskRequest, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + tokenizationTaskPredict( + argument: TokenizationTaskRequest, + metadata: grpc.Metadata, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + tokenizationTaskPredict( + argument: TokenizationTaskRequest, + metadata: grpc.Metadata, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + tokenizationTaskPredict( + argument: TokenizationTaskRequest, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + tokenizationTaskPredict( + argument: TokenizationTaskRequest, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; +} +export interface NlpServiceDefinition extends grpc.ServiceDefinition { + BidiStreamingTokenClassificationTaskPredict: MethodDefinition< + BidiStreamingTokenClassificationTaskRequest, + TokenClassificationStreamResult, + BidiStreamingTokenClassificationTaskRequest__Output, + 
TokenClassificationStreamResult__Output + >; + EmbeddingTaskPredict: MethodDefinition< + EmbeddingTaskRequest, + EmbeddingResult, + EmbeddingTaskRequest__Output, + EmbeddingResult__Output + >; + EmbeddingTasksPredict: MethodDefinition< + EmbeddingTasksRequest, + EmbeddingResults, + EmbeddingTasksRequest__Output, + EmbeddingResults__Output + >; + RerankTaskPredict: MethodDefinition< + RerankTaskRequest, + RerankResult, + RerankTaskRequest__Output, + RerankResult__Output + >; + RerankTasksPredict: MethodDefinition< + RerankTasksRequest, + RerankResults, + RerankTasksRequest__Output, + RerankResults__Output + >; + SentenceSimilarityTaskPredict: MethodDefinition< + SentenceSimilarityTaskRequest, + SentenceSimilarityResult, + SentenceSimilarityTaskRequest__Output, + SentenceSimilarityResult__Output + >; + SentenceSimilarityTasksPredict: MethodDefinition< + SentenceSimilarityTasksRequest, + SentenceSimilarityResults, + SentenceSimilarityTasksRequest__Output, + SentenceSimilarityResults__Output + >; + ServerStreamingTextGenerationTaskPredict: MethodDefinition< + ServerStreamingTextGenerationTaskRequest, + GeneratedTextStreamResult, + ServerStreamingTextGenerationTaskRequest__Output, + GeneratedTextStreamResult__Output + >; + TextClassificationTaskPredict: MethodDefinition< + TextClassificationTaskRequest, + ClassificationResults, + TextClassificationTaskRequest__Output, + ClassificationResults__Output + >; + TextGenerationTaskPredict: MethodDefinition< + TextGenerationTaskRequest, + GeneratedTextResult, + TextGenerationTaskRequest__Output, + GeneratedTextResult__Output + >; + TokenClassificationTaskPredict: MethodDefinition< + TokenClassificationTaskRequest, + TokenClassificationResults, + TokenClassificationTaskRequest__Output, + TokenClassificationResults__Output + >; + TokenizationTaskPredict: MethodDefinition< + TokenizationTaskRequest, + TokenizationResults, + TokenizationTaskRequest__Output, + TokenizationResults__Output + >; +} + +export interface S3Path { + path?: 
string; + endpoint?: string; + region?: string; + bucket?: string; + accessKey?: string; + secretKey?: string; + IAM_id?: string; + IAM_api_key?: string; +} +export interface S3Path__Output { + path: string; + endpoint: string; + region: string; + bucket: string; + accessKey: string; + secretKey: string; + IAM_id: string; + IAM_api_key: string; +} + +export interface GenerationTrainRecord { + input?: string; + output?: string; +} +export interface GenerationTrainRecord__Output { + input: string; + output: string; +} + +export interface DataStreamSourceGenerationTrainRecordJsonData { + data?: GenerationTrainRecord[]; +} +export interface DataStreamSourceGenerationTrainRecordJsonData__Output { + data: GenerationTrainRecord__Output[]; +} + +export interface FileReference { + filename?: string; +} +export interface FileReference__Output { + filename: string; +} + +export interface ListOfFileReferences { + files?: string[]; +} +export interface ListOfFileReferences__Output { + files: string[]; +} + +export interface Directory { + dirname?: string; + extension?: string; +} +export interface Directory__Output { + dirname: string; + extension: string; +} + +export interface S3Files { + files?: string[]; + endpoint?: string; + region?: string; + bucket?: string; + accessKey?: string; + secretKey?: string; + IAM_id?: string; + IAM_api_key?: string; +} +export interface S3Files__Output { + files: string[]; + endpoint: string; + region: string; + bucket: string; + accessKey: string; + secretKey: string; + IAM_id: string; + IAM_api_key: string; +} + +export interface DataStreamSourceGenerationTrainRecord { + jsondata?: DataStreamSourceGenerationTrainRecordJsonData | null; + file?: FileReference | null; + list_of_files?: ListOfFileReferences | null; + directory?: Directory | null; + s3files?: S3Files | null; + data_stream?: "jsondata" | "file" | "list_of_files" | "directory" | "s3files"; +} +export interface DataStreamSourceGenerationTrainRecord__Output { + jsondata?: 
DataStreamSourceGenerationTrainRecordJsonData__Output | null; + file?: FileReference__Output | null; + list_of_files?: ListOfFileReferences__Output | null; + directory?: Directory__Output | null; + s3files?: S3Files__Output | null; + data_stream: "jsondata" | "file" | "list_of_files" | "directory" | "s3files"; +} + +export interface TuningConfig { + num_virtual_tokens?: number | string | Long; + prompt_tuning_init_text?: string; + prompt_tuning_init_method?: string; + prompt_tuning_init_source_model?: string; + output_model_types?: string[]; +} +export interface TuningConfig__Output { + num_virtual_tokens: number; + prompt_tuning_init_text: string; + prompt_tuning_init_method: string; + prompt_tuning_init_source_model: string; + output_model_types: string[]; +} + +export interface TextGenerationTaskPeftPromptTuningTrainParameters { + base_model?: string; + train_stream?: DataStreamSourceGenerationTrainRecord | null; + tuning_config?: TuningConfig | null; + val_stream?: DataStreamSourceGenerationTrainRecord | null; + device?: string; + tuning_type?: string; + num_epochs?: number | string | Long; + learning_rate?: number | string; + verbalizer?: string; + batch_size?: number | string | Long; + max_source_length?: number | string | Long; + max_target_length?: number | string | Long; + accumulate_steps?: number | string | Long; + torch_dtype?: string; + silence_progress_bars?: boolean; + seed?: number | string | Long; + _val_stream?: "val_stream"; + _device?: "device"; + _tuning_type?: "tuning_type"; + _num_epochs?: "num_epochs"; + _learning_rate?: "learning_rate"; + _verbalizer?: "verbalizer"; + _batch_size?: "batch_size"; + _max_source_length?: "max_source_length"; + _max_target_length?: "max_target_length"; + _accumulate_steps?: "accumulate_steps"; + _torch_dtype?: "torch_dtype"; + _silence_progress_bars?: "silence_progress_bars"; + _seed?: "seed"; +} +export interface TextGenerationTaskPeftPromptTuningTrainParameters__Output { + base_model: string; + train_stream: 
DataStreamSourceGenerationTrainRecord__Output | null; + tuning_config: TuningConfig__Output | null; + val_stream?: DataStreamSourceGenerationTrainRecord__Output | null; + device?: string; + tuning_type?: string; + num_epochs?: number; + learning_rate?: number; + verbalizer?: string; + batch_size?: number; + max_source_length?: number; + max_target_length?: number; + accumulate_steps?: number; + torch_dtype?: string; + silence_progress_bars?: boolean; + seed?: number; + _val_stream: "val_stream"; + _device: "device"; + _tuning_type: "tuning_type"; + _num_epochs: "num_epochs"; + _learning_rate: "learning_rate"; + _verbalizer: "verbalizer"; + _batch_size: "batch_size"; + _max_source_length: "max_source_length"; + _max_target_length: "max_target_length"; + _accumulate_steps: "accumulate_steps"; + _torch_dtype: "torch_dtype"; + _silence_progress_bars: "silence_progress_bars"; + _seed: "seed"; +} + +export interface TextGenerationTaskPeftPromptTuningTrainRequest { + model_name?: string; + output_path?: S3Path | null; + parameters?: TextGenerationTaskPeftPromptTuningTrainParameters | null; +} +export interface TextGenerationTaskPeftPromptTuningTrainRequest__Output { + model_name: string; + output_path: S3Path__Output | null; + parameters: TextGenerationTaskPeftPromptTuningTrainParameters__Output | null; +} + +export interface TextGenerationTaskTextGenerationTrainParameters { + base_model?: string; + train_stream?: DataStreamSourceGenerationTrainRecord | null; + torch_dtype?: string; + max_source_length?: number | string | Long; + max_target_length?: number | string | Long; + batch_size?: number | string | Long; + num_epochs?: number | string | Long; + accumulate_steps?: number | string | Long; + random_seed?: number | string | Long; + lr?: number | string; + use_iterable_dataset?: boolean; + _torch_dtype?: "torch_dtype"; + _max_source_length?: "max_source_length"; + _max_target_length?: "max_target_length"; + _batch_size?: "batch_size"; + _num_epochs?: "num_epochs"; + 
_accumulate_steps?: "accumulate_steps"; + _random_seed?: "random_seed"; + _lr?: "lr"; + _use_iterable_dataset?: "use_iterable_dataset"; +} +export interface TextGenerationTaskTextGenerationTrainParameters__Output { + base_model: string; + train_stream: DataStreamSourceGenerationTrainRecord__Output | null; + torch_dtype?: string; + max_source_length?: number; + max_target_length?: number; + batch_size?: number; + num_epochs?: number; + accumulate_steps?: number; + random_seed?: number; + lr?: number; + use_iterable_dataset?: boolean; + _torch_dtype: "torch_dtype"; + _max_source_length: "max_source_length"; + _max_target_length: "max_target_length"; + _batch_size: "batch_size"; + _num_epochs: "num_epochs"; + _accumulate_steps: "accumulate_steps"; + _random_seed: "random_seed"; + _lr: "lr"; + _use_iterable_dataset: "use_iterable_dataset"; +} + +export interface TextGenerationTaskTextGenerationTrainRequest { + model_name?: string; + output_path?: S3Path | null; + parameters?: TextGenerationTaskTextGenerationTrainParameters | null; +} +export interface TextGenerationTaskTextGenerationTrainRequest__Output { + model_name: string; + output_path: S3Path__Output | null; + parameters: TextGenerationTaskTextGenerationTrainParameters__Output | null; +} + +export interface TrainingJob { + training_id?: string; + model_name?: string; +} +export interface TrainingJob__Output { + training_id: string; + model_name: string; +} + +export interface NlpTrainingServiceClient extends grpc.Client { + TextGenerationTaskPeftPromptTuningTrain( + argument: TextGenerationTaskPeftPromptTuningTrainRequest, + metadata: grpc.Metadata, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + TextGenerationTaskPeftPromptTuningTrain( + argument: TextGenerationTaskPeftPromptTuningTrainRequest, + metadata: grpc.Metadata, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + TextGenerationTaskPeftPromptTuningTrain( + argument: 
TextGenerationTaskPeftPromptTuningTrainRequest, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + TextGenerationTaskPeftPromptTuningTrain( + argument: TextGenerationTaskPeftPromptTuningTrainRequest, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + textGenerationTaskPeftPromptTuningTrain( + argument: TextGenerationTaskPeftPromptTuningTrainRequest, + metadata: grpc.Metadata, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + textGenerationTaskPeftPromptTuningTrain( + argument: TextGenerationTaskPeftPromptTuningTrainRequest, + metadata: grpc.Metadata, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + textGenerationTaskPeftPromptTuningTrain( + argument: TextGenerationTaskPeftPromptTuningTrainRequest, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + textGenerationTaskPeftPromptTuningTrain( + argument: TextGenerationTaskPeftPromptTuningTrainRequest, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + TextGenerationTaskTextGenerationTrain( + argument: TextGenerationTaskTextGenerationTrainRequest, + metadata: grpc.Metadata, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + TextGenerationTaskTextGenerationTrain( + argument: TextGenerationTaskTextGenerationTrainRequest, + metadata: grpc.Metadata, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + TextGenerationTaskTextGenerationTrain( + argument: TextGenerationTaskTextGenerationTrainRequest, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + TextGenerationTaskTextGenerationTrain( + argument: TextGenerationTaskTextGenerationTrainRequest, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + textGenerationTaskTextGenerationTrain( + argument: TextGenerationTaskTextGenerationTrainRequest, + metadata: grpc.Metadata, + options: grpc.CallOptions, + callback: 
grpc.requestCallback, + ): grpc.ClientUnaryCall; + textGenerationTaskTextGenerationTrain( + argument: TextGenerationTaskTextGenerationTrainRequest, + metadata: grpc.Metadata, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + textGenerationTaskTextGenerationTrain( + argument: TextGenerationTaskTextGenerationTrainRequest, + options: grpc.CallOptions, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; + textGenerationTaskTextGenerationTrain( + argument: TextGenerationTaskTextGenerationTrainRequest, + callback: grpc.requestCallback, + ): grpc.ClientUnaryCall; +} +export interface NlpTrainingServiceDefinition extends grpc.ServiceDefinition { + TextGenerationTaskPeftPromptTuningTrain: MethodDefinition< + TextGenerationTaskPeftPromptTuningTrainRequest, + TrainingJob, + TextGenerationTaskPeftPromptTuningTrainRequest__Output, + TrainingJob__Output + >; + TextGenerationTaskTextGenerationTrain: MethodDefinition< + TextGenerationTaskTextGenerationTrainRequest, + TrainingJob, + TextGenerationTaskTextGenerationTrainRequest__Output, + TrainingJob__Output + >; +} + +export type SubtypeConstructor$1 any, Subtype> = new ( + ...args: ConstructorParameters +) => Subtype; +export interface ProtoGrpcType$1 { + caikit: { + runtime: { + Nlp: { + BidiStreamingTokenClassificationTaskRequest: MessageTypeDefinition; + DataStreamSourceGenerationTrainRecord: MessageTypeDefinition; + DataStreamSourceGenerationTrainRecordJsonData: MessageTypeDefinition; + EmbeddingTaskRequest: MessageTypeDefinition; + EmbeddingTasksRequest: MessageTypeDefinition; + NlpService: SubtypeConstructor$1 & { + service: NlpServiceDefinition; + }; + NlpTrainingService: SubtypeConstructor$1 & { + service: NlpTrainingServiceDefinition; + }; + RerankTaskRequest: MessageTypeDefinition; + RerankTasksRequest: MessageTypeDefinition; + SentenceSimilarityTaskRequest: MessageTypeDefinition; + SentenceSimilarityTasksRequest: MessageTypeDefinition; + ServerStreamingTextGenerationTaskRequest: 
MessageTypeDefinition; + TextClassificationTaskRequest: MessageTypeDefinition; + TextGenerationTaskPeftPromptTuningTrainParameters: MessageTypeDefinition; + TextGenerationTaskPeftPromptTuningTrainRequest: MessageTypeDefinition; + TextGenerationTaskRequest: MessageTypeDefinition; + TextGenerationTaskTextGenerationTrainParameters: MessageTypeDefinition; + TextGenerationTaskTextGenerationTrainRequest: MessageTypeDefinition; + TokenClassificationTaskRequest: MessageTypeDefinition; + TokenizationTaskRequest: MessageTypeDefinition; + }; + }; + }; + caikit_data_model: { + caikit_nlp: { + EmbeddingResult: MessageTypeDefinition; + EmbeddingResults: MessageTypeDefinition; + ExponentialDecayLengthPenalty: MessageTypeDefinition; + GenerationTrainRecord: MessageTypeDefinition; + RerankResult: MessageTypeDefinition; + RerankResults: MessageTypeDefinition; + RerankScore: MessageTypeDefinition; + RerankScores: MessageTypeDefinition; + SentenceSimilarityResult: MessageTypeDefinition; + SentenceSimilarityResults: MessageTypeDefinition; + SentenceSimilarityScores: MessageTypeDefinition; + TuningConfig: MessageTypeDefinition; + }; + common: { + BoolSequence: MessageTypeDefinition; + ConnectionInfo: MessageTypeDefinition; + ConnectionTlsInfo: MessageTypeDefinition; + Directory: MessageTypeDefinition; + File: MessageTypeDefinition; + FileReference: MessageTypeDefinition; + FloatSequence: MessageTypeDefinition; + IntSequence: MessageTypeDefinition; + ListOfFileReferences: MessageTypeDefinition; + ListOfVector1D: MessageTypeDefinition; + NpFloat32Sequence: MessageTypeDefinition; + NpFloat64Sequence: MessageTypeDefinition; + ProducerId: MessageTypeDefinition; + ProducerPriority: MessageTypeDefinition; + PyFloatSequence: MessageTypeDefinition; + S3Base: MessageTypeDefinition; + S3Files: MessageTypeDefinition; + S3Path: MessageTypeDefinition; + StrSequence: MessageTypeDefinition; + TrainingStatus: EnumTypeDefinition; + Vector1D: MessageTypeDefinition; + }; + nlp: { + ClassificationResult: 
MessageTypeDefinition; + ClassificationResults: MessageTypeDefinition; + ClassificationTrainRecord: MessageTypeDefinition; + ClassifiedGeneratedTextResult: MessageTypeDefinition; + ClassifiedGeneratedTextStreamResult: MessageTypeDefinition; + FinishReason: EnumTypeDefinition; + GeneratedTextResult: MessageTypeDefinition; + GeneratedTextStreamResult: MessageTypeDefinition; + GeneratedToken: MessageTypeDefinition; + InputWarning: MessageTypeDefinition; + InputWarningReason: EnumTypeDefinition; + TextGenTokenClassificationResults: MessageTypeDefinition; + Token: MessageTypeDefinition; + TokenClassificationResult: MessageTypeDefinition; + TokenClassificationResults: MessageTypeDefinition; + TokenClassificationStreamResult: MessageTypeDefinition; + TokenStreamDetails: MessageTypeDefinition; + TokenizationResults: MessageTypeDefinition; + TokenizationStreamResult: MessageTypeDefinition; + }; + runtime: { + ModelPointer: MessageTypeDefinition; + TrainingInfoRequest: MessageTypeDefinition; + TrainingJob: MessageTypeDefinition; + TrainingStatusResponse: MessageTypeDefinition; + }; + }; + google: { + protobuf: { + ListValue: MessageTypeDefinition; + NullValue: EnumTypeDefinition; + Struct: MessageTypeDefinition; + Timestamp: MessageTypeDefinition; + Value: MessageTypeDefinition; + }; + }; +} + export interface GenerationRequest { text?: string; } diff --git a/src/internals/types.ts b/src/internals/types.ts index d856203c..25a880db 100644 --- a/src/internals/types.ts +++ b/src/internals/types.ts @@ -113,3 +113,7 @@ export type OneOf = T extends [infer Only] : never; export type AnyVoid = Promise | unknown; + +export type OmitPrivateKeys = { + [K in keyof T as K extends `_${string}` ? 
never : K]: T[K]; +}; diff --git a/yarn.lock b/yarn.lock index 8cb11ac1..ed0f69ca 100644 --- a/yarn.lock +++ b/yarn.lock @@ -4868,7 +4868,7 @@ __metadata: openai-chat-tokens: "npm:^0.2.8" openapi-fetch: "npm:^0.13.3" openapi-typescript: "npm:^7.4.4" - p-queue: "npm:^8.0.1" + p-queue-compat: "npm:^1.0.227" p-throttle: "npm:^7.0.0" picocolors: "npm:^1.1.1" pino: "npm:^9.5.0" @@ -11326,6 +11326,16 @@ __metadata: languageName: node linkType: hard +"p-queue-compat@npm:^1.0.227": + version: 1.0.227 + resolution: "p-queue-compat@npm:1.0.227" + dependencies: + eventemitter3: "npm:5.x" + p-timeout-compat: "npm:^1.0.3" + checksum: 10c0/4b1d241e0734f2dad9669b2d71e28c62218f2f8d29bd575080975154ebf2f9aba9b47ce11714e16763f7c55e89bcf2213ea7a7cd7b90cce5d0246f34fbc8ac8d + languageName: node + linkType: hard + "p-queue@npm:^6.6.2": version: 6.6.2 resolution: "p-queue@npm:6.6.2" @@ -11336,16 +11346,6 @@ __metadata: languageName: node linkType: hard -"p-queue@npm:^8.0.1": - version: 8.0.1 - resolution: "p-queue@npm:8.0.1" - dependencies: - eventemitter3: "npm:^5.0.1" - p-timeout: "npm:^6.1.2" - checksum: 10c0/fe185bc8bbd32d17a5f6dba090077b1bb326b008b4ec9b0646c57a32a6984035aa8ece909a6d0de7f6c4640296dc288197f430e7394cdc76a26d862339494616 - languageName: node - linkType: hard - "p-retry@npm:4": version: 4.6.2 resolution: "p-retry@npm:4.6.2" @@ -11379,13 +11379,6 @@ __metadata: languageName: node linkType: hard -"p-timeout@npm:^6.1.2": - version: 6.1.2 - resolution: "p-timeout@npm:6.1.2" - checksum: 10c0/d46b90a9a5fb7c650a5c56dd5cf7102ea9ab6ce998defa2b3d4672789aaec4e2f45b3b0b5a4a3e17a0fb94301ad5dd26da7d8728402e48db2022ad1847594d19 - languageName: node - linkType: hard - "p-try@npm:^2.0.0": version: 2.2.0 resolution: "p-try@npm:2.2.0" From 8e8fcb4b98e525f43be3070398286cea7f4736b4 Mon Sep 17 00:00:00 2001 From: Tomas Dvorak Date: Fri, 13 Dec 2024 13:43:52 +0100 Subject: [PATCH 2/2] fixup! 
feat(adapters): add embedding support for IBM vLLM Signed-off-by: Tomas Dvorak --- examples/llms/providers/ibm-vllm.ts | 8 ++++++++ src/adapters/ibm-vllm/chat.ts | 16 ++++++++++------ src/adapters/ibm-vllm/client.ts | 11 ++++++++++- tests/e2e/adapters/ibm-vllm/llm.test.ts | 12 ++++++++++-- 4 files changed, 38 insertions(+), 9 deletions(-) diff --git a/examples/llms/providers/ibm-vllm.ts b/examples/llms/providers/ibm-vllm.ts index a995a93a..50480c63 100644 --- a/examples/llms/providers/ibm-vllm.ts +++ b/examples/llms/providers/ibm-vllm.ts @@ -34,3 +34,11 @@ const client = new Client(); ]); console.info(response.messages); } + +{ + console.info("===EMBEDDING==="); + const llm = new IBMvLLM({ client, modelId: "baai/bge-large-en-v1.5" }); + + const response = await llm.embed([`Hello world!`, `Hello family!`]); + console.info(response); +} diff --git a/src/adapters/ibm-vllm/chat.ts b/src/adapters/ibm-vllm/chat.ts index 8a610b6b..4f776329 100644 --- a/src/adapters/ibm-vllm/chat.ts +++ b/src/adapters/ibm-vllm/chat.ts @@ -16,7 +16,13 @@ import { isFunction, isObjectType } from "remeda"; -import { IBMvLLM, IBMvLLMGenerateOptions, IBMvLLMOutput, IBMvLLMParameters } from "./llm.js"; +import { + IBMvLLM, + IBMvLLMEmbeddingOptions, + IBMvLLMGenerateOptions, + IBMvLLMOutput, + IBMvLLMParameters, +} from "./llm.js"; import { Cache } from "@/cache/decoratorCache.js"; import { BaseMessage, Role } from "@/llms/primitives/message.js"; @@ -25,7 +31,6 @@ import { ChatLLM, ChatLLMGenerateEvents, ChatLLMOutput } from "@/llms/chat.js"; import { AsyncStream, BaseLLMTokenizeOutput, - EmbeddingOptions, EmbeddingOutput, LLMCache, LLMError, @@ -36,7 +41,6 @@ import { shallowCopy } from "@/serializer/utils.js"; import { IBMVllmChatLLMPreset, IBMVllmChatLLMPresetModel } from "@/adapters/ibm-vllm/chatPreset.js"; import { Client } from "./client.js"; import { GetRunContext } from "@/context.js"; -import { NotImplementedError } from "@/errors.js"; export class GrpcChatLLMOutput extends 
ChatLLMOutput { public readonly raw: IBMvLLMOutput; @@ -118,9 +122,9 @@ export class IBMVllmChatLLM extends ChatLLM { return this.llm.meta(); } - // eslint-disable-next-line unused-imports/no-unused-vars - async embed(input: BaseMessage[][], options?: EmbeddingOptions): Promise { - throw new NotImplementedError(); + async embed(input: BaseMessage[][], options?: IBMvLLMEmbeddingOptions): Promise { + const inputs = input.map((messages) => this.messagesToPrompt(messages)); + return this.llm.embed(inputs, options); } createSnapshot() { diff --git a/src/adapters/ibm-vllm/client.ts b/src/adapters/ibm-vllm/client.ts index b294abf6..0740e11f 100644 --- a/src/adapters/ibm-vllm/client.ts +++ b/src/adapters/ibm-vllm/client.ts @@ -19,6 +19,7 @@ import grpc, { ClientOptions as GRPCClientOptions, ClientReadableStream, ClientUnaryCall, + Metadata, } from "@grpc/grpc-js"; import * as R from "remeda"; @@ -47,6 +48,7 @@ import { z } from "zod"; import { Cache } from "@/cache/decoratorCache.js"; import { Serializable } from "@/internals/serializable.js"; import PQueue from "p-queue-compat"; +import { getProp } from "@/internals/helpers/object.js"; const GENERATION_PROTO_PATH = new URL("./proto/generation.proto", import.meta.url); const NLP_PROTO_PATH = new URL("./proto/caikit_runtime_Nlp.proto", import.meta.url); @@ -193,13 +195,20 @@ export class Client extends Serializable { protected wrapGrpcCall( fn: ( request: TRequest, + metadata: Metadata, options: CallOptions, callback: UnaryCallback, ) => ClientUnaryCall, ) { return (request: TRequest, { signal, ...options }: CallOptions = {}): Promise => { + const metadata = new Metadata(); + const modelId = getProp(request, ["model_id"]); + if (modelId) { + metadata.add("mm-model-id", modelId); + } + return new Promise((resolve, reject) => { - const call = fn(request, options, (err, response) => { + const call = fn(request, metadata, options, (err, response) => { signal?.removeEventListener("abort", abortHandler); if (err) { reject(err); 
diff --git a/tests/e2e/adapters/ibm-vllm/llm.test.ts b/tests/e2e/adapters/ibm-vllm/llm.test.ts index 24496aee..dc0abf94 100644 --- a/tests/e2e/adapters/ibm-vllm/llm.test.ts +++ b/tests/e2e/adapters/ibm-vllm/llm.test.ts @@ -26,8 +26,8 @@ describe.runIf( process.env.IBM_VLLM_CERT_CHAIN, ].every((env) => Boolean(env)), )("IBM vLLM", () => { - const createLLM = () => { - return new IBMvLLM({ modelId: IBMVllmModel.LLAMA_3_1_70B_INSTRUCT }); + const createLLM = (modelId: string = IBMVllmModel.LLAMA_3_1_70B_INSTRUCT) => { + return new IBMvLLM({ modelId }); }; it("Meta", async () => { @@ -50,6 +50,14 @@ describe.runIf( } }); + it("Embeds", async () => { + const llm = createLLM("baai/bge-large-en-v1.5"); + const response = await llm.embed([`Hello world!`, `Hello family!`]); + expect(response.embeddings.length).toBe(2); + expect(response.embeddings[0].length).toBe(1024); + expect(response.embeddings[1].length).toBe(1024); + }); + it("Serializes", () => { const llm = createLLM(); const serialized = llm.serialize();