From 6828d7b8ef2de15bf81c3de46f139c55e319390e Mon Sep 17 00:00:00 2001 From: Ubospica Date: Tue, 30 Apr 2024 03:18:09 +0000 Subject: [PATCH] [Support] Simplify function names in encoding.h This PR simplifies the names of the utility functions in encoding.h. The new names are - PrintAsUTF8 - PrintAsEscaped - ParseNextUTF8 - ParseUTF8 - ParseNextUTF8OrEscaped Also make ParseNextUTF8 and ParseNextUTF8OrEscaped return a pointer to the next unparsed char instead of the number of bytes consumed, to make the interface simpler. On error, the input pointer is returned unchanged. --- cpp/serve/grammar/grammar_parser.cc | 11 ++--- cpp/serve/grammar/grammar_serializer.cc | 4 +- cpp/serve/grammar/grammar_state_matcher.cc | 10 ++--- .../grammar/grammar_state_matcher_base.h | 8 ++-- .../grammar/grammar_state_matcher_preproc.h | 2 +- cpp/support/encoding.cc | 42 +++++++++---------- cpp/support/encoding.h | 14 +++---- 7 files changed, 46 insertions(+), 45 deletions(-) diff --git a/cpp/serve/grammar/grammar_parser.cc b/cpp/serve/grammar/grammar_parser.cc index 1ece99099e..55ab0a1dff 100644 --- a/cpp/serve/grammar/grammar_parser.cc +++ b/cpp/serve/grammar/grammar_parser.cc @@ -156,14 +156,14 @@ int32_t EBNFParserImpl::ParseCharacterClass() { continue; } - auto [codepoint, len] = Utf8OrEscapeToCodepoint(cur_, kCustomEscapeMap); + auto [codepoint, new_cur] = ParseNextUTF8OrEscaped(cur_, kCustomEscapeMap); if (codepoint == static_cast(CharHandlingError::kInvalidUtf8)) { - ThrowParseError("Invalid utf8 sequence"); + ThrowParseError("Invalid UTF8 sequence"); } if (codepoint == static_cast(CharHandlingError::kInvalidEscape)) { ThrowParseError("Invalid escape sequence"); } - Consume(len); + Consume(new_cur - cur_); if (past_is_hyphen) { ICHECK(!elements.empty()); if (elements.back().lower > codepoint) { @@ -194,14 +194,15 @@ int32_t EBNFParserImpl::ParseString() { if (Peek() == '\r' || Peek() == '\n') { ThrowParseError("There should be no newline character in a string literal"); } - auto [codepoint, len] = Utf8OrEscapeToCodepoint(cur_); + + auto [codepoint, new_cur] = 
ParseNextUTF8OrEscaped(cur_); if (codepoint == static_cast(CharHandlingError::kInvalidUtf8)) { ThrowParseError("Invalid utf8 sequence"); } if (codepoint == static_cast(CharHandlingError::kInvalidEscape)) { ThrowParseError("Invalid escape sequence"); } - Consume(len); + Consume(new_cur - cur_); character_classes.push_back(builder_.AddCharacterClass({{codepoint, codepoint}})); } if (character_classes.empty()) { diff --git a/cpp/serve/grammar/grammar_serializer.cc b/cpp/serve/grammar/grammar_serializer.cc index fd41517863..c3c2c88baa 100644 --- a/cpp/serve/grammar/grammar_serializer.cc +++ b/cpp/serve/grammar/grammar_serializer.cc @@ -59,12 +59,12 @@ std::string BNFGrammarPrinter::PrintCharacterClass(const RuleExpr& rule_expr) { result += "^"; } for (auto i = 0; i < rule_expr.data_len; i += 2) { - result += CodepointToPrintable(rule_expr[i], kCustomEscapeMap); + result += PrintAsEscaped(rule_expr[i], kCustomEscapeMap); if (rule_expr[i] == rule_expr[i + 1]) { continue; } result += "-"; - result += CodepointToPrintable(rule_expr[i + 1], kCustomEscapeMap); + result += PrintAsEscaped(rule_expr[i + 1], kCustomEscapeMap); } result += "]"; return result; diff --git a/cpp/serve/grammar/grammar_state_matcher.cc b/cpp/serve/grammar/grammar_state_matcher.cc index 5c4ef98efe..451127e746 100644 --- a/cpp/serve/grammar/grammar_state_matcher.cc +++ b/cpp/serve/grammar/grammar_state_matcher.cc @@ -510,7 +510,7 @@ TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherResetState") bool MatchCompleteString(GrammarStateMatcher matcher, String str) { auto mutable_node = const_cast(matcher.as()); - auto codepoints = Utf8StringToCodepoints(str.c_str()); + auto codepoints = ParseUTF8(str.c_str()); int accepted_cnt = 0; for (auto codepoint : codepoints) { if (!mutable_node->AcceptCodepoint(codepoint, false)) { @@ -553,9 +553,9 @@ void PrintAcceptedRejectedTokens( // First cast to unsigned, then cast to int std::cerr << static_cast(static_cast(token[0])); } else { - auto codepoints = 
Utf8StringToCodepoints(token.c_str()); + auto codepoints = ParseUTF8(token.c_str()); for (auto c : codepoints) { - std::cerr << CodepointToPrintable(c); + std::cerr << PrintAsEscaped(c); } } std::cerr << "> "; @@ -571,9 +571,9 @@ void PrintAcceptedRejectedTokens( if (token.size() == 1 && ((unsigned char)token[0] >= 128 || token[0] == 0)) { std::cerr << (int)(unsigned char)token[0]; } else { - auto codepoints = Utf8StringToCodepoints(token.c_str()); + auto codepoints = ParseUTF8(token.c_str()); for (auto c : codepoints) { - std::cerr << CodepointToPrintable(c); + std::cerr << PrintAsEscaped(c); } } std::cerr << "> "; diff --git a/cpp/serve/grammar/grammar_state_matcher_base.h b/cpp/serve/grammar/grammar_state_matcher_base.h index 55c986bb10..5b774d33a4 100644 --- a/cpp/serve/grammar/grammar_state_matcher_base.h +++ b/cpp/serve/grammar/grammar_state_matcher_base.h @@ -156,15 +156,15 @@ inline bool GrammarStateMatcherBase::AcceptCodepoint(TCodepoint codepoint, bool } if (tmp_new_stack_tops_.empty()) { if (verbose) { - std::cout << "Codepoint: " << codepoint << " \"" << CodepointToPrintable(codepoint) - << "\" Rejected" << std::endl; + std::cout << "Codepoint: " << codepoint << " \"" << PrintAsEscaped(codepoint) << "\" Rejected" + << std::endl; } return false; } stack_tops_history_.PushHistory(tmp_new_stack_tops_); if (verbose) { - std::cout << "Codepoint: " << codepoint << " \"" << CodepointToPrintable(codepoint) - << "\" Accepted" << std::endl; + std::cout << "Codepoint: " << codepoint << " \"" << PrintAsEscaped(codepoint) << "\" Accepted" + << std::endl; std::cout << "Stack after accepting: " << PrintStackState() << std::endl; } #if TVM_LOG_DEBUG diff --git a/cpp/serve/grammar/grammar_state_matcher_preproc.h b/cpp/serve/grammar/grammar_state_matcher_preproc.h index c853ac7e04..f63eee2c5c 100644 --- a/cpp/serve/grammar/grammar_state_matcher_preproc.h +++ b/cpp/serve/grammar/grammar_state_matcher_preproc.h @@ -268,7 +268,7 @@ inline std::shared_ptr 
GrammarStateMatcher::CreateInitC ptr->special_token_ids.push_back(i); } else { // First replace the special underscore with space. - auto codepoints = Utf8StringToCodepoints(token.c_str()); + auto codepoints = ParseUTF8(token.c_str()); DCHECK(!codepoints.empty() && codepoints[0] != static_cast(CharHandlingError::kInvalidUtf8)) << "Invalid token: " << token; diff --git a/cpp/support/encoding.cc b/cpp/support/encoding.cc index 0509c1eb2a..d9420bbbd5 100644 --- a/cpp/support/encoding.cc +++ b/cpp/support/encoding.cc @@ -11,7 +11,7 @@ namespace mlc { namespace llm { -std::string CodepointToUtf8(TCodepoint codepoint) { +std::string PrintAsUTF8(TCodepoint codepoint) { ICHECK(codepoint <= 0x10FFFF) << "Invalid codepoint: " << codepoint; std::string utf8; if (codepoint <= 0x7F) { @@ -36,8 +36,8 @@ std::string CodepointToUtf8(TCodepoint codepoint) { return utf8; } -std::string CodepointToPrintable( - TCodepoint codepoint, const std::unordered_map& custom_escape_map) { +std::string PrintAsEscaped(TCodepoint codepoint, + const std::unordered_map& custom_escape_map) { static const std::unordered_map kCodepointToEscape = { {'\'', "\\\'"}, {'\"', "\\\""}, {'\?', "\\\?"}, {'\\', "\\\\"}, {'\a', "\\a"}, {'\b', "\\b"}, {'\f', "\\f"}, {'\n', "\\n"}, {'\r', "\\r"}, {'\t', "\\t"}, @@ -63,10 +63,10 @@ std::string CodepointToPrintable( return codepoint <= 0xFFFF ? 
"\\u" + hex : "\\U" + hex; } -std::pair Utf8ToCodepoint(const char* utf8) { - const std::array kFirstByteMask = {0x00, 0x7F, 0x1F, 0x0F, 0x07}; +std::pair ParseNextUTF8(const char* utf8) { + static const std::array kFirstByteMask = {0x00, 0x7F, 0x1F, 0x0F, 0x07}; // clang-format off - const std::array kUtf8Bytes = { + static const std::array kUtf8Bytes = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, @@ -89,7 +89,7 @@ std::pair Utf8ToCodepoint(const char* utf8) { auto bytes = kUtf8Bytes[static_cast(utf8[0])]; if (bytes == -1) { // invalid utf8 - return {static_cast(CharHandlingError::kInvalidUtf8), 0}; + return {static_cast(CharHandlingError::kInvalidUtf8), utf8}; } TCodepoint res = static_cast(utf8[0]) & kFirstByteMask[bytes]; @@ -100,23 +100,23 @@ std::pair Utf8ToCodepoint(const char* utf8) { } res = (res << 6) | (static_cast(utf8[i]) & 0x3F); } - return {res, bytes}; + return {res, utf8 + bytes}; } -std::vector Utf8StringToCodepoints(const char* utf8) { +std::vector ParseUTF8(const char* utf8) { std::vector codepoints; while (*utf8 != 0) { - auto [codepoint, bytes] = Utf8ToCodepoint(utf8); + TCodepoint codepoint; + std::tie(codepoint, utf8) = ParseNextUTF8(utf8); if (codepoint == static_cast(CharHandlingError::kInvalidUtf8)) { return {codepoint}; } codepoints.push_back(codepoint); - utf8 += bytes; } return codepoints; } -int HexCharToInt(char c) { +inline int HexCharToInt(char c) { if (c >= '0' && c <= '9') { return c - '0'; } else if (c >= 'a' && c <= 'f') { @@ -128,22 +128,22 @@ int HexCharToInt(char c) { } } -std::pair Utf8OrEscapeToCodepoint( +std::pair ParseNextUTF8OrEscaped( const char* utf8, const std::unordered_map& custom_escape_map) { static const std::unordered_map kEscapeToCodepoint = { {"\\\'", '\''}, {"\\\"", '\"'}, {"\\\?", '\?'}, {"\\\\", '\\'}, {"\\a", '\a'}, {"\\b", '\b'}, {"\\f", '\f'}, {"\\n", '\n'}, {"\\r", '\r'}, {"\\t", '\t'}, {"\\v", '\v'}, 
{"\\0", '\0'}, {"\\e", '\x1B'}}; if (utf8[0] != '\\') { - return Utf8ToCodepoint(utf8); + return ParseNextUTF8(utf8); } auto escape_sequence = std::string(utf8, 2); if (auto it = custom_escape_map.find(escape_sequence); it != custom_escape_map.end()) { - return {it->second, 2}; + return {it->second, utf8 + 2}; } if (auto it = kEscapeToCodepoint.find(escape_sequence); it != kEscapeToCodepoint.end()) { - return {it->second, 2}; + return {it->second, utf8 + 2}; } if (utf8[1] == 'x') { @@ -159,9 +159,9 @@ std::pair Utf8OrEscapeToCodepoint( ++len; } if (len == 0) { - return {static_cast(CharHandlingError::kInvalidEscape), 0}; + return {static_cast(CharHandlingError::kInvalidEscape), utf8}; } - return {codepoint, len + 2}; + return {codepoint, utf8 + len + 2}; } else if (utf8[1] == 'u' || utf8[1] == 'U') { // 4- or 8-digit hex int len = utf8[1] == 'u' ? 4 : 8; @@ -170,13 +170,13 @@ std::pair Utf8OrEscapeToCodepoint( for (int i = 0; i < len; ++i) { auto digit = HexCharToInt(utf8[i + 2]); if (digit == -1) { - return {static_cast(CharHandlingError::kInvalidEscape), 0}; + return {static_cast(CharHandlingError::kInvalidEscape), utf8}; } codepoint = codepoint * 16 + digit; } - return {codepoint, len + 2}; + return {codepoint, utf8 + len + 2}; } else { - return {static_cast(CharHandlingError::kInvalidEscape), 0}; + return {static_cast(CharHandlingError::kInvalidEscape), utf8}; } } diff --git a/cpp/support/encoding.h b/cpp/support/encoding.h index f28aae6d74..790040e97e 100644 --- a/cpp/support/encoding.h +++ b/cpp/support/encoding.h @@ -21,7 +21,7 @@ using TCodepoint = int32_t; * \param codepoint The codepoint. * \return The UTF-8 string. */ -std::string CodepointToUtf8(TCodepoint codepoint); +std::string PrintAsUTF8(TCodepoint codepoint); /*! * \brief Convert a codepoint to a printable string. If the codepoint is not printable, it will be @@ -29,10 +29,10 @@ std::string CodepointToUtf8(TCodepoint codepoint); * specify more escape sequences using custom_escape_map. 
* \param codepoint The codepoint. * \param custom_escape_map A map from codepoint to escape sequence. If the codepoint is in the map, - * it will be escaped using the corresponding escape sequence. e.g. {'-', "\\-"}. + * it will be escaped using the corresponding escape sequence. e.g. {{'-', "\\-"}}. * \return The printable string. */ -std::string CodepointToPrintable( +std::string PrintAsEscaped( TCodepoint codepoint, const std::unordered_map& custom_escape_map = {}); @@ -53,9 +53,9 @@ enum class CharHandlingError : TCodepoint { * \return The codepoint and the number of bytes consumed. If the UTF-8 string is invalid, the * function returns (CharHandlingError::kInvalidUtf8, 0). */ -std::pair Utf8ToCodepoint(const char* utf8); +std::pair ParseNextUTF8(const char* utf8); -std::vector Utf8StringToCodepoints(const char* utf8); +std::vector ParseUTF8(const char* utf8); /*! * \brief Convert a UTF-8 string or an escape sequence to a codepoint. By default the function @@ -63,12 +63,12 @@ std::vector Utf8StringToCodepoints(const char* utf8); * using custom_escape_map. * \param utf8 The UTF-8 string or the escape sequence. * \param custom_escape_map A map from escape sequence to codepoint. If the escape sequence is in - * the map, it will be converted to the corresponding codepoint. e.g. {"\\-", '-'}. + * the map, it will be converted to the corresponding codepoint. e.g. {{"\\-", '-'}}. * \return The codepoint and the number of bytes consumed. If the UTF-8 string or the escape * sequence is invalid, the function returns * (CharHandlingError::kInvalidUtf8 or CharHandlingError::kInvalidEscape, 0). */ -std::pair Utf8OrEscapeToCodepoint( +std::pair ParseNextUTF8OrEscaped( const char* utf8, const std::unordered_map& custom_escape_map = {}); } // namespace llm