Skip to content

Commit

Permalink
[vertex] Add PDF/plain text support (#1520)
Browse files Browse the repository at this point in the history
* [vertex] Add PDF support

* [vertex] Fix lint

* [vertex] Add support for text/plain
  • Loading branch information
ArthurGoupil authored Oct 15, 2024
1 parent db229c6 commit cc5dfd4
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/lib/components/chat/ChatWindow.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@
...(!$page.data?.assistant && currentModel.tools
? activeTools.flatMap((tool: ToolFront) => tool.mimeTypes ?? [])
: []),
...(currentModel.multimodal ? ["image/*"] : []),
...(currentModel.multimodal ? currentModel.multimodalAcceptedMimetypes ?? ["image/*"] : []),
];
$: isFileUploadEnabled = activeMimeTypes.length > 0;
Expand Down
69 changes: 69 additions & 0 deletions src/lib/server/endpoints/document.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import type { MessageFile } from "$lib/types/Message";
import { z } from "zod";

/**
 * Options consumed by the file processors in this module.
 *
 * @typeParam TMimeType - String type of the mime types the processor accepts.
 */
export interface FileProcessorOptions<TMimeType extends string = string> {
/** Mime types a file is allowed to have; anything else is rejected by `makeDocumentProcessor`. */
supportedMimeTypes: TMimeType[];
/** Upper bound on the decoded file size; `makeDocumentProcessor` converts it to bytes as `maxSizeInMB * 1000 * 1000`. */
maxSizeInMB: number;
}

/**
 * Signature of an asynchronous file processor: takes a `MessageFile` and
 * resolves to the file contents as a `Buffer` together with its mime type.
 *
 * NOTE(review): this alias is not referenced anywhere else in the visible part
 * of this file — confirm whether it is still needed here.
 *
 * @typeParam TMimeType - String type of the mime value in the resolved result.
 */
export type ImageProcessor<TMimeType extends string = string> = (file: MessageFile) => Promise<{
file: Buffer;
mime: TMimeType;
}>;

/**
 * Creates a zod schema that validates document-processor options.
 *
 * Each field falls back to the matching value from `defaults` when it is
 * omitted, and the whole options object falls back to `defaults` when it is
 * not provided at all. `supportedMimeTypes` entries are restricted to the mime
 * types listed in `defaults.supportedMimeTypes`.
 *
 * `defaults.supportedMimeTypes` is expected to be non-empty, since its first
 * element seeds the enum of allowed values.
 *
 * @param defaults - Default options, also used as the set of allowed mime types.
 * @returns A zod schema for the options object.
 */
export const createDocumentProcessorOptionsValidator = <TMimeType extends string = string>(
  defaults: FileProcessorOptions<TMimeType>
) => {
  const { supportedMimeTypes: defaultMimeTypes, maxSizeInMB: defaultMaxSizeInMB } = defaults;

  // z.enum requires a non-empty tuple, hence the explicit head + tail construction.
  const mimeTypeSchema = z.enum<string, [TMimeType, ...TMimeType[]]>([
    defaultMimeTypes[0],
    ...defaultMimeTypes.slice(1),
  ]);

  const optionsSchema = z.object({
    supportedMimeTypes: z.array(mimeTypeSchema).default(defaultMimeTypes),
    maxSizeInMB: z.number().positive().default(defaultMaxSizeInMB),
  });

  return optionsSchema.default(defaults);
};

/**
 * Signature of a synchronous document processor, as returned by
 * `makeDocumentProcessor`: takes a `MessageFile` and returns the file contents
 * as a `Buffer` together with its mime type.
 *
 * @typeParam TMimeType - String type of the mime value in the result.
 */
export type DocumentProcessor<TMimeType extends string = string> = (file: MessageFile) => {
file: Buffer;
mime: TMimeType;
};

/**
 * Builds a document processor bound to the given options.
 *
 * The returned function base64-decodes `file.value`, rejects the file when the
 * decoded size exceeds `options.maxSizeInMB` (1 MB counted as 1000 * 1000
 * bytes), and then checks `file.mime` against `options.supportedMimeTypes`.
 *
 * @param options - Size limit and accepted mime types.
 * @returns A processor that yields the decoded buffer and the validated mime type.
 * @throws Error "Document is too large" when the decoded file exceeds the limit.
 * @throws Error when the file's mime type is not in `supportedMimeTypes`.
 */
