Skip to content

Commit

Permalink
Embed the Tokenizer data files inside the assembly (#6403)
Browse files Browse the repository at this point in the history
  • Loading branch information
tarekgh authored Oct 24, 2022
1 parent 1903fa5 commit c69acbe
Show file tree
Hide file tree
Showing 9 changed files with 100,381 additions and 102 deletions.
114 changes: 85 additions & 29 deletions src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,71 @@ public sealed class EnglishRoberta : Model
/// <param name="highestOccurrenceMappingPath">Path of the file used to remap the original GPT-2 model Ids to high occurrence ranks and values.</param>
public EnglishRoberta(string vocabularyPath, string mergePath, string highestOccurrenceMappingPath)
{
    if (vocabularyPath is null)
    {
        throw new ArgumentNullException(nameof(vocabularyPath));
    }

    if (mergePath is null)
    {
        throw new ArgumentNullException(nameof(mergePath));
    }

    if (highestOccurrenceMappingPath is null)
    {
        throw new ArgumentNullException(nameof(highestOccurrenceMappingPath));
    }

    // vocabularyPath is typically an encoder.json file, mergePath a vocab.bpe
    // file, and highestOccurrenceMappingPath a dict.txt file. The helpers below
    // consume the streams during construction, so the using declarations can
    // safely dispose them when the constructor returns.
    using Stream vocabularyStream = File.OpenRead(vocabularyPath);
    using Stream mergeStream = File.OpenRead(mergePath);
    using Stream highestOccurrenceMappingStream = File.OpenRead(highestOccurrenceMappingPath);

    _vocabIdToHighestOccurrence = GetHighestOccurrenceMapping(highestOccurrenceMappingStream);
    _vocab = GetVocabulary(vocabularyStream);
    _vocabReverse = _vocab.ReverseSorted();
    _mergeRanks = GetMergeRanks(mergeStream);
    int maxCharValue = GetByteToUnicode(out _byteToUnicode);

    // Pre-compute the single-character strings so later lookups do not
    // allocate a fresh string per character.
    _charToString = new string[maxCharValue];
    for (char c = (char)0; c < (char)maxCharValue; c++)
    {
        _charToString[c] = c.ToString();
    }

    _unicodeToByte = _byteToUnicode.Reverse();
    _cache = new Cache<string, IReadOnlyList<Token>>();
}

/// <summary>
/// Construct tokenizer object to use with the English Roberta model.
/// </summary>
/// <param name="vocabularyStream">The stream of a JSON file containing the dictionary of string keys and their ids.</param>
/// <param name="mergeStream">The stream of a file containing the token pairs list.</param>
/// <param name="highestOccurrenceMappingStream">Remap the original GPT-2 model Ids to high occurrence ranks and values.</param>
public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highestOccurrenceMappingStream)
{
if (vocabularyStream is null)
{
throw new ArgumentNullException(nameof(vocabularyStream));
}

if (mergeStream is null)
{
throw new ArgumentNullException(nameof(mergeStream));
}

if (highestOccurrenceMappingStream is null)
{
throw new ArgumentNullException(nameof(highestOccurrenceMappingStream));
}

_vocabIdToHighestOccurrence = GetHighestOccurrenceMapping(highestOccurrenceMappingStream);
_vocab = GetVocabulary(vocabularyStream);
_vocabReverse = _vocab.ReverseSorted();
_mergeRanks = GetMergeRanks(mergeStream);
int maxCharValue = GetByteToUnicode(out _byteToUnicode);
_charToString = new string[maxCharValue];
for (char c = (char)0; c < (char)maxCharValue; c++)
Expand Down Expand Up @@ -298,28 +355,24 @@ private IReadOnlyList<Token> ModifyTokenListOffsets(IReadOnlyList<Token> tokens,
return tokens;
}

// Parses the occurrence-ranking data (dict.txt format) from the given stream
// into a HighestOccurrenceMapping. Thin wrapper kept for call-site symmetry
// with GetVocabulary/GetMergeRanks.
private static HighestOccurrenceMapping GetHighestOccurrenceMapping(Stream highestOccurrenceMappingStream) =>
    HighestOccurrenceMapping.Load(highestOccurrenceMappingStream);

private Dictionary<string, int> GetVocabulary(string vocabularyPath)
private Dictionary<string, int> GetVocabulary(Stream vocabularyStream)
{
Dictionary<string, int>? vocab;
try
{
using (Stream stream = File.OpenRead(vocabularyPath))
{
vocab = JsonSerializer.Deserialize<Dictionary<string, int>>(stream) as Dictionary<string, int>;

}
vocab = JsonSerializer.Deserialize<Dictionary<string, int>>(vocabularyStream) as Dictionary<string, int>;
}
catch (Exception e)
{
throw new ArgumentException($"Problems met when parsing JSON object in {vocabularyPath}.{Environment.NewLine}Error message: {e.Message}");
throw new ArgumentException($"Problems met when parsing JSON vocabulary object.{Environment.NewLine}Error message: {e.Message}");
}

if (vocab is null)
{
throw new ArgumentException($"Failed to read the vocabulary file '{vocabularyPath}'");
throw new ArgumentException($"Failed to read the vocabulary file.");
}

if (_vocabIdToHighestOccurrence.BosWord is not null)
Expand All @@ -345,28 +398,28 @@ private Dictionary<string, int> GetVocabulary(string vocabularyPath)
return vocab;
}

private Dictionary<(string, string), int> GetMergeRanks(string mergePath)
private Dictionary<(string, string), int> GetMergeRanks(Stream mergeStream)
{
string[] splitContents;
List<string> splitContents = new();

try
{
splitContents = File.ReadAllLines(mergePath);
using StreamReader reader = new StreamReader(mergeStream);
while (reader.Peek() >= 0)
{
splitContents.Add(reader.ReadLine());
}
}
catch (Exception e)
{
throw new IOException($"Cannot read the file '{mergePath}'.{Environment.NewLine}Error message: {e.Message}", e);
throw new IOException($"Cannot read the file Merge file.{Environment.NewLine}Error message: {e.Message}", e);
}

var mergeRanks = new Dictionary<(string, string), int>();

for (int i = 0; i < splitContents.Length; i++)
// We ignore the first and last line in the file
for (int i = 1; i < splitContents.Count - 1; i++)
{
if (i == 0 || i == splitContents.Length - 1)
{
continue;
}

var split = splitContents[i].Split(' ');
if (split.Length != 2 || string.IsNullOrEmpty(split[0]) || string.IsNullOrEmpty(split[1]))
{
Expand Down Expand Up @@ -664,22 +717,25 @@ public int this[int idx]
/// 284 432911125
/// ...
/// </summary>
public static HighestOccurrenceMapping Load(Stream stream)
{
    // Build an empty mapping, then populate it from the stream's
    // "<symbol> <occurrence-count>" lines via AddFromStream.
    var mapping = new HighestOccurrenceMapping();
    mapping.AddFromStream(stream);
    return mapping;
}

/// <summary>
/// Loads a pre-existing vocabulary from a text file and adds its symbols to this instance.
/// Loads a pre-existing vocabulary from a text stream and adds its symbols to this instance.
/// </summary>
public void AddFromFile(string fileName)
public void AddFromStream(Stream stream)
{
var lines = File.ReadAllLines(fileName, Encoding.UTF8);
Debug.Assert(stream is not null);
using StreamReader reader = new StreamReader(stream);

foreach (var line in lines)
while (reader.Peek() >= 0)
{
string line = reader.ReadLine();

var splitLine = line.Trim().Split(' ');
if (splitLine.Length != 2)
{
Expand Down
24 changes: 11 additions & 13 deletions src/Microsoft.ML.TorchSharp/Extensions/TokenizerExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System;
using System.Collections.Generic;
using System.Reflection;
using System.Runtime.CompilerServices;
using Microsoft.ML.Runtime;
using Microsoft.ML.Tokenizers;
Expand All @@ -13,25 +14,22 @@ namespace Microsoft.ML.TorchSharp.Extensions
{
internal static class TokenizerExtensions
{
private const string EncoderJsonName = "encoder.json";
private const string MergeName = "vocab.bpe";
private const string DictName = "dict.txt";

private static readonly Uri _encoderJsonUrl = new Uri("https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json");
private static readonly Uri _mergeUrl = new Uri("https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe");
private static readonly Uri _dictUrl = new Uri("https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt");

private static Tokenizer _instance;

internal static Tokenizer GetInstance(IChannel ch)
{
if (_instance is null)
{
FileUtils.LoadFromFileOrDownloadFromWeb(string.Empty, EncoderJsonName, _encoderJsonUrl, ch);
FileUtils.LoadFromFileOrDownloadFromWeb(string.Empty, MergeName, _mergeUrl, ch);
FileUtils.LoadFromFileOrDownloadFromWeb(string.Empty, DictName, _dictUrl, ch);

EnglishRoberta model = new EnglishRoberta(EncoderJsonName, MergeName, DictName);
// encoder.json, vocab.bpe, and dict.txt are picked up from the following source:
// "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json"
// "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe"
// "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt"
Assembly assembly = typeof(TokenizerExtensions).Assembly;

EnglishRoberta model = new EnglishRoberta(
assembly.GetManifestResourceStream("encoder.json"),
assembly.GetManifestResourceStream("vocab.bpe"),
assembly.GetManifestResourceStream("dict.txt"));
model.AddMaskSymbol();
_instance = new Tokenizer(model, new RobertaPreTokenizer());
}
Expand Down
18 changes: 15 additions & 3 deletions src/Microsoft.ML.TorchSharp/Microsoft.ML.TorchSharp.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@

<ItemGroup>
<PackageReference Include="TorchSharp" Version="$(TorchSharpVersion)" />
<PackageReference Include="libtorch-cpu-win-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('Windows'))" PrivateAssets="all"/>
<PackageReference Include="libtorch-cpu-linux-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('Linux'))" PrivateAssets="all"/>
<PackageReference Include="libtorch-cpu-osx-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('OSX'))" PrivateAssets="all"/>
<PackageReference Include="libtorch-cpu-win-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('Windows'))" PrivateAssets="all" />
<PackageReference Include="libtorch-cpu-linux-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('Linux'))" PrivateAssets="all" />
<PackageReference Include="libtorch-cpu-osx-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('OSX'))" PrivateAssets="all" />
</ItemGroup>

<ItemGroup>
Expand All @@ -24,4 +24,16 @@
</ProjectReference>
</ItemGroup>

<ItemGroup>
  <!-- Tokenizer data files embedded in the assembly. LogicalName pins the
       manifest resource name so Assembly.GetManifestResourceStream("dict.txt"),
       ("encoder.json"), and ("vocab.bpe") resolve without the default
       root-namespace/folder prefix. -->
  <EmbeddedResource Include="Resources\dict.txt">
    <LogicalName>dict.txt</LogicalName>
  </EmbeddedResource>
  <EmbeddedResource Include="Resources\encoder.json">
    <LogicalName>encoder.json</LogicalName>
  </EmbeddedResource>
  <EmbeddedResource Include="Resources\vocab.bpe">
    <LogicalName>vocab.bpe</LogicalName>
  </EmbeddedResource>
</ItemGroup>

</Project>
Loading

0 comments on commit c69acbe

Please sign in to comment.