From 2eadb5b7786ea9c11ed5e3a377c8f7f92a2a6792 Mon Sep 17 00:00:00 2001
From: Stephen Toub <stoub@microsoft.com>
Date: Mon, 19 Feb 2024 20:37:55 -0500
Subject: [PATCH 1/2] Tweak CreateByModelNameAsync

- Add a CancellationToken to CreateByModelNameAsync, allowing the download
  and parsing to be canceled.
- Use ReadLineAsync(cancellationToken), which not only allows it to be
  canceled, but avoids ~100K task allocations.
- Fix Helpers.FromBase64String to support lines longer than 300 chars.
---
 src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs | 13 ++--
 src/Microsoft.ML.Tokenizers/Tokenizer.cs      | 62 ++++++++++++-------
 .../Utils/Helpers.netcoreapp.cs               | 27 ++++++--
 .../Utils/Helpers.netstandard.cs              | 19 +++++-
 4 files changed, 85 insertions(+), 36 deletions(-)

diff --git a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
index 9935dd6428..74d95df23a 100644
--- a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
@@ -9,6 +9,7 @@
 using System.IO;
 using System.Linq;
 using System.Text;
+using System.Threading;
 using System.Threading.Tasks;
 
 namespace Microsoft.ML.Tokenizers
@@ -104,9 +105,11 @@ private Tiktoken(int cacheSize)
         /// </summary>
         /// <param name="tikTokenBpeFileStream">Stream to the BPE rank file</param>
         /// <param name="useAsync">Whether to perform I/O synchronously or asynchronously.</param>
+        /// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
         /// <returns>Map of byte[] to integer token id</returns>
         /// <exception cref="InvalidOperationException"></exception>
-        internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<string, int>, IReadOnlyDictionary<int, byte[]>)> LoadTikTokenBpeAsync(Stream tikTokenBpeFileStream, bool useAsync)
+        internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<string, int>, IReadOnlyDictionary<int, byte[]>)> LoadTikTokenBpeAsync(
+            Stream tikTokenBpeFileStream, bool useAsync, CancellationToken cancellationToken = default)
         {
             var encoder = new Dictionary<ReadOnlyMemory<byte>, int>(ReadOnlyMemoryByteComparer.Instance);
             var vocab = new Dictionary<string, int>();
@@ -119,7 +122,7 @@
                 while (true)
                 {
                     string? line = useAsync ?
-                        await reader.ReadLineAsync().ConfigureAwait(false) :
+                        await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) :
                         reader.ReadLine();
                     if (string.IsNullOrWhiteSpace(line))
                     {
@@ -136,10 +139,10 @@ await reader.ReadLineAsync().ConfigureAwait(false) :
                         throw new FormatException($"Invalid format in the BPE encoder file stream");
                     }
 
-                    byte[] tokenBytes = Helpers.FromBase64String(line, 0, spaceIndex);
-
                     if (Helpers.TryParseInt32(line, spaceIndex + 1, out int rank))
                     {
+                        byte[] tokenBytes = Helpers.FromBase64String(line, 0, spaceIndex);
+
                         encoder[tokenBytes] = rank;
                         decoder[rank] = tokenBytes;
 
@@ -214,7 +217,7 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
             // cache miss
             if (_vocab.TryGetValue(sequence, out int mappedId))
             {
-                return new List<Token> { new(mappedId, sequence, (0, sequence.Length)) };
+                return new Token[1] { new(mappedId, sequence, (0, sequence.Length)) };
             }
 
             byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(sequence.Length));
diff --git a/src/Microsoft.ML.Tokenizers/Tokenizer.cs b/src/Microsoft.ML.Tokenizers/Tokenizer.cs
index d002f55833..d29766d65c 100644
--- a/src/Microsoft.ML.Tokenizers/Tokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Tokenizer.cs
@@ -9,6 +9,7 @@
 using System.IO;
 using System.Net.Http;
 using System.Text.RegularExpressions;
+using System.Threading;
 using System.Threading.Tasks;
 
 namespace Microsoft.ML.Tokenizers
@@ -346,32 +347,41 @@ private static readonly (string Prefix, ModelEncoding Encoding)[] _modelPrefixTo
         /// <param name="modelName">Model name</param>
         /// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the model</param>
         /// <param name="normalizer">To normalize the text before tokenization</param>
+        /// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
         /// <returns>The tokenizer</returns>
-        public static async Task<Tokenizer> CreateByModelNameAsync(
+        public static Task<Tokenizer> CreateByModelNameAsync(
             string modelName,
             IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
-            Normalizer? normalizer = null)
+            Normalizer? normalizer = null,
+            CancellationToken cancellationToken = default)
         {
-            ModelEncoding encoder;
-
-            if (!_modelToEncoding.TryGetValue(modelName, out encoder))
+            try
             {
-                foreach ((string Prefix, ModelEncoding Encoding) in _modelPrefixToEncoding)
+                ModelEncoding encoder;
+
+                if (!_modelToEncoding.TryGetValue(modelName, out encoder))
                 {
-                    if (modelName.StartsWith(Prefix, StringComparison.OrdinalIgnoreCase))
+                    foreach ((string Prefix, ModelEncoding Encoding) in _modelPrefixToEncoding)
                     {
-                        encoder = Encoding;
-                        break;
+                        if (modelName.StartsWith(Prefix, StringComparison.OrdinalIgnoreCase))
+                        {
+                            encoder = Encoding;
+                            break;
+                        }
                     }
                 }
-            }
 
-            if (encoder == ModelEncoding.None)
+                if (encoder == ModelEncoding.None)
+                {
+                    throw new NotImplementedException($"Doesn't support this model [{modelName}]");
+                }
+
+                return CreateByEncoderNameAsync(encoder, extraSpecialTokens, normalizer, cancellationToken);
+            }
+            catch (Exception ex)
             {
-                throw new NotImplementedException($"Doesn't support this model [{modelName}]");
+                return Task.FromException<Tokenizer>(ex);
             }
-
-            return await CreateByEncoderNameAsync(encoder, extraSpecialTokens, normalizer).ConfigureAwait(false);
         }
 
         private const string Cl100kBaseRegexPattern = @"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";
@@ -402,36 +412,38 @@ public static async Task<Tokenizer> CreateByModelNameAsync(
         /// <param name="modelEncoding">Encoder label</param>
         /// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the encoder</param>
         /// <param name="normalizer">To normalize the text before tokenization</param>
+        /// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
         /// <returns>The tokenizer</returns>
         /// <exception cref="NotImplementedException">Throws if the encoder is not supported</exception>
-        private static async Task<Tokenizer> CreateByEncoderNameAsync(
+        private static Task<Tokenizer> CreateByEncoderNameAsync(
             ModelEncoding modelEncoding,
             IReadOnlyDictionary<string, int>? extraSpecialTokens,
-            Normalizer? normalizer)
+            Normalizer? normalizer,
+            CancellationToken cancellationToken)
         {
             switch (modelEncoding)
             {
                 case ModelEncoding.Cl100kBase:
                     var specialTokens = new Dictionary<string, int>
                         { { EndOfText, 100257}, { FimPrefix, 100258}, { FimMiddle, 100259}, { FimSuffix, 100260}, { EndOfPrompt, 100276} };
-                    return await CreateTikTokenTokenizerAsync(Cl100kBaseRegex(), Cl100kBaseVocabUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
+                    return CreateTikTokenTokenizerAsync(Cl100kBaseRegex(), Cl100kBaseVocabUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
 
                 case ModelEncoding.P50kBase:
                     specialTokens = new Dictionary<string, int> { { EndOfText, 50256 } };
-                    return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
+                    return CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
 
                 case ModelEncoding.P50kEdit:
                     specialTokens = new Dictionary<string, int>
                         { { EndOfText, 50256 }, { FimPrefix, 50281 }, { FimMiddle, 50282 }, { FimSuffix, 50283 } };
-                    return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
+                    return CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
 
                 case ModelEncoding.R50kBase:
                     specialTokens = new Dictionary<string, int> { { EndOfText, 50256 } };
-                    return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), R50RanksUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
+                    return CreateTikTokenTokenizerAsync(P50kBaseRegex(), R50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
 
                 case ModelEncoding.GPT2:
                     specialTokens = new Dictionary<string, int> { { EndOfText, 50256 }, };
-                    return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), GPT2Url, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
+                    return CreateTikTokenTokenizerAsync(P50kBaseRegex(), GPT2Url, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
 
                 default:
                     Debug.Assert(false, $"Unexpected encoder [{modelEncoding}]");
@@ -449,13 +461,15 @@ private static async Task<Tokenizer> CreateByEncoderNameAsync(
         /// <param name="specialTokens">Special tokens mapping. This may be mutated by the method.</param>
         /// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the encoder</param>
         /// <param name="normalizer">To normalize the text before tokenization</param>
+        /// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
         /// <returns>The tokenizer</returns>
         private static async Task<Tokenizer> CreateTikTokenTokenizerAsync(
             Regex regex,
             string mergeableRanksFileUrl,
             Dictionary<string, int> specialTokens,
             IReadOnlyDictionary<string, int>? extraSpecialTokens,
-            Normalizer? normalizer)
+            Normalizer? normalizer,
+            CancellationToken cancellationToken)
         {
             if (extraSpecialTokens is not null)
             {
@@ -467,9 +481,9 @@ private static async Task<Tokenizer> CreateTikTokenTokenizerAsync(
 
             if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<string, int> vocab, IReadOnlyDictionary<int, byte[]> decoder) cache))
             {
-                using (Stream stream = await _httpClient.GetStreamAsync(mergeableRanksFileUrl).ConfigureAwait(false))
+                using (Stream stream = await Helpers.GetStreamAsync(_httpClient, mergeableRanksFileUrl, cancellationToken).ConfigureAwait(false))
                 {
-                    cache = await Tiktoken.LoadTikTokenBpeAsync(stream, useAsync: true).ConfigureAwait(false);
+                    cache = await Tiktoken.LoadTikTokenBpeAsync(stream, useAsync: true, cancellationToken).ConfigureAwait(false);
                 }
 
                 _tiktokenCache.TryAdd(mergeableRanksFileUrl, cache);
diff --git a/src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs b/src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs
index 99d764a9cf..b64531431f 100644
--- a/src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs
+++ b/src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs
@@ -1,26 +1,41 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
 using System;
+using System.Buffers.Text;
+using System.Diagnostics;
 using System.Globalization;
+using System.IO;
+using System.Threading.Tasks;
+using System.Threading;
+using System.Net.Http;
 
 namespace Microsoft.ML.Tokenizers
 {
     internal static class Helpers
     {
+        public static ValueTask<string?> ReadLineAsync(StreamReader reader, CancellationToken cancellationToken) =>
+            reader.ReadLineAsync(cancellationToken);
+
+        public static Task<Stream> GetStreamAsync(HttpClient client, string url, CancellationToken cancellationToken) =>
+            client.GetStreamAsync(url, cancellationToken);
+
         public static byte[] FromBase64String(string base64String, int offset, int length)
         {
-            Span<byte> bytes = stackalloc byte[300];
-            if (!Convert.TryFromBase64Chars(base64String.AsSpan().Slice(offset, length), bytes, out int bytesWritten))
+            if (!Base64.IsValid(base64String.AsSpan(offset, length), out int decodedLength))
             {
-                throw new System.FormatException($"Invalid base64 string '{base64String.Substring(offset, length)}'");
+                throw new FormatException($"Invalid base64 string '{base64String.Substring(offset, length)}'");
             }
-            return bytes.Slice(0, bytesWritten).ToArray();
+
+            byte[] bytes = new byte[decodedLength];
+            bool success = Convert.TryFromBase64Chars(base64String.AsSpan(offset, length), bytes, out int bytesWritten);
+            Debug.Assert(success);
+            Debug.Assert(bytes.Length == bytesWritten);
+            return bytes;
         }
 
         internal static bool TryParseInt32(string s, int offset, out int result) =>
             int.TryParse(s.AsSpan().Slice(offset), NumberStyles.None, CultureInfo.InvariantCulture, out result);
     }
 }
-
diff --git a/src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs b/src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs
index 4f354cda5a..2979c99b6e 100644
--- a/src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs
+++ b/src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs
@@ -1,13 +1,30 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
 using System;
+using System.IO;
+using System.Net.Http;
+using System.Threading;
+using System.Threading.Tasks;
 
 namespace Microsoft.ML.Tokenizers
 {
     internal static class Helpers
     {
+        public static ValueTask<string?> ReadLineAsync(StreamReader reader, CancellationToken cancellationToken)
+        {
+            cancellationToken.ThrowIfCancellationRequested();
+            return new ValueTask<string?>(reader.ReadLineAsync());
+        }
+
+        public static async Task<Stream> GetStreamAsync(HttpClient client, string url, CancellationToken cancellationToken)
+        {
+            HttpResponseMessage response = await client.GetAsync(url, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false);
+            response.EnsureSuccessStatusCode();
+            return await response.Content.ReadAsStreamAsync().ConfigureAwait(false);
+        }
+
         public static byte[] FromBase64String(string base64String, int offset, int length) =>
             Convert.FromBase64String(base64String.Substring(offset, length));
 
         // Not support signed number

From e78ab0f3a424ed0c463290bdb1ce3414ddb2b070 Mon Sep 17 00:00:00 2001
From: Stephen Toub <stoub@microsoft.com>
Date: Mon, 19 Feb 2024 23:19:56 -0500
Subject: [PATCH 2/2] Prototype of using spans in Model

---
 src/Microsoft.ML.Tokenizers/Model/BPE.cs      |  11 +-
 src/Microsoft.ML.Tokenizers/Model/Cache.cs    |  82 ++------
 .../Model/EnglishRoberta.cs                   |  10 +-
 src/Microsoft.ML.Tokenizers/Model/Model.cs    |  42 +++-
 src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs | 184 +++++++++++-------
 src/Microsoft.ML.Tokenizers/Tokenizer.cs      |   8 +-
 src/Microsoft.ML.Tokenizers/Utils/LruCache.cs | 129 ++++++------
 .../Utils/StringSpanOrdinalKey.cs             |  55 ++++++
 8 files changed, 301 insertions(+), 220 deletions(-)
 create mode 100644 src/Microsoft.ML.Tokenizers/Utils/StringSpanOrdinalKey.cs

diff --git a/src/Microsoft.ML.Tokenizers/Model/BPE.cs b/src/Microsoft.ML.Tokenizers/Model/BPE.cs
index 008dacb573..036d86daa6 100644
--- a/src/Microsoft.ML.Tokenizers/Model/BPE.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/BPE.cs
@@ -95,7 +95,7 @@ public Bpe(string vocabFile, string? mergesFile, string? unknownToken = null, st
             (Dictionary<string, int>? vocab1, Vec<(string, string)> merges) = ReadFile(vocabFile, mergesFile);
             Vocab = vocab1 ?? new Dictionary<string, int>();
-            Cache = new Cache<string, Word>();
+            Cache = new Cache<Word>();
 
             VocabReverse = new();
 
@@ -274,7 +274,7 @@ internal static (Dictionary<string, int>?, Vec<(string, string)>) ReadFile(strin
         internal Dictionary<Pair<int>, (int, int)> Merges { get; set; }
 
         /// <summary>Contains the cache for optimizing the encoding step.</summary>
-        internal Cache<string, Word>? Cache { get; set; }
+        internal Cache<Word>? Cache { get; set; }
 
         internal static readonly int DefaultCacheCapacity = 10_000;
@@ -315,9 +315,6 @@ internal static (Dictionary<string, int>?, Vec<(string, string)>) ReadFile(strin
             return merges;
         }
 
-        /// <summary>Reset the cache.</summary>
-        internal void ClearCache() => Cache?.Clear();
-
         private readonly Dictionary<char, string> _charToString = new Dictionary<char, string>();
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -425,7 +422,7 @@ internal List<Token> TokenizeWithCache(string sequence)
             Word word;
             if (Cache is not null)
             {
-                if (Cache.TryGet(sequence, out word))
+                if (Cache.TryGetValue(sequence, out word))
                 {
                     return WordToTokens(ref word);
                 }
@@ -457,7 +454,7 @@ internal int TokenizeToIdsWithCache(string sequence, IList<int>? accumulatedIds)
             if (Cache is not null)
             {
-                if (Cache.TryGet(sequence, out Word hit))
+                if (Cache.TryGetValue(sequence, out Word hit))
                 {
                     return WordToIds(ref hit, accumulatedIds);
                 }
diff --git a/src/Microsoft.ML.Tokenizers/Model/Cache.cs b/src/Microsoft.ML.Tokenizers/Model/Cache.cs
index 1fcfa849ec..1bf1ce9b9c 100644
--- a/src/Microsoft.ML.Tokenizers/Model/Cache.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/Cache.cs
@@ -4,96 +4,52 @@
 
 using System;
 using System.Collections.Generic;
-using System.Text;
-using System.Threading;
 
 namespace Microsoft.ML.Tokenizers
 {
-    internal sealed class Cache<TKey, TValue> where TKey : notnull where TValue : notnull
+    internal sealed class Cache<TValue>
     {
-        internal Cache() : this(Bpe.DefaultCacheCapacity) { }
-
-        internal Cache(int capacity)
-        {
-            Capacity = capacity;
-            Map = new Dictionary<TKey, TValue>(Capacity);
-        }
+        private readonly int _capacity;
+        private readonly Dictionary<StringSpanOrdinalKey, TValue> _map;
 
-        private readonly ReaderWriterLockSlim _cacheLock = new ReaderWriterLockSlim();
+        private object SyncObj => _map;
 
-        internal Dictionary<TKey, TValue> Map { get; set; }
-
-        internal int Capacity { get; set; }
-
-        internal void Fresh() => Map = new Dictionary<TKey, TValue>(Capacity);
-
-        internal void Clear()
-        {
-            _cacheLock.EnterWriteLock();
-            try
-            {
-                Map.Clear();
-            }
-            finally { _cacheLock.ExitWriteLock(); }
-        }
+        internal Cache() : this(Bpe.DefaultCacheCapacity) { }
 
-        internal List<TValue> GetValues(IEnumerable<TKey> keys)
+        internal Cache(int capacity)
         {
-            List<TValue> values = new();
-            _cacheLock.EnterReadLock();
-            try
-            {
-                foreach (TKey key in keys)
-                {
-                    if (Map.TryGetValue(key, out TValue? value))
-                    {
-                        values.Add(value);
-                    }
-                }
-            }
-            finally { _cacheLock.ExitReadLock(); }
-
-            return values;
+            _capacity = capacity;
+            _map = new Dictionary<StringSpanOrdinalKey, TValue>(capacity);
         }
 
-        internal bool TryGet(TKey key, out TValue value)
+        internal bool TryGetValue(string key, out TValue value)
         {
-            _cacheLock.EnterReadLock();
-            try
+            lock (SyncObj)
             {
-                return Map.TryGetValue(key, out value!);
+                return _map.TryGetValue(new StringSpanOrdinalKey(key), out value!);
             }
-            finally { _cacheLock.ExitReadLock(); }
         }
 
-        internal void SetValues(IEnumerable<(TKey, TValue)> entries)
+        internal unsafe bool TryGetValue(ReadOnlySpan<char> key, out TValue value)
         {
-            _cacheLock.EnterWriteLock();
-            try
+            lock (SyncObj)
             {
-                foreach ((TKey, TValue) entry in entries)
+                fixed (char* ptr = key)
                 {
-                    if (Capacity <= Map.Count)
-                    {
-                        break;
-                    }
-                    Map[entry.Item1] = entry.Item2;
+                    return _map.TryGetValue(new StringSpanOrdinalKey(ptr, key.Length), out value!);
                 }
             }
-            finally { _cacheLock.ExitWriteLock(); }
         }
 
-        internal void Set(TKey k, TValue v)
+        internal void Set(string k, TValue v)
        {
-            _cacheLock.EnterWriteLock();
-            try
+            lock (SyncObj)
             {
-                if (Capacity > Map.Count)
+                if (_map.Count < _capacity)
                 {
-                    Map[k] = v;
+                    _map[new StringSpanOrdinalKey(k)] = v;
                 }
             }
-            finally { _cacheLock.ExitWriteLock(); }
         }
     }
 }
diff --git a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
index ad98ed917c..bcef6cdd8d 100644
--- a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
@@ -24,7 +24,7 @@ public sealed class EnglishRoberta : Model
         private readonly IReadOnlyDictionary<char, char> _byteToUnicode;
         private readonly IReadOnlyDictionary<char, char> _unicodeToByte;
         private readonly string[] _charToString;
-        private readonly Cache<string, List<Token>> _cache;
+        private readonly Cache<List<Token>> _cache;
 
         /// <summary>
         /// Construct tokenizer object to use with the English Roberta model.
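Note on the Cache.cs rewrite above: the ReaderWriterLockSlim is replaced by a single Monitor taken on the dictionary itself, and the string key by StringSpanOrdinalKey. A condensed sketch of just the locking shape, under the assumption that all critical sections are short and dictionary-bound (TinyCache is illustrative only and not a type in this patch):

    using System.Collections.Generic;

    internal sealed class TinyCache<TValue>
    {
        private readonly Dictionary<string, TValue> _map = new();
        private readonly int _capacity;

        internal TinyCache(int capacity) => _capacity = capacity;

        internal bool TryGetValue(string key, out TValue? value)
        {
            lock (_map) // the map doubles as the sync object, as Cache<TValue> does with SyncObj
            {
                return _map.TryGetValue(key, out value);
            }
        }

        internal void Set(string key, TValue value)
        {
            lock (_map)
            {
                if (_map.Count < _capacity) // once full, new entries are simply dropped
                {
                    _map[key] = value;
                }
            }
        }
    }

For critical sections this short, an uncontended Monitor is cheaper than reader/writer locking, and dropping GetValues/SetValues/Clear removes the only operations that benefited from a shared read lock.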
@@ -69,7 +69,7 @@ public EnglishRoberta(string vocabularyPath, string mergePath, string highestOcc
             }
 
             _unicodeToByte = _byteToUnicode.Reverse();
-            _cache = new Cache<string, List<Token>>();
+            _cache = new Cache<List<Token>>();
         }
 
         /// <summary>
@@ -107,7 +107,7 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes
             }
 
             _unicodeToByte = _byteToUnicode.Reverse();
-            _cache = new Cache<string, List<Token>>();
+            _cache = new Cache<List<Token>>();
         }
 
         //
@@ -226,7 +226,7 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
                 return Array.Empty<Token>();
             }
 
-            if (_cache.TryGet(sequence, out List<Token>? hit))
+            if (_cache.TryGetValue(sequence, out List<Token>? hit))
             {
                 ArrayPool<char>.Shared.Return(token);
                 ArrayPool<int>.Shared.Return(indexMapping);
@@ -258,7 +258,7 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
 
         private int TokenizeToIds(string sequence, IList<int>? accumulatedIds)
         {
-            if (_cache.TryGet(sequence, out List<Token>? hit))
+            if (_cache.TryGetValue(sequence, out List<Token>? hit))
             {
                 if (accumulatedIds is not null)
                 {
diff --git a/src/Microsoft.ML.Tokenizers/Model/Model.cs b/src/Microsoft.ML.Tokenizers/Model/Model.cs
index c8bd01cc06..8fc176172d 100644
--- a/src/Microsoft.ML.Tokenizers/Model/Model.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/Model.cs
@@ -4,7 +4,6 @@
 
 using System;
 using System.Collections.Generic;
-using System.Text;
 
 namespace Microsoft.ML.Tokenizers
 {
@@ -45,6 +44,19 @@ public virtual void TokenizeToIds(string sequence, bool isSpecialToken, IList<int> accumulatedIds)
             }
         }
 
+        /// <summary>
+        /// Tokenize a split sequence string to a list of Ids and add them to the accumulatedIds list.
+        /// </summary>
+        /// <param name="sequence">The sequence to split.</param>
+        /// <param name="isSpecialToken">Indicate if the token is a special token.</param>
+        /// <param name="accumulatedIds">The list of accumulated tokenized Ids.</param>
+        /// <remarks>
+        /// This method does the default implementation that uses the Tokenize method to get the token's Ids.
+        /// Tokenizer's models which care about performance may choose to override this method to provide a more efficient implementation.
+        /// </remarks>
+        public virtual void TokenizeToIds(ReadOnlySpan<char> sequence, bool isSpecialToken, IList<int> accumulatedIds) =>
+            TokenizeToIds(sequence.ToString(), isSpecialToken, accumulatedIds);
+
         /// <summary>
         /// Get the number of tokens that the input sequence will be encoded to.
         /// </summary>
@@ -62,6 +74,19 @@ public virtual int CountTokens(string sequence, bool isSpecialToken)
             return ids.Count;
         }
 
+        /// <summary>
+        /// Get the number of tokens that the input sequence will be encoded to.
+        /// </summary>
+        /// <param name="sequence">The text to tokenize.</param>
+        /// <param name="isSpecialToken">Indicate if the token is special token.</param>
+        /// <returns>The number of tokens that the input sequence will be encoded to.</returns>
+        /// <remarks>
+        /// This method does the default implementation that uses the TokenizeToIds method to get the number of token's Ids.
+        /// Tokenizer's models which care about performance may choose to override this method to provide a more efficient implementation.
+        /// </remarks>
+        public virtual int CountTokens(ReadOnlySpan<char> sequence, bool isSpecialToken) =>
+            CountTokens(sequence.ToString(), isSpecialToken);
+
         /// <summary>
         /// Map the token to tokenized Id.
         /// </summary>
         /// <param name="token">The token to map to the Id.</param>
         /// <returns>The mapped Id of the token.</returns>
         public abstract int? TokenToId(string token);
 
+        /// <summary>
+        /// Map the token to tokenized Id.
+        /// </summary>
+        /// <param name="token">The token to map to the Id.</param>
+        /// <returns>The mapped Id of the token.</returns>
+        public virtual int? TokenToId(ReadOnlySpan<char> token) => TokenToId(token.ToString());
+
         /// <summary>
         /// Map the token to tokenized id with the option to skip the special tokens.
         /// </summary>
         /// <param name="token">The token to map to Id</param>
         /// <param name="skipSpecialTokens">Indicate if want to skip the special tokens during the encoding.</param>
         /// <returns>The mapped Id of the token.</returns>
         public virtual int? TokenToId(string token, bool skipSpecialTokens) => TokenToId(token);
 
+        /// <summary>
+        /// Map the token to tokenized id with the option to skip the special tokens.
+        /// </summary>
+        /// <param name="token">The token to map to Id</param>
+        /// <param name="skipSpecialTokens">Indicate if want to skip the special tokens during the encoding.</param>
+        /// <returns>The mapped Id of the token.</returns>
+        public virtual int? TokenToId(ReadOnlySpan<char> token, bool skipSpecialTokens) => TokenToId(token.ToString(), skipSpecialTokens);
+
         /// <summary>
         /// Map the tokenized Id to the token.
         /// </summary>
diff --git a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
index 74d95df23a..8424c60e51 100644
--- a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
@@ -19,12 +19,13 @@ namespace Microsoft.ML.Tokenizers
     /// </summary>
     public sealed class Tiktoken : Model
     {
-        private readonly Dictionary<ReadOnlyMemory<byte>, int> _encoder = null!;
-        private readonly IReadOnlyDictionary<int, byte[]> _decoder = null!;
-        private readonly LruCache<string, int[]> _cache;
-        private readonly IReadOnlyDictionary<string, int>? _specialTokensEncoder;
+        private readonly Dictionary<ReadOnlyMemory<byte>, int> _encoder;
+        private readonly IReadOnlyDictionary<int, byte[]> _decoder;
+        private readonly LruCache<int[]> _cache;
+        private readonly Dictionary<StringSpanOrdinalKey, int>? _specialTokensEncoder;
         private readonly Dictionary<int, string>? _specialTokensDecoder;
-        private readonly Dictionary<string, int> _vocab = null!;
+        private readonly IReadOnlyDictionary<string, int> _vocabOriginal;
+        private readonly Dictionary<StringSpanOrdinalKey, int> _vocab;
 
         /// <summary>
         /// Create a new Tiktoken tokenizer object.
@@ -34,7 +35,7 @@ public sealed class Tiktoken : Model
         /// <param name="cacheSize">The size of the cache to use.</param>
         /// <exception cref="ArgumentNullException">Thrown when <paramref name="tikTokenBpeFile"/> is null or empty.</exception>
         /// <exception cref="InvalidOperationException">Thrown when failed to load the BPE rank file.</exception>
-        public Tiktoken(string tikTokenBpeFile, IReadOnlyDictionary<string, int>? specialTokensEncoder = null, int cacheSize = LruCache<string, int[]>.DefaultCacheSize) :
+        public Tiktoken(string tikTokenBpeFile, IReadOnlyDictionary<string, int>? specialTokensEncoder = null, int cacheSize = LruCache<int[]>.DefaultCacheSize) :
             this(string.IsNullOrEmpty(tikTokenBpeFile) ? throw new ArgumentNullException(nameof(tikTokenBpeFile)) : File.OpenRead(tikTokenBpeFile), specialTokensEncoder, cacheSize, disposeStream: true)
         {
         }
@@ -47,44 +48,40 @@ public Tiktoken(Stream tikTokenBpeFileStream, IReadOnlyDictionary<string, int>? specia
         /// <param name="cacheSize">The size of the cache to use.</param>
         /// <exception cref="ArgumentNullException">Thrown when <paramref name="tikTokenBpeFileStream"/> is null or empty.</exception>
         /// <exception cref="InvalidOperationException">Thrown when failed to load the BPE rank file.</exception>
-        public Tiktoken(Stream tikTokenBpeFileStream, IReadOnlyDictionary<string, int>? specialTokensEncoder = null, int cacheSize = LruCache<string, int[]>.DefaultCacheSize) :
+        public Tiktoken(Stream tikTokenBpeFileStream, IReadOnlyDictionary<string, int>? specialTokensEncoder = null, int cacheSize = LruCache<int[]>.DefaultCacheSize) :
             this(tikTokenBpeFileStream ?? throw new ArgumentNullException(nameof(tikTokenBpeFileStream)), specialTokensEncoder, cacheSize, disposeStream: false)
         {
         }
 
         internal Tiktoken(
             Dictionary<ReadOnlyMemory<byte>, int> encoder,
-            IReadOnlyDictionary<int, byte[]> decoder,
-            Dictionary<string, int> vocab,
+            Dictionary<int, byte[]> decoder,
+            Dictionary<StringSpanOrdinalKey, int> vocab,
             IReadOnlyDictionary<string, int>? specialTokensEncoder = null,
-            int cacheSize = LruCache<string, int[]>.DefaultCacheSize) : this(cacheSize)
+            int cacheSize = LruCache<int[]>.DefaultCacheSize)
         {
             Debug.Assert(encoder is not null);
             Debug.Assert(decoder is not null);
             Debug.Assert(vocab is not null);
 
             _encoder = encoder!;
-            _vocab = vocab!;
             _decoder = decoder!;
+            _vocab = vocab!;
+            _cache = new LruCache<int[]>(cacheSize);
 
-            _specialTokensEncoder = specialTokensEncoder;
-            if (_specialTokensEncoder is not null)
-            {
-                _specialTokensDecoder = _specialTokensEncoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
-            }
+            _vocabOriginal = _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value);
+            (_specialTokensEncoder, _specialTokensDecoder) = CreateEncoderDecoder(specialTokensEncoder);
         }
 
-        private Tiktoken(Stream tikTokenBpeFileStream, IReadOnlyDictionary<string, int>? specialTokensEncoder, int cacheSize, bool disposeStream) : this(cacheSize)
+        private Tiktoken(Stream tikTokenBpeFileStream, IReadOnlyDictionary<string, int>? specialTokensEncoder, int cacheSize, bool disposeStream)
         {
             try
             {
-                (_encoder, _vocab, _decoder) = LoadTikTokenBpeAsync(tikTokenBpeFileStream, useAsync: false).GetAwaiter().GetResult();
+                _cache = new LruCache<int[]>(cacheSize);
 
-                _specialTokensEncoder = specialTokensEncoder;
-                if (_specialTokensEncoder is not null)
-                {
-                    _specialTokensDecoder = _specialTokensEncoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
-                }
+                (_encoder, _vocab, _decoder) = LoadTikTokenBpeAsync(tikTokenBpeFileStream, useAsync: false).GetAwaiter().GetResult();
 
+                _vocabOriginal = _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value);
+                (_specialTokensEncoder, _specialTokensDecoder) = CreateEncoderDecoder(specialTokensEncoder);
             }
             finally
             {
@@ -95,9 +92,15 @@ private Tiktoken(Stream tikTokenBpeFileStream, IReadOnlyDictionary<string, int>?
             }
         }
 
-        private Tiktoken(int cacheSize)
+        private static (Dictionary<StringSpanOrdinalKey, int>?, Dictionary<int, string>?) CreateEncoderDecoder(IReadOnlyDictionary<string, int>? specialTokens)
         {
-            _cache = new LruCache<string, int[]>(cacheSize);
+            if (specialTokens is not null)
+            {
+                var encoder = specialTokens.ToDictionary(e => new StringSpanOrdinalKey(e.Key), e => e.Value);
+                return (encoder, encoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key.Data!));
+            }
+
+            return (null, null);
         }
 
         /// <summary>
@@ -108,11 +111,11 @@
         /// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
         /// <returns>Map of byte[] to integer token id</returns>
         /// <exception cref="InvalidOperationException"></exception>
-        internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<string, int>, IReadOnlyDictionary<int, byte[]>)> LoadTikTokenBpeAsync(
+        internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<StringSpanOrdinalKey, int>, Dictionary<int, byte[]>)> LoadTikTokenBpeAsync(
             Stream tikTokenBpeFileStream, bool useAsync, CancellationToken cancellationToken = default)
         {
             var encoder = new Dictionary<ReadOnlyMemory<byte>, int>(ReadOnlyMemoryByteComparer.Instance);
-            var vocab = new Dictionary<string, int>();
+            var vocab = new Dictionary<StringSpanOrdinalKey, int>();
             var decoder = new Dictionary<int, byte[]>();
 
             try
@@ -148,7 +151,7 @@ await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) :
 
                         string decodedToken = Encoding.UTF8.GetString(tokenBytes);
 
-                        vocab[decodedToken] = rank;
+                        vocab[new StringSpanOrdinalKey(decodedToken)] = rank;
                     }
                     else
                     {
@@ -165,12 +168,6 @@ await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) :
             return (encoder, vocab, decoder);
         }
 
-        /// <summary>
-        /// Gets the dictionary mapping special tokens to Ids.
-        /// </summary>
-        /// <returns>The dictionary mapping special tokens to Ids.</returns>
-        public IReadOnlyDictionary<string, int>? SpecialTokensEncoder => _specialTokensEncoder;
-
         /// <summary>
         /// Tokenize a split sequence string to a list of tokens.
         /// </summary>
@@ -193,7 +190,7 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
                     throw new InvalidOperationException($"The tokenizer doesn't have special tokens");
                 }
 
-                if (_specialTokensEncoder.TryGetValue(sequence, out int id))
+                if (_specialTokensEncoder.TryGetValue(new StringSpanOrdinalKey(sequence), out int id))
                 {
                     return new List<Token> { new(id, sequence, (0, sequence.Length)) };
                 }
@@ -201,7 +198,7 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
                 throw new InvalidOperationException($"The special token {sequence} doesn't exist in the tokenizer");
             }
 
-            if (_cache.Lookup(sequence, out int[] ids))
+            if (_cache.TryGetValue(sequence, out int[]? ids))
             {
                 tokens = new Token[ids.Length];
                 tokens[0] = new Token(ids[0], sequence, (0, sequence.Length));
@@ -215,7 +212,7 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
             }
 
             // cache miss
-            if (_vocab.TryGetValue(sequence, out int mappedId))
+            if (_vocab.TryGetValue(new StringSpanOrdinalKey(sequence), out int mappedId))
             {
                 return new Token[1] { new(mappedId, sequence, (0, sequence.Length)) };
             }
@@ -245,40 +242,58 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
         /// <param name="sequence">The sequence to tokenize.</param>
         /// <param name="isSpecialToken">Indicate if the token is a special token.</param>
         /// <param name="accumulatedIds">The list of accumulated Ids.</param>
-        public override void TokenizeToIds(string sequence, bool isSpecialToken, IList<int> accumulatedIds)
+        public override void TokenizeToIds(string sequence, bool isSpecialToken, IList<int> accumulatedIds) =>
+            TokenizeToIds(sequence.AsSpan(), sequence, isSpecialToken, accumulatedIds);
+
+        /// <summary>
+        /// Tokenize a split sequence string to a list of Ids.
+        /// </summary>
+        /// <param name="sequence">The sequence to tokenize.</param>
+        /// <param name="isSpecialToken">Indicate if the token is a special token.</param>
+        /// <param name="accumulatedIds">The list of accumulated Ids.</param>
+        public override void TokenizeToIds(ReadOnlySpan<char> sequence, bool isSpecialToken, IList<int> accumulatedIds) =>
+            TokenizeToIds(sequence, null, isSpecialToken, accumulatedIds);
+
+        private unsafe void TokenizeToIds(ReadOnlySpan<char> sequence, string? sequenceString, bool isSpecialToken, IList<int> accumulatedIds)
         {
-            if (string.IsNullOrEmpty(sequence))
+            if (sequence.IsEmpty)
             {
                 return;
             }
 
             if (isSpecialToken)
             {
-                if (_specialTokensEncoder is not null && _specialTokensEncoder.TryGetValue(sequence, out int id))
+                fixed (char* ptr = sequence)
                 {
-                    accumulatedIds.Add(id);
+                    if (_specialTokensEncoder?.TryGetValue(new StringSpanOrdinalKey(ptr, sequence.Length), out int id) is true)
+                    {
+                        accumulatedIds.Add(id);
+                    }
                 }
 
                 return;
             }
 
-            if (_cache.Lookup(sequence, out int[] tokenIds))
+            if (_cache.TryGetValue(sequence, out int[]? tokenIds))
             {
                 accumulatedIds.AddRange(tokenIds);
                 return;
             }
 
-            if (_vocab.TryGetValue(sequence, out int mappedId))
+            fixed (char* ptr = sequence)
             {
-                accumulatedIds.Add(mappedId);
-                return;
+                if (_vocab.TryGetValue(new StringSpanOrdinalKey(ptr, sequence.Length), out int mappedId))
+                {
+                    accumulatedIds.Add(mappedId);
+                    return;
+                }
             }
 
             byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(sequence.Length));
-            int encodedLength = GetUtf8Bytes(sequence.AsSpan(), arrayPoolArray);
+            int encodedLength = GetUtf8Bytes(sequence, arrayPoolArray);
 
             int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
-            _cache.Add(sequence, encodedIds);
+            _cache.Add(sequenceString ?? sequence.ToString(), encodedIds);
 
             accumulatedIds.AddRange(encodedIds);
 
@@ -292,44 +307,67 @@ public override void TokenizeToIds(string sequence, bool isSpecialToken, IList<int> accumulatedIds)
         /// <param name="sequence">The text to tokenize.</param>
         /// <param name="isSpecialToken">Indicate if the token is special token.</param>
         /// <returns>The number of tokens that the input sequence will be encoded to.</returns>
-        public override int CountTokens(string sequence, bool isSpecialToken)
+        public override int CountTokens(string sequence, bool isSpecialToken) =>
+            CountTokens(sequence.AsSpan(), sequence, isSpecialToken);
+
+        /// <summary>
+        /// Get the number of tokens that the input sequence will be encoded to.
+        /// </summary>
+        /// <param name="sequence">The text to tokenize.</param>
+        /// <param name="isSpecialToken">Indicate if the token is special token.</param>
+        /// <returns>The number of tokens that the input sequence will be encoded to.</returns>
+        public override int CountTokens(ReadOnlySpan<char> sequence, bool isSpecialToken) =>
+            CountTokens(sequence, null, isSpecialToken);
+
+        private unsafe int CountTokens(ReadOnlySpan<char> sequence, string? sequenceString, bool isSpecialToken)
         {
-            if (string.IsNullOrEmpty(sequence))
+            if (sequence.IsEmpty)
             {
                 return 0;
             }
 
             if (isSpecialToken && _specialTokensEncoder is not null)
             {
-                return _specialTokensEncoder.TryGetValue(sequence, out _) ? 1 : 0;
+                fixed (char* ptr = sequence)
+                {
+                    return _specialTokensEncoder.TryGetValue(new StringSpanOrdinalKey(ptr, sequence.Length), out _) ? 1 : 0;
+                }
             }
 
-            if (_cache.Lookup(sequence, out int[] ids))
+            if (_cache.TryGetValue(sequence, out int[] ids))
             {
                 return ids.Length;
             }
 
-            if (_vocab.TryGetValue(sequence, out _))
+            fixed (char* ptr = sequence)
             {
-                return 1;
+                if (_vocab.TryGetValue(new StringSpanOrdinalKey(ptr, sequence.Length), out _))
+                {
+                    return 1;
+                }
             }
 
             byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(sequence.Length));
-            int encodedLength = GetUtf8Bytes(sequence.AsSpan(), arrayPoolArray);
+            int encodedLength = GetUtf8Bytes(sequence, arrayPoolArray);
 
             int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
-            _cache.Add(sequence, encodedIds);
+            _cache.Add(sequenceString ?? sequence.ToString(), encodedIds);
 
             ArrayPool<byte>.Shared.Return(arrayPoolArray);
             return encodedIds.Length;
         }
 
+        public override int? TokenToId(ReadOnlySpan<char> token) => TokenToId(token, skipSpecialTokens: false);
+        public override int? TokenToId(string token) => TokenToId(token, skipSpecialTokens: false);
+
         /// <summary>
         /// Map the token to tokenized Id.
         /// </summary>
         /// <param name="token">The token to map to the Id.</param>
+        /// <param name="skipSpecialTokens">Indicate if want to skip the special tokens during the encoding.</param>
         /// <returns>The mapped Id of the token.</returns>
-        public override int? TokenToId(string token) => TokenToId(token, skipSpecialTokens: false);
+        public override int? TokenToId(string token, bool skipSpecialTokens) =>
+            TokenToId(token.AsSpan(), token, skipSpecialTokens);
 
         /// <summary>
         /// Map the token to tokenized Id.
@@ -337,19 +375,28 @@ public override int CountTokens(string sequence, bool isSpecialToken)
         /// <param name="token">The token to map to the Id.</param>
         /// <param name="skipSpecialTokens">Indicate if want to skip the special tokens during the encoding.</param>
         /// <returns>The mapped Id of the token.</returns>
-        public override int? TokenToId(string token, bool skipSpecialTokens)
+        public override unsafe int? TokenToId(ReadOnlySpan<char> token, bool skipSpecialTokens) =>
+            TokenToId(token, null, skipSpecialTokens);
+
+        private unsafe int? TokenToId(ReadOnlySpan<char> token, string? tokenString, bool skipSpecialTokens)
         {
-            if (string.IsNullOrEmpty(token))
+            if (token.IsEmpty)
             {
                 return 0;
             }
 
-            if (!skipSpecialTokens && _specialTokensEncoder is not null && _specialTokensEncoder.TryGetValue(token, out int specialTokenId))
+            if (!skipSpecialTokens && _specialTokensEncoder is not null)
             {
-                return specialTokenId;
+                fixed (char* ptr = token)
+                {
+                    if (_specialTokensEncoder.TryGetValue(new StringSpanOrdinalKey(ptr, token.Length), out int specialTokenId))
+                    {
+                        return specialTokenId;
+                    }
+                }
             }
 
-            if (_cache.Lookup(token, out int[] ids))
+            if (_cache.TryGetValue(token, out int[]? ids))
             {
                 if (ids.Length == 1)
                 {
@@ -359,18 +406,21 @@ public override int CountTokens(string sequence, bool isSpecialToken)
                 return null;
             }
 
-            if (_vocab.TryGetValue(token, out int id))
+            fixed (char* ptr = token)
             {
-                return id;
+                if (_vocab.TryGetValue(new StringSpanOrdinalKey(ptr, token.Length), out int id))
+                {
+                    return id;
+                }
             }
 
             byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(token.Length));
 
             try
             {
-                int encodedLength = GetUtf8Bytes(token.AsSpan(), arrayPoolArray);
+                int encodedLength = GetUtf8Bytes(token, arrayPoolArray);
 
                 int[] idsToCache = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
-                _cache.Add(token, idsToCache);
+                _cache.Add(tokenString ?? token.ToString(), idsToCache);
 
                 if (idsToCache.Length == 1)
                 {
@@ -478,12 +528,12 @@ static void ArrayPoolGrow(ref Span<byte> utf8Bytes, ref byte[]? arrayPoolArray,
         /// <summary>
         /// Gets the dictionary mapping tokens to Ids.
         /// </summary>
-        public override IReadOnlyDictionary<string, int> GetVocab() => _vocab;
+        public override IReadOnlyDictionary<string, int> GetVocab() => _vocabOriginal;
 
         /// <summary>
         /// Gets the dictionary size that map tokens to Ids.
         /// </summary>
-        public override int GetVocabSize() => _vocab.Count;
+        public override int GetVocabSize() => _vocabOriginal.Count;
 
         /// <summary>
         /// Save the model data into the vocabulary and merges files.
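The fixed (char* ptr = ...) blocks that now appear throughout Tiktoken.cs are the core trick of this patch: a pointer-backed StringSpanOrdinalKey lets a dictionary whose stored keys are string-backed be queried with a ReadOnlySpan<char>, without allocating a lookup string. A minimal sketch of the pattern, separate from the patch (SpanLookupDemo and its members are illustrative names only; it is placed in the library's namespace because StringSpanOrdinalKey is internal):

    using System;
    using System.Collections.Generic;

    namespace Microsoft.ML.Tokenizers
    {
        internal static class SpanLookupDemo
        {
            // Query a StringSpanOrdinalKey-keyed dictionary with a span: the key wraps
            // a raw char* that is only valid while the span stays pinned by `fixed`.
            internal static unsafe int? GetId(Dictionary<StringSpanOrdinalKey, int> vocab, ReadOnlySpan<char> token)
            {
                fixed (char* ptr = token)
                {
                    return vocab.TryGetValue(new StringSpanOrdinalKey(ptr, token.Length), out int id) ? id : (int?)null;
                }
            }

            // Store with a string-backed key only, per the remarks on StringSpanOrdinalKey
            // (defined at the end of this patch): pointer-backed keys must never be stored.
            internal static void Add(Dictionary<StringSpanOrdinalKey, int> vocab, string token, int id) =>
                vocab[new StringSpanOrdinalKey(token)] = id;
        }
    }

The dictionary hashes and compares the pointer-backed key during TryGetValue but never retains it, so the pointer never outlives the fixed block.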
diff --git a/src/Microsoft.ML.Tokenizers/Tokenizer.cs b/src/Microsoft.ML.Tokenizers/Tokenizer.cs
index d29766d65c..1e3a086f12 100644
--- a/src/Microsoft.ML.Tokenizers/Tokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Tokenizer.cs
@@ -139,7 +139,7 @@ public IReadOnlyList<int> EncodeToIds(string sequence, bool skipSpecialTokens =
             foreach (Split split in PreTokenizer.PreTokenize(normalized, skipSpecialTokens))
             {
-                Model.TokenizeToIds(split.TokenString, split.IsSpecialToken, idsList);
+                Model.TokenizeToIds(split.TokenSpan, split.IsSpecialToken, idsList);
             }
 
             return idsList;
@@ -165,7 +165,7 @@ public int CountTokens(string sequence, bool skipSpecialTokens = false)
             int idsCount = 0;
             foreach (Split split in PreTokenizer.PreTokenize(normalized, skipSpecialTokens))
             {
-                idsCount += Model.CountTokens(split.TokenString, split.IsSpecialToken);
+                idsCount += Model.CountTokens(split.TokenSpan, split.IsSpecialToken);
             }
 
             return idsCount;
@@ -451,7 +451,7 @@ private static Task<Tokenizer> CreateByEncoderNameAsync(
             }
         }
 
-        private static readonly ConcurrentDictionary<string, (Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<string, int>, IReadOnlyDictionary<int, byte[]>)> _tiktokenCache = new(StringComparer.OrdinalIgnoreCase);
+        private static readonly ConcurrentDictionary<string, (Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<StringSpanOrdinalKey, int>, Dictionary<int, byte[]>)> _tiktokenCache = new(StringComparer.OrdinalIgnoreCase);
 
         /// <summary>
         /// Create tokenizer based on regex pattern, BPE rank file and special tokens
@@ -479,7 +479,7 @@ private static async Task<Tokenizer> CreateTikTokenTokenizerAsync(
             }
 
-            if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<string, int> vocab, IReadOnlyDictionary<int, byte[]> decoder) cache))
+            if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, byte[]> decoder) cache))
             {
                 using (Stream stream = await Helpers.GetStreamAsync(_httpClient, mergeableRanksFileUrl, cancellationToken).ConfigureAwait(false))
                 {
diff --git a/src/Microsoft.ML.Tokenizers/Utils/LruCache.cs b/src/Microsoft.ML.Tokenizers/Utils/LruCache.cs
index 9ad88e2f35..1bd2275c13 100644
--- a/src/Microsoft.ML.Tokenizers/Utils/LruCache.cs
+++ b/src/Microsoft.ML.Tokenizers/Utils/LruCache.cs
@@ -1,48 +1,33 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
+using System;
 using System.Collections.Generic;
 
 namespace Microsoft.ML.Tokenizers
 {
-    internal class LruCache<TKey, TValue> where TKey : notnull where TValue : notnull
+    internal sealed class LruCache<TValue>
     {
         /// <summary>
         /// The default LRU cache size.
         /// </summary>
-        public const int DefaultCacheSize = 8192; // 4096;
+        public const int DefaultCacheSize = 8192;
 
-        private readonly object _lockObject = new object();
-
-        private class CacheItem
-        {
-            public readonly TKey Key;
-            public TValue Value;
-
-            public CacheItem(TKey key, TValue value)
-            {
-                Key = key;
-                Value = value;
-            }
-        }
-
-        private readonly Dictionary<TKey, LinkedListNode<CacheItem>> _cache;
-        private readonly LinkedList<CacheItem> _lruList;
+        private readonly Dictionary<StringSpanOrdinalKey, LinkedListNode<KeyValuePair<string, TValue>>> _cache = new();
+        private readonly LinkedList<KeyValuePair<string, TValue>> _lruList = new();
         private readonly int _cacheSize;
 
+        private object SyncObj => _cache;
+
         /// <summary>
-        /// Constructs an <see cref="LruCache{TKey, TValue}"/> object.
+        /// Constructs an <see cref="LruCache{TValue}"/> object.
         /// </summary>
         /// <param name="cacheSize">
-        /// The maximum number of <typeparamref name="TKey"/> to <typeparamref name="TValue"/> mappings
-        /// that can be cached. This defaults to <see cref="DefaultCacheSize"/>, which is set to
-        /// 4096.
+        /// The maximum number of mappings that can be cached. This defaults to <see cref="DefaultCacheSize"/>, which is set to 8192.
         /// </param>
         public LruCache(int cacheSize = DefaultCacheSize)
         {
-            _cache = new Dictionary<TKey, LinkedListNode<CacheItem>>();
-            _lruList = new LinkedList<CacheItem>();
             _cacheSize = cacheSize;
         }
 
@@ -54,11 +39,11 @@ public LruCache(int cacheSize = DefaultCacheSize)
         /// <returns>
         /// true if the cache contains a mapping for key, false otherwise.
         /// </returns>
-        public bool Lookup(TKey key, out TValue value)
+        public bool TryGetValue(string key, out TValue value)
         {
-            lock (_lockObject)
+            lock (SyncObj)
             {
-                if (_cache.TryGetValue(key, out LinkedListNode<CacheItem>? cached))
+                if (_cache.TryGetValue(new StringSpanOrdinalKey(key), out LinkedListNode<KeyValuePair<string, TValue>>? cached))
                 {
                     _lruList.Remove(cached);
                     _lruList.AddFirst(cached);
@@ -71,16 +56,31 @@ public bool Lookup(TKey key, out TValue value)
             }
         }
 
-        protected virtual void OnEviction(TValue evictedValue) { }
-
-        private void EvictIfNeeded()
+        /// <summary>
+        /// Retrieves the value associated with the specified <paramref name="key"/> object.
+        /// </summary>
+        /// <param name="key">The object to be used as a key.</param>
+        /// <param name="value">An out parameter that is set to the value of the key if key contains a mapping in the cache.</param>
+        /// <returns>
+        /// true if the cache contains a mapping for key, false otherwise.
+        /// </returns>
+        public unsafe bool TryGetValue(ReadOnlySpan<char> key, out TValue value)
         {
-            while (_cache.Count >= _cacheSize)
+            lock (SyncObj)
             {
-                LinkedListNode<CacheItem>? nodeToEvict = _lruList.Last;
-                _lruList.RemoveLast();
-                _cache.Remove(nodeToEvict!.Value.Key);
-                OnEviction(nodeToEvict.Value.Value);
+                fixed (char* ptr = key)
+                {
+                    if (_cache.TryGetValue(new StringSpanOrdinalKey(ptr, key.Length), out LinkedListNode<KeyValuePair<string, TValue>>? cached))
+                    {
+                        _lruList.Remove(cached);
+                        _lruList.AddFirst(cached);
+                        value = cached.Value.Value;
+                        return true;
+                    }
+                }
+
+                value = default!;
+                return false;
             }
         }
 
@@ -89,46 +89,29 @@ private void EvictIfNeeded()
         /// </summary>
         /// <param name="key">The key whose mapped <typeparamref name="TValue"/> is to be created or replaced.</param>
         /// <param name="value">The new value to be mapped to the <paramref name="key"/>.</param>
-        public void Add(TKey key, TValue value) => Replace(key, value, out _);
-
-        public bool Replace(TKey key, TValue value, out TValue oldValue)
+        public void Add(string key, TValue value)
         {
-            lock (_lockObject)
+            lock (SyncObj)
             {
-                return ReplaceInternal(key, value, out oldValue);
-            }
-        }
-
-        private bool ReplaceInternal(TKey key, TValue value, out TValue oldValue)
-        {
-            if (_cache.TryGetValue(key, out LinkedListNode<CacheItem>? cached))
-            {
-                oldValue = cached.Value.Value;
-                cached.Value.Value = value;
-                _lruList.Remove(cached);
-                _lruList.AddFirst(cached);
-                return true;
-            }
-            EvictIfNeeded();
-            var node = new LinkedListNode<CacheItem>(new CacheItem(key, value));
-            _cache[key] = node;
-            _lruList.AddFirst(node);
-            oldValue = default!;
-            return false;
-        }
+                if (_cache.TryGetValue(new StringSpanOrdinalKey(key), out LinkedListNode<KeyValuePair<string, TValue>>? cached))
+                {
+                    cached.Value = new KeyValuePair<string, TValue>(key, value);
+                    _lruList.Remove(cached);
+                    _lruList.AddFirst(cached);
+                    return;
+                }
 
-        /// <summary>
-        /// The number of entries currently present in the cache.
-        /// </summary>
-        public int Count => _cache.Count;
+                while (_cache.Count >= _cacheSize)
+                {
+                    LinkedListNode<KeyValuePair<string, TValue>>? nodeToEvict = _lruList.Last;
+                    _lruList.RemoveLast();
+                    _cache.Remove(new StringSpanOrdinalKey(nodeToEvict!.Value.Key));
+                }
 
-        /// <summary>
-        /// Clears the contents of this cache.
-        /// </summary>
-        public void Clear()
-        {
-            _cache.Clear();
-            _lruList.Clear();
+                var node = new LinkedListNode<KeyValuePair<string, TValue>>(new KeyValuePair<string, TValue>(key, value));
+                _cache[new StringSpanOrdinalKey(key)] = node;
+                _lruList.AddFirst(node);
+            }
         }
     }
-}
\ No newline at end of file
+}
diff --git a/src/Microsoft.ML.Tokenizers/Utils/StringSpanOrdinalKey.cs b/src/Microsoft.ML.Tokenizers/Utils/StringSpanOrdinalKey.cs
new file mode 100644
index 0000000000..4897a2fa82
--- /dev/null
+++ b/src/Microsoft.ML.Tokenizers/Utils/StringSpanOrdinalKey.cs
@@ -0,0 +1,55 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Linq;
+
+namespace Microsoft.ML.Tokenizers
+{
+    /// <summary>Used as a key in a dictionary to enable querying with either a string or a span.</summary>
+    /// <remarks>
+    /// This should only be used with a Ptr/Length for querying. For storing in a dictionary, this should
+    /// always be used with a string.
+    /// </remarks>
+    internal unsafe readonly struct StringSpanOrdinalKey : IEquatable<StringSpanOrdinalKey>
+    {
+        public readonly char* Ptr;
+        public readonly int Length;
+        public readonly string? Data;
+
+        public StringSpanOrdinalKey(char* ptr, int length)
+        {
+            Ptr = ptr;
+            Length = length;
+        }
+
+        public StringSpanOrdinalKey(string data) =>
+            Data = data;
+
+        private ReadOnlySpan<char> Span => Ptr is not null ?
+            new ReadOnlySpan<char>(Ptr, Length) :
+            Data.AsSpan();
+
+        public override bool Equals(object? obj) =>
+            obj is StringSpanOrdinalKey wrapper && Equals(wrapper);
+
+        public bool Equals(StringSpanOrdinalKey other) =>
+            Span.SequenceEqual(other.Span);
+
+        public override int GetHashCode()
+        {
+#if NET5_0_OR_GREATER
+            return string.GetHashCode(Span);
+#else
+            int hash = 17;
+            foreach (char c in Span)
+            {
+                hash = hash * 31 + c;
+            }
+
+            return hash;
+#endif
+        }
+    }
+}
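Taken together, the two patches make tokenizer creation cancelable end to end. A usage sketch, not part of the patch itself (the model name assumes "gpt-4" is among the names the model table maps, which is outside this diff, and the timeout is arbitrary):

    using System;
    using System.Threading;
    using System.Threading.Tasks;
    using Microsoft.ML.Tokenizers;

    internal static class Program
    {
        private static async Task Main()
        {
            // Cancel the vocabulary download/parse if it takes longer than 5 seconds.
            using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5));

            try
            {
                Tokenizer tokenizer = await Tokenizer.CreateByModelNameAsync("gpt-4", cancellationToken: cts.Token);
                Console.WriteLine(tokenizer.CountTokens("hello world"));
            }
            catch (OperationCanceledException)
            {
                // Surfaces from GetStreamAsync/ReadLineAsync when cts fires mid-download.
                Console.WriteLine("Tokenizer creation was canceled.");
            }
        }
    }

Note that because CreateByModelNameAsync now returns a faulted task via Task.FromException rather than throwing synchronously, the unsupported-model NotImplementedException also surfaces on await rather than at the call site.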