Skip to content

Cleanup SentencePiece tokenizer #7427

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 0 additions & 61 deletions src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,67 +59,6 @@ internal SentencePieceBaseModel(ModelProto modelProto, bool addBos = false, bool
specialTokens);
}

/// <summary>
/// Initializes the shared SentencePiece model state from caller-provided options
/// rather than from a serialized ModelProto.
/// </summary>
/// <param name="options">The tokenizer options; must supply a vocabulary and the BOS, EOS, and unknown tokens.</param>
/// <exception cref="ArgumentNullException">
/// Thrown when <paramref name="options"/> or any of its required values is null.
/// </exception>
internal SentencePieceBaseModel(SentencePieceOptions options)
{
    if (options is null)
    {
        throw new ArgumentNullException(nameof(options));
    }

    if (options.Vocabulary is null)
    {
        throw new ArgumentNullException(nameof(options.Vocabulary));
    }

    if (options.BeginningOfSentenceToken is null)
    {
        throw new ArgumentNullException(nameof(options.BeginningOfSentenceToken));
    }

    if (options.EndOfSentenceToken is null)
    {
        throw new ArgumentNullException(nameof(options.EndOfSentenceToken));
    }

    if (options.UnknownToken is null)
    {
        throw new ArgumentNullException(nameof(options.UnknownToken));
    }

    // Copy the scalar settings straight off the options object.
    AddBeginningOfSentence = options.AddBeginningOfSentence;
    AddEndOfSentence = options.AddEndOfSentence;
    BeginningOfSentenceToken = options.BeginningOfSentenceToken;
    EndOfSentenceToken = options.EndOfSentenceToken;
    UnknownToken = options.UnknownToken;
    AddDummyPrefix = options.AddDummyPrefix;
    EscapeWhiteSpaces = options.EscapeWhiteSpaces;
    TreatWhitespaceAsSuffix = options.TreatWhitespaceAsSuffix;
    ByteFallback = options.ByteFallback;
    SpecialTokens = options.SpecialTokens;

    if (SpecialTokens is { Count: > 0 })
    {
        InternalSpecialTokens = new Dictionary<StringSpanOrdinalKey, int>();
        SpecialTokensReverse = new Dictionary<int, string>();

        foreach (KeyValuePair<string, int> pair in SpecialTokens)
        {
            InternalSpecialTokens.Add(new StringSpanOrdinalKey(pair.Key), pair.Value);
            SpecialTokensReverse.Add(pair.Value, pair.Key);
        }

        // We create this Regex object without a timeout, as we expect the match operation to complete in O(N) time complexity. Note that `specialTokens` are treated as constants after the tokenizer is created.
        string pattern = string.Join("|", SpecialTokens.Keys.Select(Regex.Escape));
        SpecialTokensRegex = new Regex(pattern, RegexOptions.Compiled);
    }

    Normalizer = new SentencePieceNormalizer(
        options.PrecompiledNormalizationData,
        options.RemoveExtraWhiteSpaces,
        options.AddDummyPrefix,
        options.EscapeWhiteSpaces,
        options.TreatWhitespaceAsSuffix,
        SpecialTokens);
}

// Compiled alternation over the escaped special-token strings; null when no special tokens were supplied.
internal Regex? SpecialTokensRegex { get; }

// Maps special-token text (span-comparable key) to its id; null when no special tokens were supplied.
internal Dictionary<StringSpanOrdinalKey, int>? InternalSpecialTokens { get; }
Expand Down
46 changes: 0 additions & 46 deletions src/Microsoft.ML.Tokenizers/Model/SentencePieceBpeModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,52 +41,6 @@ internal SentencePieceBpeModel(ModelProto modelProto, bool addBos, bool addEos,
OneByteUtf8EncodingMaxId = ByteCodeToIdOffset + 0x7F; // 0x7F is the maximum value of the one byte UTF-8 character.
}

