Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wip2 #1

Merged
merged 7 commits into from
Jun 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 27 additions & 76 deletions LLama.Examples/Examples/LLama3ChatSession.cs
Original file line number Diff line number Diff line change
@@ -1,49 +1,63 @@
using LLama.Abstractions;
using LLama.Common;
using LLama.Common;
using LLama.Transformers;

namespace LLama.Examples.Examples;

// When using chatsession, it's a common case that you want to strip the role names
// rather than display them. This example shows how to use transforms to strip them.
/// <summary>
/// This sample shows a simple chatbot
/// It's configured to use the default prompt template as provided by llama.cpp and supports
/// models such as llama3, llama2, phi3, qwen1.5, etc.
/// </summary>
public class LLama3ChatSession
{
public static async Task Run()
{
string modelPath = UserSettings.GetModelPath();

var modelPath = UserSettings.GetModelPath();
var parameters = new ModelParams(modelPath)
{
Seed = 1337,
GpuLayerCount = 10
};

using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
var chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();

ChatSession session = new(executor, chatHistory);
session.WithHistoryTransform(new LLama3HistoryTransform());

// add the default templator. If llama.cpp doesn't support the template by default,
// you'll need to write your own transformer to format the prompt correctly
session.WithHistoryTransform(new PromptTemplateTransformer(model, withAssistant: true));

// Add a transformer to eliminate printing the end-of-turn tokens; llama 3 specifically has an odd LF that gets printed sometimes
session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(
new string[] { "User:", "Assistant:", "�" },
[model.Tokens.EndOfTurnToken!, "�"],
redundancyLength: 5));

InferenceParams inferenceParams = new InferenceParams()
var inferenceParams = new InferenceParams()
{
MaxTokens = -1, // keep generating tokens until the anti prompt is encountered
Temperature = 0.6f,
AntiPrompts = new List<string> { "User:" }
AntiPrompts = [model.Tokens.EndOfTurnToken!] // model specific end of turn string
};

Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The chat session has started.");

// show the prompt
Console.ForegroundColor = ConsoleColor.Green;
string userInput = Console.ReadLine() ?? "";
Console.Write("User> ");
var userInput = Console.ReadLine() ?? "";

while (userInput != "exit")
{
Console.ForegroundColor = ConsoleColor.White;
Console.Write("Assistant> ");

// as each token (partial or whole word) is streamed back, print it to the console, stream it to a web client, etc.
await foreach (
var text
in session.ChatAsync(
Expand All @@ -56,71 +70,8 @@ in session.ChatAsync(
Console.WriteLine();

Console.ForegroundColor = ConsoleColor.Green;
Console.Write("User> ");
userInput = Console.ReadLine() ?? "";

Console.ForegroundColor = ConsoleColor.White;
}
}

/// <summary>
/// Formats a <see cref="ChatHistory"/> into the Llama 3 prompt template,
/// wrapping each message in start/end header markers and an end-of-turn tag.
/// </summary>
class LLama3HistoryTransform : IHistoryTransform
{
    private const string StartHeaderId = "<|start_header_id|>";
    private const string EndHeaderId = "<|end_header_id|>";
    private const string Bos = "<|begin_of_text|>";
    private const string Eos = "<|end_of_text|>";
    private const string EndofTurn = "<|eot_id|>";

    /// <summary>
    /// Convert a ChatHistory instance to plain text in Llama 3 template format.
    /// </summary>
    /// <param name="history">The ChatHistory instance</param>
    /// <returns>The rendered prompt, ending with an empty assistant header.</returns>
    public string HistoryToText(ChatHistory history)
    {
        var prompt = Bos;
        foreach (var message in history.Messages)
        {
            prompt += EncodeMessage(message);
        }

        // Open an empty assistant turn so generation continues as the assistant.
        return prompt + EncodeHeader(new ChatHistory.Message(AuthorRole.Assistant, ""));
    }

    // Renders "<|start_header_id|>Role<|end_header_id|>\n\n" for a message.
    private static string EncodeHeader(ChatHistory.Message message)
        => $"{StartHeaderId}{message.AuthorRole}{EndHeaderId}\n\n";

    // Renders a full turn: header, content, then the end-of-turn marker.
    private static string EncodeMessage(ChatHistory.Message message)
        => EncodeHeader(message) + message.Content + EndofTurn;

    /// <summary>
    /// Converts plain text to a ChatHistory instance.
    /// </summary>
    /// <param name="role">The role for the author.</param>
    /// <param name="text">The chat history as plain text.</param>
    /// <returns>The updated history.</returns>
    public ChatHistory TextToHistory(AuthorRole role, string text)
        => new ChatHistory(new[] { new ChatHistory.Message(role, text) });

    /// <summary>
    /// Copy the transform.
    /// </summary>
    /// <returns>A new, stateless instance of this transform.</returns>
    public IHistoryTransform Clone()
        => new LLama3HistoryTransform();
}
}
39 changes: 39 additions & 0 deletions LLama.Unittest/Native/SafeLlamaModelHandleTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
using System.Text;
using LLama.Common;
using LLama.Native;
using LLama.Extensions;

namespace LLama.Unittest.Native;

/// <summary>
/// Tests for reading model metadata through the raw native handle.
/// </summary>
public class SafeLlamaModelHandleTests
    : IDisposable
{
    private readonly LLamaWeights _model;
    private readonly SafeLlamaModelHandle TestableHandle;

    public SafeLlamaModelHandleTests()
    {
        var @params = new ModelParams(Constants.GenerativeModelPath)
        {
            // Minimal context: these tests only read metadata, no inference.
            ContextSize = 1,
            GpuLayerCount = Constants.CIGpuLayerCount
        };
        _model = LLamaWeights.LoadFromFile(@params);

        TestableHandle = _model.NativeHandle;
    }

    /// <summary>
    /// Release the native model weights after each test. xUnit constructs one
    /// instance per test and calls Dispose; without this the handle is leaked
    /// (matches the pattern used by <c>TemplateTests</c>).
    /// </summary>
    public void Dispose()
    {
        _model.Dispose();
    }

    [Fact]
    public void MetadataValByKey_ReturnsCorrectly()
    {
        // Read the value through the raw native handle...
        const string key = "general.name";
        var template = TestableHandle.MetadataValueByKey(key);
        var name = Encoding.UTF8.GetStringFromSpan(template!.Value.Span);

        const string expected = "LLaMA v2";
        Assert.Equal(expected, name);

        // ...and confirm it agrees with the managed metadata dictionary.
        var metadataLookup = _model.Metadata[key];
        Assert.Equal(expected, metadataLookup);
        Assert.Equal(name, metadataLookup);
    }
}
67 changes: 63 additions & 4 deletions LLama.Unittest/TemplateTests.cs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
using System.Text;
using LLama.Common;
using LLama.Native;
using LLama.Extensions;

namespace LLama.Unittest;

public sealed class TemplateTests
: IDisposable
{
private readonly LLamaWeights _model;

public TemplateTests()
{
var @params = new ModelParams(Constants.GenerativeModelPath)
Expand All @@ -18,12 +18,12 @@ public TemplateTests()
};
_model = LLamaWeights.LoadFromFile(@params);
}

public void Dispose()
{
_model.Dispose();
}

[Fact]
public void BasicTemplate()
{
Expand Down Expand Up @@ -173,6 +173,53 @@ public void BasicTemplateWithAddAssistant()
Assert.Equal(expected, templateResult);
}

[Fact]
public void ToModelPrompt_FormatsCorrectly()
{
    // AddAssistant = true appends a trailing, empty assistant header so the
    // model would continue the conversation as the assistant.
    var templater = new LLamaTemplate(_model)
    {
        AddAssistant = true,
    };

    // Count must track each message as it is appended, in order.
    Assert.Equal(0, templater.Count);
    templater.Add("assistant", "hello");
    Assert.Equal(1, templater.Count);
    templater.Add("user", "world");
    Assert.Equal(2, templater.Count);
    templater.Add("assistant", "111");
    Assert.Equal(3, templater.Count);
    templater.Add("user", "aaa");
    Assert.Equal(4, templater.Count);
    templater.Add("assistant", "222");
    Assert.Equal(5, templater.Count);
    templater.Add("user", "bbb");
    Assert.Equal(6, templater.Count);
    templater.Add("assistant", "333");
    Assert.Equal(7, templater.Count);
    templater.Add("user", "ccc");
    Assert.Equal(8, templater.Count);

    // Render the accumulated history into a single prompt string.
    var templateResult = templater.ToModelPrompt();

    // Expected rendering: ChatML-style markers (<|im_start|>/<|im_end|>),
    // every message in insertion order, ending with the empty assistant header.
    const string expected = "<|im_start|>assistant\nhello<|im_end|>\n" +
        "<|im_start|>user\nworld<|im_end|>\n" +
        "<|im_start|>assistant\n" +
        "111<|im_end|>" +
        "\n<|im_start|>user\n" +
        "aaa<|im_end|>\n" +
        "<|im_start|>assistant\n" +
        "222<|im_end|>\n" +
        "<|im_start|>user\n" +
        "bbb<|im_end|>\n" +
        "<|im_start|>assistant\n" +
        "333<|im_end|>\n" +
        "<|im_start|>user\n" +
        "ccc<|im_end|>\n" +
        "<|im_start|>assistant\n";

    Assert.Equal(expected, templateResult);
}

[Fact]
public void GetOutOfRangeThrows()
{
Expand Down Expand Up @@ -249,4 +296,16 @@ public void RemoveOutOfRange()
Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(-1));
Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(2));
}

