[Models] Support for tokenizers in C++ runtime (#69)
Adds support for LLaMA, GPT-2, and OPT tokenizers using the Hugging Face configuration
1 parent 68a3fe4 · commit 9033c10
Showing 26 changed files with 1,868 additions and 0 deletions.
@@ -0,0 +1,17 @@
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <sstream>

// dbg is a stream which can be used to debug various aspects of the tokenizer; its contents are returned
// as part of the tokenizer output.
extern std::stringstream dbg;
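
As a rough usage sketch (the helper below is hypothetical and not part of this commit, and assumes this header is included): anything written to dbg during tokenization accumulates in the stream and can later be drained with dbg.str() and cleared with dbg.str("").

#include <string>

// Hypothetical routine inside the tokenizer: record a diagnostic line in the shared dbg
// stream, which is presumably defined once in an accompanying .cc file of this commit.
void note_stage(const std::string &stage) {
  dbg << "[tokenizer] finished stage: " << stage << '\n';
}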
@@ -0,0 +1,85 @@
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
#include <hidet/runtime/llm/tokenizer/pattern.h>
#include <hidet/runtime/llm/tokenizer/utf8.h>

// Decoder manipulates a sequence of raw tokens to produce a human-readable form.
class Decoder {
 public:
  virtual std::vector<std::string> decode_chain(std::vector<std::string> tokens) = 0;
  virtual ~Decoder() = default;

  std::string decode(std::vector<std::string> tokens) {
    std::string ret;
    for (std::string const &s : decode_chain(std::move(tokens))) ret += s;
    return ret;
  }
};

// SequenceDecoder runs a sequence of decoders sequentially, passing the output of one decoder into the next.
class SequenceDecoder: public Decoder {
 public:
  explicit SequenceDecoder(std::vector<std::unique_ptr<Decoder>> decoders);
  std::vector<std::string> decode_chain(std::vector<std::string> tokens) final;

 private:
  std::vector<std::unique_ptr<Decoder>> decoders;
};

// ReplaceDecoder replaces all instances of a pattern with the given content on a token-by-token basis.
class ReplaceDecoder: public Decoder {
 public:
  explicit ReplaceDecoder(std::string const &pattern, std::string content);
  std::vector<std::string> decode_chain(std::vector<std::string> tokens) final;

 private:
  RegexPattern pattern;
  std::string content;
};

// ByteLevelDecoder is meant to be used with ByteLevel pre-tokenization (for example, in GPT-2). It reverses the
// mapping of the ByteLevel pre-tokenization step, converting the byte-level replacements back into human-readable
// characters.
class ByteLevelDecoder: public Decoder {
 public:
  ByteLevelDecoder();
  std::vector<std::string> decode_chain(std::vector<std::string> tokens) final;

 private:
  std::map<std::string, uint8_t> chars_to_bytes;
};

// FuseDecoder fuses all tokens into a single string.
class FuseDecoder: public Decoder {
 public:
  std::vector<std::string> decode_chain(std::vector<std::string> tokens) final;
};

// StripDecoder strips up to n_begin leading and n_end trailing occurrences of content from each token.
class StripDecoder: public Decoder {
  std::string content;
  int n_begin;
  int n_end;

 public:
  explicit StripDecoder(std::string content, int n_begin, int n_end);
  std::vector<std::string> decode_chain(std::vector<std::string> tokens) final;
};

// ByteFallbackDecoder converts byte-fallback tokens (e.g. <0x61>) back into the bytes they encode.
class ByteFallbackDecoder: public Decoder {
 public:
  std::vector<std::string> decode_chain(std::vector<std::string> tokens) final;
};
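
To illustrate how these pieces compose, here is a hedged sketch of wiring up a LLaMA-style decoder chain. The particular steps and arguments mirror a typical Hugging Face tokenizer.json and are assumptions for illustration, not something this header prescribes.

#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical wiring: map the SentencePiece "▁" marker back to a space, resolve byte-fallback
// tokens, fuse the pieces into one string, and strip a single leading space.
std::unique_ptr<Decoder> build_llama_style_decoder() {
  std::vector<std::unique_ptr<Decoder>> steps;
  steps.push_back(std::make_unique<ReplaceDecoder>("\u2581", " "));
  steps.push_back(std::make_unique<ByteFallbackDecoder>());
  steps.push_back(std::make_unique<FuseDecoder>());
  steps.push_back(std::make_unique<StripDecoder>(" ", /*n_begin=*/1, /*n_end=*/0));
  return std::make_unique<SequenceDecoder>(std::move(steps));
}

Calling decode on the result with {"\u2581Hello", "\u2581world"} would then be expected to yield "Hello world".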
@@ -0,0 +1,98 @@
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include <iomanip>
#include <limits>
#include <map>
#include <queue>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include <hidet/runtime/llm/tokenizer/utf8.h>

// Model takes in a chunk of text from the pre-tokenization step and splits it into a sequence of tokens,
// each of which is described by its unsigned 32-bit ID.
class Model {
 public:
  virtual std::vector<uint32_t> tokenize(std::string const &sequence) = 0;
  virtual std::string id_to_token(uint32_t id) = 0;
  virtual ~Model() = default;
};

// BPEWord corresponds roughly to the Hugging Face "Word" implementation. It describes how a chunk of
// text from the pre-tokenization step is split. Initially a chunk of text is split character-wise, and
// the BPE algorithm works to merge these characters into successively larger tokens.
//
// The underlying representation is a doubly-linked list of "Piece" nodes, which point to the previous
// and next token in the sequence. It leverages the fact that the linked list never grows in size to
// avoid dynamic memory allocation entirely -- instead, it uses a fixed-size vector to store the nodes,
// which are then addressed by their index in this vector.
class BPEWord {
 public:
  explicit BPEWord(std::vector<uint32_t> const &ids);

  // merge merges the node at i with the node which follows it, creating a new node with the given ID.
  int merge(int i, uint32_t new_id);

  // ids returns the IDs of all tokens currently in the BPEWord.
  std::vector<uint32_t> ids() const;

  uint32_t at(int i) const { return data[i].id; }
  int prev(int i) const { return data[i].prev; }
  int next(int i) const { return data[i].next; }
  bool valid(int i) const { return data[i].id != std::numeric_limits<uint32_t>::max(); }
  int begin() const { return 0; }
  int end() const { return data.size() - 1; }

 private:
  // Piece corresponds to a single token in the BPEWord. It contains the ID of the token, as well as
  // the "addresses" of the previous and next tokens in the sequence.
  struct Piece {
    // The default ID is the maximum value of a 32-bit unsigned integer, which is also used to invalidate
    // a Piece node and indicate that it is no longer in use as the result of a merge.
    uint32_t id{std::numeric_limits<uint32_t>::max()};
    int prev{};
    int next{};

    Piece() = default;
    Piece(uint32_t id, int prev, int next) : id{id}, prev{prev}, next{next} {}
  };

  // The underlying data structure for the BPEWord.
  std::vector<Piece> data;
};

// BPEModel implements byte-pair encoding (BPE): each chunk of text is split character-wise, and adjacent
// tokens are repeatedly merged, best-ranked merge first, until no merge from the merge table applies.
class BPEModel: public Model {
 public:
  BPEModel(std::map<std::string, uint32_t> vocab, std::vector<std::pair<std::string, std::string>> const &merges,
           bool byte_fallback);
  std::vector<uint32_t> tokenize(std::string const &sequence) final;
  std::string id_to_token(uint32_t id) final;

 private:
  // The vocabulary that assigns a number to each token.
  std::map<std::string, uint32_t> vocab;
  // Reversed vocabulary, to rebuild sentences.
  std::map<uint32_t, std::string> vocab_r;
  // Contains the mapping between pairs of IDs and their (score, new_id) after merging.
  std::map<std::pair<uint32_t, uint32_t>, std::pair<int, uint32_t>> merges;
  // Caches the results of calls to tokenize.
  std::map<std::string, std::vector<uint32_t>> cache;
  // Whether byte fallbacks (e.g. <0xFA>) should be used for characters not in the vocabulary.
  bool byte_fallback;

  // Helper for tokenize: repeatedly applies the best-ranked merge available within the word.
  void merge_word(BPEWord &word);
};
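
As a toy illustration of the BPE flow (the vocabulary, merges, and expected result below are illustrative assumptions, not data from this commit): starting from single characters, the model keeps applying the best-ranked merge until none from the table applies.

#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical demo: with this vocabulary and merge table, "hello" should collapse as
// h e l l o -> he l l o -> he ll o -> hell o -> hello, i.e. the single ID 7.
std::vector<uint32_t> toy_bpe_demo() {
  std::map<std::string, uint32_t> vocab{
      {"h", 0}, {"e", 1}, {"l", 2}, {"o", 3}, {"he", 4}, {"ll", 5}, {"hell", 6}, {"hello", 7},
  };
  std::vector<std::pair<std::string, std::string>> merges{
      {"h", "e"}, {"l", "l"}, {"he", "ll"}, {"hell", "o"},
  };
  BPEModel model{std::move(vocab), merges, /*byte_fallback=*/false};
  return model.tokenize("hello");  // expected: {7}
}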
@@ -0,0 +1,53 @@
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <hidet/runtime/llm/tokenizer/pattern.h>

// Normalizer takes in an input text and applies a transformation to normalize it.
// This is the first step in the tokenization pipeline, coming before the pre-tokenizer.
class Normalizer {
 public:
  virtual void normalize(std::string &s) = 0;
  virtual ~Normalizer() = default;
};

// SequenceNormalizer runs a pre-defined set of normalizers in sequence.
class SequenceNormalizer: public Normalizer {
  std::vector<std::unique_ptr<Normalizer>> normalizers;

 public:
  explicit SequenceNormalizer(std::vector<std::unique_ptr<Normalizer>> normalizers);
  void normalize(std::string &s) final;
};

// PrependNormalizer prepends a prefix to a string.
class PrependNormalizer: public Normalizer {
  std::string prefix;

 public:
  explicit PrependNormalizer(std::string prefix);
  void normalize(std::string &s) final;
};

// ReplaceNormalizer replaces all instances of a pattern with the given content.
class ReplaceNormalizer: public Normalizer {
  RegexPattern pattern;
  std::string content;

 public:
  explicit ReplaceNormalizer(const std::string &pattern, std::string content);
  void normalize(std::string &s) final;
};
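
A hedged sketch of how these normalizers might be composed for a SentencePiece-style (LLaMA) configuration; the specific steps are assumptions taken from a typical Hugging Face tokenizer config, not mandated by this header.

#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical wiring: prepend "▁" and turn every space into "▁" before pre-tokenization,
// so that "hello world" becomes "▁hello▁world".
std::unique_ptr<Normalizer> build_llama_style_normalizer() {
  std::vector<std::unique_ptr<Normalizer>> steps;
  steps.push_back(std::make_unique<PrependNormalizer>("\u2581"));
  steps.push_back(std::make_unique<ReplaceNormalizer>(" ", "\u2581"));
  return std::make_unique<SequenceNormalizer>(std::move(steps));
}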
@@ -0,0 +1,54 @@
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <regex>
#include <string>
#include <vector>

// SplitDelimiterBehavior defines how the delimiter is treated when splitting a string.
enum class SplitDelimiterBehavior {
  Removed,             // A.B -> [A B]
  Isolated,            // A.B -> [A . B]
  MergedWithPrevious,  // A.B -> [A. B]
  MergedWithNext,      // A.B -> [A .B]
};

// PatternMatch represents a segment of a string that is either a match or not.
struct PatternMatch {
  // The starting index of the segment, inclusive.
  int start;
  // The ending index of the segment, exclusive.
  int end;
  // Whether the segment is a match or not.
  bool is_match;

  PatternMatch(int start, int end, bool is_match) : start{start}, end{end}, is_match{is_match} {}
};

class Pattern {
 public:
  // Given a string of length n, report a partition of [0, n) into intervals, each of which is either
  // entirely a match or entirely a non-match.
  // For example, with inside = "abaca" and the pattern "ba", the result is [0 1 F], [1 3 T], [3 5 F].
  virtual std::vector<PatternMatch> find_matches(std::string const &inside) const = 0;

  std::vector<std::string> split(const std::string &s, SplitDelimiterBehavior behaviour) const;

  virtual ~Pattern() = default;
};

class RegexPattern: public Pattern {
  std::regex pattern;

 public:
  explicit RegexPattern(std::regex pattern);
  std::vector<PatternMatch> find_matches(std::string const &inside) const final;
};
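
The enum above can be illustrated with the "abaca" example from find_matches; the expected outputs in this sketch assume the Hugging Face splitting semantics that this header mirrors.

#include <regex>
#include <string>
#include <vector>

// With the pattern "ba", "abaca" partitions into [0 1 F], [1 3 T], [3 5 F]; the behaviours then
// differ only in what happens to the matching segment "ba".
void split_demo() {
  RegexPattern p{std::regex{"ba"}};
  std::vector<std::string> removed = p.split("abaca", SplitDelimiterBehavior::Removed);             // {"a", "ca"}
  std::vector<std::string> isolated = p.split("abaca", SplitDelimiterBehavior::Isolated);           // {"a", "ba", "ca"}
  std::vector<std::string> with_prev = p.split("abaca", SplitDelimiterBehavior::MergedWithPrevious); // {"aba", "ca"}
  std::vector<std::string> with_next = p.split("abaca", SplitDelimiterBehavior::MergedWithNext);     // {"a", "baca"}
}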
@@ -0,0 +1,46 @@
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// PostProcessor takes the output of the BPE algorithm (a sequence of token IDs) and applies a transformation to it
// to prepare it for consumption by a language model. Usually, this involves prepending/appending special
// starting/ending tokens.
class PostProcessor {
 public:
  virtual std::vector<uint32_t> process(std::vector<uint32_t> encoding) = 0;
  virtual ~PostProcessor() = default;
};

// ByteLevelPostProcessor is a no-op post-processor that returns the input encoding as-is. The functionality
// Hugging Face provides here maps IDs back to their source offsets, which we don't need to do.
class ByteLevelPostProcessor: public PostProcessor {
 public:
  std::vector<uint32_t> process(std::vector<uint32_t> encoding) final { return encoding; }
};

// TemplateProcessingPostProcessor builds its output from a vector of strings (the "template"). Each entry in the
// template is either the literal "A" or the name of a special token: "A" is replaced by the input encoding, while
// a special token is replaced by that token's ID.
class TemplateProcessingPostProcessor: public PostProcessor {
  std::vector<std::string> tmpl;
  std::map<std::string, uint32_t> special_tokens;

 public:
  TemplateProcessingPostProcessor(std::vector<std::string> tmpl, std::map<std::string, uint32_t> special_tokens);
  std::vector<uint32_t> process(std::vector<uint32_t> encoding) final;
};
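
For example, a LLaMA-style configuration typically prepends a beginning-of-sequence token. A hedged sketch follows; the template, token name, and ID are illustrative assumptions rather than values taken from this commit.

#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical use: the template {"<s>", "A"} with <s> mapped to ID 1 prepends the BOS ID to the encoding.
std::vector<uint32_t> add_bos(std::vector<uint32_t> encoding) {
  TemplateProcessingPostProcessor post{{"<s>", "A"}, {{"<s>", 1u}}};
  return post.process(std::move(encoding));  // {1, <original encoding...>}
}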