Skip to content

Commit

Permalink
[Serving][Grammar] Refactor GrammarStateMatcher and support LLaMA-3
Browse files Browse the repository at this point in the history
This PR refactors GrammarStateMatcher and supports the LLaMA-3 tokenizer.

Common tokenizers, including Phi-2, Gemma, LLaMA-2, etc. are also
supported.

Performance is optimized for the LLaMA-3 tokenizer, since its token table
has size 128k, much larger than that of the LLaMA-2 tokenizer.

These changes are introduced to the grammar library:
1. Introduce ByteString rule expression and simplify CharacterClass
   and CharacterClassStar
2. Refactor BNFGrammarVisitor and BNFGrammarMutator for visiting and
   mutating grammar rules
3. GrammarStateMatcherBase, the internal implementation of
   GrammarStateMatcher, now accepts input char by char instead of
   codepoint by codepoint. It therefore supports any valid UTF-8 string,
   even if a token is not a complete codepoint.
4. Support lookahead assertions on rules, which specify that the rule must
   be followed by a given sequence. This can eliminate some uncertain
   tokens in preprocessing.

Minor changes:
1. Introduce template hash function HashCombine
2. Update the UTF8 encoding handling functions

Performance:
1. For JSON, finding the mask requires <30us on a 5900X with a single thread.
   The number of uncertain tokens is <30 in most cases.
2. For JSON schema, finding the mask requires <30us on a 5900X with a single
   thread. The number of uncertain tokens is <30 in most cases.
  • Loading branch information
