From b21d3184a8d129f340a8563d85e47e9ace6f3f88 Mon Sep 17 00:00:00 2001
From: Richard Groves
Date: Mon, 29 Apr 2024 13:52:55 +0100
Subject: [PATCH 1/2] Adding Llama3 prompt generator
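
Usage sketch, condensed from the chat example added below (the model
name and messages are placeholders; raw mode is required when a custom
prompt template is set):

    import { ollama, streamText } from "modelfusion";

    const textStream = await streamText({
      model: ollama
        .CompletionTextGenerator({
          model: "llama3",
          promptTemplate: ollama.prompt.Llama3,
          raw: true, // required when using a custom prompt template
        })
        .withChatPrompt(),
      prompt: {
        system: "You are a helpful, respectful and honest assistant.",
        messages: [{ role: "user", content: "Hello!" }],
      },
    });

    for await (const textPart of textStream) {
      process.stdout.write(textPart);
    }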
---
 .../ollama-chat-chatbot-completion-llama3.ts  |  48 +++++
 .../ollama/ollama-chat-chatbot-llama3.ts      |  47 +++++
 .../Llama3PromptTemplate.test.ts              |  72 +++++++
 .../prompt-template/Llama3PromptTemplate.ts   | 198 ++++++++++++++++++
 .../Llama3PromptTemplate.test.ts.snap         |  69 ++++++
 .../generate-text/prompt-template/index.ts    |   1 +
 .../model-provider/llamacpp/LlamaCppPrompt.ts |   2 +
 .../ollama/OllamaCompletionPrompt.ts          |   5 +
 8 files changed, 442 insertions(+)
 create mode 100644 examples/basic/src/model-provider/ollama/ollama-chat-chatbot-completion-llama3.ts
 create mode 100644 examples/basic/src/model-provider/ollama/ollama-chat-chatbot-llama3.ts
 create mode 100644 packages/modelfusion/src/model-function/generate-text/prompt-template/Llama3PromptTemplate.test.ts
 create mode 100644 packages/modelfusion/src/model-function/generate-text/prompt-template/Llama3PromptTemplate.ts
 create mode 100644 packages/modelfusion/src/model-function/generate-text/prompt-template/__snapshots__/Llama3PromptTemplate.test.ts.snap

diff --git a/examples/basic/src/model-provider/ollama/ollama-chat-chatbot-completion-llama3.ts b/examples/basic/src/model-provider/ollama/ollama-chat-chatbot-completion-llama3.ts
new file mode 100644
index 000000000..61030e093
--- /dev/null
+++ b/examples/basic/src/model-provider/ollama/ollama-chat-chatbot-completion-llama3.ts
@@ -0,0 +1,48 @@
+import { ChatPrompt, ollama, streamText } from "modelfusion";
+import * as readline from "node:readline/promises";
+
+const systemPrompt = `You are a helpful, respectful and honest assistant.`;
+
+const terminal = readline.createInterface({
+  input: process.stdin,
+  output: process.stdout,
+});
+
+async function main() {
+  const chat: ChatPrompt = { system: systemPrompt, messages: [] };
+
+  while (true) {
+    const userInput = await terminal.question("You: ");
+
+    chat.messages.push({ role: "user", content: userInput });
+
+    // The advanced version that calls the prompt construction code
+    const textStream = await streamText({
+      model: ollama
+        .CompletionTextGenerator({
+          model: "llama3",
+          promptTemplate: ollama.prompt.Llama3,
+          raw: true, // required when using custom prompt template
+        })
+        .withChatPrompt(),
+      prompt: chat,
+    });
+
+    //console.log("Full chat: " + JSON.stringify(chat));
+
+    let fullResponse = "";
+
+    process.stdout.write("\nAssistant : ");
+    for await (const textPart of textStream) {
+      fullResponse += textPart;
+      process.stdout.write(textPart);
+    }
+
+    process.stdout.write("\n\n");
+
+    chat.messages.push({ role: "assistant", content: fullResponse });
+    //console.log("Full response: " + fullResponse);
+  }
+}
+
+main().catch(console.error);
diff --git a/examples/basic/src/model-provider/ollama/ollama-chat-chatbot-llama3.ts b/examples/basic/src/model-provider/ollama/ollama-chat-chatbot-llama3.ts
new file mode 100644
index 000000000..89a74364c
--- /dev/null
+++ b/examples/basic/src/model-provider/ollama/ollama-chat-chatbot-llama3.ts
@@ -0,0 +1,47 @@
+import { ChatPrompt, ollama, streamText } from "modelfusion";
+import * as readline from "node:readline/promises";
+
+const systemPrompt = `You are a helpful, respectful and honest assistant.`;
+
+const terminal = readline.createInterface({
+  input: process.stdin,
+  output: process.stdout,
+});
+
+async function main() {
+  const chat: ChatPrompt = { system: systemPrompt, messages: [] };
+
+  while (true) {
+    let userInput = await terminal.question("You: ");
+
+    chat.messages.push({ role: "user", content: userInput });
+
+    // For Llama 3 we have to explicitly set the stop option value, as otherwise the response never ends
+    // - see https://github.com/ollama/ollama/issues/3759#issuecomment-2076973989
+    let model = ollama
+      .ChatTextGenerator({
+        model: "llama3",
+        stopSequences: ollama.prompt.Llama3.chat().stopSequences,
+      })
+      .withChatPrompt();
+
+    const textStream = await streamText({
+      model: model,
+      prompt: chat,
+    });
+
+    let fullResponse = "";
+
+    process.stdout.write("\nAssistant : ");
+    for await (const textPart of textStream) {
+      fullResponse += textPart;
+      process.stdout.write(textPart);
+    }
+
+    process.stdout.write("\n\n");
+
+    chat.messages.push({ role: "assistant", content: fullResponse });
+  }
+}
+
+main().catch(console.error);
diff --git a/packages/modelfusion/src/model-function/generate-text/prompt-template/Llama3PromptTemplate.test.ts b/packages/modelfusion/src/model-function/generate-text/prompt-template/Llama3PromptTemplate.test.ts
new file mode 100644
index 000000000..1829671fc
--- /dev/null
+++ b/packages/modelfusion/src/model-function/generate-text/prompt-template/Llama3PromptTemplate.test.ts
@@ -0,0 +1,72 @@
+import { chat, instruction, text } from "./Llama3PromptTemplate";
+
+describe("text prompt", () => {
+  it("should format prompt", () => {
+    const prompt = text().format("prompt");
+
+    expect(prompt).toMatchSnapshot();
+  });
+});
+
+describe("instruction prompt", () => {
+  it("should format prompt with instruction", () => {
+    const prompt = instruction().format({
+      instruction: "instruction",
+    });
+
+    expect(prompt).toMatchSnapshot();
+  });
+
+  it("should format prompt with system and instruction", () => {
+    const prompt = instruction().format({
+      system: "system",
+      instruction: "instruction",
+    });
+
+    expect(prompt).toMatchSnapshot();
+  });
+
+  it("should format prompt with instruction and response prefix", () => {
+    const prompt = instruction().format({
+      instruction: "instruction",
+      responsePrefix: "response prefix",
+    });
+
+    expect(prompt).toMatchSnapshot();
+  });
+});
+
+describe("chat prompt", () => {
+  it("should format prompt with user message", () => {
+    const prompt = chat().format({
+      messages: [{ role: "user", content: "user message" }],
+    });
+
+    expect(prompt).toMatchSnapshot();
+  });
+
+  it("should format prompt with user-assistant-user messages", () => {
+    const prompt = chat().format({
+      messages: [
+        { role: "user", content: "1st user message" },
+        { role: "assistant", content: "assistant message" },
+        { role: "user", content: "2nd user message" },
+      ],
+    });
+
+    expect(prompt).toMatchSnapshot();
+  });
+
+  it("should format prompt with system message and user-assistant-user messages", () => {
+    const prompt = chat().format({
+      system: "you are a chatbot",
+      messages: [
+        { role: "user", content: "1st user message" },
+        { role: "assistant", content: "assistant message" },
+        { role: "user", content: "2nd user message" },
+      ],
+    });
+
+    expect(prompt).toMatchSnapshot();
+  });
+});
diff --git a/packages/modelfusion/src/model-function/generate-text/prompt-template/Llama3PromptTemplate.ts b/packages/modelfusion/src/model-function/generate-text/prompt-template/Llama3PromptTemplate.ts
new file mode 100644
index 000000000..21460796a
--- /dev/null
+++ b/packages/modelfusion/src/model-function/generate-text/prompt-template/Llama3PromptTemplate.ts
@@ -0,0 +1,198 @@
+import { TextGenerationPromptTemplate } from "../TextGenerationPromptTemplate";
+import { ChatPrompt } from "./ChatPrompt";
+import { validateContentIsString } from "./ContentPart";
+import { InstructionPrompt } from "./InstructionPrompt";
+import { InvalidPromptError } from "./InvalidPromptError";
+
+// See https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/
+
+const BEGIN_SEGMENT = "<|begin_of_text|>"; // Doesn't appear to be needed, but is in the documentation so leave it in
+//const END_SEGMENT = "<|end_of_text|>"; // In the docs but never used as a sent item or a return value AFAICS. The linter won't let it be defined and not used
+
+const BEGIN_INSTRUCTION = "<|start_header_id|>user<|end_header_id|>\n\n";
+const END_INSTRUCTION = "<|eot_id|>";
+
+// This is the marker of an assistant response, or the end of the prompt to indicate the model should carry on
+const BEGIN_RESPONSE_ASSISTANT =
+  "<|start_header_id|>assistant<|end_header_id|>\n\n";
+
+const BEGIN_SYSTEM = "<|start_header_id|>system<|end_header_id|>\n\n";
+const END_SYSTEM = "<|eot_id|>";
+
+const STOP_SEQUENCE = "<|eot_id|>"; // <|eot_id|> is what the assistant sends to indicate it has finished and has no more to say
+
+/**
+ * Formats a text prompt as a Llama 3 prompt.
+ *
+ * Llama 3 prompt template:
+ * ```
+ * <|begin_of_text|><|start_header_id|>user<|end_header_id|>
+ *
+ * { instruction }<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+ *
+ *
+ * ```
+ *
+ * @see https://github.com/meta-llama/llama-recipes
+ */
+export function text(): TextGenerationPromptTemplate<string, string> {
+  return {
+    stopSequences: [STOP_SEQUENCE],
+    format(prompt) {
+      let result = `${BEGIN_SEGMENT}${BEGIN_INSTRUCTION}${prompt}${END_INSTRUCTION}${BEGIN_RESPONSE_ASSISTANT}`;
+
+      //console.log(`text(): ${result}\n--END--`);
+
+      return result;
+    },
+  };
+}
+
+/**
+ * Formats an instruction prompt as a Llama 3 prompt.
+ *
+ * Llama 3 prompt template:
+ * ```
+ * <|begin_of_text|><|start_header_id|>system<|end_header_id|>
+ *
+ * ${ system prompt }<|eot_id|><|start_header_id|>user<|end_header_id|>
+ *
+ * ${ instruction }<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+ *
+ *
+ * ```
+ *
+ * @see https://github.com/meta-llama/llama-recipes
+ */
+export function instruction(): TextGenerationPromptTemplate<
+  InstructionPrompt,
+  string
+> {
+  return {
+    stopSequences: [STOP_SEQUENCE],
+    format(prompt) {
+      const instruction = validateContentIsString(prompt.instruction, prompt);
+
+      let result = `${BEGIN_SEGMENT}`;
+      result += `${prompt.system != null ? `${BEGIN_SYSTEM}${prompt.system}${END_SYSTEM}` : ""}`;
+      result += `${BEGIN_INSTRUCTION}${instruction}${END_INSTRUCTION}${BEGIN_RESPONSE_ASSISTANT}`;
+
+      //console.log(`instruction(): ${result}\n--END--`);
+
+      return result;
+    },
+  };
+}
+
+/**
+ * Formats a chat prompt as a Llama 3 prompt.
+ *
+ * Llama 3 prompt template:
+ *
+ * ```
+ * <|begin_of_text|><|start_header_id|>system<|end_header_id|>
+ *
+ * ${ system prompt }<|eot_id|><|start_header_id|>user<|end_header_id|>
+ *
+ * ${ user msg 1 }<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+ *
+ * ${ model response 1 }<|eot_id|><|start_header_id|>user<|end_header_id|>
+ *
+ * ${ user msg 2 }<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+ *
+ * ${ model response 2 }<|eot_id|><|start_header_id|>user<|end_header_id|>
+ *
+ * ${ user msg 3 }<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+ *
+ *
+ * ```
+ *
+ * @see https://github.com/meta-llama/llama-recipes
+ */
+export function chat(): TextGenerationPromptTemplate<ChatPrompt, string> {
+  return {
+    format(prompt) {
+      validateLlama3Prompt(prompt);
+
+      // get content of the first message (validated to be a user message)
+      const content = prompt.messages[0].content;
+
+      let text = `${BEGIN_SEGMENT}`;
+      text += `${prompt.system != null ? `${BEGIN_SYSTEM}${prompt.system}${END_SYSTEM}` : ""}`;
+      text += `${BEGIN_INSTRUCTION}${content}${END_INSTRUCTION}${BEGIN_RESPONSE_ASSISTANT}`;
+
+      // process remaining messages
+      for (let i = 1; i < prompt.messages.length; i++) {
+        const { role, content } = prompt.messages[i];
+        switch (role) {
+          case "user": {
+            const textContent = validateContentIsString(content, prompt);
+            text += `${BEGIN_INSTRUCTION}${textContent}${END_INSTRUCTION}${BEGIN_RESPONSE_ASSISTANT}`;
+            break;
+          }
+          case "assistant": {
+            // The assistant will have added \n\n to the start of its response - we don't do that here, so the tests are slightly different from reality
+            text += `${validateContentIsString(content, prompt)}${END_INSTRUCTION}`;
+            break;
+          }
+          case "tool": {
+            throw new InvalidPromptError(
+              "Tool messages are not supported.",
+              prompt
+            );
+          }
+          default: {
+            const _exhaustiveCheck: never = role;
+            throw new Error(`Unsupported role: ${_exhaustiveCheck}`);
+          }
+        }
+      }
+
+      //console.log(`chat(): ${text}\n--END--`);
+
+      return text;
+    },
+    stopSequences: [STOP_SEQUENCE],
+  };
+}
+
+/**
+ * Checks if a Llama3 chat prompt is valid. Throws an {@link InvalidPromptError} if it's not.
+ *
+ * - The first message of the chat must be a user message.
+ * - Then it must be alternating between an assistant message and a user message.
+ * - The last message must always be a user message (when submitting to a model).
+ *
+ * The type checking is done at runtime when you submit a chat prompt to a model with a prompt template.
+ *
+ * @throws {@link InvalidPromptError}
+ */
+export function validateLlama3Prompt(chatPrompt: ChatPrompt) {
+  const messages = chatPrompt.messages;
+
+  if (messages.length < 1) {
+    throw new InvalidPromptError(
+      "ChatPrompt should have at least one message.",
+      chatPrompt
+    );
+  }
+
+  for (let i = 0; i < messages.length; i++) {
"user" : "assistant"; + const role = messages[i].role; + + if (role !== expectedRole) { + throw new InvalidPromptError( + `Message at index ${i} should have role '${expectedRole}', but has role '${role}'.`, + chatPrompt + ); + } + } + + if (messages.length % 2 === 0) { + throw new InvalidPromptError( + "The last message must be a user message.", + chatPrompt + ); + } +} diff --git a/packages/modelfusion/src/model-function/generate-text/prompt-template/__snapshots__/Llama3PromptTemplate.test.ts.snap b/packages/modelfusion/src/model-function/generate-text/prompt-template/__snapshots__/Llama3PromptTemplate.test.ts.snap new file mode 100644 index 000000000..82f3132e1 --- /dev/null +++ b/packages/modelfusion/src/model-function/generate-text/prompt-template/__snapshots__/Llama3PromptTemplate.test.ts.snap @@ -0,0 +1,69 @@ +// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html + +exports[`chat prompt > should format prompt with system message and user-assistant-user messages 1`] = ` +"<|begin_of_text|><|start_header_id|>system<|end_header_id|> + +you are a chatbot<|eot_id|><|start_header_id|>user<|end_header_id|> + +1st user message<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +assistant message<|eot_id|><|start_header_id|>user<|end_header_id|> + +2nd user message<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +" +`; + +exports[`chat prompt > should format prompt with user message 1`] = ` +"<|begin_of_text|><|start_header_id|>user<|end_header_id|> + +user message<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +" +`; + +exports[`chat prompt > should format prompt with user-assistant-user messages 1`] = ` +"<|begin_of_text|><|start_header_id|>user<|end_header_id|> + +1st user message<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +assistant message<|eot_id|><|start_header_id|>user<|end_header_id|> + +2nd user message<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +" +`; + +exports[`instruction prompt > should format prompt with instruction 1`] = ` +"<|begin_of_text|><|start_header_id|>user<|end_header_id|> + +instruction<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +" +`; + +exports[`instruction prompt > should format prompt with instruction and response prefix 1`] = ` +"<|begin_of_text|><|start_header_id|>user<|end_header_id|> + +instruction<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +" +`; + +exports[`instruction prompt > should format prompt with system and instruction 1`] = ` +"<|begin_of_text|><|start_header_id|>system<|end_header_id|> + +system<|eot_id|><|start_header_id|>user<|end_header_id|> + +instruction<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +" +`; + +exports[`text prompt > should format prompt 1`] = ` +"<|begin_of_text|><|start_header_id|>user<|end_header_id|> + +prompt<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +" +`; diff --git a/packages/modelfusion/src/model-function/generate-text/prompt-template/index.ts b/packages/modelfusion/src/model-function/generate-text/prompt-template/index.ts index e163d0272..348cfb500 100644 --- a/packages/modelfusion/src/model-function/generate-text/prompt-template/index.ts +++ b/packages/modelfusion/src/model-function/generate-text/prompt-template/index.ts @@ -5,6 +5,7 @@ export * from "./ContentPart"; export * from "./InstructionPrompt"; export * from "./InvalidPromptError"; export * as Llama2Prompt from "./Llama2PromptTemplate"; +export * as Llama3Prompt from "./Llama3PromptTemplate"; export * as MistralInstructPrompt from 
"./MistralInstructPromptTemplate"; export * as NeuralChatPrompt from "./NeuralChatPromptTemplate"; export * from "./PromptTemplateProvider"; diff --git a/packages/modelfusion/src/model-provider/llamacpp/LlamaCppPrompt.ts b/packages/modelfusion/src/model-provider/llamacpp/LlamaCppPrompt.ts index 05e14c9d1..28c55bc0b 100644 --- a/packages/modelfusion/src/model-provider/llamacpp/LlamaCppPrompt.ts +++ b/packages/modelfusion/src/model-provider/llamacpp/LlamaCppPrompt.ts @@ -2,6 +2,7 @@ import { TextGenerationPromptTemplate } from "../../model-function/generate-text import * as alpacaPrompt from "../../model-function/generate-text/prompt-template/AlpacaPromptTemplate"; import * as chatMlPrompt from "../../model-function/generate-text/prompt-template/ChatMLPromptTemplate"; import * as llama2Prompt from "../../model-function/generate-text/prompt-template/Llama2PromptTemplate"; +import * as llama3Prompt from "../../model-function/generate-text/prompt-template/Llama3PromptTemplate"; import * as mistralPrompt from "../../model-function/generate-text/prompt-template/MistralInstructPromptTemplate"; import * as neuralChatPrompt from "../../model-function/generate-text/prompt-template/NeuralChatPromptTemplate"; import { TextGenerationPromptTemplateProvider } from "../../model-function/generate-text/prompt-template/PromptTemplateProvider"; @@ -73,6 +74,7 @@ export const Mistral = asLlamaCppTextPromptTemplateProvider(mistralPrompt); export const ChatML = asLlamaCppTextPromptTemplateProvider(chatMlPrompt); export const Llama2 = asLlamaCppTextPromptTemplateProvider(llama2Prompt); +export const Llama3 = asLlamaCppTextPromptTemplateProvider(llama3Prompt); export const NeuralChat = asLlamaCppTextPromptTemplateProvider(neuralChatPrompt); export const Alpaca = asLlamaCppTextPromptTemplateProvider(alpacaPrompt); diff --git a/packages/modelfusion/src/model-provider/ollama/OllamaCompletionPrompt.ts b/packages/modelfusion/src/model-provider/ollama/OllamaCompletionPrompt.ts index 738d679ef..68235aee1 100644 --- a/packages/modelfusion/src/model-provider/ollama/OllamaCompletionPrompt.ts +++ b/packages/modelfusion/src/model-provider/ollama/OllamaCompletionPrompt.ts @@ -2,6 +2,7 @@ import { TextGenerationPromptTemplate } from "../../model-function/generate-text import * as alpacaPrompt from "../../model-function/generate-text/prompt-template/AlpacaPromptTemplate"; import * as chatMlPrompt from "../../model-function/generate-text/prompt-template/ChatMLPromptTemplate"; import * as llama2Prompt from "../../model-function/generate-text/prompt-template/Llama2PromptTemplate"; +import * as llama3Prompt from "../../model-function/generate-text/prompt-template/Llama3PromptTemplate"; import * as mistralPrompt from "../../model-function/generate-text/prompt-template/MistralInstructPromptTemplate"; import * as neuralChatPrompt from "../../model-function/generate-text/prompt-template/NeuralChatPromptTemplate"; import { TextGenerationPromptTemplateProvider } from "../../model-function/generate-text/prompt-template/PromptTemplateProvider"; @@ -13,6 +14,8 @@ import { OllamaCompletionPrompt } from "./OllamaCompletionModel"; export function asOllamaCompletionPromptTemplate( promptTemplate: TextGenerationPromptTemplate ): TextGenerationPromptTemplate { + console.log("stopSeq1: " + promptTemplate.stopSequences); + return { format: (prompt) => ({ prompt: promptTemplate.format(prompt), @@ -75,6 +78,8 @@ export const ChatML = asOllamaCompletionTextPromptTemplateProvider(chatMlPrompt); export const Llama2 = 
   asOllamaCompletionTextPromptTemplateProvider(llama2Prompt);
+export const Llama3 =
+  asOllamaCompletionTextPromptTemplateProvider(llama3Prompt);
 export const NeuralChat =
   asOllamaCompletionTextPromptTemplateProvider(neuralChatPrompt);
 export const Alpaca =

From fcda3b1e9604887dc90a8e069e05c4100f2343ff Mon Sep 17 00:00:00 2001
From: Richard Groves
Date: Mon, 29 Apr 2024 16:44:04 +0100
Subject: [PATCH 2/2] Removing commented out logging

---
 .../ollama/ollama-chat-chatbot-completion-llama3.ts       | 3 ---
 .../generate-text/prompt-template/Llama3PromptTemplate.ts | 8 --------
 2 files changed, 11 deletions(-)

diff --git a/examples/basic/src/model-provider/ollama/ollama-chat-chatbot-completion-llama3.ts b/examples/basic/src/model-provider/ollama/ollama-chat-chatbot-completion-llama3.ts
index 61030e093..42c5879e1 100644
--- a/examples/basic/src/model-provider/ollama/ollama-chat-chatbot-completion-llama3.ts
+++ b/examples/basic/src/model-provider/ollama/ollama-chat-chatbot-completion-llama3.ts
@@ -28,8 +28,6 @@ async function main() {
       prompt: chat,
     });
 
-    //console.log("Full chat: " + JSON.stringify(chat));
-
     let fullResponse = "";
 
     process.stdout.write("\nAssistant : ");
@@ -41,7 +39,6 @@ async function main() {
     process.stdout.write("\n\n");
 
     chat.messages.push({ role: "assistant", content: fullResponse });
-    //console.log("Full response: " + fullResponse);
   }
 }
 
diff --git a/packages/modelfusion/src/model-function/generate-text/prompt-template/Llama3PromptTemplate.ts b/packages/modelfusion/src/model-function/generate-text/prompt-template/Llama3PromptTemplate.ts
index 21460796a..2940f2a43 100644
--- a/packages/modelfusion/src/model-function/generate-text/prompt-template/Llama3PromptTemplate.ts
+++ b/packages/modelfusion/src/model-function/generate-text/prompt-template/Llama3PromptTemplate.ts
@@ -40,9 +40,6 @@ export function text(): TextGenerationPromptTemplate<string, string> {
     stopSequences: [STOP_SEQUENCE],
     format(prompt) {
       let result = `${BEGIN_SEGMENT}${BEGIN_INSTRUCTION}${prompt}${END_INSTRUCTION}${BEGIN_RESPONSE_ASSISTANT}`;
-
-      //console.log(`text(): ${result}\n--END--`);
-
       return result;
     },
   };
@@ -76,9 +73,6 @@ export function instruction(): TextGenerationPromptTemplate<
       let result = `${BEGIN_SEGMENT}`;
       result += `${prompt.system != null ? `${BEGIN_SYSTEM}${prompt.system}${END_SYSTEM}` : ""}`;
       result += `${BEGIN_INSTRUCTION}${instruction}${END_INSTRUCTION}${BEGIN_RESPONSE_ASSISTANT}`;
-
-      //console.log(`instruction(): ${result}\n--END--`);
-
       return result;
     },
   };
@@ -148,8 +142,6 @@ export function chat(): TextGenerationPromptTemplate<ChatPrompt, string> {
         }
       }
 
-      //console.log(`chat(): ${text}\n--END--`);
-
       return text;
     },
     stopSequences: [STOP_SEQUENCE],