[Fact]
public void EndOTurnToken_ReturnsExpected()
{
    // This test model is expected to define no end-of-turn token, so the
    // property should be null (callers must handle that — see the `!` usages).
    Assert.Null(_model.Tokens.EndOfTurnToken);
}

[Fact]
public void EndOSpeechToken_ReturnsExpected()
{
    // The test model's end-of-speech token is expected to be the classic
    // sentence-piece EOS marker "</s>".
    Assert.Equal("</s>", _model.Tokens.EndOfSpeechToken);
}
}
36 changes: 36 additions & 0 deletions LLama.Unittest/Transformers/PromptTemplateTransformerTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
using LLama.Common;
using LLama.Transformers;

namespace LLama.Unittest.Transformers;

/// <summary>
/// Tests for <see cref="PromptTemplateTransformer"/> history-to-prompt encoding.
/// </summary>
public class PromptTemplateTransformerTests
    : IDisposable
{
    private readonly LLamaWeights _model;
    private readonly PromptTemplateTransformer TestableTransformer;

    public PromptTemplateTransformerTests()
    {
        var @params = new ModelParams(Constants.GenerativeModelPath)
        {
            // Minimal context: the transformer only needs the model's template.
            ContextSize = 1,
            GpuLayerCount = Constants.CIGpuLayerCount
        };
        _model = LLamaWeights.LoadFromFile(@params);

        TestableTransformer = new PromptTemplateTransformer(_model, true);
    }

    /// <summary>
    /// Release the native model weights after each test. xUnit constructs one
    /// instance per test and calls Dispose; without this the handle is leaked
    /// (matches the pattern used by <c>TemplateTests</c>).
    /// </summary>
    public void Dispose()
    {
        _model.Dispose();
    }

    [Fact]
    public void HistoryToText_EncodesCorrectly()
    {
        // nameof makes the payload string refactor-safe ("userData").
        const string userData = nameof(userData);
        var template = TestableTransformer.HistoryToText(new ChatHistory(){
            Messages = [new ChatHistory.Message(AuthorRole.User, userData)]
        });

        // ChatML-style rendering with a trailing assistant header, because the
        // transformer was constructed with withAssistant: true.
        const string expected = "<|im_start|>user\n" +
                                $"{userData}<|im_end|>\n" +
                                "<|im_start|>assistant\n";
        Assert.Equal(expected, template);
    }
}
2 changes: 1 addition & 1 deletion LLama/ChatSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public class ChatSession
/// <summary>
/// The input transform pipeline used in this session.
/// </summary>
public List<ITextTransform> InputTransformPipeline { get; set; } = new();
public List<ITextTransform> InputTransformPipeline { get; set; } = [];

