Skip to content

Commit

Permalink
[release/4.0] Moved SpecialTokens assignment after the modification to avoid "Collection Modified" error (#7330)
Browse files Browse the repository at this point in the history

* Moved special tokens assignment below so the collection won't be modified

* Added safe dictionary inversion

* Added storage of the non-normalized special tokens

* Added support for .NET Standard

* Added and updated tests

* Updated to avoid additional memory allocation

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Fixed Copilot-suggested changes

---------

Co-authored-by: Shaltiel Shmidman <shaltieltzion@gmail.com>
Co-authored-by: Tarek Mahmoud Sayed <10833894+tarekgh@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Tarek Mahmoud Sayed <tarekms@microsoft.com>
  • Loading branch information
5 people authored Dec 9, 2024
1 parent d92c0b3 commit cfa306e
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 21 deletions.
45 changes: 29 additions & 16 deletions src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Licensed to the .NET Foundation under one or more agreements.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

Expand Down Expand Up @@ -762,15 +762,16 @@ private static BertTokenizer Create(

options.Normalizer ??= options.ApplyBasicTokenization ? new BertNormalizer(options.LowerCaseBeforeTokenization, options.IndividuallyTokenizeCjk, options.RemoveNonSpacingMarks) : null;

IReadOnlyDictionary<string, int>? specialTokensDict = options.SpecialTokens;
if (options.SplitOnSpecialTokens)
{
bool lowerCase = options.ApplyBasicTokenization && options.LowerCaseBeforeTokenization;
if (options.SpecialTokens is not null)
{
if (lowerCase)
{
Dictionary<string, int> dic = options.SpecialTokens.ToDictionary(kvp => kvp.Key, kvp => kvp.Value);
options.SpecialTokens = dic;
Dictionary<string, int> tempSpecialTokens = new Dictionary<string, int>();
specialTokensDict = tempSpecialTokens;

foreach (var kvp in options.SpecialTokens)
{
Expand All @@ -779,37 +780,49 @@ private static BertTokenizer Create(
throw new ArgumentException($"The special token '{kvp.Key}' is not in the vocabulary or assigned id value {id} different than the value {kvp.Value} in the special tokens.");
}

// Ensure that the special tokens are lowercased.
dic[kvp.Key.ToLowerInvariant()] = kvp.Value;
// Add the special token into our dictionary, normalizing it, and adding it into the
// main vocab, if needed.
AddSpecialToken(vocab, tempSpecialTokens, kvp.Key, true);
}
}
}
else
{
// Create a dictionary with the special tokens.
Dictionary<string, int> specialTokens = new Dictionary<string, int>();
options.SpecialTokens = specialTokens;

AddSpecialToken(vocab, specialTokens, options.UnknownToken, lowerCase);
AddSpecialToken(vocab, specialTokens, options.SeparatorToken, lowerCase);
AddSpecialToken(vocab, specialTokens, options.PaddingToken, lowerCase);
AddSpecialToken(vocab, specialTokens, options.ClassificationToken, lowerCase);
AddSpecialToken(vocab, specialTokens, options.MaskingToken, lowerCase);
// Create a dictionary with the special tokens - store the un-normalized forms in the options as
// that field is exposed to the public. In addition, store the normalized form for creating the
// pre-tokenizer.
Dictionary<string, int> tempSpecialTokens = new Dictionary<string, int>();
Dictionary<string, int> notNormalizedSpecialTokens = new Dictionary<string, int>();
AddSpecialToken(vocab, tempSpecialTokens, options.UnknownToken, lowerCase, notNormalizedSpecialTokens);
AddSpecialToken(vocab, tempSpecialTokens, options.SeparatorToken, lowerCase, notNormalizedSpecialTokens);
AddSpecialToken(vocab, tempSpecialTokens, options.PaddingToken, lowerCase, notNormalizedSpecialTokens);
AddSpecialToken(vocab, tempSpecialTokens, options.ClassificationToken, lowerCase, notNormalizedSpecialTokens);
AddSpecialToken(vocab, tempSpecialTokens, options.MaskingToken, lowerCase, notNormalizedSpecialTokens);

options.SpecialTokens = notNormalizedSpecialTokens;
specialTokensDict = tempSpecialTokens;
}
}

options.PreTokenizer ??= options.ApplyBasicTokenization ? PreTokenizer.CreateWordOrPunctuation(options.SplitOnSpecialTokens ? options.SpecialTokens : null) : PreTokenizer.CreateWhiteSpace();
// We set the PreTokenizer here using the normalized special tokens dict (if relevant), and therefore we can
// keep the not-normalized special tokens dict in the options passed to the WordPieceTokenizer.
options.PreTokenizer ??= options.ApplyBasicTokenization ? PreTokenizer.CreateWordOrPunctuation(options.SplitOnSpecialTokens ? specialTokensDict : null) : PreTokenizer.CreateWhiteSpace();

return new BertTokenizer(vocab, vocabReverse, options);
}

private static void AddSpecialToken(Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<string, int> specialTokens, string token, bool lowerCase)
private static void AddSpecialToken(Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<string, int> specialTokens, string token, bool lowerCase, Dictionary<string, int>? notNormalizedSpecialTokens = null)
{
if (token is null || !vocab.TryGetValue(new StringSpanOrdinalKey(token), out int id))
{
throw new ArgumentException($"The special token '{token}' is not in the vocabulary.");
}

if (notNormalizedSpecialTokens is not null)
{
notNormalizedSpecialTokens[token] = id;
}

string normalizedToken = token;
if (lowerCase)
{
Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.ML.Tokenizers/Model/WordPieceTokenizer.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Licensed to the .NET Foundation under one or more agreements.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

Expand Down Expand Up @@ -42,7 +42,7 @@ internal WordPieceTokenizer(
options ??= new();

SpecialTokens = options.SpecialTokens;
SpecialTokensReverse = options.SpecialTokens is not null ? options.SpecialTokens.ToDictionary(kvp => kvp.Value, kvp => kvp.Key) : null;
SpecialTokensReverse = options.SpecialTokens is not null ? options.SpecialTokens.GroupBy(kvp => kvp.Value).ToDictionary(g => g.Key, g => g.First().Key) : null;

if (options.UnknownToken is null)
{
Expand Down Expand Up @@ -800,4 +800,4 @@ public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool
return OperationStatus.Done;
}
}
}
}
93 changes: 91 additions & 2 deletions test/Microsoft.ML.Tokenizers.Tests/BertTokenizerTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Licensed to the .NET Foundation under one or more agreements.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

Expand All @@ -14,6 +14,91 @@ namespace Microsoft.ML.Tokenizers.Tests
{
public class BertTokenizerTests
{
[Fact]
public void TestWithLowerCasingExplicitSpecialTokens()
{
// Add [SPECIAL] token at end (to keep indices as is)
// Ids: 0 1 2 3 4 5 6 7 8 9 10 11 12, 13
string[] vocabTokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "!", ",", "?", "hello", "world", "how", "are", "you", "[SPECIAL]"];

string vocabFile = WordPieceTests.CreateVocabFile(vocabTokens);

Dictionary<string, int> specialTokens = new() {
{ "[PAD]", 0 },
{ "[UNK]", 1 },
{ "[CLS]", 2 },
{ "[SEP]", 3 },
{ "[MASK]", 4 },
{ "[SPECIAL]", 13 },
};
var bertOptions = new BertOptions()
{
SpecialTokens = specialTokens
};

try
{
using Stream vocabStream = File.OpenRead(vocabFile);
BertTokenizer[] bertTokenizers = [BertTokenizer.Create(vocabFile, bertOptions), BertTokenizer.Create(vocabStream, bertOptions)];

foreach (var tokenizer in bertTokenizers)
{
Assert.NotNull(tokenizer.PreTokenizer);
Assert.Equal("[UNK]", tokenizer.UnknownToken);
Assert.Equal(1, tokenizer.UnknownTokenId);
Assert.NotNull(tokenizer.Normalizer);
Assert.NotNull(tokenizer.PreTokenizer);

Assert.True(tokenizer.SpecialTokens!.ContainsKey("[SPECIAL]"));

string text = "Hello, How are you [SPECIAL]?";
var tokens = tokenizer.EncodeToTokens(text, out string? normalizedText);
Assert.Equal("hello, how are you [special]?", normalizedText);

Assert.Equal(
[
new EncodedToken(8, "hello", new Range(0, 5)),
new EncodedToken(6, ",", new Range(5, 6)),
new EncodedToken(10, "how", new Range(7, 10)),
new EncodedToken(11, "are", new Range(11, 14)),
new EncodedToken(12, "you", new Range(15, 18)),
new EncodedToken(13, "[SPECIAL]", new Range(19, 28)),
new EncodedToken(7, "?", new Range(28, 29))
],
tokens);

var ids = tokenizer.EncodeToIds(text);
Assert.Equal([tokenizer.ClassificationTokenId, 8, 6, 10, 11, 12, 13, 7, tokenizer.SeparatorTokenId], ids);

Assert.Equal("[CLS] hello, how are you [SPECIAL]? [SEP]", tokenizer.Decode(ids));
Assert.Equal("hello, how are you?", tokenizer.Decode(ids, skipSpecialTokens: true));

tokens = tokenizer.EncodeToTokens(tokenizer.Decode(ids), out normalizedText);
Assert.Equal("[cls] hello, how are you [special]? [sep]", normalizedText);
Assert.Equal(
[
new EncodedToken(2, "[CLS]", new Range(0, 5)),
new EncodedToken(8, "hello", new Range(6, 11)),
new EncodedToken(6, ",", new Range(11, 12)),
new EncodedToken(10, "how", new Range(13, 16)),
new EncodedToken(11, "are", new Range(17, 20)),
new EncodedToken(12, "you", new Range(21, 24)),
new EncodedToken(13, "[SPECIAL]", new Range(25, 34)),
new EncodedToken(7, "?", new Range(34, 35)),
new EncodedToken(3, "[SEP]", new Range(36, 41))
],
tokens);

ids = tokenizer.EncodeToIds(normalizedText!);
Assert.Equal([tokenizer.ClassificationTokenId, tokenizer.ClassificationTokenId, 8, 6, 10, 11, 12, 13, 7, tokenizer.SeparatorTokenId, tokenizer.SeparatorTokenId], ids);
}
}
finally
{
File.Delete(vocabFile);
}
}

[Fact]
public void TestWithLowerCasing()
{
Expand All @@ -35,6 +120,10 @@ public void TestWithLowerCasing()
Assert.NotNull(tokenizer.Normalizer);
Assert.NotNull(tokenizer.PreTokenizer);

// Make sure the SpecialTokens dictionary contains the not-normalized tokens
Assert.True(tokenizer.SpecialTokens!.ContainsKey(tokenizer.UnknownToken));
Assert.True(tokenizer.SpecialTokens!.ContainsKey(tokenizer.ClassificationToken));

string text = "Hello, How are you?";
var tokens = tokenizer.EncodeToTokens(text, out string? normalizedText);
Assert.Equal("hello, how are you?", normalizedText);
Expand Down Expand Up @@ -511,4 +600,4 @@ public void TestCreateTokenTypeIdsFromSequences()
}
}
}
}
}

0 comments on commit cfa306e

Please sign in to comment.