/// <summary>
/// Builds a BPE SentencePiece model from caller-provided options instead of a serialized ModelProto.
/// </summary>
/// <param name="options">The tokenizer options; ids are assigned by the enumeration order of the vocabulary.</param>
/// <exception cref="NotSupportedException">Thrown when precompiled normalization data is supplied (unsupported for BPE).</exception>
/// <exception cref="ArgumentException">
/// Thrown when byte fallback is requested without the &lt;0xNN&gt; pieces, or when the unknown,
/// beginning-of-sentence, or end-of-sentence token is missing from the vocabulary.
/// </exception>
internal SentencePieceBpeModel(SentencePieceOptions options) : base(options)
{
    if (options.PrecompiledNormalizationData is not null)
    {
        throw new NotSupportedException("Normalization data is not supported for SentencePieceBpeModel.");
    }

    Debug.Assert(options.Vocabulary is not null); // base(options) already rejected a null vocabulary.

    // Ids are assigned sequentially in the order the vocabulary enumerates.
    int nextId = 0;
    foreach (var entry in options.Vocabulary!)
    {
        _vocab.Add(new StringSpanOrdinalKey(entry.Token), (nextId, entry.Score, (byte)ModelProto.Types.SentencePiece.Types.Type.Normal));
        _vocabReverse.Add(nextId, entry.Token);
        nextId++;
    }

    if (options.ByteFallback)
    {
        // Locating <0x00> anchors the id range used for byte-fallback encoding.
        if (!_vocab.TryGetValue("<0x00>", out (int Id, float Score, byte Type) zeroByte))
        {
            throw new ArgumentException("'ByteFallback' is enabled but the vocabulary must include a special token for each byte value (0-255) in the format <0xNN>, where NN represents the byte's hexadecimal value.");
        }

        ByteCodeToIdOffset = zeroByte.Id;
        OneByteUtf8EncodingMaxId = ByteCodeToIdOffset + 0x7F; // 0x7F is the maximum value of the one byte UTF-8 character.
    }

    // Resolves a token that must be present in the vocabulary, or throws.
    int GetRequiredId(string token, string description)
    {
        if (!_vocab.TryGetValue(token, out (int Id, float Score, byte Type) found))
        {
            throw new ArgumentException($"The vocabulary must include the {description} '{token}'.");
        }

        return found.Id;
    }

    UnknownId = GetRequiredId(options.UnknownToken, "unknown token");
    BeginningOfSentenceId = GetRequiredId(options.BeginningOfSentenceToken, "beginning of sentence token");
    EndOfSentenceId = GetRequiredId(options.EndOfSentenceToken, "end of sentence token");
}

public override IReadOnlyDictionary<string, int> Vocabulary
{
get
Expand Down
118 changes: 0 additions & 118 deletions src/Microsoft.ML.Tokenizers/Model/SentencePieceOptions.cs

This file was deleted.

24 changes: 0 additions & 24 deletions src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,6 @@ internal SentencePieceTokenizer(ModelProto modelProto, bool addBos, bool addEos,
};
}

/// <summary>
/// Wraps the model implementation selected by the options' model type.
/// </summary>
/// <param name="options">The tokenizer options carrying the model type and vocabulary.</param>
/// <exception cref="ArgumentException">Thrown when the model type is neither BPE nor Unigram.</exception>
internal SentencePieceTokenizer(SentencePieceOptions options)
{
    switch (options.ModelType)
    {
        case SentencePieceModelType.Bpe:
            _model = new SentencePieceBpeModel(options);
            break;

        case SentencePieceModelType.Unigram:
            _model = new SentencePieceUnigramModel(options);
            break;

        default:
            throw new ArgumentException($"The model type '{options.ModelType}' is not supported.", nameof(options.ModelType));
    }
}

/// <summary>
/// The special tokens.
/// </summary>
Expand Down Expand Up @@ -467,19 +457,5 @@ public static SentencePieceTokenizer Create(

return new SentencePieceTokenizer(modelProto, addBeginOfSentence, addEndOfSentence, specialTokens);
}

/// <summary>
/// Creates an instance of SentencePieceTokenizer.
/// </summary>
/// <param name="options">The options to use for the sentence piece tokenizer.</param>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="options"/> is null.</exception>
public static SentencePieceTokenizer Create(SentencePieceOptions options)
{
    return options is not null
        ? new SentencePieceTokenizer(options)
        : throw new ArgumentNullException(nameof(options));
}
}
}
66 changes: 8 additions & 58 deletions src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ internal sealed class SentencePieceUnigramModel : SentencePieceBaseModel
public SentencePieceUnigramModel(ModelProto modelProto, bool addBos, bool addEos, IReadOnlyDictionary<string, int>? specialTokens = null) : base(modelProto, addBos, addEos, specialTokens)
{
_vocab = new SortedDictionary<string, int>(OrdinalUtf8StringComparer.Instance);

if (modelProto.TrainerSpec.BosId >= modelProto.Pieces.Count ||
modelProto.TrainerSpec.EosId >= modelProto.Pieces.Count ||
modelProto.TrainerSpec.UnkId >= modelProto.Pieces.Count)
{
throw new ArgumentException("The BOS, EOS, or UNK token is not present in the vocabulary.");
}

_vocabReverse = new (string Piece, float Score, ModelProto.Types.SentencePiece.Types.Type Type)[modelProto.Pieces.Count];

_minScore = float.MaxValue;
Expand Down Expand Up @@ -85,64 +93,6 @@ public SentencePieceUnigramModel(ModelProto modelProto, bool addBos, bool addEos
}
}