export function makeDocumentProcessor<TMimeType extends string = string>(
  options: FileProcessorOptions<TMimeType>
): DocumentProcessor<TMimeType> {
  return (file) => {
    const decoded = Buffer.from(file.value, "base64");
    const maxSizeInBytes = options.maxSizeInMB * 1000 * 1000;

    // The size check runs first, so an oversized file reports the size error
    // even if its mime type is also unsupported.
    if (decoded.byteLength > maxSizeInBytes) {
      throw Error("Document is too large");
    }

    return {
      file: decoded,
      mime: validateMimeType(options.supportedMimeTypes, file.mime),
    };
  };
}

/**
 * Checks that `mime` is one of `supportedMimes` and returns it unchanged.
 *
 * @param supportedMimes - List of accepted mime types.
 * @param mime - Mime type to check.
 * @returns The same `mime` value, typed as a member of `supportedMimes`.
 * @throws Error naming the rejected mime type and listing the supported ones.
 */
const validateMimeType = <T extends readonly string[]>(
  supportedMimes: T,
  mime: string
): T[number] => {
  if (supportedMimes.includes(mime)) {
    return mime;
  }

  const allowedList = supportedMimes.join(", ");
  throw Error(`Mimetype "${mime}" not found in supported mimes: ${allowedList}`);
};
30 changes: 26 additions & 4 deletions src/lib/server/endpoints/google/endpointVertex.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import { z } from "zod";
import type { Message } from "$lib/types/Message";
import type { TextGenerationStreamOutput } from "@huggingface/inference";
import { createImageProcessorOptionsValidator, makeImageProcessor } from "../images";
import { createDocumentProcessorOptionsValidator, makeDocumentProcessor } from "../document";

export const endpointVertexParametersSchema = z.object({
weight: z.number().int().positive().default(1),
Expand Down Expand Up @@ -39,12 +40,17 @@ export const endpointVertexParametersSchema = z.object({
"image/avif",
"image/tiff",
"image/gif",
"application/pdf",
],
preferredMimeType: "image/webp",
maxSizeInMB: Infinity,
maxSizeInMB: 20,
maxWidth: 4096,
maxHeight: 4096,
}),
document: createDocumentProcessorOptionsValidator({
supportedMimeTypes: ["application/pdf", "text/plain"],
maxSizeInMB: 20,
}),
})
.default({}),
});
Expand Down Expand Up @@ -109,17 +115,33 @@ export function endpointVertex(input: z.input<typeof endpointVertexParametersSch
const vertexMessages = await Promise.all(
messages.map(async ({ from, content, files }: Omit<Message, "id">): Promise<Content> => {
const imageProcessor = makeImageProcessor(multimodal.image);
const processedFiles =
const documentProcessor = makeDocumentProcessor(multimodal.document);

const processedFilesWithNull =
files && files.length > 0
? await Promise.all(files.map(async (file) => imageProcessor(file)))
? await Promise.all(
files.map(async (file) => {
if (file.mime.includes("image")) {
const { image, mime } = await imageProcessor(file);

return { file: image, mime };
} else if (file.mime === "application/pdf" || file.mime === "text/plain") {
return documentProcessor(file);
}

return null;
})
)
: [];

const processedFiles = processedFilesWithNull.filter((file) => file !== null);

return {
role: from === "user" ? "user" : "model",
parts: [
...processedFiles.map((processedFile) => ({
inlineData: {
data: processedFile.image.toString("base64"),
data: processedFile.file.toString("base64"),
mimeType: processedFile.mime,
},
})),
Expand Down
1 change: 1 addition & 0 deletions src/lib/server/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ const modelConfig = z.object({
.passthrough()
.optional(),
multimodal: z.boolean().default(false),
multimodalAcceptedMimetypes: z.array(z.string()).optional(),
tools: z.boolean().default(false),
unlisted: z.boolean().default(false),
embeddingModel: validateEmbeddingModelByName(embeddingModels).optional(),
Expand Down
1 change: 1 addition & 0 deletions src/lib/types/Model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ export type Model = Pick<
| "datasetUrl"
| "preprompt"
| "multimodal"
| "multimodalAcceptedMimetypes"
| "unlisted"
| "tools"
| "hasInferenceAPI"
Expand Down
1 change: 1 addition & 0 deletions src/routes/+layout.server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ export const load: LayoutServerLoad = async ({ locals, depends, request }) => {
parameters: model.parameters,
preprompt: model.preprompt,
multimodal: model.multimodal,
multimodalAcceptedMimetypes: model.multimodalAcceptedMimetypes,
tools:
model.tools &&
// disable tools on huggingchat android app
Expand Down

0 comments on commit cc5dfd4

Please sign in to comment.