15 changes: 15 additions & 0 deletions include/pytorch/tokenizers/normalizer.h
@@ -171,4 +171,19 @@ class SequenceNormalizer : public Normalizer {

}; // end class SequenceNormalizer

// -- NFC ----------------------------------------------------------------------
// Used for Unicode NFC (Normalization Form Canonical Composition) normalization
// CITE:
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/normalizers/unicode.rs

class NFCNormalizer : public Normalizer {
public:
/** Default constructor */
explicit NFCNormalizer() = default;

/** Normalize with NFC Unicode normalization */
std::string normalize(const std::string& input) const override;

}; // end class NFCNormalizer

} // namespace tokenizers
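
A minimal usage sketch of the new class, assuming only the declaration above and the UTF-8 in/out behavior of NFCNormalizer::normalize (the include and main() harness are illustrative):

#include <pytorch/tokenizers/normalizer.h>

#include <iostream>
#include <string>

int main() {
  tokenizers::NFCNormalizer normalizer;
  // "e" (U+0065) followed by combining acute accent (U+0301), i.e. decomposed "é"
  std::string decomposed = "e\xCC\x81";
  std::string composed = normalizer.normalize(decomposed);
  // Expected to yield the precomposed form U+00E9 ("\xC3\xA9")
  std::cout << composed << std::endl;
  return 0;
}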
29 changes: 29 additions & 0 deletions src/normalizer.cpp
@@ -10,6 +10,9 @@
// Local
#include <pytorch/tokenizers/normalizer.h>

// Third Party
#include <unicode.h>

// Standard
#include <algorithm>
#include <iterator>
@@ -54,6 +57,9 @@ Normalizer::Ptr NormalizerConfig::create() const {
[](const NormalizerConfig& cfg) { return cfg.create(); });
return Normalizer::Ptr(new SequenceNormalizer(norms));
}
if (type == "NFC") {
return Normalizer::Ptr(new NFCNormalizer());
}
throw std::runtime_error("Unsupported Normalizer type: " + type);
}

@@ -76,6 +82,11 @@ NormalizerConfig& NormalizerConfig::parse_json(const json& json_config) {
for (const auto& entry : json_config.at("normalizers")) {
normalizers->push_back(NormalizerConfig().parse_json(entry));
}
} else if (type == "NFC") {
// NFC normalizer has no additional configuration parameters
TK_LOG(
Info,
"Using NFC normalizer. Please notice that our implementation may not handle all edge cases.");
} else {
throw std::runtime_error("Unsupported Normalizer type: " + type);
}
@@ -119,4 +130,22 @@ std::string SequenceNormalizer::normalize(const std::string& input) const {
return result;
}

// NFCNormalizer ///////////////////////////////////////////////////////////////

std::string NFCNormalizer::normalize(const std::string& input) const {
// Convert UTF-8 string to codepoints
auto codepoints = unicode_cpts_from_utf8(input);

// Apply NFC normalization
auto normalized_cpts = unicode_cpts_normalize_nfc(codepoints);

// Convert back to UTF-8 string
std::string result;
for (uint32_t cpt : normalized_cpts) {
result += unicode_cpt_to_utf8(cpt);
}

return result;
}

} // namespace tokenizers
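
For context, the config path exercised by the new branch: a tokenizer.json normalizer entry of the form {"type": "NFC"} (optionally nested inside a "Sequence") now parses and creates an NFCNormalizer. A minimal sketch, assuming json is the nlohmann::json alias already used by parse_json above and that NormalizerConfig is default-constructible:

json cfg = json::parse(R"({"type": "NFC"})");
tokenizers::Normalizer::Ptr norm = tokenizers::NormalizerConfig().parse_json(cfg).create();
// norm->normalize(input) now applies NFC to the UTF-8 input string.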
28 changes: 16 additions & 12 deletions test/test_hf_tokenizer.py
@@ -10,27 +10,35 @@
"""

import unittest
import pytest
from pytorch_tokenizers import CppHFTokenizer
from transformers import AutoTokenizer
from tempfile import TemporaryDirectory

PROMPT = "What is the capital of France?"

class TestHfTokenizer(unittest.TestCase):
def setUp(self) -> None:
self.temp_dir = TemporaryDirectory()
super().setUp()

def test_smolLM3(self) -> None:
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM3-3B")
tokenizer_path = tokenizer.save_pretrained(self.temp_dir.name)[-1]
@pytest.mark.parametrize("model_id", [
"HuggingFaceTB/SmolLM3-3B",
"Qwen/Qwen2.5-0.5B"
])
def test_models(model_id: str) -> None:
with TemporaryDirectory() as temp_dir:
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer_path = tokenizer.save_pretrained(temp_dir)[-1]

cpp_tokenizer = CppHFTokenizer()
cpp_tokenizer.load(tokenizer_path)

tokens = tokenizer.encode(PROMPT)
cpp_tokens = cpp_tokenizer.encode(PROMPT)
self.assertEqual(tokens, cpp_tokens)
assert tokens == cpp_tokens


class TestHfTokenizer(unittest.TestCase):
def setUp(self) -> None:
self.temp_dir = TemporaryDirectory()
super().setUp()

def test_llama3_2_1b(self) -> None:
tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-1B-Instruct")
@@ -42,7 +50,3 @@ def test_llama3_2_1b(self) -> None:
tokens = tokenizer.encode(PROMPT)
cpp_tokens = cpp_tokenizer.encode(PROMPT, bos=1)
self.assertEqual(tokens, cpp_tokens)


async def test_async_DO_NOT_COMMIT(self) -> None:
pass
3 changes: 3 additions & 0 deletions third-party/llama.cpp-unicode/include/unicode.h
@@ -84,6 +84,9 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string &utf8);
std::vector<uint32_t>
unicode_cpts_normalize_nfd(const std::vector<uint32_t> &cpts);

std::vector<uint32_t>
unicode_cpts_normalize_nfc(const std::vector<uint32_t> &cpts);

codepoint_flags unicode_cpt_flags(const uint32_t cp);
codepoint_flags unicode_cpt_flags(const std::string &utf8);

216 changes: 214 additions & 2 deletions third-party/llama.cpp-unicode/src/unicode.cpp
@@ -37,16 +37,27 @@ SOFTWARE.
#include <codecvt>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <limits>
#include <locale>
#include <map>
#include <regex>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

// Hash function for std::pair<uint32_t, uint32_t> used in composition table
namespace std {
template<>
struct hash<std::pair<uint32_t, uint32_t>> {
std::size_t operator()(const std::pair<uint32_t, uint32_t>& p) const {
return std::hash<uint64_t>{}(((uint64_t)p.first << 32) | p.second);
}
};
}

size_t unicode_len_utf8(char src) {
const size_t lookup[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4};
uint8_t highbits = static_cast<uint8_t>(src) >> 4;
@@ -928,3 +939,204 @@ std::vector<std::string> unicode_regex_split(

return unicode_byte_encoding_process(bpe_words);
}

// Get canonical combining class for a codepoint using existing flags data
static uint8_t get_combining_class(uint32_t cpt) {
codepoint_flags flags = unicode_cpt_flags(cpt);

// Use the existing flag system to determine combining class
if (flags.is_accent_mark) {
// Most combining marks have class 230, but some have different classes
// This is a simplified mapping based on common Unicode patterns
if (cpt >= 0x0591 && cpt <= 0x05BD) return 220; // Hebrew accents
if (cpt >= 0x05BF && cpt <= 0x05C7) return 230; // Hebrew points
if (cpt >= 0x0610 && cpt <= 0x061A) return 230; // Arabic marks
if (cpt >= 0x064B && cpt <= 0x065F) return 30; // Arabic vowels
if (cpt == 0x0670) return 35; // Arabic superscript alef
if (cpt >= 0x06D6 && cpt <= 0x06E4) return 230; // Arabic small high marks
if (cpt >= 0x06E7 && cpt <= 0x06E8) return 230; // Arabic small high marks
if (cpt >= 0x06EA && cpt <= 0x06ED) return 220; // Arabic small low marks

// Default combining class for most combining marks
return 230;
}

return 0; // Non-combining character (starter)
}

// Apply canonical ordering using bubble sort (simple but correct)
static void canonical_order(std::vector<uint32_t>& cpts) {
for (size_t i = 1; i < cpts.size(); ++i) {
for (size_t j = i; j > 0; --j) {
uint8_t cc1 = get_combining_class(cpts[j-1]);
uint8_t cc2 = get_combining_class(cpts[j]);

// Only reorder if both have non-zero combining class and are out of order
if (cc1 > cc2 && cc2 != 0) {
std::swap(cpts[j-1], cpts[j]);
} else {
break;
}
}
}
}

// Build composition table by reverse-engineering the NFD data
static std::unordered_map<std::pair<uint32_t, uint32_t>, uint32_t> build_composition_table() {
std::unordered_map<std::pair<uint32_t, uint32_t>, uint32_t> composition_map;

// Iterate through all NFD mappings to build reverse composition table
for (const auto& range : unicode_ranges_nfd) {
for (uint32_t cpt = range.first; cpt <= range.last; ++cpt) {
uint32_t base = range.nfd;

// For NFC, we need to figure out what combining character was removed
// This is a simplified approach that works for the most common cases

// Common diacritic mappings based on the composed character
uint32_t combining = 0;

// Determine combining character based on the composed character
// This is derived from common Unicode patterns
switch (cpt) {
// Grave accent (0x0300)
case 0x00C0: case 0x00E0: // À à
case 0x00C8: case 0x00E8: // È è
case 0x00CC: case 0x00EC: // Ì ì
case 0x00D2: case 0x00F2: // Ò ò
case 0x00D9: case 0x00F9: // Ù ù
case 0x01CD: case 0x01CE: // Ǎ ǎ
case 0x01CF: case 0x01D0: // Ǐ ǐ
case 0x01D1: case 0x01D2: // Ǒ ǒ
case 0x01D3: case 0x01D4: // Ǔ ǔ
combining = 0x0300; break;

// Acute accent (0x0301)
case 0x00C1: case 0x00E1: // Á á
case 0x00C9: case 0x00E9: // É é
case 0x00CD: case 0x00ED: // Í í
case 0x00D3: case 0x00F3: // Ó ó
case 0x00DA: case 0x00FA: // Ú ú
case 0x00DD: case 0x00FD: // Ý ý
combining = 0x0301; break;

// Circumflex (0x0302)
case 0x00C2: case 0x00E2: // Â â
case 0x00CA: case 0x00EA: // Ê ê
case 0x00CE: case 0x00EE: // Î î
case 0x00D4: case 0x00F4: // Ô ô
case 0x00DB: case 0x00FB: // Û û
combining = 0x0302; break;

// Tilde (0x0303)
case 0x00C3: case 0x00E3: // Ã ã
case 0x00D1: case 0x00F1: // Ñ ñ
case 0x00D5: case 0x00F5: // Õ õ
combining = 0x0303; break;

// Diaeresis (0x0308)
case 0x00C4: case 0x00E4: // Ä ä
case 0x00CB: case 0x00EB: // Ë ë
case 0x00CF: case 0x00EF: // Ï ï
case 0x00D6: case 0x00F6: // Ö ö
case 0x00DC: case 0x00FC: // Ü ü
case 0x00FF: // ÿ
combining = 0x0308; break;

// Ring above (0x030A)
case 0x00C5: case 0x00E5: // Å å
combining = 0x030A; break;

// Cedilla (0x0327)
case 0x00C7: case 0x00E7: // Ç ç
combining = 0x0327; break;

default:
// For other characters, try to infer from Unicode blocks
if (cpt >= 0x0100 && cpt <= 0x017F) {
// Extended Latin A - try common patterns
if ((cpt & 1) == 0) { // Even codepoints (uppercase)
if (cpt >= 0x0100 && cpt <= 0x0105) combining = 0x0304; // macron
else if (cpt >= 0x0102 && cpt <= 0x0107) combining = 0x0306; // breve
else if (cpt >= 0x0104 && cpt <= 0x0119) combining = 0x0328; // ogonek
else if (cpt >= 0x0106 && cpt <= 0x010D) combining = 0x0301; // acute
else if (cpt >= 0x0108 && cpt <= 0x010F) combining = 0x0302; // circumflex
else if (cpt >= 0x010A && cpt <= 0x0111) combining = 0x0307; // dot above
else if (cpt >= 0x010C && cpt <= 0x0165) combining = 0x030C; // caron
}
}
break;
}

// Only add to composition table if we identified a combining character
if (combining != 0) {
composition_map[{base, combining}] = cpt;
}
}
}

return composition_map;
}

// Get the composition table (built once, cached)
static const std::unordered_map<std::pair<uint32_t, uint32_t>, uint32_t>& get_composition_table() {
static const auto composition_table = build_composition_table();
return composition_table;
}

std::vector<uint32_t> unicode_cpts_normalize_nfc(
const std::vector<uint32_t>& cpts) {

// Step 1: Apply NFD (canonical decomposition) using existing implementation
std::vector<uint32_t> nfd_result = unicode_cpts_normalize_nfd(cpts);

// Step 2: Apply canonical ordering
canonical_order(nfd_result);

// Step 3: Apply canonical composition
const auto& composition_table = get_composition_table();
std::vector<uint32_t> result;
result.reserve(nfd_result.size());

size_t i = 0;
while (i < nfd_result.size()) {
uint32_t starter = nfd_result[i];
result.push_back(starter);

// Only try to compose if this is a starter (combining class 0)
if (get_combining_class(starter) == 0) {
size_t last_starter_pos = result.size() - 1;

// Look for composable combining marks after this starter
size_t j = i + 1;
while (j < nfd_result.size()) {
uint32_t combining = nfd_result[j];
uint8_t cc = get_combining_class(combining);

// If we hit another starter, stop
if (cc == 0) break;

// Try to compose with the last starter
auto key = std::make_pair(result[last_starter_pos], combining);
auto it = composition_table.find(key);

if (it != composition_table.end()) {
// Compose: replace starter with composed character
result[last_starter_pos] = it->second;
// Skip this combining character
++j;
continue;
}

// No composition possible, add the combining character
result.push_back(combining);
++j;
}
i = j;
} else {
++i;
}
}

return result;
}
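
A minimal codepoint-level sketch of the three steps above (decompose, reorder, compose), assuming the upstream NFD table maps U+00E9 to base U+0065 so the simplified composition table can recompose it:

#include <unicode.h>

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Decomposed "é": base letter U+0065 followed by combining acute accent U+0301
  std::vector<uint32_t> decomposed = {0x0065, 0x0301};
  std::vector<uint32_t> nfc = unicode_cpts_normalize_nfc(decomposed);
  // Expected: a single precomposed codepoint U+00E9
  assert(nfc.size() == 1 && nfc[0] == 0x00E9);

  // A precomposed input should round-trip: NFD splits it, composition restores it.
  std::vector<uint32_t> precomposed = {0x00E9};
  assert(unicode_cpts_normalize_nfc(precomposed) == nfc);
  return 0;
}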