diff --git a/LLama.Examples/Examples/LLama3ChatSession.cs b/LLama.Examples/Examples/LLama3ChatSession.cs
index c9a32e0ce..01aa33cd6 100644
--- a/LLama.Examples/Examples/LLama3ChatSession.cs
+++ b/LLama.Examples/Examples/LLama3ChatSession.cs
@@ -1,38 +1,47 @@
-using LLama.Abstractions;
-using LLama.Common;
+using LLama.Common;
+using LLama.Transformers;
 
 namespace LLama.Examples.Examples;
 
-// When using chatsession, it's a common case that you want to strip the role names
-// rather than display them. This example shows how to use transforms to strip them.
+/// <summary>
+/// This sample shows a simple chatbot.
+/// It's configured to use the default prompt template as provided by llama.cpp and supports
+/// models such as llama3, llama2, phi3, qwen1.5, etc.
+/// </summary>
 public class LLama3ChatSession
 {
     public static async Task Run()
    {
-        string modelPath = UserSettings.GetModelPath();
-
+        var modelPath = UserSettings.GetModelPath();
         var parameters = new ModelParams(modelPath)
         {
             Seed = 1337,
             GpuLayerCount = 10
         };
+
         using var model = LLamaWeights.LoadFromFile(parameters);
         using var context = model.CreateContext(parameters);
         var executor = new InteractiveExecutor(context);
 
         var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
-        ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
+        var chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
 
         ChatSession session = new(executor, chatHistory);
-        session.WithHistoryTransform(new LLama3HistoryTransform());
+
+        // Add the default templating transformer. If llama.cpp doesn't support the template by default,
+        // you'll need to write your own transformer to format the prompt correctly
+        session.WithHistoryTransform(new PromptTemplateTransformer(model, withAssistant: true));
+
+        // Add a transformer to eliminate printing the end-of-turn tokens; llama 3 specifically has an odd LF that gets printed sometimes
         session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(
-            new string[] { "User:", "Assistant:", "�" },
+            [model.Tokens.EndOfTurnToken!, "�"],
             redundancyLength: 5));
 
-        InferenceParams inferenceParams = new InferenceParams()
+        var inferenceParams = new InferenceParams()
         {
+            MaxTokens = -1, // keep generating tokens until the antiprompt is encountered
             Temperature = 0.6f,
-            AntiPrompts = new List<string> { "User:" }
+            AntiPrompts = [model.Tokens.EndOfTurnToken!] // model-specific end-of-turn string
         };
 
         Console.ForegroundColor = ConsoleColor.Yellow;
@@ -40,10 +49,15 @@ public static async Task Run()
 
         // show the prompt
         Console.ForegroundColor = ConsoleColor.Green;
-        string userInput = Console.ReadLine() ?? "";
+        Console.Write("User> ");
+        var userInput = Console.ReadLine() ?? "";
 
         while (userInput != "exit")
         {
+            Console.ForegroundColor = ConsoleColor.White;
+            Console.Write("Assistant> ");
+
+            // as each token (partial or whole word) is streamed back, print it to the console, stream it to a web client, etc.
             await foreach (
                 var text
                 in session.ChatAsync(
@@ -56,71 +70,8 @@ in session.ChatAsync(
             Console.WriteLine();
 
             Console.ForegroundColor = ConsoleColor.Green;
+            Console.Write("User> ");
             userInput = Console.ReadLine() ?? "";
-
-            Console.ForegroundColor = ConsoleColor.White;
-        }
-    }
-
-    class LLama3HistoryTransform : IHistoryTransform
-    {
-        /// <summary>
-        /// Convert a ChatHistory instance to plain text.
-        /// </summary>
-        /// <param name="history">The ChatHistory instance</param>
-        /// <returns></returns>
-        public string HistoryToText(ChatHistory history)
-        {
-            string res = Bos;
-            foreach (var message in history.Messages)
-            {
-                res += EncodeMessage(message);
-            }
-            res += EncodeHeader(new ChatHistory.Message(AuthorRole.Assistant, ""));
-            return res;
-        }
-
-        private string EncodeHeader(ChatHistory.Message message)
-        {
-            string res = StartHeaderId;
-            res += message.AuthorRole.ToString();
-            res += EndHeaderId;
-            res += "\n\n";
-            return res;
-        }
-
-        private string EncodeMessage(ChatHistory.Message message)
-        {
-            string res = EncodeHeader(message);
-            res += message.Content;
-            res += EndofTurn;
-            return res;
         }
-
-        /// <summary>
-        /// Converts plain text to a ChatHistory instance.
-        /// </summary>
-        /// <param name="role">The role for the author.</param>
-        /// <param name="text">The chat history as plain text.</param>
-        /// <returns>The updated history.</returns>
-        public ChatHistory TextToHistory(AuthorRole role, string text)
-        {
-            return new ChatHistory(new ChatHistory.Message[] { new ChatHistory.Message(role, text) });
-        }
-
-        /// <summary>
-        /// Copy the transform.
-        /// </summary>
-        /// <returns></returns>
-        public IHistoryTransform Clone()
-        {
-            return new LLama3HistoryTransform();
-        }
-
-        private const string StartHeaderId = "<|start_header_id|>";
-        private const string EndHeaderId = "<|end_header_id|>";
-        private const string Bos = "<|begin_of_text|>";
-        private const string Eos = "<|end_of_text|>";
-        private const string EndofTurn = "<|eot_id|>";
     }
 }
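[Reviewer note, not part of the diff] The comment in the example above says that a model whose chat template llama.cpp does not recognize still needs a hand-written transformer. A minimal sketch of what that could look like, using the same IHistoryTransform members that the removed LLama3HistoryTransform implements; the <|...|> markers here are hypothetical placeholders, not any real model's format:

    using System.Text;
    using LLama.Abstractions;
    using LLama.Common;

    class CustomHistoryTransform : IHistoryTransform
    {
        public string HistoryToText(ChatHistory history)
        {
            // Encode every message with the model's role markers, then open an
            // assistant turn so the model starts answering.
            var sb = new StringBuilder();
            foreach (var message in history.Messages)
                sb.Append($"<|{message.AuthorRole.ToString().ToLowerInvariant()}|>\n{message.Content}<|end|>\n");
            return sb.Append("<|assistant|>\n").ToString();
        }

        public ChatHistory TextToHistory(AuthorRole role, string text)
            => new ChatHistory([new ChatHistory.Message(role, text)]);

        public IHistoryTransform Clone() => new CustomHistoryTransform();
    }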
diff --git a/LLama.Unittest/Native/SafeLlamaModelHandleTests.cs b/LLama.Unittest/Native/SafeLlamaModelHandleTests.cs
new file mode 100644
index 000000000..5211d4f6a
--- /dev/null
+++ b/LLama.Unittest/Native/SafeLlamaModelHandleTests.cs
@@ -0,0 +1,39 @@
+using System.Text;
+using LLama.Common;
+using LLama.Native;
+using LLama.Extensions;
+
+namespace LLama.Unittest.Native;
+
+public class SafeLlamaModelHandleTests
+{
+    private readonly LLamaWeights _model;
+    private readonly SafeLlamaModelHandle TestableHandle;
+
+    public SafeLlamaModelHandleTests()
+    {
+        var @params = new ModelParams(Constants.GenerativeModelPath)
+        {
+            ContextSize = 1,
+            GpuLayerCount = Constants.CIGpuLayerCount
+        };
+        _model = LLamaWeights.LoadFromFile(@params);
+
+        TestableHandle = _model.NativeHandle;
+    }
+
+    [Fact]
+    public void MetadataValByKey_ReturnsCorrectly()
+    {
+        const string key = "general.name";
+        var template = _model.NativeHandle.MetadataValueByKey(key);
+        var name = Encoding.UTF8.GetStringFromSpan(template!.Value.Span);
+
+        const string expected = "LLaMA v2";
+        Assert.Equal(expected, name);
+
+        var metadataLookup = _model.Metadata[key];
+        Assert.Equal(expected, metadataLookup);
+        Assert.Equal(name, metadataLookup);
+    }
+}
diff --git a/LLama.Unittest/TemplateTests.cs b/LLama.Unittest/TemplateTests.cs
index 3a5bb0cea..9520905b6 100644
--- a/LLama.Unittest/TemplateTests.cs
+++ b/LLama.Unittest/TemplateTests.cs
@@ -1,6 +1,6 @@
 using System.Text;
 using LLama.Common;
-using LLama.Native;
+using LLama.Extensions;
 
 namespace LLama.Unittest;
 
@@ -8,7 +8,7 @@ public sealed class TemplateTests
     : IDisposable
 {
     private readonly LLamaWeights _model;
-
+    
     public TemplateTests()
     {
         var @params = new ModelParams(Constants.GenerativeModelPath)
@@ -18,12 +18,12 @@ public TemplateTests()
         };
         _model = LLamaWeights.LoadFromFile(@params);
     }
-
+    
     public void Dispose()
     {
         _model.Dispose();
     }
-
+    
     [Fact]
     public void BasicTemplate()
     {
@@ -47,18 +47,10 @@ public void BasicTemplate()
         templater.Add("user", "ccc");
         Assert.Equal(8, templater.Count);
 
-        // Call once with empty array to discover length
-        var length = templater.Apply(Array.Empty<byte>());
-        var dest = new byte[length];
-
-        Assert.Equal(8, templater.Count);
-
-        // Call again to get contents
-        length = templater.Apply(dest);
-
+        var dest = templater.Apply();
         Assert.Equal(8, templater.Count);
 
-        var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
+        var templateResult = Encoding.UTF8.GetString(dest);
         const string expected = "<|im_start|>assistant\nhello<|im_end|>\n" +
                                 "<|im_start|>user\nworld<|im_end|>\n" +
                                 "<|im_start|>assistant\n" +
@@ -93,17 +85,10 @@ public void CustomTemplate()
         Assert.Equal(4, templater.Count);
 
         // Call once with empty array to discover length
-        var length = templater.Apply(Array.Empty<byte>());
-        var dest = new byte[length];
-
+        var dest = templater.Apply();
         Assert.Equal(4, templater.Count);
 
-        // Call again to get contents
-        length = templater.Apply(dest);
-
-        Assert.Equal(4, templater.Count);
-
-        var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
+        var templateResult = Encoding.UTF8.GetString(dest);
         const string expected = "<start_of_turn>model\n" +
                                 "hello<end_of_turn>\n" +
                                 "<start_of_turn>user\n" +
@@ -143,17 +128,10 @@ public void BasicTemplateWithAddAssistant()
         Assert.Equal(8, templater.Count);
 
         // Call once with empty array to discover length
-        var length = templater.Apply(Array.Empty<byte>());
-        var dest = new byte[length];
-
+        var dest = templater.Apply();
         Assert.Equal(8, templater.Count);
 
-        // Call again to get contents
-        length = templater.Apply(dest);
-
-        Assert.Equal(8, templater.Count);
-
-        var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
+        var templateResult = Encoding.UTF8.GetString(dest);
         const string expected = "<|im_start|>assistant\nhello<|im_end|>\n" +
                                 "<|im_start|>user\nworld<|im_end|>\n" +
                                 "<|im_start|>assistant\n" +
@@ -249,4 +227,40 @@ public void RemoveOutOfRange()
         Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(-1));
         Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(2));
     }
+
+    [Fact]
+    public void Clear_ResetsTemplateState()
+    {
+        var templater = new LLamaTemplate(_model);
+        templater.Add("assistant", "1")
+            .Add("user", "2");
+
+        Assert.Equal(2, templater.Count);
+
+        templater.Clear();
+
+        Assert.Equal(0, templater.Count);
+
+        const string userData = nameof(userData);
+        templater.Add("user", userData);
+
+        // Generate the template string
+        var dest = templater.Apply();
+        var templateResult = Encoding.UTF8.GetString(dest);
+
+        const string expectedTemplate = $"<|im_start|>user\n{userData}<|im_end|>\n";
+        Assert.Equal(expectedTemplate, templateResult);
+    }
+
+    [Fact]
+    public void EndOfTurnToken_ReturnsExpected()
+    {
+        Assert.Null(_model.Tokens.EndOfTurnToken);
+    }
+
+    [Fact]
+    public void EndOfSpeechToken_ReturnsExpected()
+    {
+        Assert.Equal("</s>", _model.Tokens.EndOfSpeechToken);
+    }
 }
\ No newline at end of file
diff --git a/LLama.Unittest/Transformers/PromptTemplateTransformerTests.cs b/LLama.Unittest/Transformers/PromptTemplateTransformerTests.cs
new file mode 100644
index 000000000..9b1255f9b
--- /dev/null
+++ b/LLama.Unittest/Transformers/PromptTemplateTransformerTests.cs
@@ -0,0 +1,83 @@
+using LLama.Common;
+using LLama.Transformers;
+
+namespace LLama.Unittest.Transformers;
+
+public class PromptTemplateTransformerTests
+{
+    private readonly LLamaWeights _model;
+    private readonly PromptTemplateTransformer TestableTransformer;
+
+    public PromptTemplateTransformerTests()
+    {
+        var @params = new ModelParams(Constants.GenerativeModelPath)
+        {
+            ContextSize = 1,
+            GpuLayerCount = Constants.CIGpuLayerCount
+        };
+        _model = LLamaWeights.LoadFromFile(@params);
+
+        TestableTransformer = new PromptTemplateTransformer(_model, true);
+    }
+
+    [Fact]
+    public void HistoryToText_EncodesCorrectly()
+    {
+        const string userData = nameof(userData);
+        var template = TestableTransformer.HistoryToText(new ChatHistory(){
+            Messages = [new ChatHistory.Message(AuthorRole.User, userData)]
+        });
+
+        const string expected = "<|im_start|>user\n" +
+                                $"{userData}<|im_end|>\n" +
+                                "<|im_start|>assistant\n";
+        Assert.Equal(expected, template);
+    }
+
+    [Fact]
+    public void ToModelPrompt_FormatsCorrectly()
+    {
+        var templater = new LLamaTemplate(_model)
+        {
+            AddAssistant = true,
+        };
+
+        Assert.Equal(0, templater.Count);
+        templater.Add("assistant", "hello");
+        Assert.Equal(1, templater.Count);
+        templater.Add("user", "world");
+        Assert.Equal(2, templater.Count);
+        templater.Add("assistant", "111");
+        Assert.Equal(3, templater.Count);
+        templater.Add("user", "aaa");
+        Assert.Equal(4, templater.Count);
+        templater.Add("assistant", "222");
+        Assert.Equal(5, templater.Count);
+        templater.Add("user", "bbb");
+        Assert.Equal(6, templater.Count);
+        templater.Add("assistant", "333");
+        Assert.Equal(7, templater.Count);
+        templater.Add("user", "ccc");
+        Assert.Equal(8, templater.Count);
+
+        // Apply the template and get the resulting prompt string
+        var templateResult = PromptTemplateTransformer.ToModelPrompt(templater);
+        const string expected = "<|im_start|>assistant\nhello<|im_end|>\n" +
+                                "<|im_start|>user\nworld<|im_end|>\n" +
+                                "<|im_start|>assistant\n" +
+                                "111<|im_end|>" +
+                                "\n<|im_start|>user\n" +
+                                "aaa<|im_end|>\n" +
+                                "<|im_start|>assistant\n" +
+                                "222<|im_end|>\n" +
+                                "<|im_start|>user\n" +
+                                "bbb<|im_end|>\n" +
+                                "<|im_start|>assistant\n" +
+                                "333<|im_end|>\n" +
+                                "<|im_start|>user\n" +
+                                "ccc<|im_end|>\n" +
+                                "<|im_start|>assistant\n";
+
+        Assert.Equal(expected, templateResult);
+    }
+}
diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs
index 3d5b5b616..2f667be0b 100644
--- a/LLama/ChatSession.cs
+++ b/LLama/ChatSession.cs
@@ -62,7 +62,7 @@ public class ChatSession
     /// <summary>
     /// The input transform pipeline used in this session.
     /// </summary>
-    public List<ITextTransform> InputTransformPipeline { get; set; } = new();
+    public List<ITextTransform> InputTransformPipeline { get; set; } = [];
 
     /// <summary>
     /// The output transform used in this session.
diff --git a/LLama/Common/InferenceParams.cs b/LLama/Common/InferenceParams.cs
index 44818a1ff..b2e429f83 100644
--- a/LLama/Common/InferenceParams.cs
+++ b/LLama/Common/InferenceParams.cs
@@ -1,5 +1,4 @@
 using LLama.Abstractions;
-using System;
 using System.Collections.Generic;
 using LLama.Native;
 using LLama.Sampling;
@@ -31,7 +30,7 @@ public record InferenceParams
     /// <summary>
     /// Sequences where the model will stop generating further tokens.
     /// </summary>
-    public IReadOnlyList<string> AntiPrompts { get; set; } = Array.Empty<string>();
+    public IReadOnlyList<string> AntiPrompts { get; set; } = [];
 
     /// <inheritdoc />
     public int TopK { get; set; } = 40;
diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index 263ab2716..e01a40ccc 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -307,7 +307,7 @@ public virtual async IAsyncEnumerable<string> InferAsync(string text, IInference
             var args = new InferStateArgs
             {
-                Antiprompts = inferenceParams.AntiPrompts.ToList(),
+                Antiprompts = [.. inferenceParams.AntiPrompts],
                 RemainedTokens = inferenceParams.MaxTokens,
                 ReturnValue = false,
                 WaitForInput = false,
@@ -359,7 +359,7 @@ public virtual async Task PrefillPromptAsync(string prompt)
             };
             var args = new InferStateArgs
             {
-                Antiprompts = new List<string>(),
+                Antiprompts = [],
                 RemainedTokens = 0,
                 ReturnValue = false,
                 WaitForInput = true,
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 226b18ef9..869a0bb44 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -123,7 +123,7 @@ protected override Task PreprocessInputs(string text, InferStateArgs args)
                 }
                 else
                 {
-                    PreprocessLlava(text, args, true );
+                    PreprocessLlava(text, args, true);
                 }
             }
             else
diff --git a/LLama/LLamaTemplate.cs b/LLama/LLamaTemplate.cs
index 0677ddb43..fb2268ac2 100644
--- a/LLama/LLamaTemplate.cs
+++ b/LLama/LLamaTemplate.cs
@@ -13,8 +13,6 @@ namespace LLama;
 public sealed class LLamaTemplate
 {
     #region private state
-    private static readonly Encoding Encoding = Encoding.UTF8;
-
     /// <summary>
     /// The model this template is for. May be null if a custom template was supplied to the constructor.
     /// </summary>
@@ -28,12 +26,12 @@
     /// <summary>
     /// Keep a cache of roles converted into bytes. Roles are very frequently re-used, so this saves converting them many times.
     /// </summary>
-    private readonly Dictionary<string, ReadOnlyMemory<byte>> _roleCache = new();
+    private readonly Dictionary<string, ReadOnlyMemory<byte>> _roleCache = [];
 
     /// <summary>
     /// Array of messages. The <see cref="Count"/> property indicates how many messages there are
     /// </summary>
-    private TextMessage?[] _messages = new TextMessage[4];
+    private TextMessage[] _messages = new TextMessage[4];
 
     /// <summary>
     /// Backing field for <see cref="AddAssistant"/>
     /// </summary>
@@ -53,7 +51,7 @@
     /// <summary>
     /// Result bytes of last call to <see cref="Apply()"/>
     /// </summary>
-    private byte[] _result = Array.Empty<byte>();
+    private byte[] _result = [];
 
     /// <summary>
     /// Indicates if this template has been modified and needs regenerating
@@ -62,6 +60,11 @@
     #endregion
 
     #region properties
+    /// <summary>
+    /// The encoding algorithm to use
+    /// </summary>
+    public static readonly Encoding Encoding = Encoding.UTF8;
+
     /// <summary>
     /// Number of messages added to this template
     /// </summary>
@@ -189,14 +192,28 @@ public LLamaTemplate RemoveAt(int index)
 
         return this;
     }
+
+    /// <summary>
+    /// Remove all messages from the template and reset internal state to accept/generate new messages
+    /// </summary>
+    public void Clear()
+    {
+        _messages = new TextMessage[4];
+        Count = 0;
+
+        _resultLength = 0;
+        _result = [];
+        _nativeChatMessages = new LLamaChatMessage[4];
+
+        _dirty = true;
+    }
     #endregion
 
     /// <summary>
     /// Apply the template to the messages and write it into the output buffer
     /// </summary>
-    /// <param name="dest">Destination to write template bytes into</param>
-    /// <returns>The length of the template. If this is longer than dest.Length this method should be called again with a larger dest buffer</returns>
-    public int Apply(Memory<byte> dest)
+    /// <returns>A span over the buffer that holds the applied template</returns>
+    public ReadOnlySpan<byte> Apply()
     {
         // Recalculate template if necessary
         if (_dirty)
         {
@@ -213,7 +230,6 @@ public int Apply(Memory<byte> dest)
         for (var i = 0; i < Count; i++)
         {
             ref var m = ref _messages[i]!;
-            Debug.Assert(m != null);
             totalInputBytes += m.RoleBytes.Length + m.ContentBytes.Length;
 
             // Pin byte arrays in place
@@ -233,7 +249,6 @@ public int Apply(Memory<byte> dest)
         var output = ArrayPool<byte>.Shared.Rent(Math.Max(32, totalInputBytes * 2));
         try
         {
-            // Run templater and discover true length
             var outputLength = ApplyInternal(_nativeChatMessages.AsSpan(0, Count), output);
 
@@ -264,8 +279,7 @@ public int Apply(Memory<byte> dest)
         }
 
         // Now that the template has been applied and is in the result buffer, copy it to the dest
-        _result.AsSpan(0, Math.Min(dest.Length, _resultLength)).CopyTo(dest.Span);
-        return _resultLength;
+        return _result.AsSpan(0, _resultLength);
 
         unsafe int ApplyInternal(Span<LLamaChatMessage> messages, byte[] output)
         {
@@ -281,7 +295,7 @@ unsafe int ApplyInternal(Span<LLamaChatMessage> messages, byte[] output)
     /// <summary>
     /// A message that has been added to a template
     /// </summary>
-    public sealed class TextMessage
+    public readonly struct TextMessage
     {
        /// <summary>
         /// The "role" string for this message
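[Reviewer note, not part of the diff] Apply() is now a single call instead of the old call-twice-to-size-the-buffer pattern. A short before/after sketch, assuming `model` is a loaded LLamaWeights (as in the tests above):

    var templater = new LLamaTemplate(model);
    templater.Add("user", "hi");

    // before this PR: call once with an empty buffer to discover the length,
    // then again to fill a correctly sized destination
    // var length = templater.Apply(Array.Empty<byte>());
    // var dest = new byte[length];
    // length = templater.Apply(dest);

    // after: one call returning a span over the template's internal result buffer
    ReadOnlySpan<byte> bytes = templater.Apply();
    var prompt = LLamaTemplate.Encoding.GetString(bytes); // UTF-8; the span overload needs .NET Core 2.1+/netstandard2.1

Because ReadOnlySpan is stack-only, convert the result to a string or array before storing it; PromptTemplateTransformer.ToModelPrompt (added later in this diff) does exactly that, including the ToArray fallback for netstandard2.0.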
diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs
index ce712b724..8646e4d93 100644
--- a/LLama/LLamaWeights.cs
+++ b/LLama/LLamaWeights.cs
@@ -5,7 +5,6 @@
 using System.Threading.Tasks;
 using LLama.Abstractions;
 using LLama.Exceptions;
-using LLama.Extensions;
 using LLama.Native;
 using Microsoft.Extensions.Logging;
 
diff --git a/LLama/Native/LLamaToken.cs b/LLama/Native/LLamaToken.cs
index 64d263a7a..dd8bca1e2 100644
--- a/LLama/Native/LLamaToken.cs
+++ b/LLama/Native/LLamaToken.cs
@@ -10,6 +10,11 @@ namespace LLama.Native;
 [DebuggerDisplay("{Value}")]
 public readonly record struct LLamaToken
 {
+    /// <summary>
+    /// Token value used when a token is inherently null
+    /// </summary>
+    public static readonly LLamaToken InvalidToken = -1;
+
     /// <summary>
     /// The raw value
     /// </summary>
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index f54a8680b..3812a3517 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -19,6 +19,10 @@ public sealed class SafeLLamaContextHandle
     /// </summary>
     public int VocabCount => ThrowIfDisposed().VocabCount;
 
+    /// <summary>
+    /// The underlying vocabulary type for the model
+    /// </summary>
+    ///
     public LLamaVocabType LLamaVocabType => ThrowIfDisposed().VocabType;
 
     /// <summary>
diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs
index f24cfe5fd..1597908e3 100644
--- a/LLama/Native/SafeLlamaModelHandle.cs
+++ b/LLama/Native/SafeLlamaModelHandle.cs
@@ -6,7 +6,6 @@
 using System.Runtime.InteropServices;
 using System.Text;
 using LLama.Exceptions;
-using LLama.Extensions;
 
 namespace LLama.Native
 {
@@ -221,11 +220,30 @@ private static int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model,
        /// </summary>
        /// <param name="model"></param>
        /// <param name="key"></param>
-        /// <param name="buf"></param>
-        /// <param name="buf_size"></param>
        /// <returns>The length of the string on success, or -1 on failure</returns>
-        [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern unsafe int llama_model_meta_val_str(SafeLlamaModelHandle model, byte* key, byte* buf, long buf_size);
+        private static int llama_model_meta_val_str(SafeLlamaModelHandle model, string key, Span<byte> dest)
+        {
+            var bytesCount = Encoding.UTF8.GetByteCount(key);
+            var bytes = ArrayPool<byte>.Shared.Rent(bytesCount);
+
+            unsafe
+            {
+                fixed (char* keyPtr = key)
+                fixed (byte* bytesPtr = bytes)
+                fixed (byte* destPtr = dest)
+                {
+                    // Convert text into bytes
+                    Encoding.UTF8.GetBytes(keyPtr, key.Length, bytesPtr, bytesCount);
+
+                    return llama_model_meta_val_str_native(model, bytesPtr, destPtr, dest.Length);
+                }
+            }
+
+            [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_model_meta_val_str")]
+            static extern unsafe int llama_model_meta_val_str_native(SafeLlamaModelHandle model, byte* key, byte* buf, long buf_size);
+        }
+
        /// <summary>
        /// Get the number of tokens in the model vocabulary
@@ -461,8 +479,8 @@ internal Span<byte> TokensToSpan(IReadOnlyList<LLamaToken> tokens, Span<byte> de
        public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding encoding)
        {
            // Early exit if there's no work to do
-            if (text == "" && !add_bos)
-                return Array.Empty<LLamaToken>();
+            if (text == string.Empty && !add_bos)
+                return [];
 
            // Convert string to bytes, adding one extra byte to the end (null terminator)
            var bytesCount = encoding.GetByteCount(text);
@@ -484,7 +502,7 @@ public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding e
                var tokens = new LLamaToken[count];
                fixed (LLamaToken* tokensPtr = tokens)
                {
-                    NativeApi.llama_tokenize(this, bytesPtr, bytesCount, tokensPtr, count, add_bos, special);
+                    _ = NativeApi.llama_tokenize(this, bytesPtr, bytesCount, tokensPtr, count, add_bos, special);
                    return tokens;
                }
            }
@@ -510,6 +528,26 @@ public SafeLLamaContextHandle CreateContext(LLamaContextParams @params)
        #endregion
 
        #region metadata
+        /// <summary>
+        /// Get the metadata value for the given key
+        /// </summary>
+        /// <param name="key">The key to fetch</param>
+        /// <returns>The value, null if there is no such key</returns>
+        public Memory<byte>? MetadataValueByKey(string key)
+        {
+            // Check if the key exists, without getting any bytes of data
+            var keyLength = llama_model_meta_val_str(this, key, []);
+            if (keyLength < 0)
+                return null;
+
+            // get a buffer large enough to hold it
+            var buffer = new byte[keyLength + 1];
+            keyLength = llama_model_meta_val_str(this, key, buffer);
+            Debug.Assert(keyLength >= 0);
+
+            return buffer.AsMemory().Slice(0, keyLength);
+        }
+
        /// <summary>
        /// Get the metadata key for the given index
        /// </summary>
@@ -576,13 +614,39 @@ internal IReadOnlyDictionary<string, string> ReadMetadata()
        /// <summary>
        /// Get tokens for a model
        /// </summary>
-        public class ModelTokens
+        public sealed class ModelTokens
        {
            private readonly SafeLlamaModelHandle _model;
+            private readonly string? _eot;
+            private readonly string? _eos;
 
            internal ModelTokens(SafeLlamaModelHandle model)
            {
                _model = model;
+                _eot = LLamaTokenToString(EOT, true);
+                _eos = LLamaTokenToString(EOS, true);
+            }
+
+            private string? LLamaTokenToString(LLamaToken? token, bool isSpecialToken)
+            {
+                const int buffSize = 32;
+                Span<byte> buff = stackalloc byte[buffSize];
+                var tokenLength = _model.TokenToSpan(token ?? LLamaToken.InvalidToken, buff, special: isSpecialToken);
+
+                if (tokenLength <= 0)
+                {
+                    return null;
+                }
+
+                // if the original buffer wasn't large enough, create a new one
+                if (tokenLength > buffSize)
+                {
+                    buff = stackalloc byte[(int)tokenLength];
+                    _ = _model.TokenToSpan(token ?? LLamaToken.InvalidToken, buff, special: isSpecialToken);
+                }
+
+                var slice = buff.Slice(0, (int)tokenLength);
+                return Encoding.UTF8.GetStringFromSpan(slice);
            }
 
            private static LLamaToken? Normalize(LLamaToken token)
@@ -599,6 +663,11 @@ internal ModelTokens(SafeLlamaModelHandle model)
            /// Get the End of Sentence token for this model
            /// </summary>
            public LLamaToken? EOS => Normalize(llama_token_eos(_model));
+
+            /// <summary>
+            /// The textual representation of the end of speech special token for this model
+            /// </summary>
+            public string? EndOfSpeechToken => _eos;
 
            /// <summary>
            /// Get the newline token for this model
@@ -635,6 +704,11 @@ internal ModelTokens(SafeLlamaModelHandle model)
            /// </summary>
            public LLamaToken? EOT => Normalize(llama_token_eot(_model));
 
+            /// <summary>
+            /// The textual representation of this model's end-of-turn special token
+            /// </summary>
+            public string? EndOfTurnToken => _eot;
+
            /// <summary>
            /// Check if the given token should end generation
            /// </summary>
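[Reviewer note, not part of the diff] A short usage sketch of the new metadata and token-text helpers, mirroring the unit tests earlier in this diff; GetStringFromSpan is the repo's existing extension (see the tests' usings), and `model` is assumed to be a loaded LLamaWeights:

    using System.Text;
    using LLama.Extensions;

    // Fetch a metadata value by key; null when the key is absent
    var raw = model.NativeHandle.MetadataValueByKey("general.name");
    var name = raw is null ? null : Encoding.UTF8.GetStringFromSpan(raw.Value.Span);

    // Pre-decoded special-token strings; null when the model doesn't define them
    var eot = model.Tokens.EndOfTurnToken;   // e.g. "<|eot_id|>" for llama 3
    var eos = model.Tokens.EndOfSpeechToken; // e.g. "</s>" for llama 2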
diff --git a/LLama/Transformers/PromptTemplateTransformer.cs b/LLama/Transformers/PromptTemplateTransformer.cs
new file mode 100644
index 000000000..3543f9a1a
--- /dev/null
+++ b/LLama/Transformers/PromptTemplateTransformer.cs
@@ -0,0 +1,67 @@
+using System;
+using System.Text;
+using LLama.Abstractions;
+using LLama.Common;
+
+namespace LLama.Transformers;
+
+/// <summary>
+/// A prompt formatter that will use llama.cpp's template formatter.
+/// If your model is not supported, you will need to define your own formatter according to the chat prompt specification for your model
+/// </summary>
+public class PromptTemplateTransformer(LLamaWeights model,
+    bool withAssistant = true) : IHistoryTransform
+{
+    private readonly LLamaWeights _model = model;
+    private readonly bool _withAssistant = withAssistant;
+
+    /// <inheritdoc />
+    public string HistoryToText(ChatHistory history)
+    {
+        var template = new LLamaTemplate(_model.NativeHandle)
+        {
+            AddAssistant = _withAssistant,
+        };
+
+        // encode each message and return the final prompt
+        foreach (var message in history.Messages)
+        {
+            template.Add(message.AuthorRole.ToString().ToLowerInvariant(), message.Content);
+        }
+        return ToModelPrompt(template);
+    }
+
+    /// <inheritdoc />
+    public ChatHistory TextToHistory(AuthorRole role, string text)
+    {
+        return new ChatHistory([new ChatHistory.Message(role, text)]);
+    }
+
+    /// <inheritdoc />
+    public IHistoryTransform Clone()
+    {
+        // need to preserve history?
+        return new PromptTemplateTransformer(_model);
+    }
+
+    #region utils
+    /// <summary>
+    /// Apply the template to the messages and return the resulting prompt as a string
+    /// </summary>
+    /// <param name="template"></param>
+    /// <returns>The formatted template string as defined by the model</returns>
+    public static string ToModelPrompt(LLamaTemplate template)
+    {
+        // Apply the template to update state and get data length
+        var templateBuffer = template.Apply();
+
+        // convert the resulting buffer to a string
+#if NET6_0_OR_GREATER
+        return LLamaTemplate.Encoding.GetString(templateBuffer);
+#endif
+
+        // need the ToArray call for netstandard -- avoided in newer runtimes
+        return LLamaTemplate.Encoding.GetString(templateBuffer.ToArray());
+    }
+    #endregion utils
+}