Ubospica committed May 14, 2024
1 parent 679d3a8 commit a418665
Show file tree
Hide file tree
Showing 27 changed files with 1,684 additions and 1,024 deletions.
8 changes: 4 additions & 4 deletions cpp/serve/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class EngineImpl : public Engine {
}
n->token_table_ =
Tokenizer::PostProcessTokenTable(n->tokenizer_->TokenTable(), token_table_postproc_method);
n->grammar_init_context_storage_ = GrammarInitContextStorage(n->token_table_);
n->grammar_init_context_cache_ = GrammarInitContextCache(n->token_table_);
// - Create the logit processor and sampler, and
// the DraftTokenWorkspaceManager for speculative decoding.
int max_num_tokens = engine_config->max_num_sequence;
Expand Down Expand Up @@ -499,9 +499,9 @@ class EngineImpl : public Engine {
if (response_format.type != "json_object") {
return std::nullopt;
} else if (!response_format.schema) {
return grammar_init_context_storage_->GetInitContextForJSON();
return grammar_init_context_cache_->GetInitContextForJSON();
} else {
return grammar_init_context_storage_->GetInitContextForJSONSchema(
return grammar_init_context_cache_->GetInitContextForJSONSchema(
response_format.schema.value());
}
}
Expand All @@ -513,7 +513,7 @@ class EngineImpl : public Engine {
Tokenizer tokenizer_;
std::vector<std::string> token_table_;
// Helper to get the grammar init context for requests.
GrammarInitContextStorage grammar_init_context_storage_;
GrammarInitContextCache grammar_init_context_cache_;
// Models
Array<Model> models_;
// Device that the models run on.
Expand Down
135 changes: 78 additions & 57 deletions cpp/serve/grammar/grammar.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

#include "grammar.h"

#include "grammar_functor.h"
#include "grammar_parser.h"
#include "grammar_serializer.h"
#include "grammar_simplifier.h"
#include "json_schema_converter.h"

namespace mlc {
Expand All @@ -21,18 +21,28 @@ std::ostream& operator<<(std::ostream& os, const BNFGrammar& grammar) {
return os;
}

BNFGrammar BNFGrammar::FromEBNFString(const std::string& ebnf_string, const std::string& main_rule,
bool normalize, bool simplify) {
BNFGrammar BNFGrammar::FromEBNFString(const std::string& ebnf_string,
const std::string& main_rule) {
auto grammar = EBNFParser::Parse(ebnf_string, main_rule);
if (normalize) {
grammar = NestedRuleUnwrapper(grammar).Apply();
}
// Normalize the grammar by default
grammar = BNFGrammarNormalizer().Apply(grammar);
return grammar;
}

TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarFromEBNFString")
.set_body_typed([](String ebnf_string, String main_rule, bool normalize, bool simplify) {
return BNFGrammar::FromEBNFString(ebnf_string, main_rule, normalize, simplify);
.set_body_typed([](String ebnf_string, String main_rule) {
return BNFGrammar::FromEBNFString(ebnf_string, main_rule);
});

/*!
 * \brief Parse the EBNF string into a BNFGrammar without normalizing it. The result of
 * EBNFParser::Parse is returned as-is; unlike BNFGrammar::FromEBNFString, no
 * BNFGrammarNormalizer pass is applied afterwards.
 * \param ebnf_string The EBNF-formatted string.
 * \param main_rule The name of the main rule.
 * \return The parsed, un-normalized grammar.
 * \note Judging by its name, this is intended for debugging/testing only.
 */
BNFGrammar DebugFromEBNFStringNoNormalize(const std::string& ebnf_string,
const std::string& main_rule) {
return EBNFParser::Parse(ebnf_string, main_rule);
}

// Register the non-normalizing parser under the global name
// "mlc.serve.BNFGrammarDebugFromEBNFStringNoNormalize". The lambda forwards both arguments
// unchanged to DebugFromEBNFStringNoNormalize.
TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarDebugFromEBNFStringNoNormalize")
.set_body_typed([](String ebnf_string, String main_rule) {
return DebugFromEBNFStringNoNormalize(ebnf_string, main_rule);
});

BNFGrammar BNFGrammar::FromJSON(const std::string& json_string) {
Expand Down Expand Up @@ -69,79 +79,90 @@ TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarFromSchema").set_body([](TVMArgs args,
*rv = BNFGrammar::FromSchema(args[0], indent, separators, args[3]);
});

// Optimized json grammar for the speed of the grammar state matcher
const std::string kJSONGrammarString = R"(
main ::= (
"{" ws members_or_embrace |
"[" ws elements_or_embrace
"{" [ \n\t]* members_and_embrace |
"[" [ \n\t]* elements_or_embrace
)
value ::= (
"{" ws members_or_embrace |
"[" ws elements_or_embrace |
"\"" characters "\"" |
[0-9] fraction exponent |
[1-9] digits fraction exponent |
value_non_str ::= (
"{" [ \n\t]* members_and_embrace |
"[" [ \n\t]* elements_or_embrace |
"0" fraction exponent |
[1-9] [0-9]* fraction exponent |
"-" [0-9] fraction exponent |
"-" [1-9] digits fraction exponent |
"-" [1-9] [0-9]* fraction exponent |
"true" |
"false" |
"null"
)
members_or_embrace ::= (
"\"" characters "\"" ws ":" ws value members_rest ws "}" |
"}"
)
members ::= "\"" characters "\"" ws ":" ws value members_rest
members_rest ::= (
"" |
"," ws "\"" characters "\"" ws ":" ws value members_rest |
" " ws "," ws "\"" characters "\"" ws ":" ws value members_rest |
"\n" ws "," ws "\"" characters "\"" ws ":" ws value members_rest |
"\t" ws "," ws "\"" characters "\"" ws ":" ws value members_rest
)
) (= [ \n\t,}\]])
members_and_embrace ::= ("\"" characters_and_colon [ \n\t]* members_suffix | "}") (= [ \n\t,}\]])
members_suffix ::= (
value_non_str [ \n\t]* member_suffix_suffix |
"\"" characters_and_embrace |
"\"" characters_and_comma [ \n\t]* "\"" characters_and_colon [ \n\t]* members_suffix
) (= [ \n\t,}\]])
member_suffix_suffix ::= (
"}" |
"," [ \n\t]* "\"" characters_and_colon [ \n\t]* members_suffix
) (= [ \n\t,}\]])
elements_or_embrace ::= (
"{" ws members_or_embrace elements_rest ws "]" |
"[" ws elements_or_embrace elements_rest ws "]" |
"\"" characters "\"" elements_rest ws "]" |
[0-9] fraction exponent elements_rest ws "]" |
[1-9] digits fraction exponent elements_rest ws "]" |
"-" [0-9] fraction exponent elements_rest ws "]" |
"-" [1-9] digits fraction exponent elements_rest ws "]" |
"true" elements_rest ws "]" |
"false" elements_rest ws "]" |
"null" elements_rest ws "]" |
"{" [ \n\t]* members_and_embrace elements_rest [ \n\t]* "]" |
"[" [ \n\t]* elements_or_embrace elements_rest [ \n\t]* "]" |
"\"" characters_item elements_rest [ \n\t]* "]" |
"0" fraction exponent elements_rest [ \n\t]* "]" |
[1-9] [0-9]* fraction exponent elements_rest [ \n\t]* "]" |
"-" "0" fraction exponent elements_rest [ \n\t]* "]" |
"-" [1-9] [0-9]* fraction exponent elements_rest [ \n\t]* "]" |
"true" elements_rest [ \n\t]* "]" |
"false" elements_rest [ \n\t]* "]" |
"null" elements_rest [ \n\t]* "]" |
"]"
)
elements ::= (
"{" ws members_or_embrace elements_rest |
"[" ws elements_or_embrace elements_rest |
"\"" characters "\"" elements_rest |
[0-9] fraction exponent elements_rest |
[1-9] digits fraction exponent elements_rest |
"{" [ \n\t]* members_and_embrace elements_rest |
"[" [ \n\t]* elements_or_embrace elements_rest |
"\"" characters_item elements_rest |
"0" fraction exponent elements_rest |
[1-9] [0-9]* fraction exponent elements_rest |
"-" [0-9] fraction exponent elements_rest |
"-" [1-9] digits fraction exponent elements_rest |
"-" [1-9] [0-9]* fraction exponent elements_rest |
"true" elements_rest |
"false" elements_rest |
"null" elements_rest
)
elements_rest ::= (
"" |
"," ws elements |
" " ws "," ws elements |
"\n" ws "," ws elements |
"\t" ws "," ws elements
[ \n\t]* "," [ \n\t]* elements
)
characters ::= "" | [^"\\\r\n] characters | "\\" escape characters
characters_and_colon ::= (
"\"" [ \n\t]* ":" |
[^"\\\x00-\x1F] characters_and_colon |
"\\" escape characters_and_colon
) (=[ \n\t]* [\"{[0-9tfn-])
characters_and_comma ::= (
"\"" [ \n\t]* "," |
[^"\\\x00-\x1F] characters_and_comma |
"\\" escape characters_and_comma
) (=[ \n\t]* "\"")
characters_and_embrace ::= (
"\"" [ \n\t]* "}" |
[^"\\\x00-\x1F] characters_and_embrace |
"\\" escape characters_and_embrace
) (=[ \n\t]* [},])
characters_item ::= (
"\"" |
[^"\\\x00-\x1F] characters_item |
"\\" escape characters_item
) (= [ \n\t]* [,\]])
escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]
digits ::= [0-9] | [0-9] digits
fraction ::= "" | "." digits
exponent ::= "" | "e" sign digits | "E" sign digits
fraction ::= "" | "." [0-9] [0-9]*
exponent ::= "" | "e" sign [0-9] [0-9]* | "E" sign [0-9] [0-9]*
sign ::= "" | "+" | "-"
ws ::= [ \n\t]*
)";

BNFGrammar BNFGrammar::GetGrammarOfJSON() {
static const BNFGrammar grammar =
BNFGrammar::FromEBNFString(kJSONGrammarString, "main", true, false);
static const BNFGrammar grammar = BNFGrammar::FromEBNFString(kJSONGrammarString, "main");
return grammar;
}

Expand Down
45 changes: 18 additions & 27 deletions cpp/serve/grammar/grammar.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,15 @@ using namespace tvm::runtime;
* #### Types of RuleExprs
* Every RuleExpr is represented by a type as well as a variable-length array containing its data.
* RuleExpr has several types:
* - Byte string: a string of bytes (0~255). Supports UTF-8 strings.
* - Character class: a range of characters (each character is a unicode codepoint), e.g. [a-z],
* [ac-z].
* A single character is represented by a character class with the same lower and upper bound.
* A string is represented by a sequence of character classes.
* - Negated character class: all characters that are not in the range, e.g. [^a-z], [^ac-z]
* [ac-z]. Can be negated: [^a-z], [^ac-z]. Currently only ASCII characters are allowed inside
* [], but the expression can still accept/reject Unicode characters.
* - Character class star: a star quantifier of a character class. e.g. [a-z]*, [^a-z]*.
* - EmptyStr: an empty string, i.e. ""
* - Rule reference: a reference to another rule
* - Sequence: a sequence of rule_exprs, e.g. ("a" "b"). These rule_exprs are concatenated together.
* - Choices: a choice of rule_exprs, e.g. ("a" "b") | "c". Each rule_expr can be matched.
* - Character class star: special support for a repetition of a character class. e.g. [a-z]*
*
* #### Storage of RuleExprs
* Each type of RuleExpr has a different data format. For the format of each type of RuleExpr, see
Expand All @@ -76,6 +75,9 @@ class BNFGrammarNode : public Object {
std::string name;
/*! \brief The RuleExpr id of the body of the rule. */
int32_t body_expr_id;
/*! \brief The id of the associated lookahead assertion expr. For now it must be the id of a
* sequence RuleExpr. -1 if it does not exist. */
int32_t lookahead_assertion_id = -1;
};

/*! \brief Get the number of rules. */
Expand All @@ -86,6 +88,8 @@ class BNFGrammarNode : public Object {
<< "rule_id " << rule_id << " is out of bound";
return rules_[rule_id];
}
/*! \brief Get the id of the main rule of the grammar. Returns main_rule_id_ as stored, with no
 * bounds check (unlike GetMainRule); the member is initialized to -1. */
int32_t GetMainRuleId() const { return main_rule_id_; }
/*! \brief Get the main rule of the grammar. */
const Rule& GetMainRule() const {
DCHECK(main_rule_id_ >= 0 && main_rule_id_ < static_cast<int32_t>(rules_.size()))
Expand All @@ -95,10 +99,11 @@ class BNFGrammarNode : public Object {

/*! \brief The type of the rule expr. */
enum class RuleExprType : int32_t {
// data format: [lower0, upper0, lower1, upper1, ...]
// data format: [byte0, byte1, ...]
kByteString,
// data format: [is_negative, lower0, upper0, lower1, upper1, ...]
kCharacterClass,
// data format: [lower0, upper0, lower1, upper1, ...]
kNegCharacterClass,
kCharacterClassStar,
// data format: []
kEmptyStr,
// data format: [rule_id]
Expand All @@ -107,8 +112,6 @@ class BNFGrammarNode : public Object {
kSequence,
// data format: [rule_expr_id0, rule_expr_id1, ...]
kChoices,
// data format: [rule_expr_id]
kCharacterClassStar,
};

/*! \brief The object representing a rule expr. */
Expand Down Expand Up @@ -154,8 +157,8 @@ class BNFGrammarNode : public Object {
std::vector<Rule> rules_;
/*! \brief The data of all rule_exprs. */
std::vector<int32_t> rule_expr_data_;
/*! \brief The start index of every rule_expr in rule_expr_data_. rule_expr_id corresponds the
* index of this vector. */
/*! \brief The start index of every rule_expr in rule_expr_data_. rule_expr_id is the index
* to the elements in this vector. */
std::vector<int32_t> rule_expr_indptr_;
/*! \brief The id of the main rule. */
int32_t main_rule_id_ = -1;
Expand All @@ -168,25 +171,13 @@ class BNFGrammarNode : public Object {
class BNFGrammar : public ObjectRef {
public:
/*!
* \brief Construct a BNF grammar with a EBNF-formatted string. Will parse the string and
* transform it into BNF AST.
* \brief Construct a BNF grammar from an EBNF-formatted string. The grammar will be normalized
* (simplified) by default.
* \param ebnf_string The EBNF-formatted string.
* \param main_rule The name of the main rule.
* \param normalize Whether to normalize the grammar. Default: true. Only set to false for the
* purpose of testing.
*
* \note In The normalized form of a BNF grammar, every rule is in the form:
* `rule_name ::= ("" | (element1_1 element1_2 ...) | (element2_1 element2_2 ...) | ...)`.
*
* I.e. a list of choices, each choice is a sequence of elements. Elements can be a character
* class or a rule reference. And if the rule can be empty, the first choice will be an empty
* string.
* \param simplify Whether to simplify the grammar to make matching more efficient. Default: true.
* Not implemented yet.
*/
static BNFGrammar FromEBNFString(const std::string& ebnf_string,
const std::string& main_rule = "main", bool normalize = true,
bool simplify = true);
const std::string& main_rule = "main");

/*!
* \brief Construct a BNF grammar from the dumped JSON string.
Expand Down
Loading

0 comments on commit a418665

Please sign in to comment.