Skip to content

Commit

Permalink
[Tokenizer] Auto-detect TokenizerInfo from tokenizer.json (#2416)
Browse files Browse the repository at this point in the history
This PR adds a new `TokenizerInfo` class that contains useful information
about the tokenizer during generation. It is auto-detected from
tokenizer.json if it exists. Otherwise it logs a warning and falls back to
the default values (byte-fallback tokenizer; no space prepending in encode or space stripping in decode).
  • Loading branch information
Ubospica authored May 26, 2024
1 parent c62e143 commit 13c0661
Show file tree
Hide file tree
Showing 12 changed files with 376 additions and 156 deletions.
2 changes: 1 addition & 1 deletion cpp/serve/data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ TVM_REGISTER_GLOBAL("mlc.serve.ImageDataGetImage").set_body_typed([](ImageData d
/*! \brief Convert a single token with probability to JSON string. */
inline void TokenToLogProbJSON(const Tokenizer& tokenizer, const TokenProbPair& token_prob,
std::ostringstream* os) {
const std::string& token = tokenizer->TokenTable()[token_prob.first];
const std::string& token = tokenizer->PostProcessedTokenTable()[token_prob.first];

(*os) << "\"token\": \"";
for (char ch : token) {
Expand Down
35 changes: 24 additions & 11 deletions cpp/serve/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,17 +126,8 @@ class EngineImpl : public Engine {
ModelWorkspace{model->AllocEmbeddingTensor(), model->AllocHiddenStatesTensor()});
}
// - Initialize tokenizer and grammar
n->tokenizer_ = Tokenizer::FromPath(engine_config->model);
std::string token_table_postproc_method;
if (model_configs[0].count("token_table_postproc_method") == 0) {
// Backward compatibility: use "byte_fallback" by default
token_table_postproc_method = "byte_fallback";
} else {
token_table_postproc_method =
model_configs[0].at("token_table_postproc_method").get<std::string>();
}
n->token_table_ =
Tokenizer::PostProcessTokenTable(n->tokenizer_->TokenTable(), token_table_postproc_method);
n->tokenizer_ = Tokenizer::FromPath(engine_config->model, GetTokenizerInfo(model_configs[0]));
n->token_table_ = n->tokenizer_->PostProcessedTokenTable();
n->grammar_init_context_cache_ = GrammarInitContextCache(n->token_table_);
// - Create the logit processor and sampler, and
// the DraftTokenWorkspaceManager for speculative decoding.
Expand Down Expand Up @@ -549,6 +540,28 @@ class EngineImpl : public Engine {
}
}

/*!
 * \brief Read the "tokenizer_info" section of mlc-chat-config.json, if present.
 * \param model_config The parsed model configuration object.
 * \return The TokenizerInfo built from the config, or std::nullopt when the
 * section is absent (the caller then falls back to auto-detection, with a
 * warning logged here).
 * \note Fields not present in the config keep TokenizerInfoNode's defaults.
 */
static std::optional<TokenizerInfo> GetTokenizerInfo(const picojson::object& model_config) {
  auto config_it = model_config.find("tokenizer_info");
  if (config_it == model_config.end()) {
    LOG(WARNING) << "Tokenizer info not found in mlc-chat-config.json. "
                 << "Trying to automatically detect the tokenizer info";
    return std::nullopt;
  }
  const picojson::object& info_obj = config_it->second.get<picojson::object>();
  auto node = make_object<TokenizerInfoNode>();
  // Each field is optional; only override the default when the key exists.
  if (auto it = info_obj.find("token_postproc_method"); it != info_obj.end()) {
    node->token_postproc_method = it->second.get<std::string>();
  }
  if (auto it = info_obj.find("prepend_space_in_encode"); it != info_obj.end()) {
    node->prepend_space_in_encode = it->second.get<bool>();
  }
  if (auto it = info_obj.find("strip_space_in_decode"); it != info_obj.end()) {
    node->strip_space_in_decode = it->second.get<bool>();
  }
  return TokenizerInfo(node);
}

// Engine state, managing requests and request states.
EngineState estate_;
// Configurations and singletons
Expand Down
8 changes: 3 additions & 5 deletions cpp/serve/grammar/grammar_state_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -483,14 +483,12 @@ GrammarStateMatcher::GrammarStateMatcher(std::shared_ptr<GrammarStateInitContext
#ifndef COMPILE_MLC_WASM_RUNTIME
// This creates tokenizer dependency issue in WASM building for web, hence skipped
TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherFromTokenizer")
.set_body_typed([](BNFGrammar grammar, Optional<Tokenizer> tokenizer, int max_rollback_steps,
String token_table_postproc_method) {
.set_body_typed([](BNFGrammar grammar, Optional<Tokenizer> tokenizer, int max_rollback_steps) {
auto preproc_start = std::chrono::high_resolution_clock::now();
std::shared_ptr<mlc::llm::serve::GrammarStateInitContext> init_ctx;
if (tokenizer) {
auto token_table = Tokenizer::PostProcessTokenTable(tokenizer.value()->TokenTable(),
token_table_postproc_method);
init_ctx = GrammarStateMatcher::CreateInitContext(grammar, token_table);
init_ctx = GrammarStateMatcher::CreateInitContext(
grammar, tokenizer.value()->PostProcessedTokenTable());
} else {
init_ctx = GrammarStateMatcher::CreateInitContext(grammar, {});
}
Expand Down
3 changes: 2 additions & 1 deletion cpp/serve/grammar/grammar_state_matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ using namespace tvm::runtime;
* \example
* \code
* Tokenizer tokenizer = ...;
* auto init_ctx = GrammarStateMatcher::CreateInitContext(grammar, tokenizer->TokenTable());
* auto init_ctx = GrammarStateMatcher::CreateInitContext(grammar,
* tokenizer->PostProcessedTokenTable());
* GrammarStateMatcher matcher(init_ctx, 10);
* matcher->AcceptToken(67);
*
Expand Down
1 change: 0 additions & 1 deletion cpp/serve/grammar/grammar_state_matcher_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

#include <vector>

#include "../../tokenizers.h"
#include "grammar.h"
#include "grammar_state_matcher_state.h"

Expand Down
2 changes: 1 addition & 1 deletion cpp/streamer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ StopStrHandler::StopStrHandler(Array<String> stop_strs,

TVM_REGISTER_GLOBAL("mlc.StopStrHandler")
.set_body_typed([](Array<String> stop_strs, const Tokenizer& tokenizer) {
return StopStrHandler(std::move(stop_strs), tokenizer->TokenTable());
return StopStrHandler(std::move(stop_strs), tokenizer->PostProcessedTokenTable());
});

TVM_REGISTER_GLOBAL("mlc.StopStrHandlerPut")
Expand Down
Loading

0 comments on commit 13c0661

Please sign in to comment.