/// <summary>
/// The output transform used in this session.
Expand Down
3 changes: 1 addition & 2 deletions LLama/Common/InferenceParams.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using LLama.Abstractions;
using System;
using System.Collections.Generic;
using LLama.Native;
using LLama.Sampling;
Expand Down Expand Up @@ -31,7 +30,7 @@ public record InferenceParams
/// <summary>
/// Sequences where the model will stop generating further tokens.
/// </summary>
public IReadOnlyList<string> AntiPrompts { get; set; } = Array.Empty<string>();
public IReadOnlyList<string> AntiPrompts { get; set; } = [];

/// <inheritdoc />
public int TopK { get; set; } = 40;
Expand Down
4 changes: 2 additions & 2 deletions LLama/LLamaExecutorBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ public virtual async IAsyncEnumerable<string> InferAsync(string text, IInference

var args = new InferStateArgs
{
Antiprompts = inferenceParams.AntiPrompts.ToList(),
Antiprompts = [.. inferenceParams.AntiPrompts],
RemainedTokens = inferenceParams.MaxTokens,
ReturnValue = false,
WaitForInput = false,
Expand Down Expand Up @@ -359,7 +359,7 @@ public virtual async Task PrefillPromptAsync(string prompt)
};
var args = new InferStateArgs
{
Antiprompts = new List<string>(),
Antiprompts = [],
RemainedTokens = 0,
ReturnValue = false,
WaitForInput = true,
Expand Down
2 changes: 1 addition & 1 deletion LLama/LLamaInteractExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ protected override Task PreprocessInputs(string text, InferStateArgs args)
}
else
{
PreprocessLlava(text, args, true );
PreprocessLlava(text, args, true);
}
}
else
Expand Down
Loading