/// <summary>
/// Builds a Unigram SentencePiece model from caller-provided options instead of a serialized ModelProto.
/// </summary>
/// <param name="options">The tokenizer options; the vocabulary supplies the pieces with their scores.</param>
/// <exception cref="ArgumentException">
/// Thrown when byte fallback is requested without the &lt;0xNN&gt; pieces, or when the unknown,
/// beginning-of-sentence, or end-of-sentence token is missing from the vocabulary.
/// </exception>
public SentencePieceUnigramModel(SentencePieceOptions options) : base(options)
{
    _vocab = new SortedDictionary<string, int>(OrdinalUtf8StringComparer.Instance);

    // 250_000 using big number to avoid reallocation during the initialization.
    List<(string Piece, float Score, ModelProto.Types.SentencePiece.Types.Type Type)> vocabReverse = new(250_000);

    _minScore = float.MaxValue;
    _maxScore = float.MinValue;

    int id = 0;
    foreach ((string Token, float Score) item in options.Vocabulary!)
    {
        _vocab.Add(item.Token, id++);
        vocabReverse.Add((item.Token, item.Score, ModelProto.Types.SentencePiece.Types.Type.Normal));
        _minScore = Math.Min(_minScore, item.Score);
        _maxScore = Math.Max(_maxScore, item.Score);
    }

    _vocabReverse = vocabReverse.ToArray();

    if (options.ByteFallback)
    {
        if (!_vocab.TryGetValue("<0x00>", out id))
        {
            throw new ArgumentException("'ByteFallback' is enabled but the vocabulary must include a special token for each byte value (0-255) in the format <0xNN>, where NN represents the byte's hexadecimal value.");
        }

        ByteCodeToIdOffset = id;
        OneByteUtf8EncodingMaxId = ByteCodeToIdOffset + 0x7F; // 0x7F is the maximum value of the one byte UTF-8 character.
        MaxIdByteFallbackId = ByteCodeToIdOffset + 0xFF; // from <0x00> to <0xFF>.
    }

    _trie = new DoubleArrayTrie(_vocab);

    // BUG FIX: resolve the special-token ids BEFORE patching _vocabReverse.
    // The previous code indexed _vocabReverse with BeginningOfSentenceId/EndOfSentenceId
    // while those properties still held their defaults, clobbering the entry for piece 0
    // instead of the actual BOS/EOS slots.
    if (!_vocab.TryGetValue(options.UnknownToken, out int unknownToken))
    {
        throw new ArgumentException($"The vocabulary must include the unknown token '{options.UnknownToken}'.");
    }
    UnknownId = unknownToken;

    if (!_vocab.TryGetValue(options.BeginningOfSentenceToken, out int beginOfSentenceToken))
    {
        throw new ArgumentException($"The vocabulary must include the beginning of sentence token '{options.BeginningOfSentenceToken}'.");
    }
    BeginningOfSentenceId = beginOfSentenceToken;

    if (!_vocab.TryGetValue(options.EndOfSentenceToken, out int endOfSentenceToken))
    {
        throw new ArgumentException($"The vocabulary must include the end of sentence token '{options.EndOfSentenceToken}'.");
    }
    EndOfSentenceId = endOfSentenceToken;

    // Now that the ids are known, pin the BOS/EOS entries in the reverse map.
    _vocabReverse[BeginningOfSentenceId] = (BeginningOfSentenceToken, 0f, 0);
    _vocabReverse[EndOfSentenceId] = (EndOfSentenceToken, 0f, 0);
}

/// <summary>
/// Gets a read-only view over the piece-to-id vocabulary.
/// Note: a fresh <see cref="ReadOnlyDictionary{TKey, TValue}"/> wrapper is allocated on each access.
/// </summary>
public override IReadOnlyDictionary<string, int> Vocabulary
{
    get
    {
        return new ReadOnlyDictionary<string, int>(_vocab);
    }
}

public int MaxIdByteFallbackId { get; }
Expand Down
Loading
Loading