diff --git a/LLama.Examples/Examples/LLama3ChatSession.cs b/LLama.Examples/Examples/LLama3ChatSession.cs
index c9a32e0ce..01aa33cd6 100644
--- a/LLama.Examples/Examples/LLama3ChatSession.cs
+++ b/LLama.Examples/Examples/LLama3ChatSession.cs
@@ -1,38 +1,47 @@
-using LLama.Abstractions;
-using LLama.Common;
+using LLama.Common;
+using LLama.Transformers;
namespace LLama.Examples.Examples;
-// When using chatsession, it's a common case that you want to strip the role names
-// rather than display them. This example shows how to use transforms to strip them.
+/// <summary>
+/// This sample shows a simple chatbot.
+/// It's configured to use the default prompt template as provided by llama.cpp and supports
+/// models such as llama3, llama2, phi3, qwen1.5, etc.
+/// </summary>
public class LLama3ChatSession
{
public static async Task Run()
{
- string modelPath = UserSettings.GetModelPath();
-
+ var modelPath = UserSettings.GetModelPath();
var parameters = new ModelParams(modelPath)
{
Seed = 1337,
GpuLayerCount = 10
};
+
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);
var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
- ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
+ var chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
ChatSession session = new(executor, chatHistory);
- session.WithHistoryTransform(new LLama3HistoryTransform());
+
+    // Add the default templater. If llama.cpp doesn't support the template by default,
+    // you'll need to write your own transformer to format the prompt correctly
+ session.WithHistoryTransform(new PromptTemplateTransformer(model, withAssistant: true));
+
+    // Add a transformer to eliminate printing the end of turn tokens; llama 3 specifically has an odd LF that gets printed sometimes
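+    // (redundancyLength buffers a few extra characters of streamed output so a keyword
+    // split across tokens can still be matched before the text is printed)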
session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(
- new string[] { "User:", "Assistant:", "�" },
+ [model.Tokens.EndOfTurnToken!, "�"],
redundancyLength: 5));
- InferenceParams inferenceParams = new InferenceParams()
+ var inferenceParams = new InferenceParams()
{
+        MaxTokens = -1, // keep generating tokens until the anti-prompt is encountered
Temperature = 0.6f,
-        AntiPrompts = new List<string> { "User:" }
+ AntiPrompts = [model.Tokens.EndOfTurnToken!] // model specific end of turn string
};
Console.ForegroundColor = ConsoleColor.Yellow;
@@ -40,10 +49,15 @@ public static async Task Run()
// show the prompt
Console.ForegroundColor = ConsoleColor.Green;
- string userInput = Console.ReadLine() ?? "";
+ Console.Write("User> ");
+ var userInput = Console.ReadLine() ?? "";
while (userInput != "exit")
{
+ Console.ForegroundColor = ConsoleColor.White;
+ Console.Write("Assistant> ");
+
+        // as each token (partial or whole word) is streamed back, print it to the console, stream to a web client, etc.
await foreach (
var text
in session.ChatAsync(
@@ -56,71 +70,8 @@ in session.ChatAsync(
Console.WriteLine();
Console.ForegroundColor = ConsoleColor.Green;
+ Console.Write("User> ");
userInput = Console.ReadLine() ?? "";
-
- Console.ForegroundColor = ConsoleColor.White;
- }
- }
-
- class LLama3HistoryTransform : IHistoryTransform
- {
-        /// <summary>
-        /// Convert a ChatHistory instance to plain text.
-        /// </summary>
-        /// <param name="history">The ChatHistory instance</param>
-        /// <returns></returns>
- public string HistoryToText(ChatHistory history)
- {
- string res = Bos;
- foreach (var message in history.Messages)
- {
- res += EncodeMessage(message);
- }
- res += EncodeHeader(new ChatHistory.Message(AuthorRole.Assistant, ""));
- return res;
- }
-
- private string EncodeHeader(ChatHistory.Message message)
- {
- string res = StartHeaderId;
- res += message.AuthorRole.ToString();
- res += EndHeaderId;
- res += "\n\n";
- return res;
- }
-
- private string EncodeMessage(ChatHistory.Message message)
- {
- string res = EncodeHeader(message);
- res += message.Content;
- res += EndofTurn;
- return res;
}
-
-        /// <summary>
-        /// Converts plain text to a ChatHistory instance.
-        /// </summary>
-        /// <param name="role">The role for the author.</param>
-        /// <param name="text">The chat history as plain text.</param>
-        /// <returns>The updated history.</returns>
- public ChatHistory TextToHistory(AuthorRole role, string text)
- {
- return new ChatHistory(new ChatHistory.Message[] { new ChatHistory.Message(role, text) });
- }
-
-        /// <summary>
-        /// Copy the transform.
-        /// </summary>
-        /// <returns></returns>
- public IHistoryTransform Clone()
- {
- return new LLama3HistoryTransform();
- }
-
- private const string StartHeaderId = "<|start_header_id|>";
- private const string EndHeaderId = "<|end_header_id|>";
- private const string Bos = "<|begin_of_text|>";
- private const string Eos = "<|end_of_text|>";
- private const string EndofTurn = "<|eot_id|>";
}
}
diff --git a/LLama.Unittest/Native/SafeLlamaModelHandleTests.cs b/LLama.Unittest/Native/SafeLlamaModelHandleTests.cs
new file mode 100644
index 000000000..5211d4f6a
--- /dev/null
+++ b/LLama.Unittest/Native/SafeLlamaModelHandleTests.cs
@@ -0,0 +1,39 @@
+using System.Text;
+using LLama.Common;
+using LLama.Native;
+using LLama.Extensions;
+
+namespace LLama.Unittest.Native;
+
+public class SafeLlamaModelHandleTests
+{
+ private readonly LLamaWeights _model;
+ private readonly SafeLlamaModelHandle TestableHandle;
+
+ public SafeLlamaModelHandleTests()
+ {
+ var @params = new ModelParams(Constants.GenerativeModelPath)
+ {
+ ContextSize = 1,
+ GpuLayerCount = Constants.CIGpuLayerCount
+ };
+ _model = LLamaWeights.LoadFromFile(@params);
+
+ TestableHandle = _model.NativeHandle;
+ }
+
+ [Fact]
+ public void MetadataValByKey_ReturnsCorrectly()
+ {
+ const string key = "general.name";
+ var template = _model.NativeHandle.MetadataValueByKey(key);
+ var name = Encoding.UTF8.GetStringFromSpan(template!.Value.Span);
+
+ const string expected = "LLaMA v2";
+ Assert.Equal(expected, name);
+
+ var metadataLookup = _model.Metadata[key];
+ Assert.Equal(expected, metadataLookup);
+ Assert.Equal(name, metadataLookup);
+ }
+}
diff --git a/LLama.Unittest/TemplateTests.cs b/LLama.Unittest/TemplateTests.cs
index 3a5bb0cea..9520905b6 100644
--- a/LLama.Unittest/TemplateTests.cs
+++ b/LLama.Unittest/TemplateTests.cs
@@ -1,6 +1,6 @@
using System.Text;
using LLama.Common;
-using LLama.Native;
+using LLama.Extensions;
namespace LLama.Unittest;
@@ -8,7 +8,7 @@ public sealed class TemplateTests
: IDisposable
{
private readonly LLamaWeights _model;
-
+
public TemplateTests()
{
var @params = new ModelParams(Constants.GenerativeModelPath)
@@ -18,12 +18,12 @@ public TemplateTests()
};
_model = LLamaWeights.LoadFromFile(@params);
}
-
+
public void Dispose()
{
_model.Dispose();
}
-
+
[Fact]
public void BasicTemplate()
{
@@ -47,18 +47,10 @@ public void BasicTemplate()
templater.Add("user", "ccc");
Assert.Equal(8, templater.Count);
- // Call once with empty array to discover length
-        var length = templater.Apply(Array.Empty<byte>());
- var dest = new byte[length];
-
- Assert.Equal(8, templater.Count);
-
- // Call again to get contents
- length = templater.Apply(dest);
-
+ var dest = templater.Apply();
Assert.Equal(8, templater.Count);
- var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
+ var templateResult = Encoding.UTF8.GetString(dest);
const string expected = "<|im_start|>assistant\nhello<|im_end|>\n" +
"<|im_start|>user\nworld<|im_end|>\n" +
"<|im_start|>assistant\n" +
@@ -93,17 +85,10 @@ public void CustomTemplate()
Assert.Equal(4, templater.Count);
-        // Call once with empty array to discover length
-        var length = templater.Apply(Array.Empty<byte>());
-        var dest = new byte[length];
-
+        var dest = templater.Apply();
Assert.Equal(4, templater.Count);
- // Call again to get contents
- length = templater.Apply(dest);
-
- Assert.Equal(4, templater.Count);
-
- var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
+ var templateResult = Encoding.UTF8.GetString(dest);
const string expected = "model\n" +
"hello\n" +
"user\n" +
@@ -143,17 +128,10 @@ public void BasicTemplateWithAddAssistant()
Assert.Equal(8, templater.Count);
-        // Call once with empty array to discover length
-        var length = templater.Apply(Array.Empty<byte>());
-        var dest = new byte[length];
-
+        var dest = templater.Apply();
Assert.Equal(8, templater.Count);
- // Call again to get contents
- length = templater.Apply(dest);
-
- Assert.Equal(8, templater.Count);
-
- var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
+ var templateResult = Encoding.UTF8.GetString(dest);
const string expected = "<|im_start|>assistant\nhello<|im_end|>\n" +
"<|im_start|>user\nworld<|im_end|>\n" +
"<|im_start|>assistant\n" +
@@ -249,4 +227,40 @@ public void RemoveOutOfRange()
        Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(-1));
        Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(2));
}
+
+ [Fact]
+ public void Clear_ResetsTemplateState()
+ {
+ var templater = new LLamaTemplate(_model);
+ templater.Add("assistant", "1")
+ .Add("user", "2");
+
+ Assert.Equal(2, templater.Count);
+
+ templater.Clear();
+
+ Assert.Equal(0, templater.Count);
+
+ const string userData = nameof(userData);
+ templater.Add("user", userData);
+
+        // Generate the template string
+ var dest = templater.Apply();
+ var templateResult = Encoding.UTF8.GetString(dest);
+
+ const string expectedTemplate = $"<|im_start|>user\n{userData}<|im_end|>\n";
+ Assert.Equal(expectedTemplate, templateResult);
+ }
+
+ [Fact]
+    public void EndOfTurnToken_ReturnsExpected()
+ {
+ Assert.Null(_model.Tokens.EndOfTurnToken);
+ }
+
+ [Fact]
+    public void EndOfSpeechToken_ReturnsExpected()
+ {
+ Assert.Equal("", _model.Tokens.EndOfSpeechToken);
+ }
}
\ No newline at end of file
diff --git a/LLama.Unittest/Transformers/PromptTemplateTransformerTests.cs b/LLama.Unittest/Transformers/PromptTemplateTransformerTests.cs
new file mode 100644
index 000000000..9b1255f9b
--- /dev/null
+++ b/LLama.Unittest/Transformers/PromptTemplateTransformerTests.cs
@@ -0,0 +1,83 @@
+using LLama.Common;
+using LLama.Transformers;
+
+namespace LLama.Unittest.Transformers;
+
+public class PromptTemplateTransformerTests
+{
+ private readonly LLamaWeights _model;
+ private readonly PromptTemplateTransformer TestableTransformer;
+
+ public PromptTemplateTransformerTests()
+ {
+ var @params = new ModelParams(Constants.GenerativeModelPath)
+ {
+ ContextSize = 1,
+ GpuLayerCount = Constants.CIGpuLayerCount
+ };
+ _model = LLamaWeights.LoadFromFile(@params);
+
+ TestableTransformer = new PromptTemplateTransformer(_model, true);
+ }
+
+ [Fact]
+ public void HistoryToText_EncodesCorrectly()
+ {
+ const string userData = nameof(userData);
+ var template = TestableTransformer.HistoryToText(new ChatHistory(){
+ Messages = [new ChatHistory.Message(AuthorRole.User, userData)]
+ });
+
+ const string expected = "<|im_start|>user\n" +
+ $"{userData}<|im_end|>\n" +
+ "<|im_start|>assistant\n";
+ Assert.Equal(expected, template);
+ }
+
+ [Fact]
+ public void ToModelPrompt_FormatsCorrectly()
+ {
+ var templater = new LLamaTemplate(_model)
+ {
+ AddAssistant = true,
+ };
+
+ Assert.Equal(0, templater.Count);
+ templater.Add("assistant", "hello");
+ Assert.Equal(1, templater.Count);
+ templater.Add("user", "world");
+ Assert.Equal(2, templater.Count);
+ templater.Add("assistant", "111");
+ Assert.Equal(3, templater.Count);
+ templater.Add("user", "aaa");
+ Assert.Equal(4, templater.Count);
+ templater.Add("assistant", "222");
+ Assert.Equal(5, templater.Count);
+ templater.Add("user", "bbb");
+ Assert.Equal(6, templater.Count);
+ templater.Add("assistant", "333");
+ Assert.Equal(7, templater.Count);
+ templater.Add("user", "ccc");
+ Assert.Equal(8, templater.Count);
+
+        // Apply the templater and convert the resulting buffer to a string
+ var templateResult = PromptTemplateTransformer.ToModelPrompt(templater);
+ const string expected = "<|im_start|>assistant\nhello<|im_end|>\n" +
+ "<|im_start|>user\nworld<|im_end|>\n" +
+ "<|im_start|>assistant\n" +
+ "111<|im_end|>" +
+ "\n<|im_start|>user\n" +
+ "aaa<|im_end|>\n" +
+ "<|im_start|>assistant\n" +
+ "222<|im_end|>\n" +
+ "<|im_start|>user\n" +
+ "bbb<|im_end|>\n" +
+ "<|im_start|>assistant\n" +
+ "333<|im_end|>\n" +
+ "<|im_start|>user\n" +
+ "ccc<|im_end|>\n" +
+ "<|im_start|>assistant\n";
+
+ Assert.Equal(expected, templateResult);
+ }
+}
diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs
index 3d5b5b616..2f667be0b 100644
--- a/LLama/ChatSession.cs
+++ b/LLama/ChatSession.cs
@@ -62,7 +62,7 @@ public class ChatSession
    /// <summary>
    /// The input transform pipeline used in this session.
    /// </summary>
-    public List<ITextTransform> InputTransformPipeline { get; set; } = new();
+    public List<ITextTransform> InputTransformPipeline { get; set; } = [];
    /// <summary>
/// The output transform used in this session.
diff --git a/LLama/Common/InferenceParams.cs b/LLama/Common/InferenceParams.cs
index 44818a1ff..b2e429f83 100644
--- a/LLama/Common/InferenceParams.cs
+++ b/LLama/Common/InferenceParams.cs
@@ -1,5 +1,4 @@
using LLama.Abstractions;
-using System;
using System.Collections.Generic;
using LLama.Native;
using LLama.Sampling;
@@ -31,7 +30,7 @@ public record InferenceParams
    /// <summary>
    /// Sequences where the model will stop generating further tokens.
    /// </summary>
-    public IReadOnlyList<string> AntiPrompts { get; set; } = Array.Empty<string>();
+    public IReadOnlyList<string> AntiPrompts { get; set; } = [];
    /// <inheritdoc />
public int TopK { get; set; } = 40;
diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index 263ab2716..e01a40ccc 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -307,7 +307,7 @@ public virtual async IAsyncEnumerable<string> InferAsync(string text, IInference
var args = new InferStateArgs
{
- Antiprompts = inferenceParams.AntiPrompts.ToList(),
+ Antiprompts = [.. inferenceParams.AntiPrompts],
RemainedTokens = inferenceParams.MaxTokens,
ReturnValue = false,
WaitForInput = false,
@@ -359,7 +359,7 @@ public virtual async Task PrefillPromptAsync(string prompt)
};
var args = new InferStateArgs
{
-            Antiprompts = new List<string>(),
+ Antiprompts = [],
RemainedTokens = 0,
ReturnValue = false,
WaitForInput = true,
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 226b18ef9..869a0bb44 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -123,7 +123,7 @@ protected override Task PreprocessInputs(string text, InferStateArgs args)
}
else
{
- PreprocessLlava(text, args, true );
+ PreprocessLlava(text, args, true);
}
}
else
diff --git a/LLama/LLamaTemplate.cs b/LLama/LLamaTemplate.cs
index 0677ddb43..fb2268ac2 100644
--- a/LLama/LLamaTemplate.cs
+++ b/LLama/LLamaTemplate.cs
@@ -13,8 +13,6 @@ namespace LLama;
public sealed class LLamaTemplate
{
#region private state
- private static readonly Encoding Encoding = Encoding.UTF8;
-
    /// <summary>
    /// The model this template is for. May be null if a custom template was supplied to the constructor.
    /// </summary>
@@ -28,12 +26,12 @@ public sealed class LLamaTemplate
    /// <summary>
    /// Keep a cache of roles converted into bytes. Roles are very frequently re-used, so this saves converting them many times.
    /// </summary>
-    private readonly Dictionary<string, ReadOnlyMemory<byte>> _roleCache = new();
+    private readonly Dictionary<string, ReadOnlyMemory<byte>> _roleCache = [];
    /// <summary>
    /// Array of messages. The <see cref="Count"/> property indicates how many messages there are
    /// </summary>
-    private TextMessage?[] _messages = new TextMessage[4];
+    private TextMessage[] _messages = new TextMessage[4];
    /// <summary>
    /// Backing field for <see cref="AddAssistant"/>
@@ -53,7 +51,7 @@ public sealed class LLamaTemplate
    /// <summary>
    /// Result bytes of last call to <see cref="Apply"/>
    /// </summary>
-    private byte[] _result = Array.Empty<byte>();
+    private byte[] _result = [];
    /// <summary>
    /// Indicates if this template has been modified and needs regenerating
    /// </summary>
    private bool _dirty = true;
#endregion
#region properties
+    /// <summary>
+    /// The encoding algorithm to use
+    /// </summary>
+ public static readonly Encoding Encoding = Encoding.UTF8;
+
    /// <summary>
    /// Number of messages added to this template
    /// </summary>
@@ -189,14 +192,28 @@ public LLamaTemplate RemoveAt(int index)
return this;
}
+
+    /// <summary>
+    /// Remove all messages from the template and reset internal state to accept/generate new messages
+    /// </summary>
+ public void Clear()
+ {
+ _messages = new TextMessage[4];
+ Count = 0;
+
+ _resultLength = 0;
+ _result = [];
+ _nativeChatMessages = new LLamaChatMessage[4];
+
+ _dirty = true;
+ }
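+
+    // Example: a single template instance can be reused across prompts (the
+    // Clear_ResetsTemplateState unit test exercises this pattern):
+    //   template.Add("user", "first");
+    //   var prompt1 = template.Apply().ToArray(); // copy out, Clear resets the buffer
+    //   template.Clear();
+    //   template.Add("user", "second");
+    //   var prompt2 = template.Apply();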
#endregion
    /// <summary>
    /// Apply the template to the messages and write it into the output buffer
    /// </summary>
-    /// <param name="dest">Destination to write template bytes into</param>
-    /// <returns>The length of the template. If this is longer than dest.Length this method should be called again with a larger dest buffer</returns>
-    public int Apply(Memory<byte> dest)
+    /// <returns>A span over the buffer that holds the applied template</returns>
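+    /// <remarks>
+    /// The returned span points into a buffer owned by this template instance; it is
+    /// regenerated by later calls to <see cref="Apply"/> and reset by <see cref="Clear"/>,
+    /// so copy the bytes out if they need to outlive the template.
+    /// </remarks>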
+    public ReadOnlySpan<byte> Apply()
{
// Recalculate template if necessary
if (_dirty)
@@ -213,7 +230,6 @@ public int Apply(Memory<byte> dest)
for (var i = 0; i < Count; i++)
{
ref var m = ref _messages[i]!;
- Debug.Assert(m != null);
totalInputBytes += m.RoleBytes.Length + m.ContentBytes.Length;
// Pin byte arrays in place
@@ -233,7 +249,6 @@ public int Apply(Memory dest)
        var output = ArrayPool<byte>.Shared.Rent(Math.Max(32, totalInputBytes * 2));
try
{
-
// Run templater and discover true length
var outputLength = ApplyInternal(_nativeChatMessages.AsSpan(0, Count), output);
@@ -264,8 +279,7 @@ public int Apply(Memory<byte> dest)
}
-        // Now that the template has been applied and is in the result buffer, copy it to the dest
+        // Now that the template has been applied and is in the result buffer, return a span over it
- _result.AsSpan(0, Math.Min(dest.Length, _resultLength)).CopyTo(dest.Span);
- return _resultLength;
+ return _result.AsSpan(0, _resultLength);
        unsafe int ApplyInternal(Span<LLamaChatMessage> messages, byte[] output)
{
@@ -281,7 +295,7 @@ unsafe int ApplyInternal(Span<LLamaChatMessage> messages, byte[] output)
    /// <summary>
    /// A message that has been added to a template
    /// </summary>
- public sealed class TextMessage
+ public readonly struct TextMessage
{
        /// <summary>
/// The "role" string for this message
diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs
index ce712b724..8646e4d93 100644
--- a/LLama/LLamaWeights.cs
+++ b/LLama/LLamaWeights.cs
@@ -5,7 +5,6 @@
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Exceptions;
-using LLama.Extensions;
using LLama.Native;
using Microsoft.Extensions.Logging;
diff --git a/LLama/Native/LLamaToken.cs b/LLama/Native/LLamaToken.cs
index 64d263a7a..dd8bca1e2 100644
--- a/LLama/Native/LLamaToken.cs
+++ b/LLama/Native/LLamaToken.cs
@@ -10,6 +10,11 @@ namespace LLama.Native;
[DebuggerDisplay("{Value}")]
public readonly record struct LLamaToken
{
+    /// <summary>
+    /// Token value used when a token is inherently null
+    /// </summary>
+ public static readonly LLamaToken InvalidToken = -1;
+
    /// <summary>
/// The raw value
///
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index f54a8680b..3812a3517 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -19,6 +19,10 @@ public sealed class SafeLLamaContextHandle
    /// </summary>
public int VocabCount => ThrowIfDisposed().VocabCount;
+    /// <summary>
+    /// The underlying vocabulary for the model
+    /// </summary>
public LLamaVocabType LLamaVocabType => ThrowIfDisposed().VocabType;
    /// <summary>
diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs
index f24cfe5fd..1597908e3 100644
--- a/LLama/Native/SafeLlamaModelHandle.cs
+++ b/LLama/Native/SafeLlamaModelHandle.cs
@@ -6,7 +6,6 @@
using System.Runtime.InteropServices;
using System.Text;
using LLama.Exceptions;
-using LLama.Extensions;
namespace LLama.Native
{
@@ -221,11 +220,30 @@ private static int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model,
    /// </summary>
    /// <param name="model"></param>
    /// <param name="key"></param>
-    /// <param name="buf"></param>
-    /// <param name="buf_size"></param>
+    /// <param name="dest"></param>
    /// <returns>The length of the string on success, or -1 on failure</returns>
- [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern unsafe int llama_model_meta_val_str(SafeLlamaModelHandle model, byte* key, byte* buf, long buf_size);
+    private static int llama_model_meta_val_str(SafeLlamaModelHandle model, string key, Span<byte> dest)
+ {
+ var bytesCount = Encoding.UTF8.GetByteCount(key);
+        var bytes = ArrayPool<byte>.Shared.Rent(bytesCount + 1);
+
+ unsafe
+ {
+ fixed (char* keyPtr = key)
+ fixed (byte* bytesPtr = bytes)
+ fixed (byte* destPtr = dest)
+ {
+                // Convert the key into null-terminated bytes for the native call
+                Encoding.UTF8.GetBytes(keyPtr, key.Length, bytesPtr, bytesCount);
+                bytesPtr[bytesCount] = 0;
+
+ return llama_model_meta_val_str_native(model, bytesPtr, destPtr, dest.Length);
+ }
+ }
+
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_model_meta_val_str")]
+ static extern unsafe int llama_model_meta_val_str_native(SafeLlamaModelHandle model, byte* key, byte* buf, long buf_size);
+ }
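+
+    // Note: this wrapper follows the usual llama.cpp two-call pattern -- probe with an
+    // empty buffer to discover the required length (a negative result means the key is
+    // missing), then call again with a buffer of that size. MetadataValueByKey below
+    // does exactly that.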
+
///
/// Get the number of tokens in the model vocabulary
@@ -461,8 +479,8 @@ internal Span<char> TokensToSpan(IReadOnlyList<LLamaToken> tokens, Span<char> de
public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding encoding)
{
// Early exit if there's no work to do
- if (text == "" && !add_bos)
-            return Array.Empty<LLamaToken>();
+ if (text == string.Empty && !add_bos)
+ return [];
// Convert string to bytes, adding one extra byte to the end (null terminator)
var bytesCount = encoding.GetByteCount(text);
@@ -484,7 +502,7 @@ public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding e
var tokens = new LLamaToken[count];
fixed (LLamaToken* tokensPtr = tokens)
{
- NativeApi.llama_tokenize(this, bytesPtr, bytesCount, tokensPtr, count, add_bos, special);
+ _ = NativeApi.llama_tokenize(this, bytesPtr, bytesCount, tokensPtr, count, add_bos, special);
return tokens;
}
}
@@ -510,6 +528,26 @@ public SafeLLamaContextHandle CreateContext(LLamaContextParams @params)
#endregion
#region metadata
+    /// <summary>
+    /// Get the metadata value for the given key
+    /// </summary>
+    /// <param name="key">The key to fetch</param>
+    /// <returns>The value, null if there is no such key</returns>
+    public Memory<byte>? MetadataValueByKey(string key)
+ {
+ // Check if the key exists, without getting any bytes of data
+ var keyLength = llama_model_meta_val_str(this, key, []);
+ if (keyLength < 0)
+ return null;
+
+ // get a buffer large enough to hold it
+ var buffer = new byte[keyLength + 1];
+ keyLength = llama_model_meta_val_str(this, key, buffer);
+ Debug.Assert(keyLength >= 0);
+
+        return buffer.AsMemory().Slice(0, keyLength);
+ }
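+
+    // Example: fetch and decode a metadata value ("general.name" is the key exercised
+    // by SafeLlamaModelHandleTests; the available keys depend on the GGUF file):
+    //   var raw = model.MetadataValueByKey("general.name");
+    //   var name = raw is null ? null : Encoding.UTF8.GetString(raw.Value.Span);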
+
    /// <summary>
    /// Get the metadata key for the given index
    /// </summary>
@@ -576,13 +614,39 @@ internal IReadOnlyDictionary<string, string> ReadMetadata()
    /// <summary>
    /// Get tokens for a model
    /// </summary>
- public class ModelTokens
+ public sealed class ModelTokens
{
private readonly SafeLlamaModelHandle _model;
+ private readonly string? _eot;
+ private readonly string? _eos;
internal ModelTokens(SafeLlamaModelHandle model)
{
_model = model;
+ _eot = LLamaTokenToString(EOT, true);
+ _eos = LLamaTokenToString(EOS, true);
+ }
+
+ private string? LLamaTokenToString(LLamaToken? token, bool isSpecialToken)
+ {
+ const int buffSize = 32;
+ Span buff = stackalloc byte[buffSize];
+ var tokenLength = _model.TokenToSpan(token ?? LLamaToken.InvalidToken, buff, special: isSpecialToken);
+
+ if (tokenLength <= 0)
+ {
+ return null;
+ }
+
+ // if the original buffer wasn't large enough, create a new one
+ if (tokenLength > buffSize)
+ {
+ buff = stackalloc byte[(int)tokenLength];
+ _ = _model.TokenToSpan(token ?? LLamaToken.InvalidToken, buff, special: isSpecialToken);
+ }
+
+ var slice = buff.Slice(0, (int)tokenLength);
+ return Encoding.UTF8.GetStringFromSpan(slice);
}
private static LLamaToken? Normalize(LLamaToken token)
@@ -599,6 +663,11 @@ internal ModelTokens(SafeLlamaModelHandle model)
/// Get the End of Sentence token for this model
    /// </summary>
public LLamaToken? EOS => Normalize(llama_token_eos(_model));
+
+    /// <summary>
+    /// The textual representation of the end of speech special token for this model
+    /// </summary>
+ public string? EndOfSpeechToken => _eos;
    /// <summary>
/// Get the newline token for this model
@@ -635,6 +704,11 @@ internal ModelTokens(SafeLlamaModelHandle model)
    /// </summary>
public LLamaToken? EOT => Normalize(llama_token_eot(_model));
+    /// <summary>
+    /// Returns the string representation of this model's end of turn (EOT) token
+    /// </summary>
+ public string? EndOfTurnToken => _eot;
+
    /// <summary>
/// Check if the given token should end generation
///
diff --git a/LLama/Transformers/PromptTemplateTransformer.cs b/LLama/Transformers/PromptTemplateTransformer.cs
new file mode 100644
index 000000000..3543f9a1a
--- /dev/null
+++ b/LLama/Transformers/PromptTemplateTransformer.cs
@@ -0,0 +1,67 @@
+using System;
+using System.Text;
+using LLama.Abstractions;
+using LLama.Common;
+
+namespace LLama.Transformers;
+
+/// <summary>
+/// A prompt formatter that will use llama.cpp's template formatter.
+/// If your model is not supported, you will need to define your own formatter according to the chat prompt specification for your model.
+/// </summary>
+public class PromptTemplateTransformer(LLamaWeights model,
+ bool withAssistant = true) : IHistoryTransform
+{
+ private readonly LLamaWeights _model = model;
+ private readonly bool _withAssistant = withAssistant;
+
+    /// <inheritdoc />
+ public string HistoryToText(ChatHistory history)
+ {
+ var template = new LLamaTemplate(_model.NativeHandle)
+ {
+ AddAssistant = _withAssistant,
+ };
+
+ // encode each message and return the final prompt
+ foreach (var message in history.Messages)
+ {
+ template.Add(message.AuthorRole.ToString().ToLowerInvariant(), message.Content);
+ }
+ return ToModelPrompt(template);
+ }
+
+    /// <inheritdoc />
+ public ChatHistory TextToHistory(AuthorRole role, string text)
+ {
+ return new ChatHistory([new ChatHistory.Message(role, text)]);
+ }
+
+    /// <inheritdoc />
+ public IHistoryTransform Clone()
+ {
+        // preserve the withAssistant setting on the cloned transformer
+        return new PromptTemplateTransformer(_model, _withAssistant);
+ }
+
+ #region utils
+    /// <summary>
+    /// Apply the template to the messages and return the resulting prompt as a string
+    /// </summary>
+    /// <param name="template"></param>
+    /// <returns>The formatted template string as defined by the model</returns>
+ public static string ToModelPrompt(LLamaTemplate template)
+ {
+        // Apply the template and get the resulting buffer
+ var templateBuffer = template.Apply();
+
+ // convert the resulting buffer to a string
+#if NET6_0_OR_GREATER
+ return LLamaTemplate.Encoding.GetString(templateBuffer);
+#endif
+
+ // need the ToArray call for netstandard -- avoided in newer runtimes
+ return LLamaTemplate.Encoding.GetString(templateBuffer.ToArray());
+ }
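+
+    // Example (mirrors PromptTemplateTransformerTests.ToModelPrompt_FormatsCorrectly):
+    //   var template = new LLamaTemplate(model) { AddAssistant = true };
+    //   template.Add("user", "hello");
+    //   var prompt = PromptTemplateTransformer.ToModelPrompt(template);
+    //   // prompt ends with "<|im_start|>assistant\n", ready for generation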
+ #endregion utils
+}