diff --git a/LLama/Native/LLamaToken.cs b/LLama/Native/LLamaToken.cs
index dd45ebbea..c1b85f2ff 100644
--- a/LLama/Native/LLamaToken.cs
+++ b/LLama/Native/LLamaToken.cs
@@ -1,4 +1,5 @@
using System.Diagnostics;
+using System.Linq;
namespace LLama.Native;
@@ -98,10 +99,7 @@ public bool IsControl(SafeLlamaModelHandle model)
///
public bool IsControl(SafeLlamaModelHandle.Vocabulary vocab)
{
- unsafe
- {
- return LLamaVocabNative.llama_vocab_is_control(vocab.VocabNative, this);
- }
+ return vocab.ControlTokens.Contains((int) this);
}
///
@@ -121,10 +119,7 @@ public bool IsEndOfGeneration(SafeLlamaModelHandle model)
///
public bool IsEndOfGeneration(SafeLlamaModelHandle.Vocabulary vocab)
{
- unsafe
- {
- return LLamaVocabNative.llama_vocab_is_eog(vocab.VocabNative, this);
- }
+ return vocab.EOGTokens.Contains((int) this);
}
///
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index 06e3baee4..b862fc069 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -7,7 +7,7 @@ namespace LLama.Native
///
/// Direct translation of the llama.cpp API
///
- public static partial class NativeApi
+ public static partial class NativeApi
{
///
/// A method that does nothing. This is a native method, calling it will force the llama native dependencies to be loaded.
@@ -202,15 +202,31 @@ public static unsafe int llama_chat_apply_template(byte* tmpl, LLamaChatMessage*
/// The length written, or if the buffer is too small a negative that indicates the length required
public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LLamaToken llamaToken, Span buffer, int lstrip, bool special)
{
+ unsafe
+ {
+ return llama_token_to_piece(vocab.VocabNative, llamaToken, buffer, lstrip, special);
+ }
+ }
+
+ ///
+ /// Convert a single token into text
+ ///
+ ///
+ ///
+ /// buffer to write string into
+ /// User can skip up to 'lstrip' leading spaces before copying (useful when encoding/decoding multiple tokens with 'add_space_prefix')
+ /// If true, special tokens are rendered in the output
+ /// The length written, or if the buffer is too small a negative that indicates the length required
+ internal static unsafe int llama_token_to_piece(LLamaVocabNative* vocabNative, LLamaToken llamaToken, Span buffer, int lstrip, bool special) {
// Handle invalid tokens
- if ((int)llamaToken < 0)
+ if ((int) llamaToken < 0)
return 0;
unsafe
{
fixed (byte* bufferPtr = buffer)
{
- return llama_token_to_piece_native(vocab.VocabNative, llamaToken, bufferPtr, buffer.Length, lstrip, special);
+ return llama_token_to_piece_native(vocabNative, llamaToken, bufferPtr, buffer.Length, lstrip, special);
}
}
diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs
index 9439c2bb3..63cc502e6 100644
--- a/LLama/Native/SafeLlamaModelHandle.cs
+++ b/LLama/Native/SafeLlamaModelHandle.cs
@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
+using System.Linq;
using System.Text;
using LLama.Exceptions;
@@ -631,34 +632,63 @@ public sealed class Vocabulary
internal unsafe LLamaVocabNative* VocabNative => llama_model_get_vocab(_model);
- internal Vocabulary(SafeLlamaModelHandle model)
- {
- _model = model;
- }
+ ///
+ /// Map of each token in this vocabulary to its string representation
+ ///
+ public readonly IReadOnlyDictionary TokenToString;
+
+ ///
+ /// Contains unique tokens that are supposed to end the generation (e.g.: EOS, EOT, etc)
+ ///
+ internal readonly HashSet EOGTokens;
- private string? LLamaTokenToString(LLamaToken? token, bool isSpecialToken)
+ ///
+ /// Contains unique tokens that exist for inference control rather than text output
+ ///
+ internal readonly HashSet ControlTokens;
+
+ internal unsafe Vocabulary(SafeLlamaModelHandle model)
{
- if (!token.HasValue)
- return null;
+ _model = model;
- // Try to convert using a fixed size buffer
- const int buffSize = 32;
- Span buff = stackalloc byte[buffSize];
- var tokenLength = _model.TokenToSpan((LLamaToken)token, buff, special: isSpecialToken);
-
- // Negative indicates that there was no result
- if (tokenLength <= 0)
- return null;
-
- // if the original buffer wasn't large enough, try again with one that's the right size
- if (tokenLength > buffSize)
- {
- buff = stackalloc byte[(int)tokenLength];
- _ = _model.TokenToSpan((LLamaToken)token, buff, special: isSpecialToken);
- }
+ // Cache the various properties that llama.cpp API exposes about the vocab
+ var vocabNative = llama_model_get_vocab(_model);
+ Count = LLamaVocabNative.llama_vocab_n_tokens(vocabNative);
+ Type = LLamaVocabNative.llama_vocab_type(vocabNative);
+
+ BOS = Normalize(LLamaVocabNative.llama_vocab_bos(vocabNative));
+ EOS = Normalize(LLamaVocabNative.llama_vocab_eos(vocabNative));
+ EOT = Normalize(LLamaVocabNative.llama_vocab_eot(vocabNative));
+ Pad = Normalize(LLamaVocabNative.llama_vocab_pad(vocabNative));
+ SEP = Normalize(LLamaVocabNative.llama_vocab_sep(vocabNative));
+ Newline = Normalize(LLamaVocabNative.llama_vocab_nl(vocabNative));
+
+ InfillPrefix = Normalize(LLamaVocabNative.llama_vocab_fim_pre(vocabNative));
+ InfillMiddle = Normalize(LLamaVocabNative.llama_vocab_fim_mid(vocabNative));
+ InfillSuffix = Normalize(LLamaVocabNative.llama_vocab_fim_suf(vocabNative));
+ InfillPad = Normalize(LLamaVocabNative.llama_vocab_fim_pad(vocabNative));
+ InfillRep = Normalize(LLamaVocabNative.llama_vocab_fim_rep(vocabNative));
+ InfillSep = Normalize(LLamaVocabNative.llama_vocab_fim_sep(vocabNative));
+
+ DecoderStartToken = Normalize(llama_model_decoder_start_token(_model));
+ ShouldAddBOS = LLamaVocabNative.llama_vocab_get_add_bos(vocabNative);
+ ShouldAddEOS = LLamaVocabNative.llama_vocab_get_add_eos(vocabNative);
+
+ // Cache `TokenToString` for quick access
+ var decoder = Encoding.UTF8.GetDecoder();
+ var (bytesArr, charsArr) = (new byte[1024], new char[1024]);
+ TokenToString = Enumerable.Range(0, Count).ToDictionary(
+ keySelector: i => (LLamaToken) i,
+ elementSelector: i =>
+ {
+ var length = NativeApi.llama_token_to_piece(vocabNative, (LLamaToken) i, bytesArr, 0, true);
+ decoder.Convert(bytesArr, 0, length, charsArr, 0, charsArr.Length, true, out var _, out var charsUsed, out var _);
+ return string.Join("", charsArr.Take(charsUsed));
+ }
+ );
- var slice = buff.Slice(0, (int)tokenLength);
- return Encoding.UTF8.GetStringFromSpan(slice);
+ EOGTokens = new(TokenToString.Keys.Where(token => LLamaVocabNative.llama_vocab_is_eog(vocabNative, token)).Select(x => (int) x));
+ ControlTokens = new(TokenToString.Keys.Where(token => LLamaVocabNative.llama_vocab_is_control(vocabNative, token)).Select(x => (int) x));
}
private static LLamaToken? Normalize(LLamaToken token)
@@ -669,232 +699,88 @@ internal Vocabulary(SafeLlamaModelHandle model)
///
/// Total number of tokens in this vocabulary
///
- public int Count
- {
- get
- {
- unsafe
- {
- return LLamaVocabNative.llama_vocab_n_tokens(VocabNative);
- }
- }
- }
+ public int Count { get; }
///
/// Get the the type of this vocabulary
///
- public LLamaVocabType Type
- {
- get
- {
- unsafe
- {
- return LLamaVocabNative.llama_vocab_type(VocabNative);
- }
- }
- }
+ public LLamaVocabType Type { get; }
///
/// Get the Beginning of Sentence token for this model
///
- public LLamaToken? BOS
- {
- get
- {
- unsafe
- {
- return Normalize(LLamaVocabNative.llama_vocab_bos(VocabNative));
- }
- }
- }
+ public LLamaToken? BOS { get; }
///
/// Get the End of Sentence token for this model
///
- public LLamaToken? EOS
- {
- get
- {
- unsafe
- {
- return Normalize(LLamaVocabNative.llama_vocab_eos(VocabNative));
- }
- }
- }
+ public LLamaToken? EOS { get; }
///
/// Get the newline token for this model
///
- public LLamaToken? Newline
- {
- get
- {
- unsafe
- {
- return Normalize(LLamaVocabNative.llama_vocab_nl(VocabNative));
- }
- }
- }
+ public LLamaToken? Newline { get; }
///
/// Get the padding token for this model
///
- public LLamaToken? Pad
- {
- get
- {
- unsafe
- {
- return Normalize(LLamaVocabNative.llama_vocab_pad(VocabNative));
- }
- }
- }
+ public LLamaToken? Pad { get; }
///
/// Get the sentence separator token for this model
///
- public LLamaToken? SEP
- {
- get
- {
- unsafe
- {
- return Normalize(LLamaVocabNative.llama_vocab_sep(VocabNative));
- }
- }
- }
+ public LLamaToken? SEP { get; }
///
/// Codellama beginning of infill prefix
///
- public LLamaToken? InfillPrefix
- {
- get
- {
- unsafe
- {
- return Normalize(LLamaVocabNative.llama_vocab_fim_pre(VocabNative));
- }
- }
- }
+ public LLamaToken? InfillPrefix { get; }
///
/// Codellama beginning of infill middle
///
- public LLamaToken? InfillMiddle
- {
- get
- {
- unsafe
- {
- return Normalize(LLamaVocabNative.llama_vocab_fim_mid(VocabNative));
- }
- }
- }
+ public LLamaToken? InfillMiddle { get; }
///
/// Codellama beginning of infill suffix
///
- public LLamaToken? InfillSuffix
- {
- get
- {
- unsafe
- {
- return Normalize(LLamaVocabNative.llama_vocab_fim_suf(VocabNative));
- }
- }
- }
+ public LLamaToken? InfillSuffix { get; }
///
/// Codellama pad
///
- public LLamaToken? InfillPad
- {
- get
- {
- unsafe
- {
- return Normalize(LLamaVocabNative.llama_vocab_fim_pad(VocabNative));
- }
- }
- }
+ public LLamaToken? InfillPad { get; }
///
/// Codellama rep
///
- public LLamaToken? InfillRep
- {
- get
- {
- unsafe
- {
- return Normalize(LLamaVocabNative.llama_vocab_fim_rep(VocabNative));
- }
- }
- }
+ public LLamaToken? InfillRep { get; }
///
/// Codellama rep
///
- public LLamaToken? InfillSep
- {
- get
- {
- unsafe
- {
- return Normalize(LLamaVocabNative.llama_vocab_fim_sep(VocabNative));
- }
- }
- }
+ public LLamaToken? InfillSep { get; }
///
/// end-of-turn token
///
- public LLamaToken? EOT
- {
- get
- {
- unsafe
- {
- return Normalize(LLamaVocabNative.llama_vocab_eot(VocabNative));
- }
- }
- }
+ public LLamaToken? EOT { get; }
///
/// For encoder-decoder models, this function returns id of the token that must be provided
/// to the decoder to start generating output sequence.
///
- public LLamaToken? DecoderStartToken => Normalize(llama_model_decoder_start_token(_model));
+ public LLamaToken? DecoderStartToken { get; }
///
/// Check if the current model requires a BOS token added
///
- public bool ShouldAddBOS
- {
- get
- {
- unsafe
- {
- return LLamaVocabNative.llama_vocab_get_add_bos(llama_model_get_vocab(_model));
- }
- }
- }
+ public bool ShouldAddBOS { get; }
///
/// Check if the current model requires a EOS token added
///
- public bool ShouldAddEOS
- {
- get
- {
- unsafe
- {
- return LLamaVocabNative.llama_vocab_get_add_eos(llama_model_get_vocab(_model));
- }
- }
- }
+ public bool ShouldAddEOS { get; }
}
}
}