Skip to content

Commit

Permalink
Merge pull request #787 from patrick-hovsepian/generic_prompt
Browse files Browse the repository at this point in the history
Generic Prompt Formatter
  • Loading branch information
martindevans authored Jun 10, 2024
2 parents a5de5f7 + 8c9bbb6 commit 2990b47
Show file tree
Hide file tree
Showing 14 changed files with 386 additions and 137 deletions.
103 changes: 27 additions & 76 deletions LLama.Examples/Examples/LLama3ChatSession.cs
Original file line number Diff line number Diff line change
@@ -1,49 +1,63 @@
using LLama.Abstractions;
using LLama.Common;
using LLama.Common;
using LLama.Transformers;

namespace LLama.Examples.Examples;

// When using chatsession, it's a common case that you want to strip the role names
// rather than display them. This example shows how to use transforms to strip them.
/// <summary>
/// This sample shows a simple chatbot
/// It's configured to use the default prompt template as provided by llama.cpp and supports
/// models such as llama3, llama2, phi3, qwen1.5, etc.
/// </summary>
public class LLama3ChatSession
{
public static async Task Run()
{
string modelPath = UserSettings.GetModelPath();

var modelPath = UserSettings.GetModelPath();
var parameters = new ModelParams(modelPath)
{
Seed = 1337,
GpuLayerCount = 10
};

using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
var chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();

ChatSession session = new(executor, chatHistory);
session.WithHistoryTransform(new LLama3HistoryTransform());

// add the default templator. If llama.cpp doesn't support the template by default,
// you'll need to write your own transformer to format the prompt correctly
session.WithHistoryTransform(new PromptTemplateTransformer(model, withAssistant: true));

// Add a transformer to eliminate printing the end of turn tokens, llama 3 specifically has an odd LF that gets printed sometimes
session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(
new string[] { "User:", "Assistant:", "�" },
[model.Tokens.EndOfTurnToken!, "�"],
redundancyLength: 5));

InferenceParams inferenceParams = new InferenceParams()
var inferenceParams = new InferenceParams()
{
MaxTokens = -1, // keep generating tokens until the anti prompt is encountered
Temperature = 0.6f,
AntiPrompts = new List<string> { "User:" }
AntiPrompts = [model.Tokens.EndOfTurnToken!] // model specific end of turn string
};

Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The chat session has started.");

// show the prompt
Console.ForegroundColor = ConsoleColor.Green;
string userInput = Console.ReadLine() ?? "";
Console.Write("User> ");
var userInput = Console.ReadLine() ?? "";

while (userInput != "exit")
{
Console.ForegroundColor = ConsoleColor.White;
Console.Write("Assistant> ");

// as each token (partial or whole word is streamed back) print it to the console, stream to web client, etc
await foreach (
var text
in session.ChatAsync(
Expand All @@ -56,71 +70,8 @@ in session.ChatAsync(
Console.WriteLine();

Console.ForegroundColor = ConsoleColor.Green;
Console.Write("User> ");
userInput = Console.ReadLine() ?? "";

Console.ForegroundColor = ConsoleColor.White;
}
}

/// <summary>
/// Formats a <see cref="ChatHistory"/> into the llama3 prompt template,
/// wrapping each message in header/end-of-turn control tokens.
/// </summary>
class LLama3HistoryTransform : IHistoryTransform
{
    private const string StartHeaderId = "<|start_header_id|>";
    private const string EndHeaderId = "<|end_header_id|>";
    private const string Bos = "<|begin_of_text|>";
    private const string Eos = "<|end_of_text|>";
    private const string EndofTurn = "<|eot_id|>";

    /// <summary>
    /// Convert a ChatHistory instance to plain text in the llama3 template format.
    /// </summary>
    /// <param name="history">The ChatHistory instance</param>
    /// <returns>The full prompt text, ending with an open assistant header.</returns>
    public string HistoryToText(ChatHistory history)
    {
        var text = Bos;
        foreach (var message in history.Messages)
            text += EncodeMessage(message);

        // Open a fresh assistant turn so the model continues as the assistant
        return text + EncodeHeader(new ChatHistory.Message(AuthorRole.Assistant, ""));
    }

    // Emit the role header: <|start_header_id|>Role<|end_header_id|> + blank line
    private string EncodeHeader(ChatHistory.Message message)
    {
        return $"{StartHeaderId}{message.AuthorRole}{EndHeaderId}\n\n";
    }

    // A full message is its header, the content, then the end-of-turn token
    private string EncodeMessage(ChatHistory.Message message)
    {
        return EncodeHeader(message) + message.Content + EndofTurn;
    }

    /// <summary>
    /// Converts plain text to a ChatHistory instance.
    /// </summary>
    /// <param name="role">The role for the author.</param>
    /// <param name="text">The chat history as plain text.</param>
    /// <returns>The updated history.</returns>
    public ChatHistory TextToHistory(AuthorRole role, string text)
    {
        // The whole text becomes a single message attributed to the given role
        return new ChatHistory(new[] { new ChatHistory.Message(role, text) });
    }

    /// <summary>
    /// Copy the transform. The transform is stateless, so a fresh instance suffices.
    /// </summary>
    /// <returns></returns>
    public IHistoryTransform Clone()
    {
        return new LLama3HistoryTransform();
    }
}
}
39 changes: 39 additions & 0 deletions LLama.Unittest/Native/SafeLlamaModelHandleTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
using System.Text;
using LLama.Common;
using LLama.Native;
using LLama.Extensions;

namespace LLama.Unittest.Native;

/// <summary>
/// Tests for metadata access through the native model handle.
/// Implements IDisposable (matching TemplateTests) so the loaded
/// model weights are released after each test run instead of leaking.
/// </summary>
public class SafeLlamaModelHandleTests
    : IDisposable
{
    private readonly LLamaWeights _model;

    public SafeLlamaModelHandleTests()
    {
        var @params = new ModelParams(Constants.GenerativeModelPath)
        {
            ContextSize = 1,
            GpuLayerCount = Constants.CIGpuLayerCount
        };
        _model = LLamaWeights.LoadFromFile(@params);
    }

    // Release the native model memory held by the weights
    public void Dispose()
    {
        _model.Dispose();
    }

    [Fact]
    public void MetadataValByKey_ReturnsCorrectly()
    {
        const string key = "general.name";

        // Read the value through the native handle and decode the UTF-8 bytes
        var template = _model.NativeHandle.MetadataValueByKey(key);
        var name = Encoding.UTF8.GetStringFromSpan(template!.Value.Span);

        const string expected = "LLaMA v2";
        Assert.Equal(expected, name);

        // The managed metadata dictionary must agree with the native lookup
        var metadataLookup = _model.Metadata[key];
        Assert.Equal(expected, metadataLookup);
        Assert.Equal(name, metadataLookup);
    }
}
78 changes: 46 additions & 32 deletions LLama.Unittest/TemplateTests.cs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
using System.Text;
using LLama.Common;
using LLama.Native;
using LLama.Extensions;

namespace LLama.Unittest;

public sealed class TemplateTests
: IDisposable
{
private readonly LLamaWeights _model;

public TemplateTests()
{
var @params = new ModelParams(Constants.GenerativeModelPath)
Expand All @@ -18,12 +18,12 @@ public TemplateTests()
};
_model = LLamaWeights.LoadFromFile(@params);
}

// Release the native model memory when the test class is torn down
public void Dispose()
{
    _model.Dispose();
}

[Fact]
public void BasicTemplate()
{
Expand All @@ -47,18 +47,10 @@ public void BasicTemplate()
templater.Add("user", "ccc");
Assert.Equal(8, templater.Count);

// Call once with empty array to discover length
var length = templater.Apply(Array.Empty<byte>());
var dest = new byte[length];

Assert.Equal(8, templater.Count);

// Call again to get contents
length = templater.Apply(dest);

var dest = templater.Apply();
Assert.Equal(8, templater.Count);

var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
var templateResult = Encoding.UTF8.GetString(dest);
const string expected = "<|im_start|>assistant\nhello<|im_end|>\n" +
"<|im_start|>user\nworld<|im_end|>\n" +
"<|im_start|>assistant\n" +
Expand Down Expand Up @@ -93,17 +85,10 @@ public void CustomTemplate()
Assert.Equal(4, templater.Count);

// Call once with empty array to discover length
var length = templater.Apply(Array.Empty<byte>());
var dest = new byte[length];

var dest = templater.Apply();
Assert.Equal(4, templater.Count);

// Call again to get contents
length = templater.Apply(dest);

Assert.Equal(4, templater.Count);

var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
var templateResult = Encoding.UTF8.GetString(dest);
const string expected = "<start_of_turn>model\n" +
"hello<end_of_turn>\n" +
"<start_of_turn>user\n" +
Expand Down Expand Up @@ -143,17 +128,10 @@ public void BasicTemplateWithAddAssistant()
Assert.Equal(8, templater.Count);

// Call once with empty array to discover length
var length = templater.Apply(Array.Empty<byte>());
var dest = new byte[length];

var dest = templater.Apply();
Assert.Equal(8, templater.Count);

// Call again to get contents
length = templater.Apply(dest);

Assert.Equal(8, templater.Count);

var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
var templateResult = Encoding.UTF8.GetString(dest);
const string expected = "<|im_start|>assistant\nhello<|im_end|>\n" +
"<|im_start|>user\nworld<|im_end|>\n" +
"<|im_start|>assistant\n" +
Expand Down Expand Up @@ -249,4 +227,40 @@ public void RemoveOutOfRange()
Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(-1));
Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(2));
}

[Fact]
public void Clear_ResetsTemplateState()
{
    // Build up some messages, then verify Clear() empties the templater
    var templater = new LLamaTemplate(_model);
    templater.Add("assistant", "1")
    .Add("user", "2");

    Assert.Equal(2, templater.Count);

    templater.Clear();

    Assert.Equal(0, templater.Count);

    // After clearing, the templater must be reusable from a clean state
    const string userData = nameof(userData);
    templater.Add("user", userData);

    // Generate the template string
    var dest = templater.Apply();
    var templateResult = Encoding.UTF8.GetString(dest);

    const string expectedTemplate = $"<|im_start|>user\n{userData}<|im_end|>\n";
    Assert.Equal(expectedTemplate, templateResult);
}

[Fact]
public void EndOTurnToken_ReturnsExpected()
{
    // The test model defines no end-of-turn token, so the property is null
    Assert.Null(_model.Tokens.EndOfTurnToken);
}

[Fact]
public void EndOSpeechToken_ReturnsExpected()
{
    // The test model's end-of-speech token is the sentence-end marker "</s>"
    Assert.Equal("</s>", _model.Tokens.EndOfSpeechToken);
}
}
Loading

0 comments on commit 2990b47

Please sign in to comment.