diff --git a/bindings/python/src/trainers.rs b/bindings/python/src/trainers.rs
index ef2c31e56..8eecf5c61 100644
--- a/bindings/python/src/trainers.rs
+++ b/bindings/python/src/trainers.rs
@@ -345,6 +345,7 @@ impl PyBpeTrainer {
                     }
                     "limit_alphabet" => builder = builder.limit_alphabet(val.extract()?),
                     "max_token_length" => builder = builder.max_token_length(val.extract()?),
+                    "enforce_utf8_boundaries" => builder = builder.enforce_utf8_boundaries(val.extract()?),
                     "initial_alphabet" => {
                         let alphabet: Vec<String> = val.extract()?;
                         builder = builder.initial_alphabet(
diff --git a/bindings/python/tests/bindings/test_trainers.py b/bindings/python/tests/bindings/test_trainers.py
index 38b599448..336a78db2 100644
--- a/bindings/python/tests/bindings/test_trainers.py
+++ b/bindings/python/tests/bindings/test_trainers.py
@@ -74,6 +74,49 @@ def test_can_pickle(self):
         )
 
+    def test_enforce_utf8_boundaries(self):
+        # This input is designed to have a very frequent but invalid merge candidate:
+        # a space (0x20) followed by the shared first byte of two different 4-byte encodings (0xF0).
+        # A less frequent but valid candidate is the first two bytes of an emoji (0xF0, 0x9F).
+        data = [" 🤗"] * 10 + [" 𝟑"] * 9
+
+        # Set up a tokenizer with a ByteLevel pre-tokenizer
+        tokenizer = Tokenizer(models.BPE())
+        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
+
+        # 1. Train with `enforce_utf8_boundaries=False` (unconstrained)
+        unconstrained_trainer = trainers.BpeTrainer(
+            vocab_size=260,
+            special_tokens=["<unk>"],
+            enforce_utf8_boundaries=False,
+            show_progress=False,
+        )
+        tokenizer.train_from_iterator(data, trainer=unconstrained_trainer)
+        vocab = tokenizer.get_vocab()
+
+        # The pre-tokenizer maps byte 0x20 to `Ġ` and 0xF0 to `ð`.
+        # The invalid merge of these two should be present.
+        invalid_token = "Ġð"  # Bytes: [20, F0]
+        assert invalid_token in vocab, "Unconstrained trainer should learn the invalid merge"
+
+        # 2. Train with `enforce_utf8_boundaries=True` (constrained)
+        # We must re-initialize the tokenizer to start with a fresh model
+        tokenizer = Tokenizer(models.BPE())
+        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
+
+        # Train with enforce_utf8_boundaries=True
+        constrained_trainer = trainers.BpeTrainer(
+            vocab_size=260,
+            special_tokens=["<unk>"],
+            enforce_utf8_boundaries=True,
+            show_progress=False,
+        )
+        tokenizer.train_from_iterator(data, trainer=constrained_trainer)
+        vocab = tokenizer.get_vocab()
+
+        # The invalid merge should not be present when enforcing UTF-8 boundaries
+        assert invalid_token not in vocab, "Constrained trainer should not learn invalid merges"
+
 
 class TestWordPieceTrainer:
     def test_can_modify(self):
         trainer = trainers.WordPieceTrainer(
diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs
index b3a6fd4b2..977d0e224 100644
--- a/tokenizers/src/models/bpe/trainer.rs
+++ b/tokenizers/src/models/bpe/trainer.rs
@@ -4,6 +4,7 @@ use super::{Pair, WithFirstLastIterator, Word, BPE};
 use crate::parallelism::*;
 use crate::tokenizer::{AddedToken, Result, Trainer};
 use crate::utils::progress::{ProgressBar, ProgressStyle};
+use crate::pre_tokenizers::byte_level::CHAR_BYTES;
 use ahash::{AHashMap, AHashSet};
 use compact_str::CompactString;
 use dary_heap::OctonaryHeap;
@@ -48,6 +49,7 @@ struct Config {
     continuing_subword_prefix: Option<String>,
     end_of_word_suffix: Option<String>,
     max_token_length: Option<usize>,
+    enforce_utf8_boundaries: bool,
 }
 
 /// A `BpeTrainerBuilder` can be used to create a `BpeTrainer` with a custom
@@ -69,6 +71,7 @@ impl Default for BpeTrainerBuilder {
                 continuing_subword_prefix: None,
                 end_of_word_suffix: None,
                 max_token_length: None,
+                enforce_utf8_boundaries: false,
             },
         }
     }
