Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions cpp/serve/grammar/grammar_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,14 @@ int32_t EBNFParserImpl::ParseCharacterClass() {
continue;
}

auto [codepoint, len] = Utf8OrEscapeToCodepoint(cur_, kCustomEscapeMap);
auto [codepoint, new_cur] = ParseNextUTF8OrEscaped(cur_, kCustomEscapeMap);
if (codepoint == static_cast<TCodepoint>(CharHandlingError::kInvalidUtf8)) {
ThrowParseError("Invalid utf8 sequence");
ThrowParseError("Invalid UTF8 sequence");
}
if (codepoint == static_cast<TCodepoint>(CharHandlingError::kInvalidEscape)) {
ThrowParseError("Invalid escape sequence");
}
Consume(len);
Consume(new_cur - cur_);
if (past_is_hyphen) {
ICHECK(!elements.empty());
if (elements.back().lower > codepoint) {
Expand Down Expand Up @@ -194,14 +194,15 @@ int32_t EBNFParserImpl::ParseString() {
if (Peek() == '\r' || Peek() == '\n') {
ThrowParseError("There should be no newline character in a string literal");
}
auto [codepoint, len] = Utf8OrEscapeToCodepoint(cur_);

auto [codepoint, new_cur] = ParseNextUTF8OrEscaped(cur_);
if (codepoint == static_cast<TCodepoint>(CharHandlingError::kInvalidUtf8)) {
ThrowParseError("Invalid utf8 sequence");
}
if (codepoint == static_cast<TCodepoint>(CharHandlingError::kInvalidEscape)) {
ThrowParseError("Invalid escape sequence");
}
Consume(len);
Consume(new_cur - cur_);
character_classes.push_back(builder_.AddCharacterClass({{codepoint, codepoint}}));
}
if (character_classes.empty()) {
Expand Down
4 changes: 2 additions & 2 deletions cpp/serve/grammar/grammar_serializer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@ std::string BNFGrammarPrinter::PrintCharacterClass(const RuleExpr& rule_expr) {
result += "^";
}
for (auto i = 0; i < rule_expr.data_len; i += 2) {
result += CodepointToPrintable(rule_expr[i], kCustomEscapeMap);
result += PrintAsEscaped(rule_expr[i], kCustomEscapeMap);
if (rule_expr[i] == rule_expr[i + 1]) {
continue;
}
result += "-";
result += CodepointToPrintable(rule_expr[i + 1], kCustomEscapeMap);
result += PrintAsEscaped(rule_expr[i + 1], kCustomEscapeMap);
}
result += "]";
return result;
Expand Down
10 changes: 5 additions & 5 deletions cpp/serve/grammar/grammar_state_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherResetState")
bool MatchCompleteString(GrammarStateMatcher matcher, String str) {
auto mutable_node =
const_cast<GrammarStateMatcherNodeImpl*>(matcher.as<GrammarStateMatcherNodeImpl>());
auto codepoints = Utf8StringToCodepoints(str.c_str());
auto codepoints = ParseUTF8(str.c_str());
int accepted_cnt = 0;
for (auto codepoint : codepoints) {
if (!mutable_node->AcceptCodepoint(codepoint, false)) {
Expand Down Expand Up @@ -553,9 +553,9 @@ void PrintAcceptedRejectedTokens(
// First cast to unsigned, then cast to int
std::cerr << static_cast<int>(static_cast<unsigned char>(token[0]));
} else {
auto codepoints = Utf8StringToCodepoints(token.c_str());
auto codepoints = ParseUTF8(token.c_str());
for (auto c : codepoints) {
std::cerr << CodepointToPrintable(c);
std::cerr << PrintAsEscaped(c);
}
}
std::cerr << "> ";
Expand All @@ -571,9 +571,9 @@ void PrintAcceptedRejectedTokens(
if (token.size() == 1 && ((unsigned char)token[0] >= 128 || token[0] == 0)) {
std::cerr << (int)(unsigned char)token[0];
} else {
auto codepoints = Utf8StringToCodepoints(token.c_str());
auto codepoints = ParseUTF8(token.c_str());
for (auto c : codepoints) {
std::cerr << CodepointToPrintable(c);
std::cerr << PrintAsEscaped(c);
}
}
std::cerr << "> ";
Expand Down
8 changes: 4 additions & 4 deletions cpp/serve/grammar/grammar_state_matcher_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,15 +156,15 @@ inline bool GrammarStateMatcherBase::AcceptCodepoint(TCodepoint codepoint, bool
}
if (tmp_new_stack_tops_.empty()) {
if (verbose) {
std::cout << "Codepoint: " << codepoint << " \"" << CodepointToPrintable(codepoint)
<< "\" Rejected" << std::endl;
std::cout << "Codepoint: " << codepoint << " \"" << PrintAsEscaped(codepoint) << "\" Rejected"
<< std::endl;
}
return false;
}
stack_tops_history_.PushHistory(tmp_new_stack_tops_);
if (verbose) {
std::cout << "Codepoint: " << codepoint << " \"" << CodepointToPrintable(codepoint)
<< "\" Accepted" << std::endl;
std::cout << "Codepoint: " << codepoint << " \"" << PrintAsEscaped(codepoint) << "\" Accepted"
<< std::endl;
std::cout << "Stack after accepting: " << PrintStackState() << std::endl;
}
#if TVM_LOG_DEBUG
Expand Down
2 changes: 1 addition & 1 deletion cpp/serve/grammar/grammar_state_matcher_preproc.h
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ inline std::shared_ptr<GrammarStateInitContext> GrammarStateMatcher::CreateInitC
ptr->special_token_ids.push_back(i);
} else {
// First replace the special underscore with space.
auto codepoints = Utf8StringToCodepoints(token.c_str());
auto codepoints = ParseUTF8(token.c_str());
DCHECK(!codepoints.empty() &&
codepoints[0] != static_cast<TCodepoint>(CharHandlingError::kInvalidUtf8))
<< "Invalid token: " << token;
Expand Down
42 changes: 21 additions & 21 deletions cpp/support/encoding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
namespace mlc {
namespace llm {

std::string CodepointToUtf8(TCodepoint codepoint) {
std::string PrintAsUTF8(TCodepoint codepoint) {
ICHECK(codepoint <= 0x10FFFF) << "Invalid codepoint: " << codepoint;
std::string utf8;
if (codepoint <= 0x7F) {
Expand All @@ -36,8 +36,8 @@ std::string CodepointToUtf8(TCodepoint codepoint) {
return utf8;
}

std::string CodepointToPrintable(
TCodepoint codepoint, const std::unordered_map<TCodepoint, std::string>& custom_escape_map) {
std::string PrintAsEscaped(TCodepoint codepoint,
const std::unordered_map<TCodepoint, std::string>& custom_escape_map) {
static const std::unordered_map<TCodepoint, std::string> kCodepointToEscape = {
{'\'', "\\\'"}, {'\"', "\\\""}, {'\?', "\\\?"}, {'\\', "\\\\"}, {'\a', "\\a"},
{'\b', "\\b"}, {'\f', "\\f"}, {'\n', "\\n"}, {'\r', "\\r"}, {'\t', "\\t"},
Expand All @@ -63,10 +63,10 @@ std::string CodepointToPrintable(
return codepoint <= 0xFFFF ? "\\u" + hex : "\\U" + hex;
}

std::pair<TCodepoint, int> Utf8ToCodepoint(const char* utf8) {
const std::array<int8_t, 5> kFirstByteMask = {0x00, 0x7F, 0x1F, 0x0F, 0x07};
std::pair<TCodepoint, const char*> ParseNextUTF8(const char* utf8) {
static const std::array<int8_t, 5> kFirstByteMask = {0x00, 0x7F, 0x1F, 0x0F, 0x07};
// clang-format off
const std::array<int, 256> kUtf8Bytes = {
static const std::array<int, 256> kUtf8Bytes = {
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
Expand All @@ -89,7 +89,7 @@ std::pair<TCodepoint, int> Utf8ToCodepoint(const char* utf8) {
auto bytes = kUtf8Bytes[static_cast<unsigned char>(utf8[0])];
if (bytes == -1) {
// invalid utf8
return {static_cast<TCodepoint>(CharHandlingError::kInvalidUtf8), 0};
return {static_cast<TCodepoint>(CharHandlingError::kInvalidUtf8), utf8};
}

TCodepoint res = static_cast<unsigned char>(utf8[0]) & kFirstByteMask[bytes];
Expand All @@ -100,23 +100,23 @@ std::pair<TCodepoint, int> Utf8ToCodepoint(const char* utf8) {
}
res = (res << 6) | (static_cast<unsigned char>(utf8[i]) & 0x3F);
}
return {res, bytes};
return {res, utf8 + bytes};
}

std::vector<TCodepoint> Utf8StringToCodepoints(const char* utf8) {
std::vector<TCodepoint> ParseUTF8(const char* utf8) {
std::vector<TCodepoint> codepoints;
while (*utf8 != 0) {
auto [codepoint, bytes] = Utf8ToCodepoint(utf8);
TCodepoint codepoint;
std::tie(codepoint, utf8) = ParseNextUTF8(utf8);
if (codepoint == static_cast<TCodepoint>(CharHandlingError::kInvalidUtf8)) {
return {codepoint};
}
codepoints.push_back(codepoint);
utf8 += bytes;
}
return codepoints;
}

int HexCharToInt(char c) {
inline int HexCharToInt(char c) {
if (c >= '0' && c <= '9') {
return c - '0';
} else if (c >= 'a' && c <= 'f') {
Expand All @@ -128,22 +128,22 @@ int HexCharToInt(char c) {
}
}

std::pair<TCodepoint, int> Utf8OrEscapeToCodepoint(
std::pair<TCodepoint, const char*> ParseNextUTF8OrEscaped(
const char* utf8, const std::unordered_map<std::string, TCodepoint>& custom_escape_map) {
static const std::unordered_map<std::string, TCodepoint> kEscapeToCodepoint = {
{"\\\'", '\''}, {"\\\"", '\"'}, {"\\\?", '\?'}, {"\\\\", '\\'}, {"\\a", '\a'},
{"\\b", '\b'}, {"\\f", '\f'}, {"\\n", '\n'}, {"\\r", '\r'}, {"\\t", '\t'},
{"\\v", '\v'}, {"\\0", '\0'}, {"\\e", '\x1B'}};
if (utf8[0] != '\\') {
return Utf8ToCodepoint(utf8);
return ParseNextUTF8(utf8);
}

auto escape_sequence = std::string(utf8, 2);
if (auto it = custom_escape_map.find(escape_sequence); it != custom_escape_map.end()) {
return {it->second, 2};
return {it->second, utf8 + 2};
}
if (auto it = kEscapeToCodepoint.find(escape_sequence); it != kEscapeToCodepoint.end()) {
return {it->second, 2};
return {it->second, utf8 + 2};
}

if (utf8[1] == 'x') {
Expand All @@ -159,9 +159,9 @@ std::pair<TCodepoint, int> Utf8OrEscapeToCodepoint(
++len;
}
if (len == 0) {
return {static_cast<TCodepoint>(CharHandlingError::kInvalidEscape), 0};
return {static_cast<TCodepoint>(CharHandlingError::kInvalidEscape), utf8};
}
return {codepoint, len + 2};
return {codepoint, utf8 + len + 2};
} else if (utf8[1] == 'u' || utf8[1] == 'U') {
// 4- or 8-digit hex
int len = utf8[1] == 'u' ? 4 : 8;
Expand All @@ -170,13 +170,13 @@ std::pair<TCodepoint, int> Utf8OrEscapeToCodepoint(
for (int i = 0; i < len; ++i) {
auto digit = HexCharToInt(utf8[i + 2]);
if (digit == -1) {
return {static_cast<TCodepoint>(CharHandlingError::kInvalidEscape), 0};
return {static_cast<TCodepoint>(CharHandlingError::kInvalidEscape), utf8};
}
codepoint = codepoint * 16 + digit;
}
return {codepoint, len + 2};
return {codepoint, utf8 + len + 2};
} else {
return {static_cast<TCodepoint>(CharHandlingError::kInvalidEscape), 0};
return {static_cast<TCodepoint>(CharHandlingError::kInvalidEscape), utf8};
}
}

Expand Down
14 changes: 7 additions & 7 deletions cpp/support/encoding.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,18 @@ using TCodepoint = int32_t;
* \param codepoint The codepoint.
* \return The UTF-8 string.
*/
std::string CodepointToUtf8(TCodepoint codepoint);
std::string PrintAsUTF8(TCodepoint codepoint);

/*!
* \brief Convert a codepoint to a printable string. If the codepoint is not printable, it will be
* escaped. By default the function supports escape sequences in C ("\n", "\t", "\u0123"). Users can
* specify more escape sequences using custom_escape_map.
* \param codepoint The codepoint.
* \param custom_escape_map A map from codepoint to escape sequence. If the codepoint is in the map,
* it will be escaped using the corresponding escape sequence. e.g. {'-', "\\-"}.
* it will be escaped using the corresponding escape sequence. e.g. {{'-', "\\-"}}.
* \return The printable string.
*/
std::string CodepointToPrintable(
std::string PrintAsEscaped(
TCodepoint codepoint,
const std::unordered_map<TCodepoint, std::string>& custom_escape_map = {});

Expand All @@ -53,22 +53,22 @@ enum class CharHandlingError : TCodepoint {
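* \return The codepoint and a pointer to the first byte after the parsed character. If the UTF-8
* string is invalid, the returned codepoint is CharHandlingError::kInvalidUtf8 (for an invalid
* leading byte it is paired with the unchanged input pointer).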
*/
std::pair<TCodepoint, int> Utf8ToCodepoint(const char* utf8);
std::pair<TCodepoint, const char*> ParseNextUTF8(const char* utf8);

std::vector<TCodepoint> Utf8StringToCodepoints(const char* utf8);
std::vector<TCodepoint> ParseUTF8(const char* utf8);

/*!
* \brief Convert a UTF-8 string or an escape sequence to a codepoint. By default the function
* supports escape sequences in C ("\n", "\t", "\u0123"). User can specify more escape sequences
* using custom_escape_map.
* \param utf8 The UTF-8 string or the escape sequence.
* \param custom_escape_map A map from escape sequence to codepoint. If the escape sequence is in
* the map, it will be converted to the corresponding codepoint. e.g. {"\\-", '-'}.
* the map, it will be converted to the corresponding codepoint. e.g. {{"\\-", '-'}}.
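* \return The codepoint and a pointer to the first byte after the parsed character or escape
* sequence. If the escape sequence is invalid, the function returns
* (CharHandlingError::kInvalidEscape, utf8), i.e. the input pointer unchanged. If the UTF-8 string
* is invalid, the returned codepoint is CharHandlingError::kInvalidUtf8.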
*/
std::pair<TCodepoint, int> Utf8OrEscapeToCodepoint(
std::pair<TCodepoint, const char*> ParseNextUTF8OrEscaped(
const char* utf8, const std::unordered_map<std::string, TCodepoint>& custom_escape_map = {});

} // namespace llm
Expand Down