From ed88215736a510fca410327e968233f2daa9d009 Mon Sep 17 00:00:00 2001
From: Stephen Toub <stoub@microsoft.com>
Date: Fri, 16 Feb 2024 21:58:51 -0500
Subject: [PATCH] First round of perf improvements for tiktoken

---
 src/Microsoft.ML.Tokenizers/AssemblyInfo.cs   |   7 +
 .../Microsoft.ML.Tokenizers.csproj            |   1 +
 .../Model/EnglishRoberta.cs                   |  19 +-
 src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs | 263 +++++++++++++-----
 .../PreTokenizer/PreTokenizer.cs              | 118 +++-----
 .../PreTokenizer/Roberta.cs                   |  18 +-
 .../PreTokenizer/TikTokenPreTokenizer.cs      | 144 ++--------
 .../PreTokenizer/Whitespace.cs                |  18 +-
 src/Microsoft.ML.Tokenizers/Tokenizer.cs      |  91 +++---
 .../TokenizerResult.cs                        |  12 +-
 .../Utils/ByteArrayComparer.cs                |  26 +-
 .../Utils/BytePairEncoder.cs                  |  21 +-
 .../Utils/IListExtensions.cs                  |  12 +-
 13 files changed, 376 insertions(+), 374 deletions(-)
 create mode 100644 src/Microsoft.ML.Tokenizers/AssemblyInfo.cs

diff --git a/src/Microsoft.ML.Tokenizers/AssemblyInfo.cs b/src/Microsoft.ML.Tokenizers/AssemblyInfo.cs
new file mode 100644
index 0000000000..4b89b383d5
--- /dev/null
+++ b/src/Microsoft.ML.Tokenizers/AssemblyInfo.cs
@@ -0,0 +1,7 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+#if NET5_0_OR_GREATER
+[module: System.Runtime.CompilerServices.SkipLocalsInit]
+#endif
diff --git a/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj b/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj
index e50c62889b..d370145bad 100644
--- a/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj
+++ b/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj
@@ -5,6 +5,7 @@
     <TargetFrameworks>netstandard2.0;net8.0</TargetFrameworks>
     <Nullable>enable</Nullable>
     <Description>Microsoft.ML.Tokenizers contains the implementation of the tokenization used in the NLP transforms.</Description>
+    <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
   </PropertyGroup>
 
   <ItemGroup>
diff --git a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
index e750a5df6c..ad98ed917c 100644
--- a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
@@ -8,10 +8,7 @@
 using System.Diagnostics;
 using System.IO;
 using System.Linq;
-using System.Runtime.CompilerServices;
-using System.Text;
 using System.Text.Json;
-using System.Text.Json.Serialization;
 
 namespace Microsoft.ML.Tokenizers
 {
@@ -27,7 +24,7 @@ public sealed class EnglishRoberta : Model
         private readonly IReadOnlyDictionary<char, char> _byteToUnicode;
         private readonly IReadOnlyDictionary<char, char> _unicodeToByte;
         private readonly string[] _charToString;
-        private readonly Cache<string, IReadOnlyList<Token>> _cache;
+        private readonly Cache<string, List<Token>> _cache;
 
         /// <summary>
         /// Construct tokenizer object to use with the English Roberta model.
@@ -72,7 +69,7 @@ public EnglishRoberta(string vocabularyPath, string mergePath, string highestOcc
             }
 
             _unicodeToByte = _byteToUnicode.Reverse();
-            _cache = new Cache<string, IReadOnlyList<Token>>();
+            _cache = new Cache<string, List<Token>>();
         }
 
         /// <summary>
@@ -110,7 +107,7 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes
             }
 
             _unicodeToByte = _byteToUnicode.Reverse();
-            _cache = new Cache<string, IReadOnlyList<Token>>();
+            _cache = new Cache<string, List<Token>>();
         }
 
         //
@@ -226,17 +223,17 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
             {
                 ArrayPool<char>.Shared.Return(token);
                 ArrayPool<int>.Shared.Return(indexMapping);
-                return Bpe.EmptyTokensList;
+                return Array.Empty<Token>();
             }
 
-            if (_cache.TryGet(sequence, out IReadOnlyList<Token>? hit))
+            if (_cache.TryGet(sequence, out List<Token>? hit))
             {
                 ArrayPool<char>.Shared.Return(token);
                 ArrayPool<int>.Shared.Return(indexMapping);
                 return ModifyTokenListOffsets(hit, indexMapping);
             }
 
-            IReadOnlyList<Token> result = EncodeToTokens(token.AsSpan().Slice(0, newTokenIndex), indexMapping);
+            List<Token> result = EncodeToTokens(token.AsSpan().Slice(0, newTokenIndex), indexMapping);
             _cache.Set(sequence, result);
             ArrayPool<char>.Shared.Return(token);
             ArrayPool<int>.Shared.Return(indexMapping);
@@ -261,7 +258,7 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
 
         private int TokenizeToIds(string sequence, IList<int>? accumulatedIds)
         {
-            if (_cache.TryGet(sequence, out IReadOnlyList<Token>? hit))
+            if (_cache.TryGet(sequence, out List<Token>? hit))
             {
                 if (accumulatedIds is not null)
                 {
@@ -299,7 +296,7 @@ private int TokenizeToIds(string sequence, IList<int>? accumulatedIds)
                 return 0;
             }
 
-            IReadOnlyList<Token> result = EncodeToTokens(token.Slice(0, newTokenIndex), indexMapping);
+            List<Token> result = EncodeToTokens(token.Slice(0, newTokenIndex), indexMapping);
             _cache.Set(sequence, result);
             return result.Count;
         }
diff --git a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
index bd9a376e20..9935dd6428 100644
--- a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
@@ -1,13 +1,15 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
 using System;
+using System.Buffers;
 using System.Collections.Generic;
 using System.Diagnostics;
 using System.IO;
 using System.Linq;
 using System.Text;
+using System.Threading.Tasks;
 
 namespace Microsoft.ML.Tokenizers
 {
@@ -16,14 +18,12 @@ namespace Microsoft.ML.Tokenizers
     /// </summary>
     public sealed class Tiktoken : Model
     {
-        private Dictionary<byte[], int> _encoder = null!;
-        private IReadOnlyDictionary<int, byte[]> _decoder = null!;
+        private readonly Dictionary<ReadOnlyMemory<byte>, int> _encoder = null!;
+        private readonly IReadOnlyDictionary<int, byte[]> _decoder = null!;
         private readonly LruCache<int[]> _cache;
-        private IReadOnlyDictionary<string, int>? _specialTokensEncoder;
-        private Dictionary<int, string>? _specialTokensDecoder;
-
-        private Dictionary<string, int> _vocab = null!;
-        private static readonly List<Token> _emptyTokenList = new();
+        private readonly IReadOnlyDictionary<string, int>? _specialTokensEncoder;
+        private readonly Dictionary<int, string>? _specialTokensDecoder;
+        private readonly Dictionary<string, int> _vocab = null!;
 
         /// <summary>
         /// Create a new Tiktoken tokenizer object.
         /// </summary>
         /// <param name="tikTokenBpeFile">The BPE rank file.</param>
         /// <param name="specialTokensEncoder">The dictionary mapping special tokens to Ids.</param>
         /// <param name="cacheSize">The size of the cache to use.</param>
         /// <exception cref="ArgumentNullException">Thrown when <paramref name="tikTokenBpeFile"/> is null or empty.</exception>
         /// <exception cref="InvalidOperationException">Thrown when failed to load the BPE rank file.</exception>
-        public Tiktoken(string tikTokenBpeFile, IReadOnlyDictionary<string, int>? specialTokensEncoder = null, int cacheSize = LruCache<int[]>.DefaultCacheSize) : this(cacheSize)
+        public Tiktoken(string tikTokenBpeFile, IReadOnlyDictionary<string, int>? specialTokensEncoder = null, int cacheSize = LruCache<int[]>.DefaultCacheSize) :
+            this(string.IsNullOrEmpty(tikTokenBpeFile) ? throw new ArgumentNullException(nameof(tikTokenBpeFile)) : File.OpenRead(tikTokenBpeFile), specialTokensEncoder, cacheSize, disposeStream: true)
         {
-            if (string.IsNullOrEmpty(tikTokenBpeFile))
-            {
-                throw new ArgumentNullException(nameof(tikTokenBpeFile));
-            }
-
-            using (Stream stream = File.OpenRead(tikTokenBpeFile))
-            {
-                Initialize(stream, specialTokensEncoder);
-            }
         }
 
         /// <summary>
         /// Create a new Tiktoken tokenizer object.
         /// </summary>
         /// <param name="tikTokenBpeFileStream">The stream to the BPE rank file.</param>
         /// <param name="specialTokensEncoder">The dictionary mapping special tokens to Ids.</param>
         /// <param name="cacheSize">The size of the cache to use.</param>
         /// <exception cref="ArgumentNullException">Thrown when <paramref name="tikTokenBpeFileStream"/> is null or empty.</exception>
         /// <exception cref="InvalidOperationException">Thrown when failed to load the BPE rank file.</exception>
-        public Tiktoken(Stream tikTokenBpeFileStream, IReadOnlyDictionary<string, int>? specialTokensEncoder = null, int cacheSize = LruCache<int[]>.DefaultCacheSize) : this(cacheSize)
+        public Tiktoken(Stream tikTokenBpeFileStream, IReadOnlyDictionary<string, int>? specialTokensEncoder = null, int cacheSize = LruCache<int[]>.DefaultCacheSize) :
+            this(tikTokenBpeFileStream ?? throw new ArgumentNullException(nameof(tikTokenBpeFileStream)), specialTokensEncoder, cacheSize, disposeStream: false)
         {
-            Initialize(tikTokenBpeFileStream, specialTokensEncoder);
         }
 
         internal Tiktoken(
-            Dictionary<byte[], int> encoder,
-            IReadOnlyDictionary<int, byte[]> decoder,
-            Dictionary<string, int> vocab,
-            IReadOnlyDictionary<string, int>? specialTokensEncoder = null,
-            int cacheSize = LruCache<int[]>.DefaultCacheSize) : this(cacheSize)
+            Dictionary<ReadOnlyMemory<byte>, int> encoder,
+            IReadOnlyDictionary<int, byte[]> decoder,
+            Dictionary<string, int> vocab,
+            IReadOnlyDictionary<string, int>? specialTokensEncoder = null,
+            int cacheSize = LruCache<int[]>.DefaultCacheSize) : this(cacheSize)
         {
             Debug.Assert(encoder is not null);
             Debug.Assert(decoder is not null);
@@ -81,36 +73,42 @@ internal Tiktoken(
             }
         }
 
-        private Tiktoken(int cacheSize)
-        {
-            _cache = new LruCache<int[]>(cacheSize);
-        }
-
-        private void Initialize(Stream tikTokenBpeFileStream, IReadOnlyDictionary<string, int>? specialTokensEncoder = null)
+        private Tiktoken(Stream tikTokenBpeFileStream, IReadOnlyDictionary<string, int>? specialTokensEncoder, int cacheSize, bool disposeStream) : this(cacheSize)
         {
-            if (tikTokenBpeFileStream is null)
+            try
             {
-                throw new ArgumentNullException(nameof(tikTokenBpeFileStream));
-            }
-
-            (_encoder, _vocab, _decoder) = LoadTikTokenBpe(tikTokenBpeFileStream);
+                (_encoder, _vocab, _decoder) = LoadTikTokenBpeAsync(tikTokenBpeFileStream, useAsync: false).GetAwaiter().GetResult();
 
-            _specialTokensEncoder = specialTokensEncoder;
-            if (_specialTokensEncoder is not null)
+                _specialTokensEncoder = specialTokensEncoder;
+                if (_specialTokensEncoder is not null)
+                {
+                    _specialTokensDecoder = _specialTokensEncoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
+                }
+            }
+            finally
             {
-                _specialTokensDecoder = _specialTokensEncoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
+                if (disposeStream)
+                {
+                    tikTokenBpeFileStream.Dispose();
+                }
             }
         }
 
+        private Tiktoken(int cacheSize)
+        {
+            _cache = new LruCache<int[]>(cacheSize);
+        }
+
         /// <summary>
         /// Load BPE rank dictionary from a stream.
         /// </summary>
         /// <param name="tikTokenBpeFileStream">Stream to the BPE rank file</param>
+        /// <param name="useAsync">Whether to perform I/O synchronously or asynchronously.</param>
         /// <returns>Map of byte[] to integer token id</returns>
         /// <exception cref="InvalidOperationException"></exception>
-        internal static (Dictionary<byte[], int>, Dictionary<string, int>, IReadOnlyDictionary<int, byte[]>) LoadTikTokenBpe(Stream tikTokenBpeFileStream)
+        internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<string, int>, IReadOnlyDictionary<int, byte[]>)> LoadTikTokenBpeAsync(Stream tikTokenBpeFileStream, bool useAsync)
         {
-            var encoder = new Dictionary<byte[], int>(new ByteArrayComparer());
+            var encoder = new Dictionary<ReadOnlyMemory<byte>, int>(ReadOnlyMemoryByteComparer.Instance);
             var vocab = new Dictionary<string, int>();
             var decoder = new Dictionary<int, byte[]>();
 
@@ -118,11 +116,17 @@ internal static (Dictionary<byte[], int>, Dictionary<string, int>, IReadOnlyDict
         {
             using (StreamReader reader = new StreamReader(tikTokenBpeFileStream))
             {
-                while (!reader.EndOfStream)
+                while (true)
                 {
-                    string? line = reader.ReadLine();
+                    string? line = useAsync ?
+                        await reader.ReadLineAsync().ConfigureAwait(false) :
+                        reader.ReadLine();
 
                     if (string.IsNullOrWhiteSpace(line))
                     {
+                        if (line is null)
+                        {
+                            break;
+                        }
                         continue;
                     }
 
@@ -172,11 +176,11 @@ internal static (Dictionary<byte[], int>, Dictionary<string, int>, IReadOnlyDict
         /// <returns>The list of tokens generated from the sequence tokenization.</returns>
         public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialToken)
         {
-            List<Token> tokens;
+            Token[] tokens;
 
             if (string.IsNullOrEmpty(sequence))
             {
-                return _emptyTokenList;
+                return Array.Empty<Token>();
             }
 
             if (isSpecialToken)
@@ -196,12 +200,12 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
 
             if (_cache.Lookup(sequence, out int[] ids))
             {
-                tokens = new(ids.Length);
-                tokens.Add(new Token(ids[0], sequence, (0, sequence.Length)));
+                tokens = new Token[ids.Length];
+                tokens[0] = new Token(ids[0], sequence, (0, sequence.Length));
                 for (int i = 1; i < ids.Length; i++)
                 {
                     // One word split mapped to multiple Ids. Make the offset of the remaining token point at the end with zero width.
-                    tokens.Add(new Token(ids[i], "", (sequence.Length, sequence.Length)));
+                    tokens[i] = new Token(ids[i], "", (sequence.Length, sequence.Length));
                 }
 
                 return tokens;
@@ -213,17 +217,22 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
                 return new List<Token> { new(mappedId, sequence, (0, sequence.Length)) };
             }
 
-            int[] encodedIds = BytePairEncoder.BytePairEncode(Encoding.UTF8.GetBytes(sequence), _encoder);
+            byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(sequence.Length));
+            int encodedLength = GetUtf8Bytes(sequence.AsSpan(), arrayPoolArray);
+
+            int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
+            Debug.Assert(encodedIds.Length > 0);
             _cache.Add(sequence, encodedIds);
 
-            tokens = new List<Token>(encodedIds.Length);
-            tokens.Add(new Token(encodedIds[0], sequence, (0, sequence.Length)));
+            tokens = new Token[encodedIds.Length];
+            tokens[0] = new Token(encodedIds[0], sequence, (0, sequence.Length));
             for (int i = 1; i < encodedIds.Length; i++)
             {
                 // One word split mapped to multiple Ids. Make the offset of the remaining token point at the end with zero width.
-                tokens.Add(new Token(encodedIds[i], "", (sequence.Length, sequence.Length)));
+                tokens[i] = new Token(encodedIds[i], "", (sequence.Length, sequence.Length));
             }
 
+            ArrayPool<byte>.Shared.Return(arrayPoolArray);
             return tokens;
         }
 
@@ -262,10 +271,15 @@ public override void TokenizeToIds(string sequence, bool isSpecialToken, IList<int> accumulatedIds)
 
-            int[] encodedIds = BytePairEncoder.BytePairEncode(Encoding.UTF8.GetBytes(sequence), _encoder);
+            byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(sequence.Length));
+            int encodedLength = GetUtf8Bytes(sequence.AsSpan(), arrayPoolArray);
+
+            int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
             _cache.Add(sequence, encodedIds);
 
             accumulatedIds.AddRange(encodedIds);
+
+            ArrayPool<byte>.Shared.Return(arrayPoolArray);
             return;
         }
 
@@ -284,7 +298,7 @@ public override int CountTokens(string sequence, bool isSpecialToken)
 
             if (isSpecialToken && _specialTokensEncoder is not null)
             {
-                return _specialTokensEncoder.TryGetValue(sequence, out int id) ? 1 : 0;
+                return _specialTokensEncoder.TryGetValue(sequence, out _) ? 1 : 0;
             }
 
             if (_cache.Lookup(sequence, out int[] ids))
@@ -292,14 +306,18 @@ public override int CountTokens(string sequence, bool isSpecialToken)
                 return ids.Length;
             }
 
-            if (_vocab.TryGetValue(sequence, out int mappedId))
+            if (_vocab.TryGetValue(sequence, out _))
             {
                 return 1;
             }
 
-            int[] encodedIds = BytePairEncoder.BytePairEncode(Encoding.UTF8.GetBytes(sequence), _encoder);
+            byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(sequence.Length));
+            int encodedLength = GetUtf8Bytes(sequence.AsSpan(), arrayPoolArray);
+
+            int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
             _cache.Add(sequence, encodedIds);
 
+            ArrayPool<byte>.Shared.Return(arrayPoolArray);
             return encodedIds.Length;
         }
 
@@ -343,15 +361,25 @@ public override int CountTokens(string sequence, bool isSpecialToken)
                 return id;
             }
 
-            int[] idsToCache = BytePairEncoder.BytePairEncode(Encoding.UTF8.GetBytes(token), _encoder);
-            _cache.Add(token, idsToCache);
+            byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(token.Length));
+            try
+            {
+                int encodedLength = GetUtf8Bytes(token.AsSpan(), arrayPoolArray);
+
+                int[] idsToCache = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
+                _cache.Add(token, idsToCache);
+
+                if (idsToCache.Length == 1)
+                {
+                    return idsToCache[0];
+                }
 
-            if (idsToCache.Length == 1)
+                return null;
+            }
+            finally
             {
-                return idsToCache[0];
+                ArrayPool<byte>.Shared.Return(arrayPoolArray);
             }
-
-            return null;
         }
 
         /// <summary>
@@ -382,26 +410,66 @@ public override int CountTokens(string sequence, bool isSpecialToken)
                 return null;
             }
 
-            List<byte> utf8Bytes = new();
-            bool useSpecialTokens = !skipSpecialTokens && _specialTokensDecoder is not null;
-
-            foreach (int id in ids)
+            byte[]? arrayPoolArray = null;
+            try
             {
-                if (_decoder.TryGetValue(id, out byte[]? tokenBytes))
+                Span<byte> utf8Bytes = stackalloc byte[256];
+                int utf8ByteCount = 0;
+
+                bool useSpecialTokens = !skipSpecialTokens && _specialTokensDecoder is not null;
+
+                foreach (int id in ids)
                 {
-                    utf8Bytes.AddRange(tokenBytes);
+                    if (_decoder.TryGetValue(id, out byte[]? tokenBytes))
+                    {
+                        if ((uint)utf8ByteCount + (uint)tokenBytes.Length > (uint)utf8Bytes.Length)
+                        {
+                            ArrayPoolGrow(ref utf8Bytes, ref arrayPoolArray, utf8ByteCount + tokenBytes.Length);
+                        }
+
+                        tokenBytes.AsSpan().CopyTo(utf8Bytes.Slice(utf8ByteCount));
+                        utf8ByteCount += tokenBytes.Length;
+                    }
+                    else if (useSpecialTokens && _specialTokensDecoder!.TryGetValue(id, out string? token))
+                    {
+                        while (true)
+                        {
+                            if (TryGetUtf8Bytes(token.AsSpan(), utf8Bytes.Slice(utf8ByteCount), out int bytesWritten))
+                            {
+                                utf8ByteCount += bytesWritten;
+                                break;
+                            }
+
+                            ArrayPoolGrow(ref utf8Bytes, ref arrayPoolArray, utf8ByteCount + Encoding.UTF8.GetByteCount(token));
+                        }
+                    }
+                    else
+                    {
+                        return null;
+                    }
                 }
-                else if (useSpecialTokens && _specialTokensDecoder!.TryGetValue(id, out string? token))
+
+                return GetString(utf8Bytes.Slice(0, utf8ByteCount));
+            }
+            finally
+            {
+                if (arrayPoolArray is not null)
                 {
-                    utf8Bytes.AddRange(Encoding.UTF8.GetBytes(token));
+                    ArrayPool<byte>.Shared.Return(arrayPoolArray);
                 }
-                else
+            }
+
+            static void ArrayPoolGrow(ref Span<byte> utf8Bytes, ref byte[]? arrayPoolArray, int requiredCapacity)
+            {
+                byte[] tmp = ArrayPool<byte>.Shared.Rent(Math.Max(utf8Bytes.Length * 2, requiredCapacity));
+                utf8Bytes.CopyTo(tmp.AsSpan());
+                byte[]? toReturn = arrayPoolArray;
+                utf8Bytes = arrayPoolArray = tmp;
+                if (toReturn is not null)
                 {
-                    return null;
+                    ArrayPool<byte>.Shared.Return(toReturn);
                 }
             }
-
-            return utf8Bytes.Count > 0 ? Encoding.UTF8.GetString(utf8Bytes.ToArray()) : string.Empty;
         }
 
         /// <summary>
@@ -426,5 +494,50 @@ public override int CountTokens(string sequence, bool isSpecialToken)
         /// Gets a trainer object to use in training the model.
         /// </summary>
         public override Trainer? GetTrainer() => throw new NotImplementedException();
+
+        private static unsafe int GetUtf8Bytes(ReadOnlySpan<char> source, Span<byte> destination)
+        {
+#if NETCOREAPP
+            return Encoding.UTF8.GetBytes(source, destination);
+#else
+            fixed (char* sourcePtr = source)
+            fixed (byte* destPtr = destination)
+            {
+                return Encoding.UTF8.GetBytes(sourcePtr, source.Length, destPtr, destination.Length);
+            }
+#endif
+        }
+
+        private static unsafe bool TryGetUtf8Bytes(ReadOnlySpan<char> source, Span<byte> destination, out int bytesWritten)
+        {
+#if NET8_0_OR_GREATER
+            return Encoding.UTF8.TryGetBytes(source, destination, out bytesWritten);
+#else
+            fixed (char* sourcePtr = source)
+            fixed (byte* destPtr = destination)
+            {
+                if (Encoding.UTF8.GetByteCount(sourcePtr, source.Length) <= destination.Length)
+                {
+                    bytesWritten = Encoding.UTF8.GetBytes(sourcePtr, source.Length, destPtr, destination.Length);
+                    return true;
+                }
+
+                bytesWritten = 0;
+                return false;
+            }
+#endif
+        }
+
+        private static unsafe string GetString(ReadOnlySpan<byte> utf8Bytes)
+        {
+#if NETCOREAPP
+            return Encoding.UTF8.GetString(utf8Bytes);
+#else
+            fixed (byte* sourcePtr = utf8Bytes)
+            {
+                return Encoding.UTF8.GetString(sourcePtr, utf8Bytes.Length);
+            }
+#endif
+        }
     }
-}
\ No newline at end of file
+}
diff --git a/src/Microsoft.ML.Tokenizers/PreTokenizer/PreTokenizer.cs b/src/Microsoft.ML.Tokenizers/PreTokenizer/PreTokenizer.cs
index 94acfcb96f..aef8b13c42 100644
--- a/src/Microsoft.ML.Tokenizers/PreTokenizer/PreTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/PreTokenizer/PreTokenizer.cs
@@ -3,9 +3,7 @@
 // See the LICENSE file in the project root for more information.
 
 using System;
-using System.Collections;
 using System.Collections.Generic;
-using System.Diagnostics;
 using System.Text.RegularExpressions;
 
 namespace Microsoft.ML.Tokenizers
@@ -15,14 +13,22 @@ namespace Microsoft.ML.Tokenizers
     /// in the original string. These offsets are in the `original` referential.
     /// It also contains any `Token` associated with the current split.
     /// </summary>
-    public readonly struct Split : IEquatable<Split>
+    public struct Split : IEquatable<Split>
     {
+        private readonly string? _originalString;
+        private string? _tokenString;
+
         /// <summary>
         /// Gets the underlying split token. Each SubString is represented by a token
        /// and in the end we might be carrying a lot of SubString representing various parts of the
         /// original input string.
         /// </summary>
-        public string TokenString { get; }
+        public string TokenString => _tokenString ??= _originalString!.Substring(Offset.Index, Offset.End - Offset.Index);
+
+        /// <summary>
+        /// Gets the underlying split token as a span.
+        /// </summary>
+        public ReadOnlySpan<char> TokenSpan => _tokenString is string s ? s.AsSpan() : _originalString.AsSpan(Offset.Index, Offset.End - Offset.Index);
 
         /// <summary>
         /// Returns the offset mapping to the original string
@@ -37,7 +43,15 @@ namespace Microsoft.ML.Tokenizers
         /// <param name="isSpecialToken">Indicates whether the token is a special token</param>
         public Split(string token, (int Index, int End) offset, bool isSpecialToken = false)
         {
-            TokenString = token;
+            _tokenString = token;
+            Offset = offset;
+            IsSpecialToken = isSpecialToken;
+        }
+
+        internal Split(string originalString, string? token, (int Index, int End) offset, bool isSpecialToken = false)
+        {
+            _originalString = originalString;
+            _tokenString = token;
             Offset = offset;
             IsSpecialToken = isSpecialToken;
         }
@@ -52,21 +66,18 @@ public Split(string token, (int Index, int End) offset, bool isSpecialToken = fa
         /// </summary>
         /// <param name="other">The Split object to compare with the current object.</param>
         public bool Equals(Split other) =>
-            TokenString == other.TokenString &&
+            (_originalString == other._originalString || TokenString == other.TokenString) &&
             IsSpecialToken == other.IsSpecialToken &&
             Offset.Index == other.Offset.Index &&
             Offset.End == other.Offset.End;
     }
 
-
     /// <summary>
     /// Base class for all pre-tokenizers classes.
     /// The PreTokenizer is in charge of doing the pre-segmentation step.
     /// </summary>
     public abstract class PreTokenizer
     {
-        internal static readonly IReadOnlyList<Split> EmptyList = new List<Split>();
-
         /// <summary>
         /// Splits the given string in multiple substrings at the word boundary, keeping track of the offsets of said substrings from the original string.
         /// </summary>
@@ -74,89 +85,36 @@ public abstract class PreTokenizer
         /// <param name="skipSpecialTokens">Indicates whether to skip the special tokens.</param>
         /// <returns>The list of the splits containing the tokens and the token's offsets to the original string.</returns>
         public abstract IEnumerable<Split> PreTokenize(string sentence, bool skipSpecialTokens = false);
-    }
 
-    internal sealed class RegexSplitEnumerable : IEnumerable<Split>
-    {
-        private readonly static Dictionary<string, Regex> _regexCache = new(StringComparer.Ordinal);
-        private readonly Regex _regex;
-        private readonly string _sentence;
-
-        public RegexSplitEnumerable(string sentence, string pattern)
+        internal static IEnumerable<Split> SplitSentence(string sentence, Regex regex)
         {
-            Debug.Assert(sentence is not null);
-            Debug.Assert(pattern is not null);
-
-            Regex? regex;
-            lock (_regexCache)
+            (int Offset, int Length) match;
+            int beginning = 0;
+            while (TryGetMatch(regex, sentence, beginning, sentence.Length - beginning, out match))
             {
-                if (!_regexCache.TryGetValue(pattern!, out regex))
-                {
-                    regex = new Regex(pattern, RegexOptions.Compiled);
-                    _regexCache[pattern!] = regex;
-                }
+                yield return new Split(sentence, null, (match.Offset, match.Offset + match.Length));
+                beginning = match.Offset + match.Length;
             }
-
-            _regex = regex;
-            _sentence = sentence!;
         }
 
-        public IEnumerator<Split> GetEnumerator() => new RegexSplitEnumerator(_regex, _sentence);
-
-        IEnumerator IEnumerable.GetEnumerator() => new RegexSplitEnumerator(_regex, _sentence);
-
-        private sealed class RegexSplitEnumerator : IEnumerator<Split>
+        internal static bool TryGetMatch(Regex regex, string sentence, int beginning, int length, out (int offset, int length) match)
         {
-            private Split _current = default;
-            private readonly Regex _regex;
-            private Match? _tokenMatch;
-            private readonly string _sentence;
-
-            public RegexSplitEnumerator(Regex regex, string sentence)
-            {
-                Debug.Assert(sentence is not null);
-                Debug.Assert(regex is not null);
-
-                _regex = regex!;
-                _sentence = sentence!;
-            }
-
-            public Split Current => _current;
-
-            object IEnumerator.Current => _current;
-
-            public bool MoveNext()
+#if NET7_0_OR_GREATER
+            foreach (ValueMatch m in regex.EnumerateMatches(sentence.AsSpan(beginning, length)))
             {
-                if (_tokenMatch is null)
-                {
-                    _tokenMatch = _regex.Match(_sentence);
-                }
-                else if (!_tokenMatch.Success)
-                {
-                    return false;
-                }
-                else
-                {
-                    _tokenMatch = _tokenMatch.NextMatch();
-                }
-
-                if (!_tokenMatch.Success)
-                {
-                    return false;
-                }
-
-                _current = new Split(_tokenMatch.Value, (_tokenMatch.Index, _tokenMatch.Index + _tokenMatch.Length));
+                match = (beginning + m.Index, m.Length);
                 return true;
             }
-
-            public void Reset()
-            {
-                _tokenMatch = null;
-            }
-
-            public void Dispose()
+#else
+            Match m = regex.Match(sentence, beginning, length);
+            if (m.Success)
             {
+                match = (m.Index, m.Length);
+                return true;
             }
+#endif
+            match = default;
+            return false;
         }
     }
 }
diff --git a/src/Microsoft.ML.Tokenizers/PreTokenizer/Roberta.cs b/src/Microsoft.ML.Tokenizers/PreTokenizer/Roberta.cs
index e07e755c29..8fd748d838 100644
--- a/src/Microsoft.ML.Tokenizers/PreTokenizer/Roberta.cs
+++ b/src/Microsoft.ML.Tokenizers/PreTokenizer/Roberta.cs
@@ -4,20 +4,28 @@
 
 using System;
 using System.Collections.Generic;
+using System.Text.RegularExpressions;
 
 namespace Microsoft.ML.Tokenizers
 {
     /// <summary>
     /// The pre-tokenizer for the Roberta English tokenizer.
     /// </summary>
-    public sealed class RobertaPreTokenizer : PreTokenizer
+    public sealed partial class RobertaPreTokenizer : PreTokenizer
     {
         /// <summary>
         /// Gets a singleton instance of the Roberta pre-tokenizer.
         /// </summary>
-        public static readonly RobertaPreTokenizer Instance = new RobertaPreTokenizer();
+        public static RobertaPreTokenizer Instance { get; } = new RobertaPreTokenizer();
 
-        private const string Pattern = @"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+";
+        private const string PretokenizePattern = @"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+";
+#if NET7_0_OR_GREATER
+        [GeneratedRegex(PretokenizePattern)]
+        private static partial Regex PretokenizeRegex();
+#else
+        private static readonly Regex _regex = new Regex(PretokenizePattern, RegexOptions.Compiled);
+        private static Regex PretokenizeRegex() => _regex;
+#endif
 
         /// <summary>
         /// Splits the given string in multiple substrings at the word boundary, keeping track of the offsets of said substrings from the original string.
@@ -29,10 +37,10 @@ public override IEnumerable<Split> PreTokenize(string sentence, bool skipSpecial
         {
             if (string.IsNullOrEmpty(sentence))
             {
-                return EmptyList;
+                return Array.Empty<Split>();
             }
 
-            return new RegexSplitEnumerable(sentence, Pattern);
+            return SplitSentence(sentence, PretokenizeRegex());
         }
     }
 }
diff --git a/src/Microsoft.ML.Tokenizers/PreTokenizer/TikTokenPreTokenizer.cs b/src/Microsoft.ML.Tokenizers/PreTokenizer/TikTokenPreTokenizer.cs
index b64096de71..7651de599d 100644
--- a/src/Microsoft.ML.Tokenizers/PreTokenizer/TikTokenPreTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/PreTokenizer/TikTokenPreTokenizer.cs
@@ -1,11 +1,9 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
 using System;
-using System.Collections;
 using System.Collections.Generic;
-using System.Diagnostics;
 using System.Linq;
 using System.Text.RegularExpressions;
 
@@ -34,7 +32,7 @@ public TikTokenPreTokenizer(Regex regex, IReadOnlyDictionary<string, int>? speci
 
             _regex = regex;
 
-            if (specialTokensEncoder is not null && specialTokensEncoder.Count > 0)
+            if (specialTokensEncoder is { Count: > 0 })
             {
                 _specialTokensRegex = new Regex(string.Join("|", specialTokensEncoder.Keys.Select(s => Regex.Escape(s))), RegexOptions.Compiled);
             }
@@ -50,131 +48,41 @@ public override IEnumerable<Split> PreTokenize(string sentence, bool skipSpecial
         {
             if (string.IsNullOrEmpty(sentence))
             {
-                return EmptyList;
+                return Array.Empty<Split>();
             }
 
-            return new TokenizationEnumerable(sentence, _regex, skipSpecialTokens ? null : _specialTokensRegex);
-        }
-
-        private sealed class TokenizationEnumerable : IEnumerable<Split>
-        {
-            private readonly string _sentence;
-            private readonly Regex _regex;
-            private readonly Regex? _specialTokensRegex;
-
-            public TokenizationEnumerable(string sentence, Regex regex, Regex? specialTokensRegex)
-            {
-                if (sentence is null)
-                {
-                    throw new ArgumentNullException(nameof(sentence));
-                }
-
-                if (regex is null)
-                {
-                    throw new ArgumentNullException(nameof(regex));
-                }
+            return SplitSentences(sentence, _regex, skipSpecialTokens ? null : _specialTokensRegex);
 
-                _sentence = sentence;
-                _regex = regex;
-                _specialTokensRegex = specialTokensRegex;
-            }
-
-            public IEnumerator<Split> GetEnumerator() => new TokenizationEnumerator(_sentence, _regex, _specialTokensRegex);
-            IEnumerator IEnumerable.GetEnumerator() => new TokenizationEnumerator(_sentence, _regex, _specialTokensRegex);
-
-            private sealed class TokenizationEnumerator : IEnumerator<Split>
+            static IEnumerable<Split> SplitSentences(string sentence, Regex regex, Regex? specialTokensRegex)
             {
-                private Split _current = default;
-                private int _startIndex;
-                private int _offset;
-                private MatchCollection? _matches;
-                private int _matchIndex;
-                private Match? _specialTokenMatch;
-                private readonly Regex _regex;
-                private readonly string _sentence;
-                private readonly Regex? _specialTokensRegex;
-
-                public TokenizationEnumerator(string sentence, Regex regex, Regex? specialTokensRegex)
-                {
-                    Debug.Assert(sentence is not null);
-                    Debug.Assert(regex is not null);
+                (int Offset, int Length) match;
+                int beginning = 0;
 
-                    _sentence = sentence!;
-                    _regex = regex!;
-                    _specialTokensRegex = specialTokensRegex;
-                    _startIndex = 0;
-                    _offset = 0;
-                }
-
-                object IEnumerator.Current => _current;
-
-                Split IEnumerator<Split>.Current => _current;
-
-                public bool MoveNext()
+                if (specialTokensRegex is not null)
                 {
-                    if (_matches is not null && _matchIndex < _matches.Count)
+                    while (true)
                     {
-                        Match match = _matches[_matchIndex];
-                        _current = new Split(match.Value, (match.Index + _offset, match.Index + _offset + match.Length), false);
-                        _startIndex += match.Length;
-                        _matchIndex++;
-                        return true;
+                        (int Offset, int Length) specialMatch;
+                        if (!TryGetMatch(specialTokensRegex, sentence, beginning, sentence.Length - beginning, out specialMatch))
+                        {
+                            break;
+                        }
+
+                        while (TryGetMatch(regex, sentence, beginning, specialMatch.Offset - beginning, out match))
+                        {
+                            yield return new Split(sentence, null, (match.Offset, match.Offset + match.Length));
+                            beginning = match.Offset + match.Length;
+                        }
+
+                        yield return new Split(sentence, null, (specialMatch.Offset, specialMatch.Offset + specialMatch.Length), isSpecialToken: true);
+                        beginning = specialMatch.Offset + specialMatch.Length;
                     }
-
-                    if (_specialTokenMatch is not null && _specialTokenMatch.Success)
-                    {
-                        _current = new Split(_specialTokenMatch.Value, (_specialTokenMatch.Index, _specialTokenMatch.Index + _specialTokenMatch.Length), true);
-                        _startIndex += _specialTokenMatch.Length;
-                        _specialTokenMatch = null;
-                        return true;
-                    }
-
-                    if (_startIndex >= _sentence.Length)
-                    {
-                        return false;
-                    }
-
-                    if (_specialTokensRegex is not null)
-                    {
-                        _specialTokenMatch = _specialTokensRegex.Match(_sentence, _startIndex);
-                        _offset = _startIndex;
-                        _matches = _regex.Matches(_sentence.Substring(_startIndex, _specialTokenMatch.Success ? _specialTokenMatch.Index - _startIndex : _sentence.Length - _startIndex));
-                    }
-                    else
-                    {
-                        _matches = _regex.Matches(_sentence);
-                    }
-
-                    if (_matches.Count > 0)
-                    {
-                        Match match = _matches[0];
-                        _current = new Split(match.Value, (match.Index + _startIndex, match.Index + _startIndex + match.Length), false);
-                        _startIndex += match.Length;
-                        _matchIndex = 1;
-                        return true;
-                    }
-                    else if (_specialTokenMatch is not null && _specialTokenMatch.Success)
-                    {
-                        _current = new Split(_specialTokenMatch.Value, (_specialTokenMatch.Index, _specialTokenMatch.Index + _specialTokenMatch.Length), true);
-                        _startIndex += _specialTokenMatch.Length;
-                        _specialTokenMatch = null;
-                        return true;
-                    }
-
-                    return false;
-                }
-
-                public void Reset()
-                {
-                    _current = default;
-                    _startIndex = 0;
-                    _matches = null;
-                    _matchIndex = -1;
-                    _specialTokenMatch = null;
                 }
 
-                public void Dispose()
+                while (TryGetMatch(regex, sentence, beginning, sentence.Length - beginning, out match))
                 {
+                    yield return new Split(sentence, null, (match.Offset, match.Offset + match.Length));
+                    beginning = match.Length + match.Offset;
                 }
             }
         }
diff --git a/src/Microsoft.ML.Tokenizers/PreTokenizer/Whitespace.cs b/src/Microsoft.ML.Tokenizers/PreTokenizer/Whitespace.cs
index d2d0158885..2a53bec814 100644
--- a/src/Microsoft.ML.Tokenizers/PreTokenizer/Whitespace.cs
+++ b/src/Microsoft.ML.Tokenizers/PreTokenizer/Whitespace.cs
@@ -4,6 +4,7 @@
 
 using System;
 using System.Collections.Generic;
+using System.Text.RegularExpressions;
 
 namespace Microsoft.ML.Tokenizers
 {
@@ -11,14 +12,21 @@ namespace Microsoft.ML.Tokenizers
     /// The pre-tokenizer which splits the text at the word boundary.
     /// The word is a set of alphabet, numeric, and underscore characters.
     /// </summary>
-    public sealed class WhiteSpace : PreTokenizer
+    public sealed partial class WhiteSpace : PreTokenizer
    {
         /// <summary>
         /// Gets a singleton instance of the WhiteSpace pre-tokenizer.
         /// </summary>
-        public static readonly WhiteSpace Instance = new WhiteSpace();
+        public static WhiteSpace Instance { get; } = new WhiteSpace();
 
-        private const string Pattern = @"\w+|[^\w\s]+";
+        private const string PretokenizePattern = @"\w+|[^\w\s]+";
+#if NET7_0_OR_GREATER
+        [GeneratedRegex(PretokenizePattern)]
+        private static partial Regex PretokenizeRegex();
+#else
+        private static readonly Regex _regex = new Regex(PretokenizePattern, RegexOptions.Compiled);
+        private static Regex PretokenizeRegex() => _regex;
+#endif
 
         /// <summary>
         /// Splits the given string in multiple substrings at the word boundary, keeping track of the offsets of said substrings from the original string.
@@ -30,10 +38,10 @@ public override IEnumerable<Split> PreTokenize(string sentence, bool skipSpecial
         {
             if (string.IsNullOrEmpty(sentence))
             {
-                return EmptyList;
+                return Array.Empty<Split>();
             }
 
-            return new RegexSplitEnumerable(sentence, Pattern);
+            return SplitSentence(sentence, PretokenizeRegex());
         }
     }
 }
diff --git a/src/Microsoft.ML.Tokenizers/Tokenizer.cs b/src/Microsoft.ML.Tokenizers/Tokenizer.cs
index aee4c84bcb..d002f55833 100644
--- a/src/Microsoft.ML.Tokenizers/Tokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Tokenizer.cs
@@ -3,10 +3,10 @@
 // See the LICENSE file in the project root for more information.
 
 using System;
+using System.Collections.Concurrent;
 using System.Collections.Generic;
 using System.Diagnostics;
 using System.IO;
-using System.Linq;
 using System.Net.Http;
 using System.Text.RegularExpressions;
 using System.Threading.Tasks;
@@ -16,7 +16,7 @@ namespace Microsoft.ML.Tokenizers
     /// <summary>
     /// A Tokenizer works as a pipeline. It processes some raw text as input and outputs a TokenizerResult object.
     /// </summary>
-    public class Tokenizer
+    public partial class Tokenizer
    {
         /// <summary>
         /// Create a new Tokenizer object.
         /// </summary>
@@ -282,15 +282,14 @@ private enum ModelEncoding
             GPT2
         }
 
-        private static readonly IReadOnlyDictionary<string, ModelEncoding> _modelPrefixToEncoding =
-            new Dictionary<string, ModelEncoding>()
-            {
+        private static readonly (string Prefix, ModelEncoding Encoding)[] _modelPrefixToEncoding =
+            [
                 // chat
-                { "gpt-4-", ModelEncoding.Cl100kBase },          // e.g., gpt-4-0314, etc., plus gpt-4-32k
-                { "gpt-3.5-turbo-", ModelEncoding.Cl100kBase }   // e.g, gpt-3.5-turbo-0301, -0401, etc.
-            };
+                ( "gpt-4-", ModelEncoding.Cl100kBase ),          // e.g., gpt-4-0314, etc., plus gpt-4-32k
+                ( "gpt-3.5-turbo-", ModelEncoding.Cl100kBase )   // e.g., gpt-3.5-turbo-0301, -0401, etc.
+            ];
 
-        private static readonly IReadOnlyDictionary<string, ModelEncoding> _modelToEncoding =
+        private static readonly Dictionary<string, ModelEncoding> _modelToEncoding =
             new Dictionary<string, ModelEncoding>(StringComparer.OrdinalIgnoreCase)
             {
                 // chat
@@ -353,15 +352,15 @@ public static async Task<Tokenizer> CreateByModelNameAsync(
             IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
             Normalizer? normalizer = null)
         {
-            var encoder = ModelEncoding.None;
+            ModelEncoding encoder;
 
             if (!_modelToEncoding.TryGetValue(modelName, out encoder))
             {
-                foreach (KeyValuePair<string, ModelEncoding> kvp in _modelPrefixToEncoding)
+                foreach ((string Prefix, ModelEncoding Encoding) in _modelPrefixToEncoding)
                 {
-                    if (modelName.StartsWith(kvp.Key, StringComparison.OrdinalIgnoreCase))
+                    if (modelName.StartsWith(Prefix, StringComparison.OrdinalIgnoreCase))
                     {
-                        encoder = kvp.Value;
+                        encoder = Encoding;
                         break;
                     }
                 }
@@ -372,16 +371,30 @@ public static async Task<Tokenizer> CreateByModelNameAsync(
                 throw new NotImplementedException($"Doesn't support this model [{modelName}]");
             }
 
-            return await CreateByEncoderNameAsync(encoder, extraSpecialTokens, normalizer);
+            return await CreateByEncoderNameAsync(encoder, extraSpecialTokens, normalizer).ConfigureAwait(false);
         }
 
         private const string Cl100kBaseRegexPattern = @"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";
         private const string P50kBaseRegexPattern = @"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+";
 
-        const string Cl100kBaseVocabUrl = @"https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken";
-        const string P50RegexUrl = @"https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken";
-        const string R50RegexUrl = @"https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken";
-        const string GPT2Url = @"https://pythia.blob.core.windows.net/public/encoding/gpt2.tiktoken";
+        private const string Cl100kBaseVocabUrl = @"https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken";
+        private const string P50RanksUrl = @"https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken";
+        private const string R50RanksUrl = @"https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken";
+        private const string GPT2Url = @"https://pythia.blob.core.windows.net/public/encoding/gpt2.tiktoken";
+
+#if NET7_0_OR_GREATER
+        [GeneratedRegex(Cl100kBaseRegexPattern)]
+        private static partial Regex Cl100kBaseRegex();
+
+        [GeneratedRegex(P50kBaseRegexPattern)]
+        private static partial Regex P50kBaseRegex();
+#else
+        private static Regex? _cl100kBaseRegex;
+        private static Regex Cl100kBaseRegex() => _cl100kBaseRegex ??= new Regex(Cl100kBaseRegexPattern, RegexOptions.Compiled);
+
+        private static Regex? _p50kBaseRegex;
+        private static Regex P50kBaseRegex() => _p50kBaseRegex ??= new Regex(P50kBaseRegexPattern, RegexOptions.Compiled);
+#endif
 
         /// <summary>
         /// Create tokenizer based on encoder name and extra special tokens
         /// </summary>
@@ -401,24 +414,24 @@ private static async Task<Tokenizer> CreateByEncoderNameAsync(
                 case ModelEncoding.Cl100kBase:
                     var specialTokens = new Dictionary<string, int>
                         { { EndOfText, 100257}, { FimPrefix, 100258}, { FimMiddle, 100259}, { FimSuffix, 100260}, { EndOfPrompt, 100276} };
-                    return await CreateTikTokenTokenizerAsync(Cl100kBaseRegexPattern, Cl100kBaseVocabUrl, specialTokens, extraSpecialTokens, normalizer);
+                    return await CreateTikTokenTokenizerAsync(Cl100kBaseRegex(), Cl100kBaseVocabUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
 
                 case ModelEncoding.P50kBase:
                     specialTokens = new Dictionary<string, int> { { EndOfText, 50256 } };
-                    return await CreateTikTokenTokenizerAsync(P50kBaseRegexPattern, P50RegexUrl, specialTokens, extraSpecialTokens, normalizer);
+                    return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
 
                 case ModelEncoding.P50kEdit:
                     specialTokens = new Dictionary<string, int>
                         { { EndOfText, 50256 }, { FimPrefix, 50281 }, { FimMiddle, 50282 }, { FimSuffix, 50283 } };
-                    return await CreateTikTokenTokenizerAsync(P50kBaseRegexPattern, P50RegexUrl, specialTokens, extraSpecialTokens, normalizer);
+                    return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
 
                 case ModelEncoding.R50kBase:
                     specialTokens = new Dictionary<string, int> { { EndOfText, 50256 } };
-                    return await CreateTikTokenTokenizerAsync(P50kBaseRegexPattern, R50RegexUrl, specialTokens, extraSpecialTokens, normalizer);
+                    return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), R50RanksUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
 
                 case ModelEncoding.GPT2:
                     specialTokens = new Dictionary<string, int> { { EndOfText, 50256 }, };
-                    return await CreateTikTokenTokenizerAsync(P50kBaseRegexPattern, GPT2Url, specialTokens, extraSpecialTokens, normalizer);
+                    return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), GPT2Url, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
 
                 default:
                     Debug.Assert(false, $"Unexpected encoder [{modelEncoding}]");
@@ -426,39 +439,43 @@ private static async Task<Tokenizer> CreateByEncoderNameAsync(
             }
         }
 
-        private static readonly Dictionary<string, (Dictionary<byte[], int>, Dictionary<string, int>, IReadOnlyDictionary<int, byte[]>)> _tiktokenCache = new(StringComparer.OrdinalIgnoreCase);
+        private static readonly ConcurrentDictionary<string, (Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<string, int>, IReadOnlyDictionary<int, byte[]>)> _tiktokenCache = new(StringComparer.OrdinalIgnoreCase);
 
         /// <summary>
         /// Create tokenizer based on regex pattern, BPE rank file and special tokens
         /// </summary>
-        /// <param name="regexPatternStr">Regex pattern to break a long string</param>
+        /// <param name="regex">Regex to break a long string</param>
         /// <param name="mergeableRanksFileUrl">BPE rank file</param>
-        /// <param name="specialTokens">Special tokens mapping</param>
+        /// <param name="specialTokens">Special tokens mapping. This may be mutated by the method.</param>
         /// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the encoder</param>
         /// <param name="normalizer">To normalize the text before tokenization</param>
         /// <returns>The tokenizer</returns>
         private static async Task<Tokenizer> CreateTikTokenTokenizerAsync(
-            string regexPatternStr,
-            string mergeableRanksFileUrl,
-            Dictionary<string, int> specialTokens,
-            IReadOnlyDictionary<string, int>? extraSpecialTokens,
-            Normalizer? normalizer)
+            Regex regex,
+            string mergeableRanksFileUrl,
+            Dictionary<string, int> specialTokens,
+            IReadOnlyDictionary<string, int>? extraSpecialTokens,
+            Normalizer? normalizer)
         {
             if (extraSpecialTokens is not null)
             {
-                specialTokens = specialTokens.Concat(extraSpecialTokens).ToDictionary(pair => pair.Key, pair => pair.Value);
+                foreach (var extraSpecialToken in extraSpecialTokens)
+                {
+                    specialTokens.Add(extraSpecialToken.Key, extraSpecialToken.Value);
+                }
             }
 
-            if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary<byte[], int> encoder, Dictionary<string, int> vocab, IReadOnlyDictionary<int, byte[]> decoder) cache))
+            if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<string, int> vocab, IReadOnlyDictionary<int, byte[]> decoder) cache))
             {
-                using (Stream stream = await _httpClient.GetStreamAsync(mergeableRanksFileUrl))
+                using (Stream stream = await _httpClient.GetStreamAsync(mergeableRanksFileUrl).ConfigureAwait(false))
                 {
-                    cache = Tiktoken.LoadTikTokenBpe(stream);
+                    cache = await Tiktoken.LoadTikTokenBpeAsync(stream, useAsync: true).ConfigureAwait(false);
                 }
-                _tiktokenCache.Add(mergeableRanksFileUrl, cache);
+
+                _tiktokenCache.TryAdd(mergeableRanksFileUrl, cache);
             }
 
-            return new Tokenizer(new Tiktoken(cache.encoder, cache.decoder, cache.vocab, specialTokens), new TikTokenPreTokenizer(new Regex(regexPatternStr, RegexOptions.Compiled), specialTokens), normalizer);
+            return new Tokenizer(new Tiktoken(cache.encoder, cache.decoder, cache.vocab, specialTokens), new TikTokenPreTokenizer(regex, specialTokens), normalizer);
         }
     }
 }
diff --git a/src/Microsoft.ML.Tokenizers/TokenizerResult.cs b/src/Microsoft.ML.Tokenizers/TokenizerResult.cs
index 192c215eec..6b8e434878 100644
--- a/src/Microsoft.ML.Tokenizers/TokenizerResult.cs
+++ b/src/Microsoft.ML.Tokenizers/TokenizerResult.cs
@@ -63,8 +63,6 @@ internal void AddTokens(IReadOnlyList<Token> addedTokens)
             }
         }
 
-        private static readonly IReadOnlyList<int> _emptyIds = new List<int>();
-
         /// <summary>
         /// Gets list of the tokens Ids.
         /// The Ids are the main input to a Language Model. They are the token indices, the numerical representations that a LM understands.
@@ -80,7 +78,7 @@ public IReadOnlyList<int> Ids
 
                 if (_tokens is null)
                 {
-                    return _emptyIds;
+                    return Array.Empty<int>();
                 }
 
                 _ids = new List<int>(_tokens.Count);
@@ -94,8 +92,6 @@ public IReadOnlyList<int> Ids
             }
         }
 
-        private static readonly IReadOnlyList<string> _emptyTokens = new List<string>();
-
         /// <summary>
         /// Gets the generated tokens. They are the string representation of the Ids.
         /// </summary>
@@ -110,7 +106,7 @@ public IReadOnlyList<string> Tokens
 
                 if (_tokens is null)
                 {
-                    return _emptyTokens;
+                    return Array.Empty<string>();
                 }
 
                 _tokensWords = new List<string>(_tokens.Count);
@@ -124,8 +120,6 @@ public IReadOnlyList<string> Tokens
             }
         }
 
-        private static readonly IReadOnlyList<(int, int)> _emptyOffsets = new List<(int, int)>();
-
         /// <summary>
         /// Gets the list of offsets. These offsets let you slice the input string, and thus retrieve
         /// the original part that led to producing the corresponding token.
         /// </summary>
@@ -141,7 +135,7 @@ public IReadOnlyList<string> Tokens
 
                 if (_tokens is null)
                 {
-                    return _emptyOffsets;
+                    return Array.Empty<(int, int)>();
                 }
 
                 _offsets = new List<(int Index, int End)>(_tokens.Count);
diff --git a/src/Microsoft.ML.Tokenizers/Utils/ByteArrayComparer.cs b/src/Microsoft.ML.Tokenizers/Utils/ByteArrayComparer.cs
index 9ccca49b89..a3f418317d 100644
--- a/src/Microsoft.ML.Tokenizers/Utils/ByteArrayComparer.cs
+++ b/src/Microsoft.ML.Tokenizers/Utils/ByteArrayComparer.cs
@@ -1,4 +1,4 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
@@ -8,27 +8,17 @@
 
 namespace Microsoft.ML.Tokenizers
 {
-    internal class ByteArrayComparer : IEqualityComparer<byte[]>
+    internal sealed class ReadOnlyMemoryByteComparer : IEqualityComparer<ReadOnlyMemory<byte>>
     {
-        public bool Equals(byte[]? x, byte[]? y)
-        {
-            if (x is null || y is null)
-            {
-                return x == y;
-            }
+        public static ReadOnlyMemoryByteComparer Instance { get; } = new();
 
-            return x.SequenceEqual(y);
-        }
+        public bool Equals(ReadOnlyMemory<byte> x, ReadOnlyMemory<byte> y) =>
+            x.Span.SequenceEqual(y.Span);
 
-        public int GetHashCode(byte[] bytes)
+        public int GetHashCode(ReadOnlyMemory<byte> x)
         {
-            if (bytes == null)
-            {
-                throw new ArgumentNullException(nameof(bytes));
-            }
-
             int hash = 17;
-            foreach (byte b in bytes)
+            foreach (byte b in x.Span)
             {
                 hash = hash * 31 + b;
             }
@@ -36,4 +26,4 @@ public int GetHashCode(byte[] bytes)
             return hash;
         }
     }
-}
\ No newline at end of file
+}
diff --git a/src/Microsoft.ML.Tokenizers/Utils/BytePairEncoder.cs b/src/Microsoft.ML.Tokenizers/Utils/BytePairEncoder.cs
index 07a4c7db3b..523db677ee 100644
--- a/src/Microsoft.ML.Tokenizers/Utils/BytePairEncoder.cs
+++ b/src/Microsoft.ML.Tokenizers/Utils/BytePairEncoder.cs
@@ -1,10 +1,9 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
 using System;
 using System.Collections.Generic;
-using System.Runtime.CompilerServices;
 
 namespace Microsoft.ML.Tokenizers
 {
@@ -13,11 +12,11 @@ namespace Microsoft.ML.Tokenizers
     /// </summary>
     internal static class BytePairEncoder
     {
-        public static int[] BytePairEncode(byte[] mergingBytes, IReadOnlyDictionary<byte[], int> ranks)
+        public static int[] BytePairEncode(ReadOnlyMemory<byte> mergingBytes, Dictionary<ReadOnlyMemory<byte>, int> ranks)
         {
             if (mergingBytes.Length == 1)
             {
-                return new int[] { ranks[mergingBytes] };
+                return [ranks[mergingBytes]];
             }
 
             var byteIndicesAndRanks = new List<(int Index, int Rank)>();
@@ -29,7 +28,7 @@ int GetRank(int startIndex, int skip = 0)
             {
                 if (startIndex + skip + 2 < byteIndicesAndRanks.Count)
                 {
-                    var slice = mergingBytes.Slice(byteIndicesAndRanks[startIndex].Index, byteIndicesAndRanks[startIndex + skip + 2].Index);
+                    var slice = mergingBytes.SliceStartEnd(byteIndicesAndRanks[startIndex].Index, byteIndicesAndRanks[startIndex + skip + 2].Index);
                     if (ranks.TryGetValue(slice, out var rank))
                     {
                         return rank;
@@ -74,17 +73,11 @@ int GetRank(int startIndex, int skip = 0)
             var outList = new int[byteIndicesAndRanks.Count - 1];
             for (int i = 0; i < byteIndicesAndRanks.Count - 1; i++)
             {
-                outList[i] = ranks[mergingBytes.Slice(byteIndicesAndRanks[i].Index, byteIndicesAndRanks[i + 1].Index)];
+                outList[i] = ranks[mergingBytes.SliceStartEnd(byteIndicesAndRanks[i].Index, byteIndicesAndRanks[i + 1].Index)];
             }
             return outList;
         }
 
-        private static T[] Slice<T>(this T[] array, int start, int end)
-        {
-            var length = end - start;
-            var result = new T[length];
-            Array.Copy(array, start, result, 0, length);
-            return result;
-        }
+        private static ReadOnlyMemory<byte> SliceStartEnd(this ReadOnlyMemory<byte> memory, int start, int end) => memory.Slice(start, end - start);
     }
-}
\ No newline at end of file
+}
diff --git a/src/Microsoft.ML.Tokenizers/Utils/IListExtensions.cs b/src/Microsoft.ML.Tokenizers/Utils/IListExtensions.cs
index 061f4cc876..feb913158f 100644
--- a/src/Microsoft.ML.Tokenizers/Utils/IListExtensions.cs
+++ b/src/Microsoft.ML.Tokenizers/Utils/IListExtensions.cs
@@ -1,4 +1,4 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
@@ -14,6 +14,14 @@ public static void AddRange<T>(this IList<T> list, IEnumerable<T> items)
         {
             concreteList.AddRange(items);
         }
+        else if (items is IList<T> listToAdd)
+        {
+            int count = listToAdd.Count;
+            for (int i = 0; i < count; i++)
+            {
+                list.Add(listToAdd[i]);
+            }
+        }
         else
         {
             foreach (var item in items)
@@ -23,4 +31,4 @@ public static void AddRange<T>(this IList<T> list, IEnumerable<T> items)
             }
         }
     }
-}
\ No newline at end of file
+}
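
Note on the pooled UTF-8 encoding pattern: Tokenize, TokenizeToIds, CountTokens, and the id-mapping path in the patched Tiktoken.cs all replace Encoding.UTF8.GetBytes(string), which allocates a fresh byte[] per call, with a buffer rented from ArrayPool<byte>. A minimal standalone sketch of that pattern follows; the type and method names here are illustrative, not part of the library:

    using System;
    using System.Buffers;
    using System.Text;

    internal static class PooledUtf8Sketch
    {
        public static void Encode(string sequence)
        {
            // GetMaxByteCount guarantees the rented buffer is large enough for any input,
            // so the encode below can never run out of space mid-way.
            byte[] buffer = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(sequence.Length));
            try
            {
                int byteCount = Encoding.UTF8.GetBytes(sequence, 0, sequence.Length, buffer, 0);
                ReadOnlyMemory<byte> utf8 = buffer.AsMemory(0, byteCount);
                // ... use `utf8` transiently, e.g. for encoder-dictionary lookups; anything
                // that outlives this call (such as the encoded id array) must not alias it ...
            }
            finally
            {
                // Return in finally so the pool is not starved if the consumer throws.
                ArrayPool<byte>.Shared.Return(buffer);
            }
        }
    }

The constraint the patch respects is that nothing keyed on the rented memory may outlive the Return call: BytePairEncode only reads the bytes and produces an int[], and it is that array, not the buffer, which gets cached.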
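Note on switching the encoder keys from byte[] to ReadOnlyMemory<byte>: with byte[] keys, every lookup requires materializing an array, while ReadOnlyMemory<byte> keys let the byte-pair loop probe with slices of one buffer. The default equality for ReadOnlyMemory<byte> compares the backing object, offset, and length rather than the contents, so a content-based comparer like the patch's ReadOnlyMemoryByteComparer is required. A small self-contained demonstration, where the comparer body mirrors the patch and the surrounding program is illustrative:

    using System;
    using System.Collections.Generic;

    internal sealed class BytesComparer : IEqualityComparer<ReadOnlyMemory<byte>>
    {
        public bool Equals(ReadOnlyMemory<byte> x, ReadOnlyMemory<byte> y) => x.Span.SequenceEqual(y.Span);

        public int GetHashCode(ReadOnlyMemory<byte> x)
        {
            int hash = 17;
            foreach (byte b in x.Span)
            {
                hash = hash * 31 + b;
            }
            return hash;
        }
    }

    internal static class BytesComparerDemo
    {
        private static void Main()
        {
            var ranks = new Dictionary<ReadOnlyMemory<byte>, int>(new BytesComparer());
            ranks[new byte[] { 1, 2 }.AsMemory()] = 42;

            // A slice over a different backing array with the same contents hits the entry.
            byte[] scratch = { 1, 2, 3, 4 };
            Console.WriteLine(ranks[scratch.AsMemory(0, 2)]); // prints 42
        }
    }

One caveat worth keeping in mind: keys stored long-term must own their backing memory. The patch only probes with pooled buffers; the persistent encoder table is built from arrays read out of the BPE rank file.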
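Note on the regex changes: each pre-tokenizer now exposes its pattern through a method so that .NET 7+ builds get a compile-time, source-generated matcher while netstandard2.0 builds keep a cached RegexOptions.Compiled instance. The shape, extracted from the patch into a standalone sketch (the class name is illustrative):

    using System.Text.RegularExpressions;

    public sealed partial class PreTokenizerSketch
    {
        private const string PretokenizePattern = @"\w+|[^\w\s]+";

    #if NET7_0_OR_GREATER
        // The source generator emits the matching code at build time; no runtime compilation.
        [GeneratedRegex(PretokenizePattern)]
        private static partial Regex PretokenizeRegex();
    #else
        // Fallback for older TFMs: compile once and reuse for the lifetime of the process.
        private static readonly Regex _regex = new Regex(PretokenizePattern, RegexOptions.Compiled);
        private static Regex PretokenizeRegex() => _regex;
    #endif
    }

This pairs with the TryGetMatch helper in PreTokenizer.cs, which on .NET 7+ uses Regex.EnumerateMatches to walk ValueMatch structs instead of allocating a Match object per token.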