@@ -144,6 +147,13 @@ impl BpeTrainerBuilder {
         self
     }
 
+    /// Whether to enforce UTF-8 character boundaries during merges
+    #[must_use]
+    pub fn enforce_utf8_boundaries(mut self, enforce: bool) -> Self {
+        self.config.enforce_utf8_boundaries = enforce;
+        self
+    }
+
     /// Constructs the final BpeTrainer
     pub fn build(self) -> BpeTrainer {
         BpeTrainer {
@@ -156,6 +166,7 @@ impl BpeTrainerBuilder {
             continuing_subword_prefix: self.config.continuing_subword_prefix,
             end_of_word_suffix: self.config.end_of_word_suffix,
             max_token_length: self.config.max_token_length,
+            enforce_utf8_boundaries: self.config.enforce_utf8_boundaries,
             words: AHashMap::new(),
         }
     }
@@ -199,6 +210,11 @@ pub struct BpeTrainer {
     pub end_of_word_suffix: Option<String>,
     /// An optional parameter to limit the max length of any single token
     pub max_token_length: Option<usize>,
+    /// Whether to enforce UTF-8 character boundaries during merges. When true, only allows merging:
+    /// 1. Complete UTF-8 characters with each other
+    /// 2. Single bytes that are part of the same UTF-8 character, from left to right
+    /// This is useful to avoid creating tokens that are not valid UTF-8 sequences, at no cost to compression.
+    pub enforce_utf8_boundaries: bool,
 
     words: AHashMap<CompactString, u64>,
 }
@@ -210,6 +226,7 @@ impl Default for BpeTrainer {
     }
 }
 impl BpeTrainer {
+
     pub fn new(min_frequency: u64, vocab_size: usize) -> Self {
         Self {
             min_frequency,
@@ -270,6 +287,67 @@ impl BpeTrainer {
         }
     }
 
+    /// Helper for `is_merge_allowed`: gets the original bytes of a part by reversing the ByteLevel mapping
+    fn get_original_bytes(&self, part: &str) -> Option<Vec<u8>> {
+        part.chars().map(|c| CHAR_BYTES.get(&c).copied()).collect()
+    }
+    /// Determines if a merge is allowed under UTF-8 boundary constraints.
+    ///
+    /// This check is only performed if `enforce_utf8_boundaries` is true.
+    /// A merge is allowed if it meets one of the following criteria:
+    /// 1. Both tokens consist of complete characters.
+    /// 2. Both tokens are part of the same single character, and the second is a single byte.
+    ///    This allows building multi-byte characters from their individual bytes, left to right.
+    /// All other combinations, such as merging a complete character with a partial byte, are disallowed.
+    ///
+    /// This function is designed to work on the character-mapped output of a `ByteLevel`
+    /// pre-tokenizer by reversing the mapping to check the original bytes.
+    fn is_merge_allowed(&self, pair: &Pair, id_to_word: &[CompactString]) -> bool {
+        if !self.enforce_utf8_boundaries {
+            return true;
+        }
+
+        let part_a = &id_to_word[pair.0 as usize];
+        let part_b = &id_to_word[pair.1 as usize];
+
+        // Get the original bytes by reversing the ByteLevel character mapping.
+        let bytes_a = self.get_original_bytes(part_a.as_ref()).unwrap_or_default();
+        let bytes_b = self.get_original_bytes(part_b.as_ref()).unwrap_or_default();
+
+        // A "complete" token is one whose underlying bytes form a valid UTF-8 string.
+        // For ByteLevel, this means single-byte ASCII chars (like a space) are complete,
+        // but single bytes from a multi-byte sequence (like 0xF0) are not.
+        let is_a_complete = std::str::from_utf8(&bytes_a).is_ok();
+        let is_b_complete = std::str::from_utf8(&bytes_b).is_ok();
+
+        // - Allow merging two complete tokens.
+        // - Any mix of complete and incomplete is disallowed.
+        if is_a_complete && is_b_complete {
+            return true;
+        }
+        if is_a_complete ^ is_b_complete {
+            return false;
+        }
+
+        // Here we know both tokens are incomplete.
+        // Allow the merge only if appending a single byte still yields a valid UTF-8 prefix.
+        if bytes_b.len() == 1 {
+            let mut merged = bytes_a;
+            merged.extend_from_slice(&bytes_b);
+            match std::str::from_utf8(&merged) {
+                // The merged bytes form one or more complete characters. Valid.
+                Ok(_) => true,
+                // The merged bytes are an incomplete but valid prefix. Valid.
+                Err(e) => e.error_len().is_none(),
+            }
+        } else {
+            // If part_b is not a single byte, it's not a valid continuation merge.
+            false
+        }
+    }
+
     /// Compute the initial alphabet and limit it if relevant
     fn compute_alphabet(
         &self,
@@ -455,7 +533,7 @@ impl BpeTrainer {
         let mut queue = OctonaryHeap::with_capacity(pair_counts.len());
         where_to_update.drain().for_each(|(pair, pos)| {
             let count = pair_counts[&pair];
-            if count > 0 {
+            if count > 0 && self.is_merge_allowed(&pair, &id_to_word) {
                 queue.push(Merge {
                     pair,
                     count: count as u64,
@@ -550,13 +628,13 @@ impl BpeTrainer {
             for ((pair, change), iw) in changes {
                 let count = change * counts[iw] as i32;
                 *pair_counts.entry(pair).or_default() += count;
-                if change > 0 {
+                if change > 0 && self.is_merge_allowed(&pair, &id_to_word) {
                     where_to_update.entry(pair).or_default().insert(iw);
                 }
             }
             where_to_update.drain().for_each(|(pair, pos)| {
                 let count = pair_counts[&pair];
-                if count > 0 {
+                if count > 0 && self.is_merge_allowed(&pair, &id_to_word) {
                     queue.push(Merge {
                         pair,
                         count: count as u64,
@@ -644,8 +722,14 @@ impl Trainer for BpeTrainer {
 #[cfg(test)]
 mod tests {
     use super::{BpeTrainer, Pair, BPE};
+    use crate::pre_tokenizers::byte_level::{bytes_char, ByteLevel};
+    use crate::tokenizer::{
+        OffsetReferential, OffsetType, PreTokenizedString, PreTokenizer, Result, Trainer,
+    };
     use ahash::AHashMap;
     use compact_str::CompactString;
+    use std::collections::HashMap;
+    use std::sync::LazyLock;
 
     #[test]
     fn test_train() {
@@ -762,6 +846,7 @@ mod tests {
             )
         }
     }
+
     #[test]
     fn bpe_test_max_token_length_direct_assert() {
        /* more direct version of bpe_test_max_token_length test
@@ -831,4 +916,74 @@
             .collect();
         assert_eq!(trained_vocab, expected_vocab)
     }
+
+    static BYTE_TO_CHAR: LazyLock<HashMap<u8, char>> = LazyLock::new(bytes_char);
+
+    #[test]
+    fn test_bpe_utf8_boundary_enforcement_with_byte_level_pretokenizer() {
+        // Use the actual ByteLevel pre-tokenizer to process the input string.
+        let byte_level_pretok = ByteLevel::new(false, false, false);
+        let process_fn = |s: &str| -> Result<Vec<String>> {
+            let mut pretokenized = PreTokenizedString::from(s);
+            byte_level_pretok.pre_tokenize(&mut pretokenized)?;
+            Ok(pretokenized
+                .get_splits(OffsetReferential::Original, OffsetType::Byte)
+                .into_iter()
+                .map(|(word, _, _)| word.to_string())
+                .collect())
+        };
+
+        let sequence = " 🤗 🦒 🐹 🦦 🤗 𝟑".to_string();
+        let vocab_size = 25;
+
+        // --- Part 1: Unconstrained BPE ---
+        let mut unconstrained_trainer = BpeTrainer::builder()
+            .vocab_size(vocab_size)
+            .show_progress(false)
+            .enforce_utf8_boundaries(false)
+            .build();
+        unconstrained_trainer
+            .feed(std::iter::once(&sequence), &process_fn)
+            .unwrap();
+        let mut unconstrained_model = BPE::default();
+        unconstrained_trainer
+            .train(&mut unconstrained_model)
+            .unwrap();
+
+        let invalid_merge_token: String =
+            [BYTE_TO_CHAR[&b' '], BYTE_TO_CHAR[&0xF0]].iter().collect();
+        assert!(
+            unconstrained_model
+                .get_vocab()
+                .contains_key(&invalid_merge_token),
+            "Unconstrained vocab SHOULD contain the top-frequency merge (bytes [20 F0])"
+        );
+
+        // --- Part 2: Constrained BPE ---
+        let mut constrained_trainer = BpeTrainer::builder()
+            .vocab_size(vocab_size)
+            .show_progress(false)
+            .enforce_utf8_boundaries(true)
+            .build();
+        constrained_trainer
+            .feed(std::iter::once(&sequence), &process_fn)
+            .unwrap();
+        let mut constrained_model = BPE::default();
+        constrained_trainer.train(&mut constrained_model).unwrap();
+
+        let valid_merge_token: String =
+            [BYTE_TO_CHAR[&0xF0], BYTE_TO_CHAR[&0x9F]].iter().collect();
+        assert!(
+            !constrained_model
+                .get_vocab()
+                .contains_key(&invalid_merge_token),
+            "Constrained vocab MUST NOT contain the invalid merge (bytes [20 F0])"
+        );
+        assert!(
+            constrained_model
+                .get_vocab()
+                .contains_key(&valid_merge_token),
+            "Constrained vocab SHOULD contain the next valid merge (bytes [F0 9F])"
+        );
+    }
 }
diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs
index 8bc0f30af..54a984282 100644
--- a/tokenizers/src/pre_tokenizers/byte_level.rs
+++ b/tokenizers/src/pre_tokenizers/byte_level.rs
@@ -45,7 +45,7 @@ static RE: LazyLock<SysRegex> = LazyLock::new(|| {
     .unwrap()
 });
 static BYTES_CHAR: LazyLock<HashMap<u8, char>> = LazyLock::new(bytes_char);
-static CHAR_BYTES: LazyLock<HashMap<char, u8>> =
+pub(crate) static CHAR_BYTES: LazyLock<HashMap<char, u8>> =
     LazyLock::new(|| bytes_char().into_iter().map(|(c, b)| (b, c)).collect());
 
 #[derive(Copy, Clone, Debug, PartialEq, Eq)]
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index 84f77a523..7a3086dc5 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -534,6 +534,16 @@ where
     PP: PostProcessor,
     D: Decoder,
 {
+    /// Validates compatibility between a trainer and the current tokenizer configuration.
+    /// Currently only checks:
+    /// - for a `BpeTrainer` with `enforce_utf8_boundaries=True`, the pre-tokenizer must be `ByteLevel`.
+    fn _check_trainer_compat<T: Trainer>(
+        &self,
+        _trainer: &T,
+    ) -> Result<()> {
+        Ok(())
+    }
+
     /// Instantiate a new Tokenizer, with the given Model
     pub fn new(model: M) -> Self {
         Self {
@@ -1345,6 +1355,7 @@ where
     where
         T: Trainer<Model = M> + Sync,
     {
+        self._check_trainer_compat(trainer)?; // check that settings are compatible
         let mut len = 0;
         for file in files.iter() {
             len += File::open(file)
@@ -1420,6 +1431,7 @@ where
         I: Iterator<Item = S> + Send,
         S: AsRef<str> + Send,
     {
+        self._check_trainer_compat(trainer)?; // check that settings are compatible
         let (lower, upper) = sequences.size_hint();
         let len = upper.unwrap_or(lower) as u64;
         let progress = if trainer.should_show_progress() {