From 2ef22c30f66d358b8d648e02b65d51908b8bed96 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Wed, 16 Jul 2025 18:02:17 -0700 Subject: [PATCH] Support NFC Normalizer A simplified version of NFC. Based on what we have for NFD, this PR implements the composition logic. Also refer back to https://unicode.org/reports/tr15/ on how this works. --- include/pytorch/tokenizers/normalizer.h | 15 ++ src/normalizer.cpp | 29 +++ test/test_hf_tokenizer.py | 28 ++- .../llama.cpp-unicode/include/unicode.h | 3 + third-party/llama.cpp-unicode/src/unicode.cpp | 216 +++++++++++++++++- 5 files changed, 277 insertions(+), 14 deletions(-) diff --git a/include/pytorch/tokenizers/normalizer.h b/include/pytorch/tokenizers/normalizer.h index 5d0dda5..92715ed 100644 --- a/include/pytorch/tokenizers/normalizer.h +++ b/include/pytorch/tokenizers/normalizer.h @@ -171,4 +171,19 @@ class SequenceNormalizer : public Normalizer { }; // end class SequenceNormalizer +// -- NFC ---------------------------------------------------------------------- +// Used for Unicode NFC (Normalization Form Canonical Composition) normalization +// CITE: +// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/normalizers/unicode.rs + +class NFCNormalizer : public Normalizer { + public: + /** Default constructor */ + explicit NFCNormalizer() = default; + + /** Normalize with NFC Unicode normalization */ + std::string normalize(const std::string& input) const override; + +}; // end class NFCNormalizer + } // namespace tokenizers diff --git a/src/normalizer.cpp b/src/normalizer.cpp index 3c9e7f9..04af6c5 100644 --- a/src/normalizer.cpp +++ b/src/normalizer.cpp @@ -10,6 +10,9 @@ // Local #include +// Third Party +#include + // Standard #include #include @@ -54,6 +57,9 @@ Normalizer::Ptr NormalizerConfig::create() const { [](const NormalizerConfig& cfg) { return cfg.create(); }); return Normalizer::Ptr(new SequenceNormalizer(norms)); } + if (type == "NFC") { + return Normalizer::Ptr(new NFCNormalizer()); + } 
throw std::runtime_error("Unsupported Normalizer type: " + type); } @@ -76,6 +82,11 @@ NormalizerConfig& NormalizerConfig::parse_json(const json& json_config) { for (const auto& entry : json_config.at("normalizers")) { normalizers->push_back(NormalizerConfig().parse_json(entry)); } + } else if (type == "NFC") { + // NFC normalizer has no additional configuration parameters + TK_LOG( + Info, + "Using NFC normalizer. Please notice that our implementation may not handle all edge cases."); } else { throw std::runtime_error("Unsupported Normalizer type: " + type); } @@ -119,4 +130,22 @@ std::string SequenceNormalizer::normalize(const std::string& input) const { return result; } +// NFCNormalizer /////////////////////////////////////////////////////////////// + +std::string NFCNormalizer::normalize(const std::string& input) const { + // Convert UTF-8 string to codepoints + auto codepoints = unicode_cpts_from_utf8(input); + + // Apply NFC normalization + auto normalized_cpts = unicode_cpts_normalize_nfc(codepoints); + + // Convert back to UTF-8 string + std::string result; + for (uint32_t cpt : normalized_cpts) { + result += unicode_cpt_to_utf8(cpt); + } + + return result; +} + } // namespace tokenizers diff --git a/test/test_hf_tokenizer.py b/test/test_hf_tokenizer.py index 378a357..304da49 100644 --- a/test/test_hf_tokenizer.py +++ b/test/test_hf_tokenizer.py @@ -10,27 +10,35 @@ """ import unittest +import pytest from pytorch_tokenizers import CppHFTokenizer from transformers import AutoTokenizer from tempfile import TemporaryDirectory PROMPT = "What is the capital of France?" 
-class TestHfTokenizer(unittest.TestCase): - def setUp(self) -> None: - self.temp_dir = TemporaryDirectory() - super().setUp() - def test_smolLM3(self) -> None: - tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM3-3B") - tokenizer_path = tokenizer.save_pretrained(self.temp_dir.name)[-1] +@pytest.mark.parametrize("model_id", [ + "HuggingFaceTB/SmolLM3-3B", + "Qwen/Qwen2.5-0.5B" +]) +def test_models(model_id: str) -> None: + with TemporaryDirectory() as temp_dir: + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer_path = tokenizer.save_pretrained(temp_dir)[-1] cpp_tokenizer = CppHFTokenizer() cpp_tokenizer.load(tokenizer_path) tokens = tokenizer.encode(PROMPT) cpp_tokens = cpp_tokenizer.encode(PROMPT) - self.assertEqual(tokens, cpp_tokens) + assert tokens == cpp_tokens + + +class TestHfTokenizer(unittest.TestCase): + def setUp(self) -> None: + self.temp_dir = TemporaryDirectory() + super().setUp() def test_llama3_2_1b(self) -> None: tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-1B-Instruct") @@ -42,7 +50,3 @@ def test_llama3_2_1b(self) -> None: tokens = tokenizer.encode(PROMPT) cpp_tokens = cpp_tokenizer.encode(PROMPT, bos=1) self.assertEqual(tokens, cpp_tokens) - - - async def test_async_DO_NOT_COMMIT(self) -> None: - pass diff --git a/third-party/llama.cpp-unicode/include/unicode.h b/third-party/llama.cpp-unicode/include/unicode.h index a458285..be55557 100644 --- a/third-party/llama.cpp-unicode/include/unicode.h +++ b/third-party/llama.cpp-unicode/include/unicode.h @@ -84,6 +84,9 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string &utf8); std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> &cpts); +std::vector<uint32_t> +unicode_cpts_normalize_nfc(const std::vector<uint32_t> &cpts); + codepoint_flags unicode_cpt_flags(const uint32_t cp); codepoint_flags unicode_cpt_flags(const std::string &utf8); diff --git a/third-party/llama.cpp-unicode/src/unicode.cpp b/third-party/llama.cpp-unicode/src/unicode.cpp index 096fdce..75f44ec 100644
--- a/third-party/llama.cpp-unicode/src/unicode.cpp +++ b/third-party/llama.cpp-unicode/src/unicode.cpp @@ -37,16 +37,27 @@ SOFTWARE. #include #include #include +#include +#include +#include #include #include #include #include #include #include -#include -#include #include +// Hash function for std::pair used in composition table +namespace std { + template<> + struct hash<std::pair<uint32_t, uint32_t>> { + std::size_t operator()(const std::pair<uint32_t, uint32_t>& p) const { + return std::hash<uint64_t>{}(((uint64_t)p.first << 32) | p.second); + } + }; +} + size_t unicode_len_utf8(char src) { const size_t lookup[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4}; uint8_t highbits = static_cast<uint8_t>(src) >> 4; @@ -928,3 +939,204 @@ std::vector<std::string> unicode_regex_split( return unicode_byte_encoding_process(bpe_words); } + +// Get canonical combining class for a codepoint using existing flags data +static uint8_t get_combining_class(uint32_t cpt) { + codepoint_flags flags = unicode_cpt_flags(cpt); + + // Use the existing flag system to determine combining class + if (flags.is_accent_mark) { + // Most combining marks have class 230, but some have different classes + // This is a simplified mapping based on common Unicode patterns + if (cpt >= 0x0591 && cpt <= 0x05BD) return 220; // Hebrew accents + if (cpt >= 0x05BF && cpt <= 0x05C7) return 230; // Hebrew points + if (cpt >= 0x0610 && cpt <= 0x061A) return 230; // Arabic marks + if (cpt >= 0x064B && cpt <= 0x065F) return 30; // Arabic vowels + if (cpt >= 0x0670 && cpt <= 0x0670) return 35; // Arabic superscript alef + if (cpt >= 0x06D6 && cpt <= 0x06E4) return 230; // Arabic small high marks + if (cpt >= 0x06E7 && cpt <= 0x06E8) return 230; // Arabic small high marks + if (cpt >= 0x06EA && cpt <= 0x06ED) return 220; // Arabic small low marks + + // Default combining class for most combining marks + return 230; + } + + return 0; // Non-combining character (starter) +} + +// Apply canonical ordering using bubble sort (simple but correct) +static void canonical_order(std::vector<uint32_t>& cpts)
{ + for (size_t i = 1; i < cpts.size(); ++i) { + for (size_t j = i; j > 0; --j) { + uint8_t cc1 = get_combining_class(cpts[j-1]); + uint8_t cc2 = get_combining_class(cpts[j]); + + // Only reorder if both have non-zero combining class and are out of order + if (cc1 > cc2 && cc2 != 0) { + std::swap(cpts[j-1], cpts[j]); + } else { + break; + } + } + } +} + +// Build composition table by reverse-engineering the NFD data +static std::unordered_map<std::pair<uint32_t, uint32_t>, uint32_t> build_composition_table() { + std::unordered_map<std::pair<uint32_t, uint32_t>, uint32_t> composition_map; + + // Iterate through all NFD mappings to build reverse composition table + for (const auto& range : unicode_ranges_nfd) { + for (uint32_t cpt = range.first; cpt <= range.last; ++cpt) { + uint32_t base = range.nfd; + + // For NFC, we need to figure out what combining character was removed + // This is a simplified approach that works for the most common cases + + // Common diacritic mappings based on the composed character + uint32_t combining = 0; + + // Determine combining character based on the composed character + // This is derived from common Unicode patterns + switch (cpt) { + // Grave accent (0x0300) + case 0x00C0: case 0x00E0: // À à + case 0x00C8: case 0x00E8: // È è + case 0x00CC: case 0x00EC: // Ì ì + case 0x00D2: case 0x00F2: // Ò ò + case 0x00D9: case 0x00F9: // Ù ù + case 0x01CD: case 0x01CE: // Ǎ ǎ + case 0x01CF: case 0x01D0: // Ǐ ǐ + case 0x01D1: case 0x01D2: // Ǒ ǒ + case 0x01D3: case 0x01D4: // Ǔ ǔ + combining = 0x0300; break; + + // Acute accent (0x0301) + case 0x00C1: case 0x00E1: // Á á + case 0x00C9: case 0x00E9: // É é + case 0x00CD: case 0x00ED: // Í í + case 0x00D3: case 0x00F3: // Ó ó + case 0x00DA: case 0x00FA: // Ú ú + case 0x00DD: case 0x00FD: // Ý ý + combining = 0x0301; break; + + // Circumflex (0x0302) + case 0x00C2: case 0x00E2: // Â â + case 0x00CA: case 0x00EA: // Ê ê + case 0x00CE: case 0x00EE: // Î î + case 0x00D4: case 0x00F4: // Ô ô + case 0x00DB: case 0x00FB: // Û û + combining = 0x0302; break; + 
// Tilde (0x0303) + case 0x00C3: case 0x00E3: // Ã ã + case 0x00D1: case 0x00F1: // Ñ ñ + case 0x00D5: case 0x00F5: // Õ õ + combining = 0x0303; break; + + // Diaeresis (0x0308) + case 0x00C4: case 0x00E4: // Ä ä + case 0x00CB: case 0x00EB: // Ë ë + case 0x00CF: case 0x00EF: // Ï ï + case 0x00D6: case 0x00F6: // Ö ö + case 0x00DC: case 0x00FC: // Ü ü + case 0x00FF: // ÿ + combining = 0x0308; break; + + // Ring above (0x030A) + case 0x00C5: case 0x00E5: // Å å + combining = 0x030A; break; + + // Cedilla (0x0327) + case 0x00C7: case 0x00E7: // Ç ç + combining = 0x0327; break; + + default: + // For other characters, try to infer from Unicode blocks + if (cpt >= 0x0100 && cpt <= 0x017F) { + // Extended Latin A - try common patterns + if ((cpt & 1) == 0) { // Even codepoints (uppercase) + if (cpt >= 0x0100 && cpt <= 0x0105) combining = 0x0304; // macron + else if (cpt >= 0x0102 && cpt <= 0x0107) combining = 0x0306; // breve + else if (cpt >= 0x0104 && cpt <= 0x0119) combining = 0x0328; // ogonek + else if (cpt >= 0x0106 && cpt <= 0x010D) combining = 0x0301; // acute + else if (cpt >= 0x0108 && cpt <= 0x010F) combining = 0x0302; // circumflex + else if (cpt >= 0x010A && cpt <= 0x0111) combining = 0x0307; // dot above + else if (cpt >= 0x010C && cpt <= 0x0165) combining = 0x030C; // caron + } + } + break; + } + + // Only add to composition table if we identified a combining character + if (combining != 0) { + composition_map[{base, combining}] = cpt; + } + } + } + + return composition_map; +} + +// Get the composition table (built once, cached) +static const std::unordered_map<std::pair<uint32_t, uint32_t>, uint32_t>& get_composition_table() { + static const auto composition_table = build_composition_table(); + return composition_table; +} + +std::vector<uint32_t> unicode_cpts_normalize_nfc( + const std::vector<uint32_t>& cpts) { + + // Step 1: Apply NFD (canonical decomposition) using existing implementation + std::vector<uint32_t> nfd_result = unicode_cpts_normalize_nfd(cpts); + + // Step 2: Apply canonical ordering + 
canonical_order(nfd_result); + + // Step 3: Apply canonical composition + const auto& composition_table = get_composition_table(); + std::vector<uint32_t> result; + result.reserve(nfd_result.size()); + + size_t i = 0; + while (i < nfd_result.size()) { + uint32_t starter = nfd_result[i]; + result.push_back(starter); + + // Only try to compose if this is a starter (combining class 0) + if (get_combining_class(starter) == 0) { + size_t last_starter_pos = result.size() - 1; + + // Look for composable combining marks after this starter + size_t j = i + 1; + while (j < nfd_result.size()) { + uint32_t combining = nfd_result[j]; + uint8_t cc = get_combining_class(combining); + + // If we hit another starter, stop + if (cc == 0) break; + + // Try to compose with the last starter + auto key = std::make_pair(result[last_starter_pos], combining); + auto it = composition_table.find(key); + + if (it != composition_table.end()) { + // Compose: replace starter with composed character + result[last_starter_pos] = it->second; + // Skip this combining character + ++j; + continue; + } + + // No composition possible, add the combining character + result.push_back(combining); + ++j; + } + i = j; + } else { + ++i; + } + } + + return result; +}