feat(adapters): add embedding support for IBM vLLM #251

Merged · 2 commits · Dec 13, 2024 · Showing changes from 1 commit
2 changes: 1 addition & 1 deletion package.json
@@ -179,7 +179,7 @@
"mathjs": "^14.0.0",
"mustache": "^4.2.0",
"object-hash": "^3.0.0",
"p-queue": "^8.0.1",
"p-queue-compat": "^1.0.227",
"p-throttle": "^7.0.0",
"pino": "^9.5.0",
"promise-based-task": "^3.1.1",
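The package.json change swaps p-queue for p-queue-compat, a drop-in that keeps the PQueue API while remaining consumable from CommonJS builds; client.ts further down in this diff imports PQueue from it to throttle embedding requests. A minimal sketch of how the queue is used there (the concurrency value mirrors the default added in client.ts, the rest is illustrative):

import PQueue from "p-queue-compat";

// Same surface as p-queue: limit how many tasks run at once.
const queue = new PQueue({ concurrency: 5, throwOnTimeout: true });

const upper = await Promise.all(
  ["first", "second", "third"].map((text) =>
    queue.add(() => Promise.resolve(text.toUpperCase()), { throwOnTimeout: true }),
  ),
);
// upper === ["FIRST", "SECOND", "THIRD"], produced at most 5 at a time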
@@ -19,8 +19,8 @@ GRPC_PROTO_PATH="./src/adapters/ibm-vllm/proto"
GRPC_TYPES_PATH="./src/adapters/ibm-vllm/types.ts"

SCRIPT_DIR="$(dirname "$0")"
OUTPUT_RELATIVE_PATH="dist/generation.d.ts"
GRPC_TYPES_TMP_PATH=types
OUTPUT_RELATIVE_PATH="dist/merged.d.ts"
GRPC_TYPES_TMP_PATH="types"

rm -f "$GRPC_TYPES_PATH"

@@ -39,7 +39,7 @@ yarn run proto-loader-gen-types \


cd "$SCRIPT_DIR"
tsup --dts-only
ENTRY="$(basename "$OUTPUT_RELATIVE_PATH" ".d.ts")" tsup --dts-only
sed -i.bak '$ d' "$OUTPUT_RELATIVE_PATH"
sed -i.bak -E "s/^interface/export interface/" "$OUTPUT_RELATIVE_PATH"
sed -i.bak -E "s/^type/export type/" "$OUTPUT_RELATIVE_PATH"
@@ -50,4 +50,4 @@ rm -rf "${SCRIPT_DIR}"/{dist,dts,types}

yarn run lint:fix "${GRPC_TYPES_PATH}"
yarn prettier --write "${GRPC_TYPES_PATH}"
yarn copyright
TARGETS="$GRPC_TYPES_PATH" yarn copyright
2 changes: 1 addition & 1 deletion scripts/ibm_vllm_generate_protos/tsconfig.proto.json
@@ -4,7 +4,7 @@
"rootDir": ".",
"baseUrl": ".",
"target": "ESNext",
"module": "ES6",
"module": "ESNext",
"outDir": "dist",
"declaration": true,
"emitDeclarationOnly": true,
17 changes: 15 additions & 2 deletions scripts/ibm_vllm_generate_protos/tsup.config.ts
@@ -15,14 +15,27 @@
*/

import { defineConfig } from "tsup";
import fs from "node:fs";

if (!process.env.ENTRY) {
throw new Error(`Entry file was not provided!`);
}
const target = `types/${process.env.ENTRY}.ts`;
await fs.promises.writeFile(
target,
[
`export { ProtoGrpcType as A } from "./caikit_runtime_Nlp.js"`,
`export { ProtoGrpcType as B } from "./generation.js"`,
].join("\n"),
);

export default defineConfig({
entry: ["types/generation.ts"],
entry: [target],
tsconfig: "./tsconfig.proto.json",
sourcemap: false,
dts: true,
format: ["esm"],
treeshake: false,
treeshake: true,
legacyOutput: false,
skipNodeModulesBundle: true,
bundle: true,
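With this change tsup no longer has a hard-coded entry point: the generation script exports ENTRY (derived from OUTPUT_RELATIVE_PATH, i.e. "merged"), and the config writes a throwaway barrel file that re-exports the ProtoGrpcType declarations from both generated proto modules before bundling; treeshake is also flipped to true, presumably so only what is reachable from those two re-exports ends up in dist/merged.d.ts. Assuming ENTRY=merged, the file written by the snippet above is simply:

// types/merged.ts — generated on the fly by tsup.config.ts
export { ProtoGrpcType as A } from "./caikit_runtime_Nlp.js"
export { ProtoGrpcType as B } from "./generation.js"

This is presumably how the single generated types.ts ends up exposing both ProtoGrpcType and ProtoGrpcType$1, which client.ts imports below as GenerationProtoGentypes and CaikitProtoGentypes.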
78 changes: 61 additions & 17 deletions src/adapters/ibm-vllm/client.ts
@@ -25,25 +25,31 @@ import * as R from "remeda";
// eslint-disable-next-line no-restricted-imports
import { UnaryCallback } from "@grpc/grpc-js/build/src/client.js";
import { FrameworkError, ValueError } from "@/errors.js";
import protoLoader from "@grpc/proto-loader";
import protoLoader, { Options } from "@grpc/proto-loader";

import {
BatchedGenerationRequest,
BatchedGenerationResponse__Output,
BatchedTokenizeRequest,
BatchedTokenizeResponse__Output,
type EmbeddingTasksRequest,
GenerationRequest__Output,
ModelInfoRequest,
ModelInfoResponse__Output,
ProtoGrpcType as GenerationProtoGentypes,
ProtoGrpcType$1 as CaikitProtoGentypes,
SingleGenerationRequest,
EmbeddingResults__Output,
type SubtypeConstructor,
} from "@/adapters/ibm-vllm/types.js";
import { parseEnv } from "@/internals/env.js";
import { z } from "zod";
import { Cache } from "@/cache/decoratorCache.js";
import { Serializable } from "@/internals/serializable.js";
import PQueue from "p-queue-compat";

const GENERATION_PROTO_PATH = new URL("./proto/generation.proto", import.meta.url);
const NLP_PROTO_PATH = new URL("./proto/caikit_runtime_Nlp.proto", import.meta.url);

interface ClientOptions {
modelRouterSubdomain?: string;
@@ -55,6 +61,11 @@ interface ClientOptions {
};
grpcClientOptions: GRPCClientOptions;
clientShutdownDelay: number;
limits?: {
concurrency?: {
embeddings?: number;
};
};
}

const defaultOptions = {
@@ -66,18 +77,24 @@
},
};

const generationPackageObject = grpc.loadPackageDefinition(
protoLoader.loadSync([GENERATION_PROTO_PATH.pathname], {
longs: Number,
enums: String,
arrays: true,
objects: true,
oneofs: true,
keepCase: true,
defaults: true,
}),
const grpcConfig: Options = {
longs: Number,
enums: String,
arrays: true,
objects: true,
oneofs: true,
keepCase: true,
defaults: true,
};

const generationPackage = grpc.loadPackageDefinition(
protoLoader.loadSync([GENERATION_PROTO_PATH.pathname], grpcConfig),
) as unknown as GenerationProtoGentypes;

const embeddingsPackage = grpc.loadPackageDefinition(
protoLoader.loadSync([NLP_PROTO_PATH.pathname], grpcConfig),
) as unknown as CaikitProtoGentypes;

const GRPC_CLIENT_TTL = 15 * 60 * 1000;

type CallOptions = GRPCCallOptions & { signal?: AbortSignal };
@@ -88,9 +105,12 @@ export class Client extends Serializable {
private usedDefaultCredentials = false;

@Cache({ ttl: GRPC_CLIENT_TTL })
protected getClient(modelId: string) {
protected getClient<T extends { close: () => void }>(
modelId: string,
factory: SubtypeConstructor<typeof grpc.Client, T>,
): T {
const modelSpecificUrl = this.options.url.replace(/{model_id}/, modelId.replaceAll("/", "--"));
const client = new generationPackageObject.fmaas.GenerationService(
const client = new factory(
modelSpecificUrl,
grpc.credentials.createSsl(
Buffer.from(this.options.credentials.rootCert),
@@ -129,33 +149,47 @@
}

async modelInfo(request: RequiredModel<ModelInfoRequest>, options?: CallOptions) {
const client = this.getClient(request.model_id);
const client = this.getClient(request.model_id, generationPackage.fmaas.GenerationService);
return this.wrapGrpcCall<ModelInfoRequest, ModelInfoResponse__Output>(
client.modelInfo.bind(client),
)(request, options);
}

async generate(request: RequiredModel<BatchedGenerationRequest>, options?: CallOptions) {
const client = this.getClient(request.model_id);
const client = this.getClient(request.model_id, generationPackage.fmaas.GenerationService);
return this.wrapGrpcCall<BatchedGenerationRequest, BatchedGenerationResponse__Output>(
client.generate.bind(client),
)(request, options);
}

async generateStream(request: RequiredModel<SingleGenerationRequest>, options?: CallOptions) {
const client = this.getClient(request.model_id);
const client = this.getClient(request.model_id, generationPackage.fmaas.GenerationService);
return this.wrapGrpcStream<SingleGenerationRequest, GenerationRequest__Output>(
client.generateStream.bind(client),
)(request, options);
}

async tokenize(request: RequiredModel<BatchedTokenizeRequest>, options?: CallOptions) {
const client = this.getClient(request.model_id);
const client = this.getClient(request.model_id, generationPackage.fmaas.GenerationService);
return this.wrapGrpcCall<BatchedTokenizeRequest, BatchedTokenizeResponse__Output>(
client.tokenize.bind(client),
)(request, options);
}

async embed(request: RequiredModel<EmbeddingTasksRequest>, options?: CallOptions) {
const client = this.getClient(
request.model_id,
embeddingsPackage.caikit.runtime.Nlp.NlpService,
);
return this.queues.embeddings.add(
() =>
this.wrapGrpcCall<EmbeddingTasksRequest, EmbeddingResults__Output>(
client.embeddingTasksPredict.bind(client),
)(request, options),
{ throwOnTimeout: true },
);
}

protected wrapGrpcCall<TRequest, TResponse>(
fn: (
request: TRequest,
@@ -213,4 +247,14 @@
Object.assign(this, snapshot);
this.options.credentials = this.getDefaultCredentials();
}

@Cache({ enumerable: false })
protected get queues() {
return {
embeddings: new PQueue({
concurrency: this.options.limits?.concurrency?.embeddings ?? 5,
throwOnTimeout: true,
}),
};
}
}
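The new Client.embed call targets the Caikit NLP service (embeddingTasksPredict) rather than the fmaas GenerationService, and every request is funnelled through a lazily created PQueue whose concurrency defaults to 5 and can be capped via limits.concurrency.embeddings. A minimal usage sketch; the import path, the constructor accepting partial options, and the model id are assumptions, since the PR only shows the changed parts of the class:

import { Client } from "bee-agent-framework/adapters/ibm-vllm/client.js";

// Assumed constructor shape; only the options added in this PR are shown.
const client = new Client({
  limits: { concurrency: { embeddings: 2 } }, // at most 2 embeddingTasksPredict calls in flight
});

// "my-embedding-model" is a placeholder model id.
const response = await client.embed(
  { model_id: "my-embedding-model", texts: ["hello", "world"], truncate_input_tokens: 512 },
  { signal: AbortSignal.timeout(30_000) },
);
// response.results?.vectors should hold one vector per input text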
48 changes: 43 additions & 5 deletions src/adapters/ibm-vllm/llm.ts
@@ -27,8 +27,12 @@ import {
LLMError,
LLMMeta,
} from "@/llms/base.js";
import { isEmpty, isString } from "remeda";
import type { DecodingParameters, SingleGenerationRequest } from "@/adapters/ibm-vllm/types.js";
import { chunk, isEmpty, isString } from "remeda";
import type {
DecodingParameters,
SingleGenerationRequest,
EmbeddingTasksRequest,
} from "@/adapters/ibm-vllm/types.js";
import { LLM, LLMEvents, LLMInput } from "@/llms/llm.js";
import { Emitter } from "@/emitter/emitter.js";
import { GenerationResponse__Output } from "@/adapters/ibm-vllm/types.js";
@@ -39,6 +43,7 @@ import { ServiceError } from "@grpc/grpc-js";
import { Client } from "@/adapters/ibm-vllm/client.js";
import { GetRunContext } from "@/context.js";
import { BatchedGenerationRequest } from "./types.js";
import { OmitPrivateKeys } from "@/internals/types.js";

function isGrpcServiceError(err: unknown): err is ServiceError {
return (
@@ -100,6 +105,12 @@ export type IBMvLLMParameters = NonNullable<

export interface IBMvLLMGenerateOptions extends GenerateOptions {}

export interface IBMvLLMEmbeddingOptions
extends EmbeddingOptions,
Omit<OmitPrivateKeys<EmbeddingTasksRequest>, "texts"> {
chunkSize?: number;
}

export type IBMvLLMEvents = LLMEvents<IBMvLLMOutput>;

export class IBMvLLM extends LLM<IBMvLLMOutput, IBMvLLMGenerateOptions> {
@@ -128,9 +139,36 @@ export class IBMvLLM extends LLM<IBMvLLMOutput, IBMvLLMGenerateOptions> {
};
}

// eslint-disable-next-line unused-imports/no-unused-vars
async embed(input: LLMInput[], options?: EmbeddingOptions): Promise<EmbeddingOutput> {
throw new NotImplementedError();
async embed(
input: LLMInput[],
{ chunkSize, signal, ...options }: IBMvLLMEmbeddingOptions = {},
): Promise<EmbeddingOutput> {
const results = await Promise.all(
chunk(input, chunkSize ?? 100).map(async (texts) => {
const response = await this.client.embed(
{
model_id: this.modelId,
truncate_input_tokens: options?.truncate_input_tokens ?? 512,
texts,
},
{
signal,
},
);
const embeddings = response.results?.vectors.map((vector) => {
const embedding = vector[vector.data]?.values;
if (!embedding) {
throw new LLMError("Missing embedding");
}
return embedding;
});
if (embeddings?.length !== texts.length) {
throw new LLMError("Missing embedding");
}
return embeddings;
}),
);
return { embeddings: results.flat() };
}

async tokenize(input: LLMInput): Promise<BaseLLMTokenizeOutput> {
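IBMvLLM.embed is no longer a stub: it splits the input into chunks (100 texts per request unless chunkSize says otherwise), sends each chunk through client.embed in parallel, and flattens the per-chunk vectors back into a single embeddings array, throwing an LLMError if any vector is missing. A usage sketch; the constructor arguments are inferred from the fields the diff references (this.client, this.modelId) and may not match the actual signature:

import { Client } from "bee-agent-framework/adapters/ibm-vllm/client.js";
import { IBMvLLM } from "bee-agent-framework/adapters/ibm-vllm/llm.js";

// Assumed constructor shape; "my-embedding-model" is a placeholder model id.
const llm = new IBMvLLM({ client: new Client(), modelId: "my-embedding-model" });

const { embeddings } = await llm.embed(["first passage", "second passage"], {
  chunkSize: 50, // optional, defaults to 100
  truncate_input_tokens: 512, // forwarded to the EmbeddingTasksRequest (matches the default)
});
// embeddings.length === 2 — one numeric vector per input text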