Use jinja template for chat formatting (#730) #744

Merged · 6 commits · Apr 4, 2024
Changes from 5 commits
2 changes: 1 addition & 1 deletion .env.template
@@ -64,7 +64,7 @@ MODELS=`[
"description": "The latest and biggest model from Meta, fine-tuned for chat.",
"logoUrl": "https://huggingface.co/datasets/huggingchat/models-logo/resolve/main/meta-logo.png",
"websiteUrl": "https://ai.meta.com/llama/",
"preprompt": " ",
"preprompt": "",
"chatPromptTemplate" : "<s>[INST] <<SYS>>\n{{preprompt}}\n<</SYS>>\n\n{{#each messages}}{{#ifUser}}{{content}} [/INST] {{/ifUser}}{{#ifAssistant}}{{content}} </s><s>[INST] {{/ifAssistant}}{{/each}}",
"promptExamples": [
{
48 changes: 38 additions & 10 deletions package-lock.json

Some generated files are not rendered by default.

6 changes: 3 additions & 3 deletions package.json
Expand Up @@ -54,7 +54,7 @@
"@huggingface/inference": "^2.6.3",
"@iconify-json/bi": "^1.1.21",
"@resvg/resvg-js": "^2.6.0",
"@xenova/transformers": "^2.6.0",
"@xenova/transformers": "^2.16.1",
"autoprefixer": "^10.4.14",
"browser-image-resizer": "^2.4.1",
"date-fns": "^2.29.3",
@@ -83,8 +83,8 @@
},
"optionalDependencies": {
"@anthropic-ai/sdk": "^0.17.1",
"@google-cloud/vertexai": "^0.5.0",
"aws4fetch": "^1.0.17",
"openai": "^4.14.2",
"@google-cloud/vertexai": "^0.5.0"
"openai": "^4.14.2"
}
}
20 changes: 2 additions & 18 deletions src/lib/components/TokensCounter.svelte
@@ -1,6 +1,7 @@
<script lang="ts">
import type { Model } from "$lib/types/Model";
import { AutoTokenizer, PreTrainedTokenizer } from "@xenova/transformers";
import { getTokenizer } from "$lib/utils/getTokenizer";
import type { PreTrainedTokenizer } from "@xenova/transformers";

export let classNames = "";
export let prompt = "";
@@ -9,23 +10,6 @@

let tokenizer: PreTrainedTokenizer | undefined = undefined;

async function getTokenizer(_modelTokenizer: Exclude<Model["tokenizer"], undefined>) {
if (typeof _modelTokenizer === "string") {
// return auto tokenizer
return await AutoTokenizer.from_pretrained(_modelTokenizer);
}
{
// construct & return pretrained tokenizer
const { tokenizerUrl, tokenizerConfigUrl } = _modelTokenizer satisfies {
tokenizerUrl: string;
tokenizerConfigUrl: string;
};
const tokenizerJSON = await (await fetch(tokenizerUrl)).json();
const tokenizerConfig = await (await fetch(tokenizerConfigUrl)).json();
return new PreTrainedTokenizer(tokenizerJSON, tokenizerConfig);
}
}

async function tokenizeText(_prompt: string) {
if (!tokenizer) {
return;
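The inline helper removed above now lives in a shared util imported as `getTokenizer` from `$lib/utils/getTokenizer` (presumably `src/lib/utils/getTokenizer.ts`). That file is not shown in this diff; below is a minimal sketch of what it likely contains, assuming it simply mirrors the code removed from `TokensCounter.svelte`:

```ts
// Hypothetical sketch of src/lib/utils/getTokenizer.ts (not shown in this diff),
// assuming it mirrors the helper removed from TokensCounter.svelte above.
import { AutoTokenizer, PreTrainedTokenizer } from "@xenova/transformers";
import type { Model } from "$lib/types/Model";

export async function getTokenizer(modelTokenizer: Exclude<Model["tokenizer"], undefined>) {
	if (typeof modelTokenizer === "string") {
		// model id on the Hub: let transformers.js resolve the tokenizer files
		return await AutoTokenizer.from_pretrained(modelTokenizer);
	}
	// explicit URLs: fetch the tokenizer JSON and config, then construct the tokenizer
	const { tokenizerUrl, tokenizerConfigUrl } = modelTokenizer;
	const tokenizerJSON = await (await fetch(tokenizerUrl)).json();
	const tokenizerConfig = await (await fetch(tokenizerConfigUrl)).json();
	return new PreTrainedTokenizer(tokenizerJSON, tokenizerConfig);
}
```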
@@ -1,6 +1,6 @@
import { z } from "zod";
import type { EmbeddingEndpoint } from "../embeddingEndpoints";
import type { Tensor, Pipeline } from "@xenova/transformers";
import type { Tensor, FeatureExtractionPipeline } from "@xenova/transformers";
import { pipeline } from "@xenova/transformers";

export const embeddingEndpointTransformersJSParametersSchema = z.object({
@@ -11,9 +11,9 @@ export const embeddingEndpointTransformersJSParametersSchema = z.object({

// Use the Singleton pattern to enable lazy construction of the pipeline.
class TransformersJSModelsSingleton {
static instances: Array<[string, Promise<Pipeline>]> = [];
static instances: Array<[string, Promise<FeatureExtractionPipeline>]> = [];

static async getInstance(modelName: string): Promise<Pipeline> {
static async getInstance(modelName: string): Promise<FeatureExtractionPipeline> {
const modelPipelineInstance = this.instances.find(([name]) => name === modelName);

if (modelPipelineInstance) {
79 changes: 61 additions & 18 deletions src/lib/server/models.ts
@@ -14,7 +14,10 @@ import endpointTgi from "./endpoints/tgi/endpointTgi";
import { sum } from "$lib/utils/sum";
import { embeddingModels, validateEmbeddingModelByName } from "./embeddingModels";

import type { PreTrainedTokenizer } from "@xenova/transformers";

import JSON5 from "json5";
import { getTokenizer } from "$lib/utils/getTokenizer";

type Optional<T, K extends keyof T> = Pick<Partial<T>, K> & Omit<T, K>;

@@ -39,23 +42,9 @@ const modelConfig = z.object({
.optional(),
datasetName: z.string().min(1).optional(),
datasetUrl: z.string().url().optional(),
userMessageToken: z.string().default(""),
userMessageEndToken: z.string().default(""),
assistantMessageToken: z.string().default(""),
assistantMessageEndToken: z.string().default(""),
messageEndToken: z.string().default(""),
preprompt: z.string().default(""),
prepromptUrl: z.string().url().optional(),
chatPromptTemplate: z
.string()
.default(
"{{preprompt}}" +
"{{#each messages}}" +
"{{#ifUser}}{{@root.userMessageToken}}{{content}}{{@root.userMessageEndToken}}{{/ifUser}}" +
"{{#ifAssistant}}{{@root.assistantMessageToken}}{{content}}{{@root.assistantMessageEndToken}}{{/ifAssistant}}" +
"{{/each}}" +
"{{assistantMessageToken}}"
),
chatPromptTemplate: z.string().optional(),
promptExamples: z
.array(
z.object({
@@ -84,11 +73,65 @@ const modelConfig = z.object({

const modelsRaw = z.array(modelConfig).parse(JSON5.parse(MODELS));

async function getChatPromptRender(
m: z.infer<typeof modelConfig>
): Promise<ReturnType<typeof compileTemplate<ChatTemplateInput>>> {
if (m.chatPromptTemplate) {
Collaborator:

Shouldn't the fallback order be the other way around? First try to use the transformers.js tokenizer template, then fall back to chatPromptTemplate.

Collaborator (Author):

The downside of doing it that way is that in prod, for example, we specify the tokenizer for token counting but actually want to override the chat template with our own template.

In terms of specificity, I think custom chat template > tokenizer "default" template makes sense, wdyt?

Collaborator:

Yep, sounds good. If we want to change the order, let's do it in a follow-up PR. Also, if we change the order, we would need to handle #744 (comment) first.

Collaborator (Author):

Yes, agreed. Overall we should push towards using tokenizers rather than chat prompt templates, but we will indeed need to handle the edge cases first 😁 (see the config sketch after this diff).

return compileTemplate<ChatTemplateInput>(m.chatPromptTemplate, m);
} else {
let tokenizer: PreTrainedTokenizer;

if (!m.tokenizer) {
throw new Error(
"No tokenizer specified and no chat prompt template specified for model " + m.name
);
}

try {
tokenizer = await getTokenizer(m.tokenizer);
} catch (e) {
throw Error(
"Failed to load tokenizer for model " +
m.name +
" consider setting chatPromptTemplate manually or making sure the model is available on the hub."
);
}

const renderTemplate = ({ messages, preprompt }: ChatTemplateInput) => {
let formattedMessages: { role: string; content: string }[] = messages.map((message) => ({
content: message.content,
role: message.from,
}));

if (preprompt) {
formattedMessages = [
{
role: "system",
content: preprompt,
},
...formattedMessages,
];
}

const output = tokenizer.apply_chat_template(formattedMessages, {
Collaborator:

Right now, the "system" role would cause an error with templates that don't natively support "system", right?

Collaborator (Author):

Yep, I tried looking into it, but because the error messages are hardcoded in the Jinja template, it's not obvious how to tell whether the template is failing because it lacks system-prompt support or for some other reason.

I guess we could retry without the "system" message and see if the prompt builds then? We could do it in a second PR if that sounds good to you (see the retry sketch after this diff).

tokenize: false,
add_generation_prompt: true,
});

if (typeof output !== "string") {
throw new Error("Failed to apply chat template, the output is not a string");
}

return output;
};

return renderTemplate;
}
}

const processModel = async (m: z.infer<typeof modelConfig>) => ({
...m,
userMessageEndToken: m?.userMessageEndToken || m?.messageEndToken,
assistantMessageEndToken: m?.assistantMessageEndToken || m?.messageEndToken,
chatPromptRender: compileTemplate<ChatTemplateInput>(m.chatPromptTemplate, m),
chatPromptRender: await getChatPromptRender(m),
id: m.id || m.name,
displayName: m.displayName || m.name,
preprompt: m.prepromptUrl ? await fetch(m.prepromptUrl).then((r) => r.text()) : m.preprompt,
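To make the fallback order from the review discussion concrete: a model entry that sets `chatPromptTemplate` keeps using it, while an entry that only specifies a `tokenizer` now gets its prompt rendered through the tokenizer's built-in Jinja chat template. Below is a hypothetical `MODELS` sketch; the field names (`name`, `tokenizer`, `chatPromptTemplate`) come from this PR, but the model ids and the template string are placeholders, not from the repository.

```env
MODELS=`[
  {
    // no chatPromptTemplate: the prompt is rendered with the tokenizer's Jinja chat template
    "name": "meta-llama/Llama-2-70b-chat-hf",
    "tokenizer": "meta-llama/Llama-2-70b-chat-hf"
  },
  {
    // explicit chatPromptTemplate: overrides whatever template the tokenizer ships with
    "name": "example-org/example-model",
    "tokenizer": "example-org/example-model",
    "chatPromptTemplate": "{{preprompt}}{{#each messages}}{{#ifUser}}<|user|>{{content}}<|end|>{{/ifUser}}{{#ifAssistant}}<|assistant|>{{content}}<|end|>{{/ifAssistant}}{{/each}}<|assistant|>"
  }
]`
```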
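On the system-role thread above: nothing in this PR retries yet, but a possible follow-up, sketched here purely as an assumption (the helper name and `ChatMessage` type are hypothetical), would catch the template error and fold the preprompt into the first user message before rendering again.

```ts
// Hypothetical follow-up sketch, not part of this PR: retry without a "system"
// message when the tokenizer's Jinja template rejects that role.
import type { PreTrainedTokenizer } from "@xenova/transformers";

type ChatMessage = { role: string; content: string };

function renderWithSystemFallback(
	tokenizer: PreTrainedTokenizer,
	messages: ChatMessage[],
	preprompt?: string
): string {
	const render = (msgs: ChatMessage[]) =>
		tokenizer.apply_chat_template(msgs, { tokenize: false, add_generation_prompt: true }) as string;

	const withSystem = preprompt ? [{ role: "system", content: preprompt }, ...messages] : messages;
	try {
		return render(withSystem);
	} catch {
		// the template most likely has no "system" branch: merge the preprompt
		// into the first user message and try again
		const merged = messages.map((m, i) =>
			i === 0 && m.role === "user" && preprompt ? { ...m, content: `${preprompt}\n\n${m.content}` } : m
		);
		return render(merged);
	}
}
```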