diff --git a/src/Microsoft.ML.Tokenizers/Model/BPE.cs b/src/Microsoft.ML.Tokenizers/Model/BPE.cs index 415e2e0e57..008dacb573 100644 --- a/src/Microsoft.ML.Tokenizers/Model/BPE.cs +++ b/src/Microsoft.ML.Tokenizers/Model/BPE.cs @@ -95,6 +95,7 @@ public Bpe(string vocabFile, string? mergesFile, string? unknownToken = null, st (Dictionary? vocab1, Vec<(string, string)> merges) = ReadFile(vocabFile, mergesFile); Vocab = vocab1 ?? new Dictionary(); + Cache = new Cache(); VocabReverse = new(); @@ -146,23 +147,33 @@ public Bpe(string vocabFile, string? mergesFile, string? unknownToken = null, st /// Tokenize a sequence string to a list of tokens. /// /// The sequence to tokenize. + /// Indicate if the token is a special token. /// The list of tokens generated from the sequence tokenization. - public override IReadOnlyList Tokenize(string sequence) + public override IReadOnlyList Tokenize(string sequence, bool isSpecialToken = false) { if (sequence.Length == 0) { return EmptyTokensList; } - if (!Dropout.HasValue) - { - return TokenizeWithCache(sequence); - } + return TokenizeWithCache(sequence); + } - Word word = MergeWord(sequence); + /// + /// Tokenize a split sequence string to a list of Ids and add them to the accumulatedIds list. + /// + /// The sequence to split. + /// Indicate if the token is a special token. + /// The list of accumulated tokenized Ids. + public override void TokenizeToIds(string sequence, bool isSpecialToken, IList accumulatedIds) => TokenizeToIdsWithCache(sequence, accumulatedIds); - return WordToTokens(ref word); - } + /// + /// Get the number of tokens that the input sequence will be encoded to. + /// + /// The text to tokenize. + /// Indicate if the token is special token. + /// The number of tokens that the input sequence will be encoded to. + public override int CountTokens(string sequence, bool isSpecialToken) => TokenizeToIdsWithCache(sequence, null); /// /// Map the token to tokenized Id. @@ -195,14 +206,6 @@ public override IReadOnlyList Tokenize(string sequence) return null; } - /// - /// Map the tokenized Id to the token. - /// - /// The Id to map to the token. - /// Indicate if want to skip the special tokens during the decoding. - /// The mapped token of the Id. - public override string? IdToString(int id, bool skipSpecialTokens = false) => throw new NotImplementedException(); - /// /// Gets the dictionary mapping tokens to Ids. /// @@ -332,7 +335,7 @@ internal string CharToString(char c) internal Word MergeWord(string w) { - Word word = Word.WithCapacity((int)w.Length); + Word word = Word.WithCapacity(w.Length); (int Id, int Len)? 
unk = null; int i = 0; @@ -344,7 +347,7 @@ internal Word MergeWord(string w) if (Char.IsHighSurrogate(w[i]) && i < w.Length - 1 && Char.IsLowSurrogate(w[i + 1])) { length = 2; - s = w.Substring(i, (int)length); + s = w.Substring(i, length); } else { @@ -403,7 +406,7 @@ internal Word MergeWord(string w) } } - i += (int)length; + i += length; } if (unk.HasValue) @@ -415,45 +418,59 @@ internal Word MergeWord(string w) return word; } - // internal Word.Enumerator WordToTokens(Word word) => word.GetIterator(VocabReverse); - internal List WordToTokens(ref Word word) + internal List WordToTokens(ref Word word) => word.ToTokens(VocabReverse); + + internal List TokenizeWithCache(string sequence) { - List tokens = new(word.SymbolsCount); + Word word; + if (Cache is not null) + { + if (Cache.TryGet(sequence, out word)) + { + return WordToTokens(ref word); + } - foreach (Token token in word.GetIterator(VocabReverse)) + word = MergeWord(sequence); + Cache.Set(sequence, word); + } + else { - tokens.Add(token); + word = MergeWord(sequence); } - return tokens; + return WordToTokens(ref word); } - internal List TokenizeWithCache(string sequence) + internal int WordToIds(ref Word word, IList? accumulatedIds) { - if (Cache is not null) + if (accumulatedIds is not null) { - Word? hit = Cache.Get(sequence); - if (hit.HasValue) - { - Word w = hit.Value; - return WordToTokens(ref w); - } + word.PopulateIds(accumulatedIds); } - Word word = MergeWord(sequence); - List tokens = WordToTokens(ref word); + return word.SymbolsCount; + } + + internal int TokenizeToIdsWithCache(string sequence, IList? accumulatedIds) + { + Word word; if (Cache is not null) { + if (Cache.TryGet(sequence, out Word hit)) + { + return WordToIds(ref hit, accumulatedIds); + } + + word = MergeWord(sequence); Cache.Set(sequence, word); } + else + { + word = MergeWord(sequence); + } - return tokens; - } - - public override bool IsValidChar(char ch) - { - throw new NotImplementedException(); + return WordToIds(ref word, accumulatedIds); } internal static readonly List EmptyTokensList = new(); diff --git a/src/Microsoft.ML.Tokenizers/Model/Cache.cs b/src/Microsoft.ML.Tokenizers/Model/Cache.cs index 269a580b1e..1fcfa849ec 100644 --- a/src/Microsoft.ML.Tokenizers/Model/Cache.cs +++ b/src/Microsoft.ML.Tokenizers/Model/Cache.cs @@ -9,14 +9,14 @@ namespace Microsoft.ML.Tokenizers { - internal sealed class Cache where TKey : notnull + internal sealed class Cache where TKey : notnull where TValue : notnull { internal Cache() : this(Bpe.DefaultCacheCapacity) { } internal Cache(int capacity) { Capacity = capacity; - Map = new Dictionary((int)Capacity); + Map = new Dictionary(Capacity); } private readonly ReaderWriterLockSlim _cacheLock = new ReaderWriterLockSlim(); @@ -25,7 +25,7 @@ internal Cache(int capacity) internal int Capacity { get; set; } - internal void Fresh() => Map = new Dictionary((int)Capacity); + internal void Fresh() => Map = new Dictionary(Capacity); internal void Clear() { @@ -56,27 +56,22 @@ internal List GetValues(IEnumerable keys) return values; } - internal TValue? Get(TKey key) + internal bool TryGet(TKey key, out TValue value) { _cacheLock.EnterReadLock(); try { - if (Map.TryGetValue(key, out TValue? 
value)) - { - return value; - } + return Map.TryGetValue(key, out value!); } finally { _cacheLock.ExitReadLock(); } - - return default; } - internal void SetValues(IEnumerable<(TKey, TValue)> enteries) + internal void SetValues(IEnumerable<(TKey, TValue)> entries) { _cacheLock.EnterWriteLock(); try { - foreach ((TKey, TValue) entry in enteries) + foreach ((TKey, TValue) entry in entries) { if (Capacity <= Map.Count) { diff --git a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs index 02fef91ded..e750a5df6c 100644 --- a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs +++ b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Buffers; using System.Collections.Generic; using System.Diagnostics; using System.IO; @@ -55,9 +56,9 @@ public EnglishRoberta(string vocabularyPath, string mergePath, string highestOcc using Stream mergeStream = File.OpenRead(mergePath); using Stream highestOccurrenceMappingStream = File.OpenRead(highestOccurrenceMappingPath); - // vocabularyPath like encoder.json - // merge file like vocab.bpe - // highestOccurrenceMappingPath like dict.txt + // vocabularyPath like "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json" + // merge file like "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe" + // highestOccurrenceMappingPath like "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt" _vocabIdToHighestOccurrence = GetHighestOccurrenceMapping(highestOccurrenceMappingStream); _vocab = GetVocabulary(vocabularyStream); @@ -136,12 +137,12 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes skipSpecialTokens && id < 0 ? null : _vocabReverse.TryGetValue(id, out var value) ? value : null; /// - /// Map the tokenized Id to the original string. + /// Map the tokenized Id to the original string while filtering out unsupported characters. /// /// The Id to map to the string. /// Indicate if want to skip the special tokens during the decoding. /// The mapped token of the Id. - public override string? IdToString(int id, bool skipSpecialTokens = false) + public string? IdToFilteredToken(int id, bool skipSpecialTokens = false) { if (skipSpecialTokens && id < 0) return null; @@ -166,8 +167,8 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes public override string[] Save(string path, string? prefix = null) { // Write vocab.json - string vocabFileNname = prefix is null ? "vocab.json" : $"{prefix}-vocab.json"; - string vocabPath = Path.Combine(path, vocabFileNname); + string vocabFileName = prefix is null ? "vocab.json" : $"{prefix}-vocab.json"; + string vocabPath = Path.Combine(path, vocabFileName); string serialized = JsonSerializer.Serialize(_vocabReverse, new JsonSerializerOptions { Converters = { new DictReversingConverter() } }); File.WriteAllText(vocabPath, serialized, System.Text.Encoding.UTF8); @@ -203,10 +204,75 @@ public override string[] Save(string path, string? prefix = null) /// Tokenize a sequence string to a list of tokens. /// /// The sequence to tokenize. + /// Indicate if the token is a special token. /// The list of tokens generated from the sequence tokenization. 
- public override IReadOnlyList Tokenize(string sequence) + public override IReadOnlyList Tokenize(string sequence, bool isSpecialToken = false) { - var bpeTokens = new List(); + char[] token = ArrayPool.Shared.Rent(sequence.Length); + int[] indexMapping = ArrayPool.Shared.Rent(sequence.Length); + + int newTokenIndex = 0; + for (int i = 0; i < sequence.Length; i++) + { + if (_byteToUnicode.TryGetValue(sequence[i], out var value)) + { + token[newTokenIndex] = value; + indexMapping[newTokenIndex] = i; + newTokenIndex++; + } + } + + if (newTokenIndex == 0) + { + ArrayPool.Shared.Return(token); + ArrayPool.Shared.Return(indexMapping); + return Bpe.EmptyTokensList; + } + + if (_cache.TryGet(sequence, out IReadOnlyList? hit)) + { + ArrayPool.Shared.Return(token); + ArrayPool.Shared.Return(indexMapping); + return ModifyTokenListOffsets(hit, indexMapping); + } + + IReadOnlyList result = EncodeToTokens(token.AsSpan().Slice(0, newTokenIndex), indexMapping); + _cache.Set(sequence, result); + ArrayPool.Shared.Return(token); + ArrayPool.Shared.Return(indexMapping); + return result; + } + + /// + /// Tokenize a split sequence string to a list of Ids and add them to the accumulatedIds list. + /// + /// The sequence to split. + /// Indicate if the token is a special token. + /// The list of accumulated tokenized Ids. + public override void TokenizeToIds(string sequence, bool isSpecialToken, IList accumulatedIds) => TokenizeToIds(sequence, accumulatedIds); + + /// + /// Get the number of tokens that the input sequence will be encoded to. + /// + /// The text to tokenize. + /// Indicate if the token is special token. + /// The number of tokens that the input sequence will be encoded to. + public override int CountTokens(string sequence, bool isSpecialToken) => TokenizeToIds(sequence, null); + + private int TokenizeToIds(string sequence, IList? accumulatedIds) + { + if (_cache.TryGet(sequence, out IReadOnlyList? hit)) + { + if (accumulatedIds is not null) + { + foreach (var t in hit) + { + accumulatedIds.Add(t.Id); + } + } + + return hit.Count; + } Span token = stackalloc char[100]; Span indexMapping = stackalloc int[100]; @@ -230,18 +296,12 @@ public override IReadOnlyList Tokenize(string sequence) if (newTokenIndex == 0) { - return Bpe.EmptyTokensList; - } - - IReadOnlyList? 
hit = _cache.Get(sequence); - if (hit is not null) - { - return ModifyTokenListOffsets(hit, indexMapping); + return 0; } - IReadOnlyList result = BpeToken(token.Slice(0, newTokenIndex), indexMapping); + IReadOnlyList result = EncodeToTokens(token.Slice(0, newTokenIndex), indexMapping); _cache.Set(sequence, result); - return result; + return result.Count; } /// @@ -422,14 +482,28 @@ private Dictionary GetVocabulary(Stream vocabularyStream) private Dictionary<(string, string), int> GetMergeRanks(Stream mergeStream) { - List splitContents = new(); - + var mergeRanks = new Dictionary<(string, string), int>(); try { using StreamReader reader = new StreamReader(mergeStream); + + // We ignore the first and last line in the file + if (reader.Peek() >= 0) + { + string ignored = reader.ReadLine()!; + } + + int rank = 1; while (reader.Peek() >= 0) { - splitContents.Add(reader.ReadLine()!); + string line = reader.ReadLine()!; + int index = line.IndexOf(' '); + if (index < 1 || index == line.Length - 1 || line.IndexOf(' ', index + 1) != -1) + { + throw new Exception($"Invalid format of merge file: \"{line}\""); + } + + mergeRanks.Add((line.Substring(0, index), line.Substring(index + 1)), rank++); } } catch (Exception e) @@ -437,20 +511,6 @@ private Dictionary GetVocabulary(Stream vocabularyStream) throw new IOException($"Cannot read the file Merge file.{Environment.NewLine}Error message: {e.Message}", e); } - var mergeRanks = new Dictionary<(string, string), int>(); - - // We ignore the first and last line in the file - for (int i = 1; i < splitContents.Count - 1; i++) - { - var split = splitContents[i].Split(' '); - if (split.Length != 2 || string.IsNullOrEmpty(split[0]) || string.IsNullOrEmpty(split[1])) - { - throw new Exception($"Invalid format of merge file: \"{splitContents[i]}\""); - } - - mergeRanks.Add((split[0], split[1]), i); - } - return mergeRanks; } @@ -481,10 +541,28 @@ private static int GetByteToUnicode(out IReadOnlyDictionary byteToUn } /// - /// Encode a token into BPE-ed sub-tokens. E.g., "playing" into ["play", "ing"]. + /// Encode a token into BPE-ed Ids. E.g., "playing" into ["play", "ing"]. /// - private List BpeToken(Span token, Span indexMapping) + /// The token to encode. + /// The list of Ids to encode the token into. + /// The number of encoded ids. + private int EncodeToIds(Span token, IList? 
ids) { + if (token.Length == 0) + { + return 0; + } + + if (token.Length == 1) + { + if (ids is not null) + { + ids.Add(_vocab[_charToString[token[0]]]); + } + + return 1; + } + List word = new(token.Length); foreach (char c in token) { @@ -492,14 +570,120 @@ private List BpeToken(Span token, Span indexMapping) word.Add(_charToString[c]); } - HashSet<(string, string)> pairs = WordToPairs(word); + HashSet<(string, string)> pairs = new(); - if (pairs.Count == 0) + WordToPairs(word, pairs); + + var newWord = new List(); + + Debug.Assert(pairs.Count != 0, "Pairs should not be empty."); + + while (true) { - string tokenValue = token.ToString(); - return new List { new Token(_vocab[tokenValue], tokenValue, (indexMapping[0], indexMapping[token.Length - 1] + 1)) }; + /* while conditions */ + // if only one element left, merge is finished (with the whole word merged) + if (word.Count == 1) + { + break; + } + + // get the most frequent bi-gram pair + var (first, second) = pairs.ArgMin(pair => _mergeRanks.GetOrAdd(pair, int.MaxValue)); + if (!_mergeRanks.ContainsKey((first, second))) + { + break; + } + /* end while conditions */ + + // search and merge all (first, second) pairs in {word} + var i = 0; + while (i < word.Count) + { + // find the next occurrence of {first} and add the elements before into {newWord} + var j = word.IndexOf(first, i); + if (j == -1) + { + // Equivalent to newWord.AddRange(word.Skip(i)) without allocations + for (int k = i; k < word.Count; k++) + { + newWord.Add(word[k]); + } + break; + } + else + { + // Equivalent to newWord.AddRange(word.Skip(i).Take(j - i)) without allocations + for (int k = i; k < j; k++) + { + newWord.Add(word[k]); + } + i = j; + } + + // check the next element is {second} or not + if (i < word.Count - 1 && word[i + 1] == second) + { + newWord.Add(first + second); + i += 2; + } + else + { + newWord.Add(word[i]); + i += 1; + } + } + + List temp = word; + word = newWord; + newWord = temp; + newWord.Clear(); + + // otherwise, continue merging + WordToPairs(word, pairs); + } + + if (ids is not null) + { + foreach (string w in word) + { + ids.Add(_vocab[w]); + } + } + + return word.Count; + } + + /// + /// Encode a token into BPE-ed sub-tokens. E.g., "playing" into ["play", "ing"]. + /// + private List EncodeToTokens(Span token, Span indexMapping) + { + if (token.Length == 0) + { + return Bpe.EmptyTokensList; + } + + if (token.Length == 1) + { + string tokenValue = _charToString[token[0]]; + return new List { new Token(_vocab[tokenValue], tokenValue, (indexMapping[0], indexMapping[0] + 1)) }; } + List word = new(token.Length); + foreach (char c in token) + { + Debug.Assert(c < _charToString.Length); + word.Add(_charToString[c]); + } + + HashSet<(string, string)> pairs = new(); + + WordToPairs(word, pairs); + + var newWord = new List(); + + Debug.Assert(pairs.Count != 0, "Pairs should not be empty."); + while (true) { /* while conditions */ @@ -518,7 +702,6 @@ private List BpeToken(Span token, Span indexMapping) /* end while conditions */ // search and merge all (first, second) pairs in {word} - var newWord = new List(); var i = 0; while (i < word.Count) { @@ -548,10 +731,13 @@ private List BpeToken(Span token, Span indexMapping) } } + List temp = word; word = newWord; + newWord = temp; + newWord.Clear(); // otherwise, continue merging - pairs = WordToPairs(word); + WordToPairs(word, pairs); } var tokens = new List(word.Count); @@ -570,12 +756,13 @@ private List BpeToken(Span token, Span indexMapping) /// Extract element pairs in an aggregating word. E.g. 
[p, l, ay] into [(p,l), (l,ay)]. /// If word contains 0 or 1 element, an empty HashSet will be returned. /// - private static HashSet<(string, string)> WordToPairs(IReadOnlyList word) + private static void WordToPairs(IReadOnlyList word, HashSet<(string, string)> pairs) { - var pairs = new HashSet<(string, string)>(); + pairs.Clear(); + if (word.Count <= 1) { - return pairs; + return; } var prevElem = word[0]; @@ -584,11 +771,9 @@ private List BpeToken(Span token, Span indexMapping) pairs.Add((prevElem, elem)); prevElem = elem; } - - return pairs; } - public override bool IsValidChar(char ch) + public bool CharInSupportedRange(char ch) { return _byteToUnicode.ContainsKey(ch); } @@ -629,16 +814,16 @@ public HighestOccurrenceMapping(string pad = "", string eos = "", strin PadWord = pad; EosWord = eos; UnkWord = unk; - BosIndex = ReserveStringSymboleSlot(bos); - PadIndex = ReserveStringSymboleSlot(pad); - EosIndex = ReserveStringSymboleSlot(eos); - UnkIndex = ReserveStringSymboleSlot(unk); + BosIndex = ReserveStringSymbolSlot(bos); + PadIndex = ReserveStringSymbolSlot(pad); + EosIndex = ReserveStringSymbolSlot(eos); + UnkIndex = ReserveStringSymbolSlot(unk); if (extraSpecialSymbols is not null) { foreach (var symbol in extraSpecialSymbols) { - ReserveStringSymboleSlot(symbol); + ReserveStringSymbolSlot(symbol); } } } @@ -675,7 +860,7 @@ public int OccurrenceRankToId(int rank) return _symbols[rank].Id; } - private int ReserveStringSymboleSlot(string symbol, int defaultOccurrence = -1) + private int ReserveStringSymbolSlot(string symbol, int defaultOccurrence = -1) { if (symbol is null) { @@ -707,7 +892,7 @@ public int AddSymbol(int id, int highOccuranceScore) public int AddMaskSymbol(string mask = "") { MaskWord = mask; - MaskIndex = ReserveStringSymboleSlot(mask, 1); + MaskIndex = ReserveStringSymbolSlot(mask, 1); return MaskIndex; } @@ -780,7 +965,7 @@ public void AddFromStream(Stream stream) if (!int.TryParse(splitLine[0], out var id)) { - ReserveStringSymboleSlot(splitLine[0], occurrenceScore); + ReserveStringSymbolSlot(splitLine[0], occurrenceScore); } else { diff --git a/src/Microsoft.ML.Tokenizers/Model/Model.cs b/src/Microsoft.ML.Tokenizers/Model/Model.cs index 9f7fe9e698..c8bd01cc06 100644 --- a/src/Microsoft.ML.Tokenizers/Model/Model.cs +++ b/src/Microsoft.ML.Tokenizers/Model/Model.cs @@ -13,20 +13,13 @@ namespace Microsoft.ML.Tokenizers /// public abstract class Model { - /// - /// Tokenize a sequence string to a list of tokens. - /// - /// The sequence to tokenize. - /// The list of tokens generated from the sequence tokenization. - public abstract IReadOnlyList Tokenize(string sequence); - /// /// Tokenize a split sequence string to a list of tokens. /// /// The text to tokenize. /// Indicate if the token is a special token. /// The list of tokens generated from the sequence tokenization. - public virtual IReadOnlyList Tokenize(string sequence, bool isSpecialToken) => Tokenize(sequence); + public abstract IReadOnlyList Tokenize(string sequence, bool isSpecialToken = false); /// /// Tokenize a split sequence string to a list of Ids and add them to the accumulatedIds list. @@ -34,8 +27,11 @@ public abstract class Model /// The sequence to split. /// Indicate if the token is a special token. /// The list of accumulated tokenized Ids. - /// True if the operation succeeded, false otherwise. 
- public virtual bool TokenizeToIds(string sequence, bool isSpecialToken, IList accumulatedIds) + /// + /// This method does the default implementation that uses the Tokenize method to get the token's Ids. + /// Tokenizer's models which care about performance may choose to override this method to provide a more efficient implementation. + /// + public virtual void TokenizeToIds(string sequence, bool isSpecialToken, IList accumulatedIds) { if (accumulatedIds is null) { @@ -47,7 +43,23 @@ public virtual bool TokenizeToIds(string sequence, bool isSpecialToken, IList + /// Get the number of tokens that the input sequence will be encoded to. + /// + /// The text to tokenize. + /// Indicate if the token is special token. + /// The number of tokens that the input sequence will be encoded to. + /// + /// This method does the default implementation that uses the TokenizeToIds method to get the number of token's Ids. + /// Tokenizer's models which care about performance may choose to override this method to provide a more efficient implementation. + /// + public virtual int CountTokens(string sequence, bool isSpecialToken) + { + var ids = new List(); + TokenizeToIds(sequence, isSpecialToken, ids); + return ids.Count; } /// @@ -73,8 +85,6 @@ public virtual bool TokenizeToIds(string sequence, bool isSpecialToken, IListThe mapped token of the Id. public abstract string? IdToToken(int id, bool skipSpecialTokens = false); - public abstract string? IdToString(int id, bool skipSpecialTokens = false); - /// /// Gets the dictionary mapping tokens to Ids. /// @@ -97,12 +107,5 @@ public virtual bool TokenizeToIds(string sequence, bool isSpecialToken, IList public abstract Trainer? GetTrainer(); - - /// - /// Return true if the char is valid in the tokenizer; otherwise return false. - /// - /// - /// - public abstract bool IsValidChar(char ch); } } diff --git a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs index cfc657afbc..bd9a376e20 100644 --- a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs +++ b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs @@ -164,13 +164,6 @@ internal static (Dictionary, Dictionary, IReadOnlyDict /// The dictionary mapping special tokens to Ids. public IReadOnlyDictionary? SpecialTokensEncoder => _specialTokensEncoder; - /// - /// Tokenize a sequence string to a list of tokens. - /// - /// The sequence to tokenize. - /// The list of tokens generated from the sequence tokenization. - public override IReadOnlyList Tokenize(string sequence) => Tokenize(sequence, isSpecialToken: false); - /// /// Tokenize a split sequence string to a list of tokens. /// @@ -240,12 +233,11 @@ public override IReadOnlyList Tokenize(string sequence, bool isSpecialTok /// The sequence to tokenize. /// Indicate if the token is a special token. /// The list of accumulated Ids. - /// True if the operation succeeded, false otherwise. - public override bool TokenizeToIds(string sequence, bool isSpecialToken, IList accumulatedIds) + public override void TokenizeToIds(string sequence, bool isSpecialToken, IList accumulatedIds) { if (string.IsNullOrEmpty(sequence)) { - return true; + return; } if (isSpecialToken) @@ -253,29 +245,62 @@ public override bool TokenizeToIds(string sequence, bool isSpecialToken, IList + /// Get the number of tokens that the input sequence will be encoded to. + /// + /// The text to tokenize. + /// Indicate if the token is special token. + /// The number of tokens that the input sequence will be encoded to. 
+ public override int CountTokens(string sequence, bool isSpecialToken) + { + if (string.IsNullOrEmpty(sequence)) + { + return 0; + } + + if (isSpecialToken && _specialTokensEncoder is not null) + { + return _specialTokensEncoder.TryGetValue(sequence, out int id) ? 1 : 0; + } + + if (_cache.Lookup(sequence, out int[] ids)) + { + return ids.Length; + } + + if (_vocab.TryGetValue(sequence, out int mappedId)) + { + return 1; + } + + int[] encodedIds = BytePairEncoder.BytePairEncode(Encoding.UTF8.GetBytes(sequence), _encoder); + _cache.Add(sequence, encodedIds); + + return encodedIds.Length; } /// @@ -379,8 +404,6 @@ public override bool TokenizeToIds(string sequence, bool isSpecialToken, IList 0 ? Encoding.UTF8.GetString(utf8Bytes.ToArray()) : string.Empty; } - public override string? IdToString(int id, bool skipSpecialTokens = false) => IdToToken(id, skipSpecialTokens); - /// /// Gets the dictionary mapping tokens to Ids. /// @@ -403,12 +426,5 @@ public override bool TokenizeToIds(string sequence, bool isSpecialToken, IList public override Trainer? GetTrainer() => throw new NotImplementedException(); - - /// - /// Return true if the char is valid in the tokenizer; otherwise return false. - /// - /// - /// - public override bool IsValidChar(char ch) => true; } } \ No newline at end of file diff --git a/src/Microsoft.ML.Tokenizers/Model/Word.cs b/src/Microsoft.ML.Tokenizers/Model/Word.cs index 7124f8d8e0..15674f402f 100644 --- a/src/Microsoft.ML.Tokenizers/Model/Word.cs +++ b/src/Microsoft.ML.Tokenizers/Model/Word.cs @@ -22,7 +22,7 @@ public Word(int capacity) { throw new ArgumentOutOfRangeException(nameof(capacity)); } - _symbols = new Vec((int)capacity); + _symbols = new Vec(capacity); } public static Word WithCapacity(int capacity) => new Word(capacity); @@ -174,7 +174,7 @@ public void MergeAll(Dictionary, (int, int)> merges, float? dropout) int next = current.Next; if ((uint)next < (uint)_symbols.Count) { - Symbol nextSymbol = _symbols[(int)next]; + Symbol nextSymbol = _symbols[next]; Pair newPair = Pair.Create(current.C, nextSymbol.C); if (merges.TryGetValue(newPair, out value)) { @@ -194,6 +194,14 @@ public void MergeAll(Dictionary, (int, int)> merges, float? dropout) } } + public void PopulateIds(IList accumulatedIds) + { + for (int i = 0; i < SymbolsCount; i++) + { + accumulatedIds.Add(_symbols[i].C); + } + } + public Vec GetChars() { Vec chars = new Vec(); @@ -223,39 +231,19 @@ public override string ToString() return sb.ToString(); } - public Enumerator GetIterator(SortedDictionary vocabReverse) => new Enumerator(ref _symbols, vocabReverse); - - public struct Enumerator + public List ToTokens(SortedDictionary vocabReverse) { - private int _index; - private int _pos; - private Vec _symbols; - private readonly SortedDictionary _vocabReverse; + List tokens = new(SymbolsCount); + int index = 0; - public Enumerator(ref Vec symbols, SortedDictionary vocabReverse) + for (int i = 0; i < SymbolsCount; i++) { - _index = -1; - _pos = 0; - _symbols = symbols; - _vocabReverse = vocabReverse; + int endIndex = index + _symbols[i].Len; + tokens.Add(new Token(_symbols[i].C, vocabReverse[_symbols[i].C], (index, endIndex))); + index = endIndex; } - public readonly Enumerator GetEnumerator() => this; - - public readonly Token Current => new Token(_symbols[_index].C, _vocabReverse[_symbols[_index].C], (_pos, _pos + _symbols[_index].Len)); - - public bool MoveNext() - { - if (_symbols.Count == 0 || _index >= _symbols.Count - 1) - { - return false; - } - - _pos = _index == -1 ? 
0 : _pos + _symbols[_index].Len; - - _index++; - return true; - } + return tokens; } } } diff --git a/src/Microsoft.ML.Tokenizers/Tokenizer.cs b/src/Microsoft.ML.Tokenizers/Tokenizer.cs index 94d07abb4d..aee4c84bcb 100644 --- a/src/Microsoft.ML.Tokenizers/Tokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/Tokenizer.cs @@ -51,13 +51,6 @@ public Tokenizer(Model model, PreTokenizer? preTokenizer = null, Normalizer? nor /// public TokenizerDecoder? Decoder { get; set; } - /// - /// Encodes input text to object has the tokens list, tokens Ids, tokens offset mapping. - /// - /// The text to tokenize. - /// The tokenization result includes the tokens list, tokens Ids, tokens offset mapping. - public TokenizerResult Encode(string sequence) => Encode(sequence, skipSpecialTokens: false); - /// /// Encodes input text to object has the tokens list, tokens Ids, tokens offset mapping. /// @@ -140,33 +133,41 @@ public IReadOnlyList EncodeToIds(string sequence, bool skipSpecialTokens = throw new ArgumentNullException(nameof(sequence)); } - string normalized; - NormalizedString normalizedString = default; + string normalized = Normalizer is not null ? Normalizer.Normalize(sequence).Normalized : sequence; + List idsList = new(); - bool offsetsMappedToOriginal = true; - if (Normalizer is not null) + foreach (Split split in PreTokenizer.PreTokenize(normalized, skipSpecialTokens)) { - normalizedString = Normalizer.Normalize(sequence); - normalized = normalizedString.Normalized; - - offsetsMappedToOriginal = normalizedString.CanMapToOriginal; + Model.TokenizeToIds(split.TokenString, split.IsSpecialToken, idsList); } - else + + return idsList; + } + + /// + /// Get the number of tokens that the input sequence will be encoded to. + /// + /// The text to tokenize. + /// Indicate if want to skip the special tokens during the encoding. + /// The number of tokens Ids that the input sequence will be encoded to. + /// The input sequence is null. + /// Unable to tokenize the sequence. + public int CountTokens(string sequence, bool skipSpecialTokens = false) + { + if (sequence is null) { - normalized = sequence; + throw new ArgumentNullException(nameof(sequence)); } - List idsList = new(); + string normalized = Normalizer is not null ? Normalizer.Normalize(sequence).Normalized : sequence; + int idsCount = 0; foreach (Split split in PreTokenizer.PreTokenize(normalized, skipSpecialTokens)) { - if (!Model.TokenizeToIds(split.TokenString, split.IsSpecialToken, idsList)) - { - throw new ArgumentException($"Unable to tokenize the sequence: {split.TokenString}"); - } + idsCount += Model.CountTokens(split.TokenString, split.IsSpecialToken); } - return idsList; + return idsCount; } // skipSpecialTokens is used in post processing we don't support yet. We are keeping it to allow using it when we support post processing. @@ -199,12 +200,19 @@ public IReadOnlyList EncodeToIds(string sequence, bool skipSpecialTokens = List tokens = new List(); - foreach (int id in ids) + if (Model is EnglishRoberta robertaModel) { - if (Model.GetType() == typeof(EnglishRoberta)) - tokens.Add(Model.IdToString(id) ?? ""); - else - tokens.Add(Model.IdToToken(id) ?? ""); + foreach (int id in ids) + { + tokens.Add(robertaModel.IdToFilteredToken(id, skipSpecialTokens) ?? ""); + } + } + else + { + foreach (int id in ids) + { + tokens.Add(Model.IdToToken(id, skipSpecialTokens) ?? ""); + } } return Decoder?.Decode(tokens) ?? 
string.Join("", tokens); @@ -256,11 +264,6 @@ public void TrainFromFiles( // self.add_special_tokens(&special_tokens); } - public bool IsValidChar(char ch) - { - return Model.IsValidChar(ch); - } - private const string EndOfText = "<|endoftext|>"; private const string FimPrefix = "<|fim_prefix|>"; private const string FimMiddle = "<|fim_middle|>"; diff --git a/src/Microsoft.ML.TorchSharp/Roberta/QATrainer.cs b/src/Microsoft.ML.TorchSharp/Roberta/QATrainer.cs index 6e69c56e32..75e667fe39 100644 --- a/src/Microsoft.ML.TorchSharp/Roberta/QATrainer.cs +++ b/src/Microsoft.ML.TorchSharp/Roberta/QATrainer.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using Microsoft.ML.Data; using Microsoft.ML.Internal.Utilities; @@ -62,7 +63,7 @@ public sealed class Options : NasBertOptions public string AnswerIndexStartColumnName = DefaultColumnNames.AnswerIndex; /// - /// Number of top predicted answers in question answering task. + /// Number of top predicted answers in question answering task. /// public int TopKAnswers = DefaultColumnNames.TopKAnswers; @@ -435,6 +436,9 @@ private torch.Tensor PrepareBatchTensor(ref List inputTensors, Device de private Dictionary AlignAnswerPosition(IReadOnlyList tokens, string text) { + EnglishRoberta robertaModel = Tokenizer.Model as EnglishRoberta; + Debug.Assert(robertaModel is not null); + var mapping = new Dictionary(); int surrogateDeduce = 0; for (var (i, j, tid) = (0, 0, 0); i < text.Length && tid < tokens.Count;) @@ -457,7 +461,7 @@ private Dictionary AlignAnswerPosition(IReadOnlyList tokens, s ++i; } // Chars not included in tokenizer will not appear in tokens - else if (!Tokenizer.IsValidChar(text[i])) + else if (!robertaModel.CharInSupportedRange(text[i])) { mapping[i - surrogateDeduce] = tid; ++i; diff --git a/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs b/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs index d0f1116976..e1aaf74a6e 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs @@ -140,10 +140,13 @@ public void SimpleTestWithUnknownToken(Dictionary vocab, (string, s Tokenizer tokenizer = new Tokenizer(bpe); TokenizerResult encoding = tokenizer.Encode(sentence); + IReadOnlyList idsList = tokenizer.EncodeToIds(sentence); Assert.Equal(expectedTokens.Length, encoding.Tokens.Count); Assert.Equal(offsets.Length, encoding.Offsets.Count); Assert.Equal(ids.Length, encoding.Ids.Count); + Assert.Equal(ids.Length, idsList.Count); + Assert.Equal(ids.Length, tokenizer.CountTokens(sentence)); Assert.Equal(decodedTokens, tokenizer.Decode(encoding.Ids)); for (int i = 0; i < encoding.Tokens.Count; i++) @@ -151,6 +154,7 @@ public void SimpleTestWithUnknownToken(Dictionary vocab, (string, s Assert.Equal(expectedTokens[i], encoding.Tokens[i]); Assert.Equal(offsets[i], encoding.Offsets[i]); Assert.Equal(ids[i], encoding.Ids[i]); + Assert.Equal(ids[i], idsList[i]); Assert.Equal(encoding.Tokens[i], tokenizer.Model.IdToToken(encoding.Ids[i])); Assert.Equal(encoding.Ids[i], tokenizer.Model.TokenToId(encoding.Tokens[i])); Assert.Equal(encoding.Tokens[i], tokenizer.Decode(encoding.Ids[i])); @@ -232,11 +236,13 @@ public void TestTrainingLoadingVocabFile() foreach (object?[] arguments in BpeTestData) { TokenizerResult enc = tokenizer.Encode((string)arguments[0]!); + IReadOnlyList ids = tokenizer.EncodeToIds((string)arguments[0]!); Assert.Equal((string)arguments[0]!, enc.OriginalString); Assert.Equal((string[])arguments[1]!, enc.Tokens); (int, int)[] offsets = 
((int, int)[])arguments[2]!; - for (int i = 0; i < offsets.Length; i++) - Assert.Equal(offsets[i], enc.Offsets[i]); + Assert.Equal(offsets, enc.Offsets); + Assert.Equal(enc.Tokens.Count, ids.Count); + Assert.Equal(enc.Tokens.Count, tokenizer.CountTokens((string)arguments[0]!)); Assert.Equal(enc.Tokens.Count, enc.Ids.Count); @@ -244,6 +250,7 @@ public void TestTrainingLoadingVocabFile() for (int i = 0; i < enc.Ids.Count; i++) { Assert.Equal(vocab[enc.Tokens[i]], enc.Ids[i]); + Assert.Equal(enc.Ids[i], ids[i]); } } } diff --git a/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs b/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs index 8e4832ba43..c01e0a809c 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs @@ -48,7 +48,7 @@ public static IEnumerable BertaData // Sentence, Expected Ids, Expected Tokens, Expected Offsets, Decoded Tokens, Token occurrence values yield return new object[] { - "In the night.", // Heighest occurence tokens + "In the night.", // Highest occurrence tokens new int[] { 818, 262, 1755, 13 }, new string[] { "In", "\u0120the", "\u0120night", "." }, new (int, int)[] { (0, 2), (2, 6), (6, 12), (12, 13) }, @@ -131,7 +131,11 @@ private void TestTokenizer(Tokenizer tokenizer) foreach (object[] p in BertaData) { TokenizerResult encoding = tokenizer.Encode((string)p[0]); + IReadOnlyList ids = tokenizer.EncodeToIds((string)p[0]); + int idsCount = tokenizer.CountTokens((string)p[0]); Assert.Equal(p[1], encoding.Ids); + Assert.Equal(p[1], ids); + Assert.Equal(((int[])p[1]).Length, idsCount); Assert.Equal(p[2], encoding.Tokens); Assert.Equal(p[3], encoding.Offsets); Assert.Equal(encoding.Ids.Count, encoding.Tokens.Count); diff --git a/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs b/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs index 3e2b2913f6..dd554405d4 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs @@ -40,9 +40,12 @@ public void TestGPT4TokenizationEncoding() Assert.Equal(text, GPT4.Decode(encoded.ToArray())!); TokenizerResult result = GPT4.Encode(text); - Assert.Equal(new List() { 9906, 4435 }, result.Ids); + int idsCount = GPT4.CountTokens(text); + Assert.Equal(encoded, result.Ids); Assert.Equal(new string[] { "Hello", " World" }, result.Tokens); Assert.Equal(new List<(int, int)> { (0, 5), (5, 11) }, result.Offsets); + Assert.Equal(encoded.Count, idsCount); + Assert.Equal(encoded, result.Ids); } [Fact] @@ -54,9 +57,12 @@ public void TestEncode1() Assert.Equal(text, GPT4.Decode(encoded.ToArray())); TokenizerResult result = GPT4.Encode(text); - Assert.Equal(new List() { 100264, 9906, 4435, 100265 }, result.Ids); + int idsCount = GPT4.CountTokens(text); + Assert.Equal(encoded, result.Ids); Assert.Equal(new string[] { "<|im_start|>", "Hello", " World", "<|im_end|>" }, result.Tokens); Assert.Equal(new List<(int, int)> { (0, 12), (12, 17), (17, 23), (23, 33) }, result.Offsets); + Assert.Equal(encoded.Count, idsCount); + Assert.Equal(encoded, result.Ids); } [Fact] @@ -65,6 +71,8 @@ public void TestEncode2() string text = ReadAndSanitizeFile("./Data/lib.rs.txt"); IReadOnlyList encoded = GPT4.EncodeToIds(text, skipSpecialTokens: true); Assert.Equal(5584, encoded.Count); + int idsCount = GPT4.CountTokens(text, skipSpecialTokens: true); + Assert.Equal(encoded.Count, idsCount); using (Stream stream = File.OpenRead("./Data/tokens.json")) { @@ -86,7 +94,9 @@ public void TestEncode3() Assert.Equal(text, 
decoded); TokenizerResult result = GPT4.Encode(text); - Assert.Equal(new List() { 100264, 9906, 100265, 4435 }, result.Ids); + int idsCount = GPT4.CountTokens(text); + Assert.Equal(encoded, result.Ids); + Assert.Equal(encoded.Count, idsCount); Assert.Equal(new string[] { "<|im_start|>", "Hello", "<|im_end|>", " World" }, result.Tokens); Assert.Equal(new List<(int, int)> { (0, 12), (12, 17), (17, 27), (27, 33) }, result.Offsets); } @@ -99,9 +109,11 @@ public void TestEncode4() Assert.Empty(encoded); TokenizerResult result = GPT4.Encode(text); + int idsCount = GPT4.CountTokens(text); Assert.Empty(result.Ids); Assert.Empty(result.Tokens); Assert.Empty(result.Offsets); + Assert.Equal(result.Ids.Count, idsCount); } [Fact] @@ -109,11 +121,13 @@ public void TestEncode5() { string text = "<|im_start|>Hello ⭐ World<|im_end|>"; IReadOnlyList encoded = GPT4.EncodeToIds(text); + int idsCount = GPT4.CountTokens(text); Assert.Equal(new List() { 100264, 9906, 2928, 99834, 4435, 100265 }, encoded); Assert.Equal(text, GPT4.Decode(encoded.ToArray())); TokenizerResult result = GPT4.Encode(text); - Assert.Equal(new List() { 100264, 9906, 2928, 99834, 4435, 100265 }, result.Ids); + Assert.Equal(encoded, result.Ids); + Assert.Equal(encoded.Count, idsCount); Assert.Equal(new string[] { "<|im_start|>", "Hello", " ⭐", "", " World", "<|im_end|>" }, result.Tokens); Assert.Equal(new List<(int, int)> { (0, 12), (12, 17), (17, 19), (19, 19), (19, 25), (25, 35) }, result.Offsets); } @@ -123,7 +137,9 @@ public void TestEncodeGpt2() { string text = ReadAndSanitizeFile("./Data/lib.rs.txt"); IReadOnlyList encoded = GPT2.EncodeToIds(text); + int idsCount = GPT2.CountTokens(text); Assert.Equal(11378, encoded.Count); + Assert.Equal(encoded.Count, idsCount); using (Stream stream = File.OpenRead("./Data/tokens_gpt2.json")) { @@ -140,7 +156,9 @@ public void TestEncodeP50kBase() { string text = ReadAndSanitizeFile("./Data/lib.rs.txt"); IReadOnlyList encoded = P50kBase.EncodeToIds(text); + int idsCount = P50kBase.CountTokens(text); Assert.Equal(7230, encoded.Count); + Assert.Equal(encoded.Count, idsCount); using (Stream stream = File.OpenRead("./Data/tokens_p50k_base.json")) { @@ -157,7 +175,9 @@ public void TestEncodeP50kEdit() { string text = ReadAndSanitizeFile("./Data/lib.rs.txt"); IReadOnlyList encoded = P50kEdit.EncodeToIds(text); + int idsCount = P50kEdit.CountTokens(text); Assert.Equal(7230, encoded.Count); + Assert.Equal(encoded.Count, idsCount); using (Stream stream = File.OpenRead("./Data/tokens_p50k_edit.json")) { @@ -174,7 +194,9 @@ public void TestEncodeR50kBase() { string text = ReadAndSanitizeFile("./Data/lib.rs.txt"); IReadOnlyList encoded = R50kBase.EncodeToIds(text); + int idsCount = R50kBase.CountTokens(text); Assert.Equal(11378, encoded.Count); + Assert.Equal(encoded.Count, idsCount); using (Stream stream = File.OpenRead("./Data/tokens_r50k_base.json")) {
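
For reference, a minimal usage sketch of the tokenization entry points touched by this change, mirroring what the updated tests assert: Encode, EncodeToIds, and the new CountTokens are expected to agree on the number of produced ids. The vocabulary and merges file names below are placeholders, the tokenizer is built with the default pre-tokenizer and normalizer, and any of the Bpe, EnglishRoberta, or Tiktoken models can sit behind the same Tokenizer calls.

using System;
using System.Collections.Generic;
using Microsoft.ML.Tokenizers;

class CountTokensSketch
{
    static void Main()
    {
        // Placeholder vocabulary/merges files; substitute a real BPE vocabulary.
        Model model = new Bpe("vocab.json", "merges.txt");
        var tokenizer = new Tokenizer(model);

        string text = "Hello World";

        // Full encoding: token strings, ids, and offsets.
        TokenizerResult encoding = tokenizer.Encode(text);

        // Ids only: routed through Model.TokenizeToIds, so no Token objects are surfaced.
        IReadOnlyList<int> ids = tokenizer.EncodeToIds(text);

        // Count only: routed through Model.CountTokens, so no id list is materialized either.
        int count = tokenizer.CountTokens(text);

        // All three views of the same text agree on the number of ids.
        Console.WriteLine(encoding.Ids.Count == ids.Count && ids.Count == count);
    }
}

CountTokens is useful when only a budget check is needed (for example, checking input length against a model's context limit), because the model-level overrides added here, such as Bpe.TokenizeToIdsWithCache and Tiktoken.CountTokens, can answer from their caches without building a full TokenizerResult.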