Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions src/Microsoft.ML.Tokenizers/Model/BPE.cs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ public Bpe(string vocabFile, string? mergesFile, string? unknownToken = null, st

(Dictionary<string, int>? vocab1, Vec<(string, string)> merges) = ReadFile(vocabFile, mergesFile);
Vocab = vocab1 ?? new Dictionary<string, int>();
Cache = new Cache<string, Word>();
Cache = new Cache<Word>();

VocabReverse = new();

Expand Down Expand Up @@ -274,7 +274,7 @@ internal static (Dictionary<string, int>?, Vec<(string, string)>) ReadFile(strin
internal Dictionary<Pair<int>, (int, int)> Merges { get; set; }

/// Contains the cache for optimizing the encoding step.
internal Cache<string, Word>? Cache { get; set; }
internal Cache<Word>? Cache { get; set; }

internal static readonly int DefaultCacheCapacity = 10_000;

Expand Down Expand Up @@ -315,9 +315,6 @@ internal static (Dictionary<string, int>?, Vec<(string, string)>) ReadFile(strin
return merges;
}

/// Reset the cache.
internal void ClearCache() => Cache?.Clear();

private readonly Dictionary<char, string> _charToString = new Dictionary<char, string>();

[MethodImpl(MethodImplOptions.AggressiveInlining)]
Expand Down Expand Up @@ -425,7 +422,7 @@ internal List<Token> TokenizeWithCache(string sequence)
Word word;
if (Cache is not null)
{
if (Cache.TryGet(sequence, out word))
if (Cache.TryGetValue(sequence, out word))
{
return WordToTokens(ref word);
}
Expand Down Expand Up @@ -457,7 +454,7 @@ internal int TokenizeToIdsWithCache(string sequence, IList<int>? accumulatedIds)

if (Cache is not null)
{
if (Cache.TryGet(sequence, out Word hit))
if (Cache.TryGetValue(sequence, out Word hit))
{
return WordToIds(ref hit, accumulatedIds);
}
Expand Down
82 changes: 19 additions & 63 deletions src/Microsoft.ML.Tokenizers/Model/Cache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,96 +4,52 @@

using System;
using System.Collections.Generic;
using System.Text;
using System.Threading;

namespace Microsoft.ML.Tokenizers
{
internal sealed class Cache<TKey, TValue> where TKey : notnull where TValue : notnull
internal sealed class Cache<TValue>
{
internal Cache() : this(Bpe.DefaultCacheCapacity) { }

internal Cache(int capacity)
{
Capacity = capacity;
Map = new Dictionary<TKey, TValue>(Capacity);
}
private readonly int _capacity;
private readonly Dictionary<StringSpanOrdinalKey, TValue> _map;

private readonly ReaderWriterLockSlim _cacheLock = new ReaderWriterLockSlim();
private object SyncObj => _map;

internal Dictionary<TKey, TValue> Map { get; set; }

internal int Capacity { get; set; }

internal void Fresh() => Map = new Dictionary<TKey, TValue>(Capacity);

internal void Clear()
{
_cacheLock.EnterWriteLock();
try
{
Map.Clear();
}
finally { _cacheLock.ExitWriteLock(); }
}
internal Cache() : this(Bpe.DefaultCacheCapacity) { }

internal List<TValue> GetValues(IEnumerable<TKey> keys)
internal Cache(int capacity)
{
List<TValue> values = new();
_cacheLock.EnterReadLock();
try
{
foreach (TKey key in keys)
{
if (Map.TryGetValue(key, out TValue? value))
{
values.Add(value);
}
}
}
finally { _cacheLock.ExitReadLock(); }

return values;
_capacity = capacity;
_map = new Dictionary<StringSpanOrdinalKey, TValue>(capacity);
}

internal bool TryGet(TKey key, out TValue value)
internal bool TryGetValue(string key, out TValue value)
{
_cacheLock.EnterReadLock();
try
lock (SyncObj)
{
return Map.TryGetValue(key, out value!);
return _map.TryGetValue(new StringSpanOrdinalKey(key), out value!);
}
finally { _cacheLock.ExitReadLock(); }
}

internal void SetValues(IEnumerable<(TKey, TValue)> entries)
internal unsafe bool TryGetValue(ReadOnlySpan<char> key, out TValue value)
{
_cacheLock.EnterWriteLock();
try
lock (SyncObj)
{
foreach ((TKey, TValue) entry in entries)
fixed (char* ptr = key)
{
if (Capacity <= Map.Count)
{
break;
}
Map[entry.Item1] = entry.Item2;
return _map.TryGetValue(new StringSpanOrdinalKey(ptr, key.Length), out value!);
}
}
finally { _cacheLock.ExitWriteLock(); }
}

internal void Set(TKey k, TValue v)
internal void Set(string k, TValue v)
{
_cacheLock.EnterWriteLock();
try
lock (SyncObj)
{
if (Capacity > Map.Count)
if (_map.Count < _capacity)
{
Map[k] = v;
_map[new StringSpanOrdinalKey(k)] = v;
}
}
finally { _cacheLock.ExitWriteLock(); }
}
}
}
10 changes: 5 additions & 5 deletions src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public sealed class EnglishRoberta : Model
private readonly IReadOnlyDictionary<char, char> _byteToUnicode;
private readonly IReadOnlyDictionary<char, char> _unicodeToByte;
private readonly string[] _charToString;
private readonly Cache<string, List<Token>> _cache;
private readonly Cache<List<Token>> _cache;

/// <summary>
/// Construct tokenizer object to use with the English Roberta model.
Expand Down Expand Up @@ -69,7 +69,7 @@ public EnglishRoberta(string vocabularyPath, string mergePath, string highestOcc
}

_unicodeToByte = _byteToUnicode.Reverse();
_cache = new Cache<string, List<Token>>();
_cache = new Cache<List<Token>>();
}

/// <summary>
Expand Down Expand Up @@ -107,7 +107,7 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes
}

_unicodeToByte = _byteToUnicode.Reverse();
_cache = new Cache<string, List<Token>>();
_cache = new Cache<List<Token>>();
}

//
Expand Down Expand Up @@ -226,7 +226,7 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
return Array.Empty<Token>();
}

