Adding Llama3 prompt generator #335

Open
wants to merge 3 commits into main
@@ -0,0 +1,45 @@
import { ChatPrompt, ollama, streamText } from "modelfusion";
import * as readline from "node:readline/promises";

const systemPrompt = `You are a helpful, respectful and honest assistant.`;

const terminal = readline.createInterface({
input: process.stdin,
output: process.stdout,
});

async function main() {
const chat: ChatPrompt = { system: systemPrompt, messages: [] };

while (true) {
const userInput = await terminal.question("You: ");

chat.messages.push({ role: "user", content: userInput });

// The advanced version that calls the prompt construction code directly (raw completion mode)
const textStream = await streamText({
model: ollama
.CompletionTextGenerator({
model: "llama3",
promptTemplate: ollama.prompt.Llama3,
raw: true, // required when using custom prompt template
})
.withChatPrompt(),
prompt: chat,
});

let fullResponse = "";

process.stdout.write("\nAssistant : ");
for await (const textPart of textStream) {
fullResponse += textPart;
process.stdout.write(textPart);
}

process.stdout.write("\n\n");

chat.messages.push({ role: "assistant", content: fullResponse });
}
}

main().catch(console.error);
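For reference, a minimal sketch of what the first turn of this chat is serialized to before it is sent to Ollama, assuming the Llama 3 prompt template added later in this PR (the user input "Hello!" and the variable name are illustrative):

// Illustrative only: the rendered prompt for the first user turn, assembled from the
// template constants in Llama3PromptTemplate.ts (system block, user block, then the
// assistant header that tells the model to continue as the assistant).
const renderedFirstTurn =
  "<|begin_of_text|>" +
  "<|start_header_id|>system<|end_header_id|>\n\n" +
  "You are a helpful, respectful and honest assistant.<|eot_id|>" +
  "<|start_header_id|>user<|end_header_id|>\n\nHello!<|eot_id|>" +
  "<|start_header_id|>assistant<|end_header_id|>\n\n";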
@@ -0,0 +1,47 @@
import { ChatPrompt, ollama, streamText } from "modelfusion";
import * as readline from "node:readline/promises";

const systemPrompt = `You are a helpful, respectful and honest assistant.`;

const terminal = readline.createInterface({
input: process.stdin,
output: process.stdout,
});

async function main() {
const chat: ChatPrompt = { system: systemPrompt, messages: [] };

while (true) {
const userInput = await terminal.question("You: ");

chat.messages.push({ role: "user", content: userInput });

// For Llama 3 we have to explicitly set the stop sequences, otherwise the model never ends its response
// - see https://github.com/ollama/ollama/issues/3759#issuecomment-2076973989
const model = ollama
.ChatTextGenerator({
model: "llama3",
stopSequences: ollama.prompt.Llama3.chat().stopSequences,
})
.withChatPrompt();

const textStream = await streamText({
model: model,
prompt: chat,
});

let fullResponse = "";

process.stdout.write("\nAssistant : ");
for await (const textPart of textStream) {
fullResponse += textPart;
process.stdout.write(textPart);
}

process.stdout.write("\n\n");

chat.messages.push({ role: "assistant", content: fullResponse });
}
}

main().catch(console.error);
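A note on the stop-sequence reuse above: ollama.prompt.Llama3.chat() returns the prompt template defined later in this PR, so its stopSequences field is simply the Llama 3 end-of-turn token. A minimal sketch of what the ChatTextGenerator ends up receiving (the variable name is illustrative):

// Illustrative only: equivalent to ollama.prompt.Llama3.chat().stopSequences.
// Generation stops as soon as the model emits the end-of-turn token.
const llama3StopSequences = ["<|eot_id|>"];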
@@ -0,0 +1,72 @@
import { chat, instruction, text } from "./Llama3PromptTemplate";

describe("text prompt", () => {
it("should format prompt", () => {
const prompt = text().format("prompt");

expect(prompt).toMatchSnapshot();
});
});

describe("instruction prompt", () => {
it("should format prompt with instruction", () => {
const prompt = instruction().format({
instruction: "instruction",
});

expect(prompt).toMatchSnapshot();
});

it("should format prompt with system and instruction", () => {
const prompt = instruction().format({
system: "system",
instruction: "instruction",
});

expect(prompt).toMatchSnapshot();
});

it("should format prompt with instruction and response prefix", () => {
const prompt = instruction().format({
instruction: "instruction",
responsePrefix: "response prefix",
});

expect(prompt).toMatchSnapshot();
});
});

describe("chat prompt", () => {
it("should format prompt with user message", () => {
const prompt = chat().format({
messages: [{ role: "user", content: "user message" }],
});

expect(prompt).toMatchSnapshot();
});

it("should format prompt with user-assistant-user messages", () => {
const prompt = chat().format({
messages: [
{ role: "user", content: "1st user message" },
{ role: "assistant", content: "assistant message" },
{ role: "user", content: "2nd user message" },
],
});

expect(prompt).toMatchSnapshot();
});

it("should format prompt with system message and user-assistant-user messages", () => {
const prompt = chat().format({
system: "you are a chatbot",
messages: [
{ role: "user", content: "1st user message" },
{ role: "assistant", content: "assistant message" },
{ role: "user", content: "2nd user message" },
],
});

expect(prompt).toMatchSnapshot();
});
});
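These are snapshot tests, so the expected strings live in a generated snapshot file rather than in the test bodies. For orientation, a sketch of what the single-user-message chat case should render to, derived from the template constants in Llama3PromptTemplate.ts (the variable name is illustrative):

// Illustrative only: expected output of
// chat().format({ messages: [{ role: "user", content: "user message" }] })
const expectedChatSnapshot =
  "<|begin_of_text|>" +
  "<|start_header_id|>user<|end_header_id|>\n\nuser message<|eot_id|>" +
  "<|start_header_id|>assistant<|end_header_id|>\n\n";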
@@ -0,0 +1,190 @@
import { TextGenerationPromptTemplate } from "../TextGenerationPromptTemplate";
import { ChatPrompt } from "./ChatPrompt";
import { validateContentIsString } from "./ContentPart";
import { InstructionPrompt } from "./InstructionPrompt";
import { InvalidPromptError } from "./InvalidPromptError";

// See https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/

const BEGIN_SEGMENT = "<|begin_of_text|>"; // Doesn't appear to be needed, but it is in the documentation, so leave it in
//const END_SEGMENT = "<|end_of_text|>"; // In the docs, but never sent or returned as far as we can see. The linter won't let it be defined and left unused.

const BEGIN_INSTRUCTION = "<|start_header_id|>user<|end_header_id|>\n\n";
const END_INSTRUCTION = "<|eot_id|>";

// Marks the start of an assistant response; it is also appended to the end of the prompt to signal that the model should continue as the assistant
const BEGIN_RESPONSE_ASSISTANT =
"<|start_header_id|>assistant<|end_header_id|>\n\n";

const BEGIN_SYSTEM = "<|start_header_id|>system<|end_header_id|>\n\n";
const END_SYSTEM = "<|eot_id|>";

const STOP_SEQUENCE = "<|eot_id|>"; // <|eot_id|> is what the assistant sends to indicate it has finished and has no more to say

/**
* Formats a text prompt as a Llama 3 prompt.
*
* Llama 3 prompt template:
* ```
* <|begin_of_text|><|start_header_id|>user<|end_header_id|>
*
* { instruction }<|eot_id|><|start_header_id|>assistant<|end_header_id|>
*
*
* ```
*
* @see https://github.com/meta-llama/llama-recipes
*/
export function text(): TextGenerationPromptTemplate<string, string> {
return {
stopSequences: [STOP_SEQUENCE],
format(prompt) {
return `${BEGIN_SEGMENT}${BEGIN_INSTRUCTION}${prompt}${END_INSTRUCTION}${BEGIN_RESPONSE_ASSISTANT}`;
},
};
}
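// Illustrative example (not part of the template code): text().format("What is a llama?")
// would return:
// "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nWhat is a llama?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"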

/**
* Formats an instruction prompt as a Llama 3 prompt.
*
* Llama 3 prompt template:
* ```
* <|begin_of_text|><|start_header_id|>system<|end_header_id|>
*
* ${ system prompt }<|eot_id|><|start_header_id|>user<|end_header_id|>
*
* ${ instruction }<|eot_id|><|start_header_id|>assistant<|end_header_id|>
*
*
* ```
*
* @see https://github.com/meta-llama/llama-recipes
*/
export function instruction(): TextGenerationPromptTemplate<
InstructionPrompt,
string
> {
return {
stopSequences: [STOP_SEQUENCE],
format(prompt) {
const instruction = validateContentIsString(prompt.instruction, prompt);

let result = BEGIN_SEGMENT;
result +=
prompt.system != null ? `${BEGIN_SYSTEM}${prompt.system}${END_SYSTEM}` : "";
result += `${BEGIN_INSTRUCTION}${instruction}${END_INSTRUCTION}${BEGIN_RESPONSE_ASSISTANT}`;
return result;
},
};
}

/**
* Formats a chat prompt as a Llama 3 prompt.
*
* Llama 3 prompt template:
*
* ```
* <|begin_of_text|><|start_header_id|>system<|end_header_id|>
*
* ${ system prompt }<|eot_id|><|start_header_id|>user<|end_header_id|>
*
* ${ user msg 1 }<|eot_id|><|start_header_id|>assistant<|end_header_id|>
*
* ${ model response 1 }<|eot_id|><|start_header_id|>user<|end_header_id|>
*
* ${ user msg 2 }<|eot_id|><|start_header_id|>assistant<|end_header_id|>
*
* ${ model response 2 }<|eot_id|><|start_header_id|>user<|end_header_id|>
*
* ${ user msg 3 }<|eot_id|><|start_header_id|>assistant<|end_header_id|>
*
*
* ```
*
* @see https://github.com/meta-llama/llama-recipes
*/
export function chat(): TextGenerationPromptTemplate<ChatPrompt, string> {
return {
format(prompt) {
validateLlama3Prompt(prompt);

// get content of the first message (validated to be a user message)
const content = prompt.messages[0].content;

let text = BEGIN_SEGMENT;
text +=
prompt.system != null ? `${BEGIN_SYSTEM}${prompt.system}${END_SYSTEM}` : "";
text += `${BEGIN_INSTRUCTION}${content}${END_INSTRUCTION}${BEGIN_RESPONSE_ASSISTANT}`;

// process remaining messages
for (let i = 1; i < prompt.messages.length; i++) {
const { role, content } = prompt.messages[i];
switch (role) {
case "user": {
const textContent = validateContentIsString(content, prompt);
text += `${BEGIN_INSTRUCTION}${textContent}${END_INSTRUCTION}${BEGIN_RESPONSE_ASSISTANT}`;
break;
}
case "assistant": {
// The assistant will have added \n\n to the start of its response; we don't do that here, so the tests are slightly different from reality
text += `${validateContentIsString(content, prompt)}${END_INSTRUCTION}`;
break;
}
case "tool": {
throw new InvalidPromptError(
"Tool messages are not supported.",
prompt
);
}
default: {
const _exhaustiveCheck: never = role;
throw new Error(`Unsupported role: ${_exhaustiveCheck}`);
}
}
}

return text;
},
stopSequences: [STOP_SEQUENCE],
};
}

/**
* Checks if a Llama 3 chat prompt is valid. Throws an {@link InvalidPromptError} if it is not.
*
* - The first message of the chat must be a user message.
* - Then it must be alternating between an assistant message and a user message.
* - The last message must always be a user message (when submitting to a model).
*
* The type checking is done at runtime when you submit a chat prompt to a model with a prompt template.
*
* @throws {@link InvalidPromptError}
*/
export function validateLlama3Prompt(chatPrompt: ChatPrompt) {
const messages = chatPrompt.messages;

if (messages.length < 1) {
throw new InvalidPromptError(
"ChatPrompt should have at least one message.",
chatPrompt
);
}

for (let i = 0; i < messages.length; i++) {
const expectedRole = i % 2 === 0 ? "user" : "assistant";
const role = messages[i].role;

if (role !== expectedRole) {
throw new InvalidPromptError(
`Message at index ${i} should have role '${expectedRole}', but has role '${role}'.`,
chatPrompt
);
}
}

if (messages.length % 2 === 0) {
throw new InvalidPromptError(
"The last message must be a user message.",
chatPrompt
);
}
}
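// Illustrative example (not part of the validation code): a valid chat prompt alternates
// user/assistant roles and ends on a user message, e.g.
//   { system: "...", messages: [userMessage, assistantMessage, userMessage] }
// which has an odd messages.length, so the final "last message must be a user message" check passes.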