diff --git a/LLama/Native/LLamaToken.cs b/LLama/Native/LLamaToken.cs index dd45ebbea..c1b85f2ff 100644 --- a/LLama/Native/LLamaToken.cs +++ b/LLama/Native/LLamaToken.cs @@ -1,4 +1,5 @@ using System.Diagnostics; +using System.Linq; namespace LLama.Native; @@ -98,10 +99,7 @@ public bool IsControl(SafeLlamaModelHandle model) /// public bool IsControl(SafeLlamaModelHandle.Vocabulary vocab) { - unsafe - { - return LLamaVocabNative.llama_vocab_is_control(vocab.VocabNative, this); - } + return vocab.ControlTokens.Contains((int) this); } /// @@ -121,10 +119,7 @@ public bool IsEndOfGeneration(SafeLlamaModelHandle model) /// public bool IsEndOfGeneration(SafeLlamaModelHandle.Vocabulary vocab) { - unsafe - { - return LLamaVocabNative.llama_vocab_is_eog(vocab.VocabNative, this); - } + return vocab.EOGTokens.Contains((int) this); } /// diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index 06e3baee4..b862fc069 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -7,7 +7,7 @@ namespace LLama.Native /// /// Direct translation of the llama.cpp API /// - public static partial class NativeApi + public static partial class NativeApi { /// /// A method that does nothing. This is a native method, calling it will force the llama native dependencies to be loaded. @@ -202,15 +202,31 @@ public static unsafe int llama_chat_apply_template(byte* tmpl, LLamaChatMessage* /// The length written, or if the buffer is too small a negative that indicates the length required public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LLamaToken llamaToken, Span buffer, int lstrip, bool special) { + unsafe + { + return llama_token_to_piece(vocab.VocabNative, llamaToken, buffer, lstrip, special); + } + } + + /// + /// Convert a single token into text + /// + /// + /// + /// buffer to write string into + /// User can skip up to 'lstrip' leading spaces before copying (useful when encoding/decoding multiple tokens with 'add_space_prefix') + /// If true, special tokens are rendered in the output + /// The length written, or if the buffer is too small a negative that indicates the length required + internal static unsafe int llama_token_to_piece(LLamaVocabNative* vocabNative, LLamaToken llamaToken, Span buffer, int lstrip, bool special) { // Handle invalid tokens - if ((int)llamaToken < 0) + if ((int) llamaToken < 0) return 0; unsafe { fixed (byte* bufferPtr = buffer) { - return llama_token_to_piece_native(vocab.VocabNative, llamaToken, bufferPtr, buffer.Length, lstrip, special); + return llama_token_to_piece_native(vocabNative, llamaToken, bufferPtr, buffer.Length, lstrip, special); } } diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index 9439c2bb3..63cc502e6 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.IO; +using System.Linq; using System.Text; using LLama.Exceptions; @@ -631,34 +632,63 @@ public sealed class Vocabulary internal unsafe LLamaVocabNative* VocabNative => llama_model_get_vocab(_model); - internal Vocabulary(SafeLlamaModelHandle model) - { - _model = model; - } + /// + /// Map of each token in this vocabulary to its string representation + /// + public readonly IReadOnlyDictionary TokenToString; + + /// + /// Contains unique tokens that are supposed to end the generation (e.g.: EOS, EOT, etc) + /// + internal readonly HashSet EOGTokens; - private string? LLamaTokenToString(LLamaToken? token, bool isSpecialToken) + /// + /// Contains unique tokens that exist for inference control rather than text output + /// + internal readonly HashSet ControlTokens; + + internal unsafe Vocabulary(SafeLlamaModelHandle model) { - if (!token.HasValue) - return null; + _model = model; - // Try to convert using a fixed size buffer - const int buffSize = 32; - Span buff = stackalloc byte[buffSize]; - var tokenLength = _model.TokenToSpan((LLamaToken)token, buff, special: isSpecialToken); - - // Negative indicates that there was no result - if (tokenLength <= 0) - return null; - - // if the original buffer wasn't large enough, try again with one that's the right size - if (tokenLength > buffSize) - { - buff = stackalloc byte[(int)tokenLength]; - _ = _model.TokenToSpan((LLamaToken)token, buff, special: isSpecialToken); - } + // Cache the various properties that llama.cpp API exposes about the vocab + var vocabNative = llama_model_get_vocab(_model); + Count = LLamaVocabNative.llama_vocab_n_tokens(vocabNative); + Type = LLamaVocabNative.llama_vocab_type(vocabNative); + + BOS = Normalize(LLamaVocabNative.llama_vocab_bos(vocabNative)); + EOS = Normalize(LLamaVocabNative.llama_vocab_eos(vocabNative)); + EOT = Normalize(LLamaVocabNative.llama_vocab_eot(vocabNative)); + Pad = Normalize(LLamaVocabNative.llama_vocab_pad(vocabNative)); + SEP = Normalize(LLamaVocabNative.llama_vocab_sep(vocabNative)); + Newline = Normalize(LLamaVocabNative.llama_vocab_nl(vocabNative)); + + InfillPrefix = Normalize(LLamaVocabNative.llama_vocab_fim_pre(vocabNative)); + InfillMiddle = Normalize(LLamaVocabNative.llama_vocab_fim_mid(vocabNative)); + InfillSuffix = Normalize(LLamaVocabNative.llama_vocab_fim_suf(vocabNative)); + InfillPad = Normalize(LLamaVocabNative.llama_vocab_fim_pad(vocabNative)); + InfillRep = Normalize(LLamaVocabNative.llama_vocab_fim_rep(vocabNative)); + InfillSep = Normalize(LLamaVocabNative.llama_vocab_fim_sep(vocabNative)); + + DecoderStartToken = Normalize(llama_model_decoder_start_token(_model)); + ShouldAddBOS = LLamaVocabNative.llama_vocab_get_add_bos(vocabNative); + ShouldAddEOS = LLamaVocabNative.llama_vocab_get_add_eos(vocabNative); + + // Cache `TokenToString` for quick access + var decoder = Encoding.UTF8.GetDecoder(); + var (bytesArr, charsArr) = (new byte[1024], new char[1024]); + TokenToString = Enumerable.Range(0, Count).ToDictionary( + keySelector: i => (LLamaToken) i, + elementSelector: i => + { + var length = NativeApi.llama_token_to_piece(vocabNative, (LLamaToken) i, bytesArr, 0, true); + decoder.Convert(bytesArr, 0, length, charsArr, 0, charsArr.Length, true, out var _, out var charsUsed, out var _); + return string.Join("", charsArr.Take(charsUsed)); + } + ); - var slice = buff.Slice(0, (int)tokenLength); - return Encoding.UTF8.GetStringFromSpan(slice); + EOGTokens = new(TokenToString.Keys.Where(token => LLamaVocabNative.llama_vocab_is_eog(vocabNative, token)).Select(x => (int) x)); + ControlTokens = new(TokenToString.Keys.Where(token => LLamaVocabNative.llama_vocab_is_control(vocabNative, token)).Select(x => (int) x)); } private static LLamaToken? Normalize(LLamaToken token) @@ -669,232 +699,88 @@ internal Vocabulary(SafeLlamaModelHandle model) /// /// Total number of tokens in this vocabulary /// - public int Count - { - get - { - unsafe - { - return LLamaVocabNative.llama_vocab_n_tokens(VocabNative); - } - } - } + public int Count { get; } /// /// Get the the type of this vocabulary /// - public LLamaVocabType Type - { - get - { - unsafe - { - return LLamaVocabNative.llama_vocab_type(VocabNative); - } - } - } + public LLamaVocabType Type { get; } /// /// Get the Beginning of Sentence token for this model /// - public LLamaToken? BOS - { - get - { - unsafe - { - return Normalize(LLamaVocabNative.llama_vocab_bos(VocabNative)); - } - } - } + public LLamaToken? BOS { get; } /// /// Get the End of Sentence token for this model /// - public LLamaToken? EOS - { - get - { - unsafe - { - return Normalize(LLamaVocabNative.llama_vocab_eos(VocabNative)); - } - } - } + public LLamaToken? EOS { get; } /// /// Get the newline token for this model /// - public LLamaToken? Newline - { - get - { - unsafe - { - return Normalize(LLamaVocabNative.llama_vocab_nl(VocabNative)); - } - } - } + public LLamaToken? Newline { get; } /// /// Get the padding token for this model /// - public LLamaToken? Pad - { - get - { - unsafe - { - return Normalize(LLamaVocabNative.llama_vocab_pad(VocabNative)); - } - } - } + public LLamaToken? Pad { get; } /// /// Get the sentence separator token for this model /// - public LLamaToken? SEP - { - get - { - unsafe - { - return Normalize(LLamaVocabNative.llama_vocab_sep(VocabNative)); - } - } - } + public LLamaToken? SEP { get; } /// /// Codellama beginning of infill prefix /// - public LLamaToken? InfillPrefix - { - get - { - unsafe - { - return Normalize(LLamaVocabNative.llama_vocab_fim_pre(VocabNative)); - } - } - } + public LLamaToken? InfillPrefix { get; } /// /// Codellama beginning of infill middle /// - public LLamaToken? InfillMiddle - { - get - { - unsafe - { - return Normalize(LLamaVocabNative.llama_vocab_fim_mid(VocabNative)); - } - } - } + public LLamaToken? InfillMiddle { get; } /// /// Codellama beginning of infill suffix /// - public LLamaToken? InfillSuffix - { - get - { - unsafe - { - return Normalize(LLamaVocabNative.llama_vocab_fim_suf(VocabNative)); - } - } - } + public LLamaToken? InfillSuffix { get; } /// /// Codellama pad /// - public LLamaToken? InfillPad - { - get - { - unsafe - { - return Normalize(LLamaVocabNative.llama_vocab_fim_pad(VocabNative)); - } - } - } + public LLamaToken? InfillPad { get; } /// /// Codellama rep /// - public LLamaToken? InfillRep - { - get - { - unsafe - { - return Normalize(LLamaVocabNative.llama_vocab_fim_rep(VocabNative)); - } - } - } + public LLamaToken? InfillRep { get; } /// /// Codellama rep /// - public LLamaToken? InfillSep - { - get - { - unsafe - { - return Normalize(LLamaVocabNative.llama_vocab_fim_sep(VocabNative)); - } - } - } + public LLamaToken? InfillSep { get; } /// /// end-of-turn token /// - public LLamaToken? EOT - { - get - { - unsafe - { - return Normalize(LLamaVocabNative.llama_vocab_eot(VocabNative)); - } - } - } + public LLamaToken? EOT { get; } /// /// For encoder-decoder models, this function returns id of the token that must be provided /// to the decoder to start generating output sequence. /// - public LLamaToken? DecoderStartToken => Normalize(llama_model_decoder_start_token(_model)); + public LLamaToken? DecoderStartToken { get; } /// /// Check if the current model requires a BOS token added /// - public bool ShouldAddBOS - { - get - { - unsafe - { - return LLamaVocabNative.llama_vocab_get_add_bos(llama_model_get_vocab(_model)); - } - } - } + public bool ShouldAddBOS { get; } /// /// Check if the current model requires a EOS token added /// - public bool ShouldAddEOS - { - get - { - unsafe - { - return LLamaVocabNative.llama_vocab_get_add_eos(llama_model_get_vocab(_model)); - } - } - } + public bool ShouldAddEOS { get; } } } }