if (_cache.TryGet(sequence, out List<Token>? hit))
if (_cache.TryGetValue(sequence, out List<Token>? hit))
{
ArrayPool<char>.Shared.Return(token);
ArrayPool<int>.Shared.Return(indexMapping);
Expand Down Expand Up @@ -258,7 +258,7 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok

private int TokenizeToIds(string sequence, IList<int>? accumulatedIds)
{
if (_cache.TryGet(sequence, out List<Token>? hit))
if (_cache.TryGetValue(sequence, out List<Token>? hit))
{
if (accumulatedIds is not null)
{
Expand Down
42 changes: 41 additions & 1 deletion src/Microsoft.ML.Tokenizers/Model/Model.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

using System;
using System.Collections.Generic;
using System.Text;

namespace Microsoft.ML.Tokenizers
{
Expand Down Expand Up @@ -45,6 +44,19 @@ public virtual void TokenizeToIds(string sequence, bool isSpecialToken, IList<in
}
}

/// <summary>
/// Tokenize a split sequence, supplied as a span of characters, to a list of Ids and add them to the accumulatedIds list.
/// </summary>
/// <param name="sequence">The sequence to split.</param>
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
/// <param name="accumulatedIds">The list of accumulated tokenized Ids.</param>
/// <remarks>
/// The default implementation copies the span into a new string and forwards to the string-based TokenizeToIds overload.
/// Tokenizer's models which care about performance may choose to override this method to provide a more efficient implementation.
/// </remarks>
public virtual void TokenizeToIds(ReadOnlySpan<char> sequence, bool isSpecialToken, IList<int> accumulatedIds)
{
    // Materialize the span so the string overload can do the actual work.
    string text = sequence.ToString();
    TokenizeToIds(text, isSpecialToken, accumulatedIds);
}

/// <summary>
/// Get the number of tokens that the input sequence will be encoded to.
/// </summary>
Expand All @@ -62,13 +74,33 @@ public virtual int CountTokens(string sequence, bool isSpecialToken)
return ids.Count;
}

/// <summary>
/// Get the number of tokens that the input sequence, supplied as a span of characters, will be encoded to.
/// </summary>
/// <param name="sequence">The text to tokenize.</param>
/// <param name="isSpecialToken">Indicate if the token is special token.</param>
/// <returns>The number of tokens that the input sequence will be encoded to.</returns>
/// <remarks>
/// The default implementation copies the span into a new string and forwards to the string-based CountTokens overload.
/// Tokenizer's models which care about performance may choose to override this method to provide a more efficient implementation.
/// </remarks>
public virtual int CountTokens(ReadOnlySpan<char> sequence, bool isSpecialToken)
{
    // Materialize the span so the string overload can do the actual counting.
    string text = sequence.ToString();
    return CountTokens(text, isSpecialToken);
}

/// <summary>
/// Map the token to tokenized Id.
/// </summary>
/// <param name="token">The token to map to the Id.</param>
/// <returns>The mapped Id of the token.</returns>
public abstract int? TokenToId(string token);

/// <summary>
/// Map the token, supplied as a span of characters, to its tokenized Id.
/// </summary>
/// <param name="token">The token to map to the Id.</param>
/// <returns>The mapped Id of the token, as produced by the string-based TokenToId overload.</returns>
public virtual int? TokenToId(ReadOnlySpan<char> token)
{
    // Materialize the span and defer to the abstract string-based lookup.
    string text = token.ToString();
    return TokenToId(text);
}

/// <summary>
/// Map the token to tokenized id with the option to skip the special tokens.
/// </summary>
Expand All @@ -77,6 +109,14 @@ public virtual int CountTokens(string sequence, bool isSpecialToken)
/// <returns>The mapped Id of the token.</returns>
public virtual int? TokenToId(string token, bool skipSpecialTokens) => TokenToId(token);

/// <summary>
/// Map the token, supplied as a span of characters, to tokenized id with the option to skip the special tokens.
/// </summary>
/// <param name="token">The token to map to Id</param>
/// <param name="skipSpecialTokens">Indicate if want to skip the special tokens during the encoding.</param>
/// <returns>The mapped Id of the token.</returns>
/// <remarks>
/// The default implementation copies the span into a new string and forwards to the string-based
/// overload that takes the same skipSpecialTokens flag, so models overriding that overload are honored.
/// Tokenizer's models which care about performance may choose to override this method to provide a more efficient implementation.
/// </remarks>
public virtual int? TokenToId(ReadOnlySpan<char> token, bool skipSpecialTokens) =>
    // Must convert to string here: passing the span straight through would bind back to this
    // same (ReadOnlySpan<char>, bool) overload and recurse until the stack overflows.
    TokenToId(token.ToString(), skipSpecialTokens);

/// <summary>
/// Map the tokenized Id to the token.
/// </summary>
Expand Down
Loading