From 7a6423eba7656e00ff8e9788d1ec83085c3cbb9c Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Wed, 7 Oct 2020 16:16:02 -0400 Subject: [PATCH 01/11] Rust - Trainer::process_tokens has a default impl --- tokenizers/src/models/bpe/trainer.rs | 13 +------------ tokenizers/src/models/unigram/trainer.rs | 10 ---------- tokenizers/src/tokenizer/mod.rs | 9 ++++++++- 3 files changed, 9 insertions(+), 23 deletions(-) diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs index 2938b56cd..4e7146a03 100644 --- a/tokenizers/src/models/bpe/trainer.rs +++ b/tokenizers/src/models/bpe/trainer.rs @@ -583,18 +583,7 @@ impl Trainer for BpeTrainer { /// Train a BPE model fn train(&self, word_counts: HashMap) -> Result<(BPE, Vec)> { - let (bpe, tokens) = self.train(word_counts)?; - Ok((bpe, tokens)) - } - - /// Process a bunch of tokens, counting them - fn process_tokens(&self, words: &mut HashMap, tokens: Vec) { - for token in tokens { - words - .entry(token.clone()) - .and_modify(|c| *c += 1) - .or_insert(1); - } + self.train(word_counts) } /// Whether we should show progress diff --git a/tokenizers/src/models/unigram/trainer.rs b/tokenizers/src/models/unigram/trainer.rs index 9af21aa72..91f65a975 100644 --- a/tokenizers/src/models/unigram/trainer.rs +++ b/tokenizers/src/models/unigram/trainer.rs @@ -531,16 +531,6 @@ impl Trainer for UnigramTrainer { self._train(sentences) } - /// Process a bunch of tokens, counting them - fn process_tokens(&self, words: &mut HashMap, tokens: Vec) { - for token in tokens { - words - .entry(token.clone()) - .and_modify(|c| *c += 1) - .or_insert(1); - } - } - /// Whether we should show progress fn should_show_progress(&self) -> bool { self.show_progress diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 2baed70b0..8947654c7 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -133,7 +133,14 @@ pub trait Trainer { words: HashMap, ) -> Result<(::Model, Vec)>; /// Process a bunch of token, counting them as relevant. 
- fn process_tokens(&self, words: &mut HashMap, tokens: Vec); + fn process_tokens(&self, words: &mut HashMap, tokens: Vec) { + for token in tokens { + words + .entry(token.clone()) + .and_modify(|c| *c += 1) + .or_insert(1); + } + } } #[derive(Debug, Clone, PartialEq)] From 5a6e91be32d7a0fb9598a51e4b4ccd4ac4fdcfd3 Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Wed, 7 Oct 2020 16:46:55 -0400 Subject: [PATCH 02/11] Add WordLevel trainer --- .../py_src/tokenizers/trainers/__init__.py | 1 + bindings/python/src/lib.rs | 1 + bindings/python/src/trainers.rs | 63 ++++++++++ tokenizers/src/models/wordlevel/mod.rs | 4 + tokenizers/src/models/wordlevel/trainer.rs | 116 ++++++++++++++++++ 5 files changed, 185 insertions(+) create mode 100644 tokenizers/src/models/wordlevel/trainer.rs diff --git a/bindings/python/py_src/tokenizers/trainers/__init__.py b/bindings/python/py_src/tokenizers/trainers/__init__.py index 05243aa57..22f94c50b 100644 --- a/bindings/python/py_src/tokenizers/trainers/__init__.py +++ b/bindings/python/py_src/tokenizers/trainers/__init__.py @@ -4,4 +4,5 @@ Trainer = trainers.Trainer BpeTrainer = trainers.BpeTrainer UnigramTrainer = trainers.UnigramTrainer +WordLevelTrainer = trainers.WordLevelTrainer WordPieceTrainer = trainers.WordPieceTrainer diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index 7cd727ff6..d43de428c 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -44,6 +44,7 @@ fn trainers(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; Ok(()) } diff --git a/bindings/python/src/trainers.rs b/bindings/python/src/trainers.rs index 3db218513..96743331a 100644 --- a/bindings/python/src/trainers.rs +++ b/bindings/python/src/trainers.rs @@ -242,6 +242,69 @@ impl PyWordPieceTrainer { } } +/// Capable of training a WorldLevel model +/// +/// Args: +/// vocab_size: unsigned int: +/// The size of the final vocabulary, including all tokens and alphabet. +/// +/// min_frequency: unsigned int: +/// The minimum frequency a pair should have in order to be merged. +/// +/// show_progress: boolean: +/// Whether to show progress bars while training. +/// +/// special_tokens: List[Union[str, AddedToken]]: +/// A list of special tokens the model should know of. +/// +/// Returns: +/// Trainer +#[pyclass(extends=PyTrainer, name=WordLevelTrainer)] +pub struct PyWordLevelTrainer {} +#[pymethods] +impl PyWordLevelTrainer { + /// Create a new WordLevelTrainer with the given configuration + #[new] + #[args(kwargs = "**")] + pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> { + let mut trainer = tk::models::wordlevel::WordLevelTrainer::default(); + + if let Some(kwargs) = kwargs { + for (key, val) in kwargs { + let key: &str = key.extract()?; + match key { + "vocab_size" => trainer.vocab_size = val.extract()?, + "min_frequency" => trainer.min_frequency = val.extract()?, + "show_progress" => trainer.show_progress = val.extract()?, + "special_tokens" => { + trainer.special_tokens = val + .cast_as::()? + .into_iter() + .map(|token| { + if let Ok(content) = token.extract::() { + Ok(PyAddedToken::from(content, Some(true)).get_token()) + } else if let Ok(mut token) = + token.extract::>() + { + token.is_special_token = true; + Ok(token.get_token()) + } else { + Err(exceptions::PyTypeError::new_err( + "special_tokens must be a List[Union[str, AddedToken]]", + )) + } + }) + .collect::>>()? 
+ } + _ => println!("Ignored unknown kwargs option {}", key), + } + } + } + + Ok((PyWordLevelTrainer {}, PyTrainer::new(trainer.into()))) + } +} + /// Capable of training a Unigram model /// /// Args: diff --git a/tokenizers/src/models/wordlevel/mod.rs b/tokenizers/src/models/wordlevel/mod.rs index 7e1ac4407..88d731d70 100644 --- a/tokenizers/src/models/wordlevel/mod.rs +++ b/tokenizers/src/models/wordlevel/mod.rs @@ -8,6 +8,10 @@ use std::io::{BufReader, Read, Write}; use std::path::{Path, PathBuf}; mod serialization; +mod trainer; + +// Re-export +pub use trainer::*; type Vocab = HashMap; diff --git a/tokenizers/src/models/wordlevel/trainer.rs b/tokenizers/src/models/wordlevel/trainer.rs new file mode 100644 index 000000000..9fa2b664a --- /dev/null +++ b/tokenizers/src/models/wordlevel/trainer.rs @@ -0,0 +1,116 @@ +use super::WordLevel; +use crate::{AddedToken, Result, Trainer}; +use std::collections::HashMap; + +pub struct WordLevelTrainer { + /// The minimum frequency a word must have to be part of the vocabulary + pub min_frequency: u32, + /// The target vocabulary size + pub vocab_size: usize, + /// Whether to show progress while training + pub show_progress: bool, + /// A list of special tokens that the model should know of + pub special_tokens: Vec, +} + +impl Default for WordLevelTrainer { + fn default() -> Self { + Self { + min_frequency: 0, + vocab_size: 30_000, + show_progress: true, + special_tokens: vec![], + } + } +} + +impl WordLevelTrainer { + fn train(&self, word_counts: HashMap) -> Result<(WordLevel, Vec)> { + let mut ordered_counts = word_counts.into_iter().collect::>(); + ordered_counts.sort_by_key(|(_, n)| std::cmp::Reverse(*n)); + let word_level = WordLevel::builder() + .vocab( + self.special_tokens + .iter() + .map(|token| token.content.clone()) + .chain( + ordered_counts + .into_iter() + .filter(|(_, n)| *n >= self.min_frequency) + .map(|(w, _)| w), + ) + .take(self.vocab_size) + .enumerate() + .map(|(i, w)| (w, i as u32)) + .collect(), + ) + .build(); + + Ok((word_level, self.special_tokens.clone())) + } +} + +impl Trainer for WordLevelTrainer { + type Model = WordLevel; + + /// Train a WordLevel model + fn train(&self, word_counts: HashMap) -> Result<(WordLevel, Vec)> { + self.train(word_counts) + } + + /// Whether we should show progress + fn should_show_progress(&self) -> bool { + self.show_progress + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_train() { + let word_counts: HashMap = [ + ("the".into(), 25), + ("roses".into(), 22), + ("are".into(), 24), + ("red".into(), 12), + ("voilets".into(), 10), + ("blue".into(), 16), + ] + .iter() + .cloned() + .collect(); + + let mut trainer = WordLevelTrainer::default(); + trainer.vocab_size = 5; + + let (model, _) = trainer.train(word_counts.clone()).unwrap(); + let expected_vocab: HashMap = [ + ("the".into(), 0), + ("are".into(), 1), + ("roses".into(), 2), + ("blue".into(), 3), + ("red".into(), 4), + ] + .iter() + .cloned() + .collect(); + assert_eq!(model.vocab, expected_vocab); + + // If we specify a min_frequency + trainer.min_frequency = 15; + let (model, _) = trainer.train(word_counts).unwrap(); + let expected_vocab: HashMap = [ + ("the".into(), 0), + ("are".into(), 1), + ("roses".into(), 2), + ("blue".into(), 3), + ] + .iter() + .cloned() + .collect(); + + assert_eq!(model.vocab, expected_vocab); + } +} From 7dd89990cd3e3f6b6c51ead8387f497338f5f90e Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Wed, 7 Oct 2020 17:44:58 -0400 Subject: [PATCH 03/11] A Model can return its associated 
Trainer --- bindings/python/src/models.rs | 11 +++++ bindings/python/src/trainers.rs | 51 +++++++++++++++++--- tokenizers/src/models/bpe/model.rs | 8 ++- tokenizers/src/models/mod.rs | 19 +++++++- tokenizers/src/models/unigram/model.rs | 13 ++++- tokenizers/src/models/unigram/trainer.rs | 6 +++ tokenizers/src/models/wordlevel/mod.rs | 6 +++ tokenizers/src/models/wordpiece/mod.rs | 6 +++ tokenizers/src/tokenizer/added_vocabulary.rs | 18 ++++++- tokenizers/src/tokenizer/mod.rs | 3 ++ 10 files changed, 130 insertions(+), 11 deletions(-) diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index b38e3033c..e2d2e6fee 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -3,6 +3,7 @@ use std::path::{Path, PathBuf}; use std::sync::Arc; use crate::token::PyToken; +use crate::trainers::PyTrainer; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; @@ -45,6 +46,8 @@ impl PyModel { } impl Model for PyModel { + type Trainer = PyTrainer; + fn tokenize(&self, tokens: &str) -> tk::Result> { self.model.tokenize(tokens) } @@ -68,6 +71,10 @@ impl Model for PyModel { fn save(&self, folder: &Path, name: Option<&str>) -> tk::Result> { self.model.save(folder, name) } + + fn get_trainer(&self) -> Self::Trainer { + self.model.get_trainer().into() + } } #[pymethods] @@ -142,6 +149,10 @@ impl PyModel { .map(|path| path.to_string_lossy().into_owned()) .collect()) } + + fn get_trainer(&self) -> PyResult { + PyTrainer::from(self.model.get_trainer()).get_as_subtype() + } } /// Instantiate a BPE Model from the given vocab and merges. diff --git a/bindings/python/src/trainers.rs b/bindings/python/src/trainers.rs index 96743331a..eb90ea5a2 100644 --- a/bindings/python/src/trainers.rs +++ b/bindings/python/src/trainers.rs @@ -47,15 +47,34 @@ use crate::tokenizer::PyAddedToken; /// Returns: /// Trainer #[pyclass(name=Trainer)] +#[derive(Clone)] #[text_signature = "(self, vocab_size=30000, min_frequency=0,show_progress=True, special_tokens=[],limit_alphabet=None, initial_alphabet = [], continuing_subword_prefix=None, end_of_word_suffix=None)"] pub struct PyTrainer { - pub trainer: TrainerWrapper, + pub trainer: Arc, } impl PyTrainer { - pub fn new(trainer: TrainerWrapper) -> Self { + pub(crate) fn new(trainer: Arc) -> Self { PyTrainer { trainer } } + + pub(crate) fn get_as_subtype(&self) -> PyResult { + let base = self.clone(); + let gil = Python::acquire_gil(); + let py = gil.python(); + Ok(match self.trainer.as_ref() { + TrainerWrapper::BpeTrainer(_) => Py::new(py, (PyBpeTrainer {}, base))?.into_py(py), + TrainerWrapper::WordPieceTrainer(_) => { + Py::new(py, (PyWordPieceTrainer {}, base))?.into_py(py) + } + TrainerWrapper::WordLevelTrainer(_) => { + Py::new(py, (PyWordLevelTrainer {}, base))?.into_py(py) + } + TrainerWrapper::UnigramTrainer(_) => { + Py::new(py, (PyUnigramTrainer {}, base))?.into_py(py) + } + }) + } } impl Trainer for PyTrainer { @@ -77,6 +96,17 @@ impl Trainer for PyTrainer { } } +impl From for PyTrainer +where + I: Into, +{ + fn from(trainer: I) -> Self { + PyTrainer { + trainer: trainer.into().into(), + } + } +} + /// Capable of training a BPE model #[pyclass(extends=PyTrainer, name=BpeTrainer)] pub struct PyBpeTrainer {} @@ -138,7 +168,10 @@ impl PyBpeTrainer { }; } } - Ok((PyBpeTrainer {}, PyTrainer::new(builder.build().into()))) + Ok(( + PyBpeTrainer {}, + PyTrainer::new(Arc::new(builder.build().into())), + )) } } @@ -237,7 +270,7 @@ impl PyWordPieceTrainer { Ok(( PyWordPieceTrainer {}, - PyTrainer::new(builder.build().into()), + 
PyTrainer::new(Arc::new(builder.build().into())), )) } } @@ -301,7 +334,10 @@ impl PyWordLevelTrainer { } } - Ok((PyWordLevelTrainer {}, PyTrainer::new(trainer.into()))) + Ok(( + PyWordLevelTrainer {}, + PyTrainer::new(Arc::new(trainer.into())), + )) } } @@ -388,6 +424,9 @@ impl PyUnigramTrainer { builder.build().map_err(|e| { exceptions::PyException::new_err(format!("Cannot build UnigramTrainer: {}", e)) })?; - Ok((PyUnigramTrainer {}, PyTrainer::new(trainer.into()))) + Ok(( + PyUnigramTrainer {}, + PyTrainer::new(Arc::new(trainer.into())), + )) } } diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index 508bd0219..15c6f679f 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -1,4 +1,4 @@ -use super::{super::OrderedVocabIter, Error, Pair, Word}; +use super::{super::OrderedVocabIter, trainer::BpeTrainer, Error, Pair, Word}; use crate::tokenizer::{Model, Result, Token}; use crate::utils::cache::{Cache, DEFAULT_CACHE_CAPACITY}; use crate::utils::iter::ResultShunt; @@ -415,6 +415,8 @@ impl BPE { } impl Model for BPE { + type Trainer = BpeTrainer; + fn get_vocab(&self) -> &HashMap { &self.vocab } @@ -487,6 +489,10 @@ impl Model for BPE { Ok(vec![vocab_path, merges_path]) } + + fn get_trainer(&self) -> BpeTrainer { + BpeTrainer::default() + } } #[cfg(test)] diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs index 0d6b4915f..46a5b8c57 100644 --- a/tokenizers/src/models/mod.rs +++ b/tokenizers/src/models/mod.rs @@ -12,7 +12,7 @@ use serde::{Deserialize, Serialize, Serializer}; use crate::models::bpe::{BpeTrainer, BPE}; use crate::models::unigram::{Unigram, UnigramTrainer}; -use crate::models::wordlevel::WordLevel; +use crate::models::wordlevel::{WordLevel, WordLevelTrainer}; use crate::models::wordpiece::{WordPiece, WordPieceTrainer}; use crate::{AddedToken, Model, Result, Token, Trainer}; @@ -53,6 +53,8 @@ impl_enum_from!(BPE, ModelWrapper, BPE); impl_enum_from!(Unigram, ModelWrapper, Unigram); impl Model for ModelWrapper { + type Trainer = TrainerWrapper; + fn tokenize(&self, tokens: &str) -> Result> { use ModelWrapper::*; match self { @@ -112,11 +114,22 @@ impl Model for ModelWrapper { Unigram(t) => t.save(folder, name), } } + + fn get_trainer(&self) -> Self::Trainer { + use ModelWrapper::*; + match self { + WordLevel(t) => t.get_trainer().into(), + WordPiece(t) => t.get_trainer().into(), + BPE(t) => t.get_trainer().into(), + Unigram(t) => t.get_trainer().into(), + } + } } pub enum TrainerWrapper { BpeTrainer(BpeTrainer), WordPieceTrainer(WordPieceTrainer), + WordLevelTrainer(WordLevelTrainer), UnigramTrainer(UnigramTrainer), } @@ -127,6 +140,7 @@ impl Trainer for TrainerWrapper { match self { TrainerWrapper::BpeTrainer(bpe) => bpe.should_show_progress(), TrainerWrapper::WordPieceTrainer(wpt) => wpt.should_show_progress(), + TrainerWrapper::WordLevelTrainer(wpt) => wpt.should_show_progress(), TrainerWrapper::UnigramTrainer(wpt) => wpt.should_show_progress(), } } @@ -135,6 +149,7 @@ impl Trainer for TrainerWrapper { match self { TrainerWrapper::BpeTrainer(bpe) => bpe.train(words).map(|(m, t)| (m.into(), t)), TrainerWrapper::WordPieceTrainer(wpt) => wpt.train(words).map(|(m, t)| (m.into(), t)), + TrainerWrapper::WordLevelTrainer(wpt) => wpt.train(words).map(|(m, t)| (m.into(), t)), TrainerWrapper::UnigramTrainer(wpt) => wpt.train(words).map(|(m, t)| (m.into(), t)), } } @@ -143,6 +158,7 @@ impl Trainer for TrainerWrapper { match self { TrainerWrapper::BpeTrainer(bpe) => bpe.process_tokens(words, 
tokens), TrainerWrapper::WordPieceTrainer(wpt) => wpt.process_tokens(words, tokens), + TrainerWrapper::WordLevelTrainer(wpt) => wpt.process_tokens(words, tokens), TrainerWrapper::UnigramTrainer(wpt) => wpt.process_tokens(words, tokens), } } @@ -151,3 +167,4 @@ impl Trainer for TrainerWrapper { impl_enum_from!(BpeTrainer, TrainerWrapper, BpeTrainer); impl_enum_from!(WordPieceTrainer, TrainerWrapper, WordPieceTrainer); impl_enum_from!(UnigramTrainer, TrainerWrapper, UnigramTrainer); +impl_enum_from!(WordLevelTrainer, TrainerWrapper, WordLevelTrainer); diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs index 8847d0270..fc6d082e7 100644 --- a/tokenizers/src/models/unigram/model.rs +++ b/tokenizers/src/models/unigram/model.rs @@ -1,5 +1,8 @@ -use crate::models::unigram::lattice::Lattice; -use crate::models::unigram::trie::{Trie, TrieBuilder}; +use super::{ + lattice::Lattice, + trainer::UnigramTrainer, + trie::{Trie, TrieBuilder}, +}; use crate::tokenizer::{Model, Result, Token}; use crate::utils::cache::Cache; @@ -404,6 +407,8 @@ impl<'a> Iterator for UnigramIterator<'a> { } impl Model for Unigram { + type Trainer = UnigramTrainer; + fn get_vocab(&self) -> &HashMap { &self.token_to_ids } @@ -452,6 +457,10 @@ impl Model for Unigram { std::fs::write(&fullpath, string)?; Ok(vec![fullpath]) } + + fn get_trainer(&self) -> Self::Trainer { + UnigramTrainer::default() + } } #[cfg(test)] diff --git a/tokenizers/src/models/unigram/trainer.rs b/tokenizers/src/models/unigram/trainer.rs index 91f65a975..ba06023a9 100644 --- a/tokenizers/src/models/unigram/trainer.rs +++ b/tokenizers/src/models/unigram/trainer.rs @@ -60,6 +60,12 @@ pub struct UnigramTrainer { seed_size: usize, } +impl Default for UnigramTrainer { + fn default() -> Self { + Self::builder().build().unwrap() + } +} + impl UnigramTrainer { pub fn builder() -> UnigramTrainerBuilder { UnigramTrainerBuilder::default() diff --git a/tokenizers/src/models/wordlevel/mod.rs b/tokenizers/src/models/wordlevel/mod.rs index 88d731d70..2918f8bb0 100644 --- a/tokenizers/src/models/wordlevel/mod.rs +++ b/tokenizers/src/models/wordlevel/mod.rs @@ -165,6 +165,8 @@ impl Default for WordLevel { } impl Model for WordLevel { + type Trainer = WordLevelTrainer; + fn tokenize(&self, token: &str) -> Result> { Ok(vec![Token { id: *self @@ -210,4 +212,8 @@ impl Model for WordLevel { Ok(vec![vocab_path]) } + + fn get_trainer(&self) -> Self::Trainer { + WordLevelTrainer::default() + } } diff --git a/tokenizers/src/models/wordpiece/mod.rs b/tokenizers/src/models/wordpiece/mod.rs index 1c817a6ff..a3b2e6e04 100644 --- a/tokenizers/src/models/wordpiece/mod.rs +++ b/tokenizers/src/models/wordpiece/mod.rs @@ -200,6 +200,8 @@ impl WordPiece { } impl Model for WordPiece { + type Trainer = WordPieceTrainer; + fn get_vocab(&self) -> &HashMap { &self.vocab } @@ -299,6 +301,10 @@ impl Model for WordPiece { Ok(vec![vocab_path]) } + + fn get_trainer(&self) -> Self::Trainer { + WordPieceTrainer::builder().build() + } } #[cfg(test)] diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index 7d2816480..93e1c1ec9 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -499,7 +499,7 @@ mod tests { use super::*; use crate::normalizers::utils::Lowercase; use crate::normalizers::NormalizerWrapper; - use crate::{OffsetReferential, OffsetType, Result, Token}; + use crate::{OffsetReferential, OffsetType, Result, Token, Trainer}; use std::path::{Path, 
PathBuf}; #[derive(Serialize, Deserialize)] @@ -526,7 +526,20 @@ mod tests { } } + struct TrainerMock; + impl Trainer for TrainerMock { + type Model = ModelMock; + fn should_show_progress(&self) -> bool { + true + } + fn train(&self, _words: HashMap) -> Result<(ModelMock, Vec)> { + unimplemented!() + } + } + impl Model for ModelMock { + type Trainer = TrainerMock; + fn tokenize(&self, _sequence: &str) -> Result> { unimplemented!() } @@ -545,6 +558,9 @@ mod tests { fn save(&self, _folder: &Path, _name: Option<&str>) -> Result> { unimplemented!() } + fn get_trainer(&self) -> Self::Trainer { + TrainerMock + } } #[test] diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 8947654c7..93504c478 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -68,6 +68,7 @@ pub trait PreTokenizer { /// Represents a model used during Tokenization (like BPE or Word or Unigram). pub trait Model { + type Trainer: Trainer + Sync; /// Tokenize the given sequence into multiple underlying `Token`. The `offsets` on the `Token` /// are expected to be relative to the given sequence. fn tokenize(&self, sequence: &str) -> Result>; @@ -82,6 +83,8 @@ pub trait Model { /// Save the current `Model` in the given folder, using the given `prefix` for the various /// files that need to be saved. fn save(&self, folder: &Path, prefix: Option<&str>) -> Result>; + /// Get an instance of a Trainer capable of training this Model + fn get_trainer(&self) -> ::Trainer; } /// A `PostProcessor` has the responsibility to post process an encoded output of the `Tokenizer`. From 9bd12fd71a7b87038f3bf99a7a4f0892ab7544ad Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Wed, 7 Oct 2020 21:25:32 -0400 Subject: [PATCH 04/11] Python - Make the trainer optional on Tokenizer.train --- bindings/python/README.md | 4 ++-- .../tokenizers/implementations/bert_wordpiece.py | 2 +- .../tokenizers/implementations/byte_level_bpe.py | 2 +- .../tokenizers/implementations/char_level_bpe.py | 2 +- .../tokenizers/implementations/sentencepiece_bpe.py | 2 +- .../implementations/sentencepiece_unigram.py | 2 +- bindings/python/src/tokenizer.rs | 13 ++++++++----- 7 files changed, 15 insertions(+), 12 deletions(-) diff --git a/bindings/python/README.md b/bindings/python/README.md index 6518b03d3..a2d02b1c0 100644 --- a/bindings/python/README.md +++ b/bindings/python/README.md @@ -138,11 +138,11 @@ tokenizer.post_processor = processors.ByteLevel(trim_offsets=True) # And then train trainer = trainers.BpeTrainer(vocab_size=20000, min_frequency=2) -tokenizer.train(trainer, [ +tokenizer.train([ "./path/to/dataset/1.txt", "./path/to/dataset/2.txt", "./path/to/dataset/3.txt" -]) +], trainer=trainer) # And Save it tokenizer.save("byte-level-bpe.tokenizer.json", pretty=True) diff --git a/bindings/python/py_src/tokenizers/implementations/bert_wordpiece.py b/bindings/python/py_src/tokenizers/implementations/bert_wordpiece.py index 14c122c3b..04cb4f920 100644 --- a/bindings/python/py_src/tokenizers/implementations/bert_wordpiece.py +++ b/bindings/python/py_src/tokenizers/implementations/bert_wordpiece.py @@ -115,4 +115,4 @@ def train( ) if isinstance(files, str): files = [files] - self._tokenizer.train(trainer, files) + self._tokenizer.train(files, trainer=trainer) diff --git a/bindings/python/py_src/tokenizers/implementations/byte_level_bpe.py b/bindings/python/py_src/tokenizers/implementations/byte_level_bpe.py index e9e88cb8d..85c423269 100644 --- a/bindings/python/py_src/tokenizers/implementations/byte_level_bpe.py 
+++ b/bindings/python/py_src/tokenizers/implementations/byte_level_bpe.py @@ -101,4 +101,4 @@ def train( ) if isinstance(files, str): files = [files] - self._tokenizer.train(trainer, files) + self._tokenizer.train(files, trainer=trainer) diff --git a/bindings/python/py_src/tokenizers/implementations/char_level_bpe.py b/bindings/python/py_src/tokenizers/implementations/char_level_bpe.py index 73b89f4a7..22111ede8 100644 --- a/bindings/python/py_src/tokenizers/implementations/char_level_bpe.py +++ b/bindings/python/py_src/tokenizers/implementations/char_level_bpe.py @@ -123,4 +123,4 @@ def train( ) if isinstance(files, str): files = [files] - self._tokenizer.train(trainer, files) + self._tokenizer.train(files, trainer=trainer) diff --git a/bindings/python/py_src/tokenizers/implementations/sentencepiece_bpe.py b/bindings/python/py_src/tokenizers/implementations/sentencepiece_bpe.py index 645777f3b..5e481dd69 100644 --- a/bindings/python/py_src/tokenizers/implementations/sentencepiece_bpe.py +++ b/bindings/python/py_src/tokenizers/implementations/sentencepiece_bpe.py @@ -74,4 +74,4 @@ def train( ) if isinstance(files, str): files = [files] - self._tokenizer.train(trainer, files) + self._tokenizer.train(files, trainer=trainer) diff --git a/bindings/python/py_src/tokenizers/implementations/sentencepiece_unigram.py b/bindings/python/py_src/tokenizers/implementations/sentencepiece_unigram.py index 8553bc38c..bc378f38e 100644 --- a/bindings/python/py_src/tokenizers/implementations/sentencepiece_unigram.py +++ b/bindings/python/py_src/tokenizers/implementations/sentencepiece_unigram.py @@ -75,7 +75,7 @@ def train( if isinstance(files, str): files = [files] - self._tokenizer.train(trainer, files) + self._tokenizer.train(files, trainer=trainer) @staticmethod def from_spm(filename: str): diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index df2ae32c5..0d26edba5 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -8,7 +8,7 @@ use pyo3::types::*; use pyo3::PyObjectProtocol; use tk::models::bpe::BPE; use tk::tokenizer::{ - PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl, + Model, PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl, TruncationParams, TruncationStrategy, }; use tokenizers as tk; @@ -1039,10 +1039,13 @@ impl PyTokenizer { Ok(self.tokenizer.add_special_tokens(&tokens)) } - fn train(&mut self, trainer: &PyTrainer, files: Vec) -> PyResult<()> { - let gil = Python::acquire_gil(); - gil.python() - .allow_threads(|| ToPyResult(self.tokenizer.train_and_replace(trainer, files)).into()) + #[args(trainer = "None")] + fn train(&mut self, files: Vec, trainer: Option<&PyTrainer>) -> PyResult<()> { + let trainer = + trainer.map_or_else(|| self.tokenizer.get_model().get_trainer(), |t| t.clone()); + Python::with_gil(|py| { + py.allow_threads(|| ToPyResult(self.tokenizer.train_and_replace(&trainer, files)).into()) + }) } /// Apply all the post-processing steps to the given encodings. From a4b6588c155b7a8bbee4928d6fe00af4146909ad Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Thu, 8 Oct 2020 18:20:38 -0400 Subject: [PATCH 05/11] Train Model in place This let us keep everything that was set on the model except from the vocabulary when trained. For example, this let us keep the configured `unk_token` of BPE when its trained. 
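For reference, a minimal sketch of the intended flow after this change, mirroring the regression test added later in this series (the corpus path is the repository's test fixture; everything else comes from the APIs introduced in these patches):

```rust
use tokenizers::models::bpe::BPE;
use tokenizers::{DecoderWrapper, NormalizerWrapper, PostProcessorWrapper, PreTokenizerWrapper};
use tokenizers::{Model, TokenizerBuilder};

fn main() {
    // Build a tokenizer whose BPE model carries a configured unk_token.
    let mut tokenizer = TokenizerBuilder::<
        BPE,
        NormalizerWrapper,
        PreTokenizerWrapper,
        PostProcessorWrapper,
        DecoderWrapper,
    >::default()
    .with_model(
        BPE::builder()
            .unk_token("[UNK]".to_string())
            .build()
            .unwrap(),
    )
    .build()
    .unwrap();

    // Ask the model for a default trainer, then train in place:
    // only the vocabulary is replaced, the other options are kept.
    let trainer = tokenizer.get_model().get_trainer();
    tokenizer
        .train(&trainer, vec!["./data/small.txt".to_string()])
        .unwrap();

    // The unk_token configured above survives training.
    assert_eq!(
        tokenizer.get_model().get_unk_token(),
        &Some("[UNK]".to_string())
    );
}
```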
--- bindings/python/src/tokenizer.rs | 2 +- bindings/python/src/trainers.rs | 12 ++-- tokenizers/benches/common/mod.rs | 2 +- tokenizers/src/lib.rs | 4 +- tokenizers/src/models/bpe/trainer.rs | 57 ++++++++-------- tokenizers/src/models/mod.rs | 26 ++++++-- tokenizers/src/models/unigram/trainer.rs | 70 +++++++++++--------- tokenizers/src/models/wordlevel/trainer.rs | 28 ++++++-- tokenizers/src/models/wordpiece/trainer.rs | 30 +++++++-- tokenizers/src/tokenizer/added_vocabulary.rs | 6 +- tokenizers/src/tokenizer/mod.rs | 37 ++--------- tokenizers/tests/unigram.rs | 3 +- 12 files changed, 156 insertions(+), 121 deletions(-) diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index 0d26edba5..25de9bbc5 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -1044,7 +1044,7 @@ impl PyTokenizer { let trainer = trainer.map_or_else(|| self.tokenizer.get_model().get_trainer(), |t| t.clone()); Python::with_gil(|py| { - py.allow_threads(|| ToPyResult(self.tokenizer.train_and_replace(&trainer, files)).into()) + py.allow_threads(|| ToPyResult(self.tokenizer.train(&trainer, files)).into()) }) } diff --git a/bindings/python/src/trainers.rs b/bindings/python/src/trainers.rs index eb90ea5a2..58c35d330 100644 --- a/bindings/python/src/trainers.rs +++ b/bindings/python/src/trainers.rs @@ -84,11 +84,13 @@ impl Trainer for PyTrainer { self.trainer.should_show_progress() } - fn train(&self, words: HashMap) -> tk::Result<(PyModel, Vec)> { - self.trainer.train(words).map(|(m, t)| { - let m = PyModel { model: Arc::new(m) }; - (m, t) - }) + fn train( + &self, + words: HashMap, + model: &mut PyModel, + ) -> tk::Result> { + todo!("FIX THIS"); + self.trainer.train(words, &mut model.model) } fn process_tokens(&self, words: &mut HashMap, tokens: Vec) { diff --git a/tokenizers/benches/common/mod.rs b/tokenizers/benches/common/mod.rs index 76adbde97..1c2453df2 100644 --- a/tokenizers/benches/common/mod.rs +++ b/tokenizers/benches/common/mod.rs @@ -75,7 +75,7 @@ where let mut duration = Duration::new(0, 0); for _i in 0..iters { let start = Instant::now(); - tokenizer.train_and_replace(trainer, files.clone()).unwrap(); + tokenizer.train(trainer, files.clone()).unwrap(); duration = duration.checked_add(start.elapsed()).unwrap(); } duration diff --git a/tokenizers/src/lib.rs b/tokenizers/src/lib.rs index eac3aa1fc..5e536a162 100644 --- a/tokenizers/src/lib.rs +++ b/tokenizers/src/lib.rs @@ -43,7 +43,7 @@ //! ``` //! //! ## Training and serialization example -//! +//! //! ```no_run //! use tokenizers::decoders::DecoderWrapper; //! use tokenizers::models::bpe::{BpeTrainerBuilder, BPE}; @@ -71,7 +71,7 @@ //! ]) //! .build(); //! -//! let tokenizer = TokenizerBuilder::new() +//! let mut tokenizer = TokenizerBuilder::new() //! .with_model(BPE::default()) //! .with_normalizer(Some(Sequence::new(vec![ //! 
Strip::new(true, true).into(), diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs index 4e7146a03..4e59aff1c 100644 --- a/tokenizers/src/models/bpe/trainer.rs +++ b/tokenizers/src/models/bpe/trainer.rs @@ -144,14 +144,15 @@ impl BpeTrainerBuilder { /// ``` /// use std::collections::HashMap; /// use tokenizers::tokenizer::Trainer; -/// use tokenizers::models::bpe::BpeTrainer; +/// use tokenizers::models::bpe::{BPE, BpeTrainer}; /// /// let word_counts: HashMap = [ /// (String::from("Hello"), 1), /// (String::from("World"), 1), /// ].iter().cloned().collect(); /// let trainer = BpeTrainer::default(); -/// let (model, special_tokens) = trainer.train(word_counts).unwrap(); +/// let mut model = BPE::default(); +/// let special_tokens = trainer.train(word_counts, &mut model).unwrap(); /// ``` pub struct BpeTrainer { /// The minimum frequency a pair must have to produce a merge operation @@ -404,7 +405,11 @@ impl BpeTrainer { ) } - pub fn train(&self, word_counts: HashMap) -> Result<(BPE, Vec)> { + pub fn train( + &self, + word_counts: HashMap, + model: &mut BPE, + ) -> Result> { let mut word_to_id: HashMap = HashMap::with_capacity(self.vocab_size); let mut id_to_word: Vec = Vec::with_capacity(self.vocab_size); @@ -551,30 +556,27 @@ impl BpeTrainer { } self.finalize_progress(&progress, merges.len()); - let mut builder = BPE::builder().vocab_and_merges( - word_to_id, - merges - .into_iter() - .map(|((a_id, b_id), _)| { - ( - id_to_word[a_id as usize].clone(), - id_to_word[b_id as usize].clone(), - ) - }) - .collect(), - ); + // Transfer new vocab & options to model + model.vocab = word_to_id; + model.vocab_r = model + .vocab + .iter() + .map(|(key, val)| (*val, key.to_owned())) + .collect(); + model.merges = merges + .into_iter() + .enumerate() + .map(|(i, (pair, new_token_id))| (pair, (i as u32, new_token_id))) + .collect(); + if let Some(prefix) = &self.continuing_subword_prefix { - builder = builder.continuing_subword_prefix(prefix.to_owned()); + model.continuing_subword_prefix = Some(prefix.to_owned()); } if let Some(suffix) = &self.end_of_word_suffix { - builder = builder.end_of_word_suffix(suffix.to_owned()); + model.end_of_word_suffix = Some(suffix.to_owned()); } - Ok(( - builder - .build() - .expect("Trainer should know how to build BPE"), - self.special_tokens.clone(), - )) + + Ok(self.special_tokens.clone()) } } @@ -582,8 +584,8 @@ impl Trainer for BpeTrainer { type Model = BPE; /// Train a BPE model - fn train(&self, word_counts: HashMap) -> Result<(BPE, Vec)> { - self.train(word_counts) + fn train(&self, word_counts: HashMap, model: &mut BPE) -> Result> { + self.train(word_counts, model) } /// Whether we should show progress @@ -594,7 +596,7 @@ impl Trainer for BpeTrainer { #[cfg(test)] mod tests { - use super::{BpeTrainer, Pair}; + use super::{BpeTrainer, Pair, BPE}; use std::collections::HashMap; #[test] @@ -619,7 +621,8 @@ mod tests { .show_progress(false) .min_frequency(2) .build(); - let (model, _) = trainer.train(word_counts).unwrap(); + let mut model = BPE::default(); + trainer.train(word_counts, &mut model).unwrap(); // Vocab should contain all of the characters from the `word_counts` mapping // as well as three merges: 're', 'are', and 'is'. 
diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs index 46a5b8c57..3ee88ebfd 100644 --- a/tokenizers/src/models/mod.rs +++ b/tokenizers/src/models/mod.rs @@ -145,12 +145,28 @@ impl Trainer for TrainerWrapper { } } - fn train(&self, words: HashMap) -> Result<(Self::Model, Vec)> { + fn train( + &self, + words: HashMap, + model: &mut ModelWrapper, + ) -> Result> { match self { - TrainerWrapper::BpeTrainer(bpe) => bpe.train(words).map(|(m, t)| (m.into(), t)), - TrainerWrapper::WordPieceTrainer(wpt) => wpt.train(words).map(|(m, t)| (m.into(), t)), - TrainerWrapper::WordLevelTrainer(wpt) => wpt.train(words).map(|(m, t)| (m.into(), t)), - TrainerWrapper::UnigramTrainer(wpt) => wpt.train(words).map(|(m, t)| (m.into(), t)), + TrainerWrapper::BpeTrainer(t) => match model { + ModelWrapper::BPE(bpe) => t.train(words, bpe), + _ => unreachable!(), + }, + TrainerWrapper::WordPieceTrainer(t) => match model { + ModelWrapper::WordPiece(wp) => t.train(words, wp), + _ => unreachable!(), + }, + TrainerWrapper::WordLevelTrainer(t) => match model { + ModelWrapper::WordLevel(wl) => t.train(words, wl), + _ => unreachable!(), + }, + TrainerWrapper::UnigramTrainer(t) => match model { + ModelWrapper::Unigram(u) => t.train(words, u), + _ => unreachable!(), + }, } } diff --git a/tokenizers/src/models/unigram/trainer.rs b/tokenizers/src/models/unigram/trainer.rs index ba06023a9..ae467af8b 100644 --- a/tokenizers/src/models/unigram/trainer.rs +++ b/tokenizers/src/models/unigram/trainer.rs @@ -450,7 +450,7 @@ impl UnigramTrainer { .collect(); new_pieces } - pub fn _train(&self, sentences: Vec) -> Result<(Unigram, Vec)> { + pub fn _train(&self, sentences: Vec, model: &mut Unigram) -> Result> { let progress = self.setup_progress(); // // 1. Compute frequent substrings @@ -484,22 +484,22 @@ impl UnigramTrainer { let expected_updates = expected_loops as usize * self.n_sub_iterations as usize; self.update_progress(&progress, expected_updates, "EM training"); let required_chars = self.required_chars(&sentences); - let mut model = Unigram::from(pieces.clone(), Some(0))?; + let mut new_model = Unigram::from(pieces.clone(), Some(0))?; loop { // Sub-EM iteration. for _iter in 0..self.n_sub_iterations { // Executes E step - let (_objective, _num_tokens, expected) = self.run_e_step(&model, &sentences); + let (_objective, _num_tokens, expected) = self.run_e_step(&new_model, &sentences); // Executes M step. pieces = self.run_m_step(&pieces, &expected); - model = Unigram::from(pieces.clone(), Some(0))?; + new_model = Unigram::from(pieces.clone(), Some(0))?; // Useful comment for checking compatibility with spm debug!( "Em iter={} size={} obj={} num_tokens={} num_tokens/piece={}", _iter, - model.len(), + new_model.len(), _objective, _num_tokens, _num_tokens as f64 / model.len() as f64 @@ -516,15 +516,15 @@ impl UnigramTrainer { } // Prunes pieces. - pieces = self.prune_sentence_pieces(&model, &pieces, &sentences); - model = Unigram::from(pieces.clone(), Some(0))?; + pieces = self.prune_sentence_pieces(&new_model, &pieces, &sentences); + new_model = Unigram::from(pieces.clone(), Some(0))?; } self.finalize_progress(&progress, expected_updates); // Finally, adjusts the size of sentencepices to be |vocab_size|. 
- model = self.finalize(model, required_chars)?; + *model = self.finalize(new_model, required_chars)?; - Ok((model, self.special_tokens.clone())) + Ok(self.special_tokens.clone()) } } @@ -532,9 +532,13 @@ impl Trainer for UnigramTrainer { type Model = Unigram; /// Train a Unigram model - fn train(&self, word_counts: HashMap) -> Result<(Self::Model, Vec)> { + fn train( + &self, + word_counts: HashMap, + model: &mut Unigram, + ) -> Result> { let sentences: Vec<_> = word_counts.into_iter().collect(); - self._train(sentences) + self._train(sentences, model) } /// Whether we should show progress @@ -633,11 +637,12 @@ mod tests { .build() .unwrap(); - let (unigram, _) = trainer - .train(HashMap::from_iter(vec![ - ("The".into(), 12), - ("are".into(), 11), - ])) + let mut unigram = Unigram::default(); + trainer + .train( + HashMap::from_iter(vec![("The".into(), 12), ("are".into(), 11)]), + &mut unigram, + ) .unwrap(); let mut pieces = unigram.iter(); @@ -657,11 +662,12 @@ mod tests { .build() .unwrap(); - let (unigram, _) = trainer - .train(HashMap::from_iter(vec![ - ("The".into(), 12), - ("are".into(), 11), - ])) + let mut unigram = Unigram::default(); + trainer + .train( + HashMap::from_iter(vec![("The".into(), 12), ("are".into(), 11)]), + &mut unigram, + ) .unwrap(); let mut pieces = unigram.iter(); @@ -675,11 +681,12 @@ mod tests { .build() .unwrap(); - let (unigram, _) = trainer - .train(HashMap::from_iter(vec![ - ("The".into(), 12), - ("are".into(), 11), - ])) + let mut unigram = Unigram::default(); + trainer + .train( + HashMap::from_iter(vec![("The".into(), 12), ("are".into(), 11)]), + &mut unigram, + ) .unwrap(); let mut pieces = unigram.iter(); @@ -697,11 +704,12 @@ mod tests { .build() .unwrap(); - let (unigram, _) = trainer - .train(HashMap::from_iter(vec![ - ("The".into(), 12), - ("are".into(), 11), - ])) + let mut unigram = Unigram::default(); + trainer + .train( + HashMap::from_iter(vec![("The".into(), 12), ("are".into(), 11)]), + &mut unigram, + ) .unwrap(); let mut pieces = unigram.iter(); diff --git a/tokenizers/src/models/wordlevel/trainer.rs b/tokenizers/src/models/wordlevel/trainer.rs index 9fa2b664a..2485b57ca 100644 --- a/tokenizers/src/models/wordlevel/trainer.rs +++ b/tokenizers/src/models/wordlevel/trainer.rs @@ -25,7 +25,11 @@ impl Default for WordLevelTrainer { } impl WordLevelTrainer { - fn train(&self, word_counts: HashMap) -> Result<(WordLevel, Vec)> { + fn train( + &self, + word_counts: HashMap, + model: &mut WordLevel, + ) -> Result> { let mut ordered_counts = word_counts.into_iter().collect::>(); ordered_counts.sort_by_key(|(_, n)| std::cmp::Reverse(*n)); let word_level = WordLevel::builder() @@ -44,9 +48,13 @@ impl WordLevelTrainer { .map(|(i, w)| (w, i as u32)) .collect(), ) - .build(); + .build()?; - Ok((word_level, self.special_tokens.clone())) + // Transfer the vocab + model.vocab = word_level.vocab; + model.vocab_r = word_level.vocab_r; + + Ok(self.special_tokens.clone()) } } @@ -54,8 +62,12 @@ impl Trainer for WordLevelTrainer { type Model = WordLevel; /// Train a WordLevel model - fn train(&self, word_counts: HashMap) -> Result<(WordLevel, Vec)> { - self.train(word_counts) + fn train( + &self, + word_counts: HashMap, + model: &mut WordLevel, + ) -> Result> { + self.train(word_counts, model) } /// Whether we should show progress @@ -85,7 +97,8 @@ mod tests { let mut trainer = WordLevelTrainer::default(); trainer.vocab_size = 5; - let (model, _) = trainer.train(word_counts.clone()).unwrap(); + let mut model = WordLevel::default(); + 
trainer.train(word_counts.clone(), &mut model).unwrap(); let expected_vocab: HashMap = [ ("the".into(), 0), ("are".into(), 1), @@ -100,7 +113,8 @@ mod tests { // If we specify a min_frequency trainer.min_frequency = 15; - let (model, _) = trainer.train(word_counts).unwrap(); + let mut model = WordLevel::default(); + trainer.train(word_counts, &mut model).unwrap(); let expected_vocab: HashMap = [ ("the".into(), 0), ("are".into(), 1), diff --git a/tokenizers/src/models/wordpiece/trainer.rs b/tokenizers/src/models/wordpiece/trainer.rs index 1ec4b2669..2db8a78c6 100644 --- a/tokenizers/src/models/wordpiece/trainer.rs +++ b/tokenizers/src/models/wordpiece/trainer.rs @@ -1,5 +1,5 @@ use super::WordPiece; -use crate::models::bpe::{BpeTrainer, BpeTrainerBuilder}; +use crate::models::bpe::{BpeTrainer, BpeTrainerBuilder, BPE}; use crate::tokenizer::{AddedToken, Result, Trainer}; use std::collections::{HashMap, HashSet}; @@ -89,18 +89,34 @@ impl WordPieceTrainer { WordPieceTrainerBuilder::default() } - pub fn train(&self, word_counts: HashMap) -> Result<(WordPiece, Vec)> { - let (bpe, tokens) = self.bpe_trainer.train(word_counts)?; - Ok((WordPiece::from_bpe(&bpe), tokens)) + pub fn train( + &self, + word_counts: HashMap, + model: &mut WordPiece, + ) -> Result> { + let mut bpe = BPE::default(); + let special_tokens = self.bpe_trainer.train(word_counts, &mut bpe)?; + let new_wordpiece = WordPiece::from_bpe(&bpe); + + // Transfer the vocab + model.vocab = new_wordpiece.vocab; + model.vocab_r = new_wordpiece.vocab_r; + // The continuing_subword_prefix is the only other option to be overriden by the trainer + model.continuing_subword_prefix = new_wordpiece.continuing_subword_prefix; + + Ok(special_tokens) } } impl Trainer for WordPieceTrainer { type Model = WordPiece; - fn train(&self, word_counts: HashMap) -> Result<(WordPiece, Vec)> { - let (wp, tokens) = self.train(word_counts)?; - Ok((wp, tokens)) + fn train( + &self, + word_counts: HashMap, + model: &mut WordPiece, + ) -> Result> { + self.train(word_counts, model) } fn process_tokens(&self, mut words: &mut HashMap, tokens: Vec) { diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index 93e1c1ec9..b61e420ac 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -532,7 +532,11 @@ mod tests { fn should_show_progress(&self) -> bool { true } - fn train(&self, _words: HashMap) -> Result<(ModelMock, Vec)> { + fn train( + &self, + _words: HashMap, + _model: &mut ModelMock, + ) -> Result> { unimplemented!() } } diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 93504c478..74ecc315e 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -134,7 +134,8 @@ pub trait Trainer { fn train( &self, words: HashMap, - ) -> Result<(::Model, Vec)>; + model: &mut Self::Model, + ) -> Result>; /// Process a bunch of token, counting them as relevant. 
fn process_tokens(&self, words: &mut HashMap, tokens: Vec) { for token in tokens { @@ -1054,44 +1055,14 @@ where Ok(words) } - /// Train a model and return a new Tokenizer, using the given Trainer - pub fn train( - self, - trainer: &T, - files: Vec, - ) -> Result> - where - T: Trainer + Sync, - TM: Model, - { - let words = self.word_count(trainer, files)?; - - let (model, special_tokens) = trainer.train(words)?; - let mut new_tok = TokenizerImpl { - normalizer: self.normalizer, - pre_tokenizer: self.pre_tokenizer, - model, - post_processor: self.post_processor, - decoder: self.decoder, - added_vocabulary: self.added_vocabulary, - truncation: self.truncation, - padding: self.padding, - }; - - new_tok.add_special_tokens(&special_tokens); - - Ok(new_tok) - } - /// Train a model and replace our current Model, using the given Trainer - pub fn train_and_replace(&mut self, trainer: &T, files: Vec) -> Result<()> + pub fn train(&mut self, trainer: &T, files: Vec) -> Result<()> where T: Trainer + Sync, { let words = self.word_count(trainer, files)?; - let (model, special_tokens) = trainer.train(words)?; - self.model = model; + let special_tokens = trainer.train(words, &mut self.model)?; self.add_special_tokens(&special_tokens); Ok(()) diff --git a/tokenizers/tests/unigram.rs b/tokenizers/tests/unigram.rs index 3c3ed9b9c..e335f9d37 100644 --- a/tokenizers/tests/unigram.rs +++ b/tokenizers/tests/unigram.rs @@ -55,7 +55,8 @@ fn test_train_unigram_from_file() { .unk_token(Some("".into())) .build() .unwrap(); - let (model, _) = trainer.train(word_counts).unwrap(); + let mut model = Unigram::default(); + trainer.train(word_counts, &mut model).unwrap(); assert_eq!(model.get_vocab_size(), 719); } From b401c3b7f4b0c59f70af249568916deabeb84ba1 Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Thu, 8 Oct 2020 19:33:30 -0400 Subject: [PATCH 06/11] PyModel uses a RwLock to allow modifications --- bindings/python/src/models.rs | 72 +++++++++++--------- bindings/python/src/tokenizer.rs | 7 +- bindings/python/src/trainers.rs | 3 +- tokenizers/src/models/bpe/model.rs | 12 ++-- tokenizers/src/models/mod.rs | 4 +- tokenizers/src/models/unigram/model.rs | 8 +-- tokenizers/src/models/wordlevel/mod.rs | 8 +-- tokenizers/src/models/wordpiece/mod.rs | 13 ++-- tokenizers/src/tokenizer/added_vocabulary.rs | 12 ++-- tokenizers/src/tokenizer/mod.rs | 6 +- 10 files changed, 76 insertions(+), 69 deletions(-) diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index e2d2e6fee..0eb9bde1b 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; use std::path::{Path, PathBuf}; -use std::sync::Arc; +use std::sync::{Arc, RwLock}; use crate::token::PyToken; use crate::trainers::PyTrainer; @@ -24,11 +24,11 @@ use super::error::{deprecation_warning, ToPyResult}; #[derive(Clone, Serialize, Deserialize)] pub struct PyModel { #[serde(flatten)] - pub model: Arc, + pub model: Arc>, } impl PyModel { - pub(crate) fn new(model: Arc) -> Self { + pub(crate) fn new(model: Arc>) -> Self { PyModel { model } } @@ -36,7 +36,7 @@ impl PyModel { let base = self.clone(); let gil = Python::acquire_gil(); let py = gil.python(); - Ok(match self.model.as_ref() { + Ok(match *self.model.as_ref().read().unwrap() { ModelWrapper::BPE(_) => Py::new(py, (PyBPE {}, base))?.into_py(py), ModelWrapper::WordPiece(_) => Py::new(py, (PyWordPiece {}, base))?.into_py(py), ModelWrapper::WordLevel(_) => Py::new(py, (PyWordLevel {}, base))?.into_py(py), @@ -49,31 +49,31 @@ impl Model for 
PyModel { type Trainer = PyTrainer; fn tokenize(&self, tokens: &str) -> tk::Result> { - self.model.tokenize(tokens) + self.model.read().unwrap().tokenize(tokens) } fn token_to_id(&self, token: &str) -> Option { - self.model.token_to_id(token) + self.model.read().unwrap().token_to_id(token) } - fn id_to_token(&self, id: u32) -> Option<&str> { - self.model.id_to_token(id) + fn id_to_token(&self, id: u32) -> Option { + self.model.read().unwrap().id_to_token(id) } - fn get_vocab(&self) -> &HashMap { - self.model.get_vocab() + fn get_vocab(&self) -> HashMap { + self.model.read().unwrap().get_vocab() } fn get_vocab_size(&self) -> usize { - self.model.get_vocab_size() + self.model.read().unwrap().get_vocab_size() } fn save(&self, folder: &Path, name: Option<&str>) -> tk::Result> { - self.model.save(folder, name) + self.model.read().unwrap().save(folder, name) } fn get_trainer(&self) -> Self::Trainer { - self.model.get_trainer().into() + self.model.read().unwrap().get_trainer().into() } } @@ -84,7 +84,7 @@ impl PyModel { // Instantiate a default empty model. This doesn't really make sense, but we need // to be able to instantiate an empty model for pickle capabilities. Ok(PyModel { - model: Arc::new(BPE::default().into()), + model: Arc::new(RwLock::new(BPE::default().into())), }) } @@ -116,7 +116,7 @@ impl PyModel { /// Tokenize the given sequence #[text_signature = "(self, tokens)"] fn tokenize(&self, tokens: &str) -> PyResult> { - Ok(ToPyResult(self.model.tokenize(tokens)) + Ok(ToPyResult(self.model.read().unwrap().tokenize(tokens)) .into_py()? .into_iter() .map(|t| t.into()) @@ -126,13 +126,13 @@ impl PyModel { /// Returns the id associated with the given token #[text_signature = "(self, tokens)"] fn token_to_id(&self, token: &str) -> Option { - self.model.token_to_id(token) + self.model.read().unwrap().token_to_id(token) } /// Returns the token associated with the given id #[text_signature = "(self, id)"] - fn id_to_token(&self, id: u32) -> Option<&str> { - self.model.id_to_token(id) + fn id_to_token(&self, id: u32) -> Option { + self.model.read().unwrap().id_to_token(id) } /// Save the current model @@ -142,7 +142,8 @@ impl PyModel { /// Any file with the same name that already exist in this folder will be overwritten. #[text_signature = "(self, folder, name)"] fn save(&self, folder: &str, name: Option<&str>) -> PyResult> { - let saved: PyResult> = ToPyResult(self.model.save(Path::new(folder), name)).into(); + let saved: PyResult> = + ToPyResult(self.model.read().unwrap().save(Path::new(folder), name)).into(); Ok(saved? 
.into_iter() @@ -151,7 +152,7 @@ impl PyModel { } fn get_trainer(&self) -> PyResult { - PyTrainer::from(self.model.get_trainer()).get_as_subtype() + PyTrainer::from(self.model.read().unwrap().get_trainer()).get_as_subtype() } } @@ -219,7 +220,7 @@ impl PyBPE { "Error while initializing BPE: {}", e ))), - Ok(bpe) => Ok((PyBPE {}, PyModel::new(Arc::new(bpe.into())))), + Ok(bpe) => Ok((PyBPE {}, PyModel::new(Arc::new(RwLock::new(bpe.into()))))), } } } @@ -360,7 +361,10 @@ impl PyWordPiece { "Error while initializing WordPiece: {}", e ))), - Ok(wordpiece) => Ok((PyWordPiece {}, PyModel::new(Arc::new(wordpiece.into())))), + Ok(wordpiece) => Ok(( + PyWordPiece {}, + PyModel::new(Arc::new(RwLock::new(wordpiece.into()))), + )), } } } @@ -476,11 +480,14 @@ impl PyWordLevel { } }; - Ok((PyWordLevel {}, PyModel::new(Arc::new(model.into())))) + Ok(( + PyWordLevel {}, + PyModel::new(Arc::new(RwLock::new(model.into()))), + )) } else { Ok(( PyWordLevel {}, - PyModel::new(Arc::new(WordLevel::default().into())), + PyModel::new(Arc::new(RwLock::new(WordLevel::default().into()))), )) } } @@ -523,11 +530,14 @@ impl PyUnigram { let model = Unigram::from(vocab, unk_id).map_err(|e| { exceptions::PyException::new_err(format!("Error while loading Unigram: {}", e)) })?; - Ok((PyUnigram {}, PyModel::new(Arc::new(model.into())))) + Ok(( + PyUnigram {}, + PyModel::new(Arc::new(RwLock::new(model.into()))), + )) } (None, None) => Ok(( PyUnigram {}, - PyModel::new(Arc::new(Unigram::default().into())), + PyModel::new(Arc::new(RwLock::new(Unigram::default().into()))), )), _ => Err(exceptions::PyValueError::new_err( "`vocab` and `unk_id` must be both specified", @@ -540,13 +550,13 @@ impl PyUnigram { mod test { use crate::models::PyModel; use pyo3::prelude::*; - use std::sync::Arc; + use std::sync::{Arc, RwLock}; use tk::models::bpe::BPE; use tk::models::ModelWrapper; #[test] fn get_subtype() { - let py_model = PyModel::new(Arc::new(BPE::default().into())); + let py_model = PyModel::new(Arc::new(RwLock::new(BPE::default().into()))); let py_bpe = py_model.get_as_subtype().unwrap(); let gil = Python::acquire_gil(); assert_eq!( @@ -562,19 +572,19 @@ mod test { let rs_wrapper: ModelWrapper = rs_bpe.into(); let rs_wrapper_ser = serde_json::to_string(&rs_wrapper).unwrap(); - let py_model = PyModel::new(Arc::new(rs_wrapper)); + let py_model = PyModel::new(Arc::new(RwLock::new(rs_wrapper))); let py_ser = serde_json::to_string(&py_model).unwrap(); assert_eq!(py_ser, rs_bpe_ser); assert_eq!(py_ser, rs_wrapper_ser); let py_model: PyModel = serde_json::from_str(&rs_bpe_ser).unwrap(); - match py_model.model.as_ref() { + match *py_model.model.as_ref().read().unwrap() { ModelWrapper::BPE(_) => (), _ => panic!("Expected Bert postprocessor."), } let py_model: PyModel = serde_json::from_str(&rs_wrapper_ser).unwrap(); - match py_model.model.as_ref() { + match *py_model.model.as_ref().read().unwrap() { ModelWrapper::BPE(_) => (), _ => panic!("Expected Bert postprocessor."), } diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index 25de9bbc5..cc8f2ed9d 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -1,5 +1,5 @@ use std::collections::HashMap; -use std::sync::Arc; +use std::sync::{Arc, RwLock}; use numpy::PyArray1; use pyo3::exceptions; @@ -457,7 +457,8 @@ impl PyTokenizer { } fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<&'p PyTuple> { - let model: PyObject = PyModel::new(Arc::new(BPE::default().into())).into_py(py); + let model: PyObject = + 
PyModel::new(Arc::new(RwLock::new(BPE::default().into()))).into_py(py); let args = PyTuple::new(py, vec![model]); Ok(args) } @@ -965,7 +966,7 @@ impl PyTokenizer { /// Returns: /// :obj:`Optional[str]`: An optional token, :obj:`None` if out of vocabulary #[text_signature = "(self, id)"] - fn id_to_token(&self, id: u32) -> Option<&str> { + fn id_to_token(&self, id: u32) -> Option { self.tokenizer.id_to_token(id) } diff --git a/bindings/python/src/trainers.rs b/bindings/python/src/trainers.rs index 58c35d330..404adc656 100644 --- a/bindings/python/src/trainers.rs +++ b/bindings/python/src/trainers.rs @@ -89,8 +89,7 @@ impl Trainer for PyTrainer { words: HashMap, model: &mut PyModel, ) -> tk::Result> { - todo!("FIX THIS"); - self.trainer.train(words, &mut model.model) + self.trainer.train(words, &mut model.model.write().unwrap()) } fn process_tokens(&self, words: &mut HashMap, tokens: Vec) { diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index 15c6f679f..64eb413bd 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -318,8 +318,8 @@ impl BPE { } } - pub fn get_vocab(&self) -> &Vocab { - &self.vocab + pub fn get_vocab(&self) -> Vocab { + self.vocab.clone() } pub fn get_unk_token(&self) -> &Option { @@ -417,8 +417,8 @@ impl BPE { impl Model for BPE { type Trainer = BpeTrainer; - fn get_vocab(&self) -> &HashMap { - &self.vocab + fn get_vocab(&self) -> HashMap { + self.vocab.clone() } fn get_vocab_size(&self) -> usize { @@ -442,8 +442,8 @@ impl Model for BPE { self.vocab.get(token).copied() } - fn id_to_token(&self, id: u32) -> Option<&str> { - self.vocab_r.get(&id).map(String::as_ref) + fn id_to_token(&self, id: u32) -> Option { + self.vocab_r.get(&id).cloned() } fn save(&self, folder: &Path, name: Option<&str>) -> Result> { diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs index 3ee88ebfd..44a0d17ca 100644 --- a/tokenizers/src/models/mod.rs +++ b/tokenizers/src/models/mod.rs @@ -75,7 +75,7 @@ impl Model for ModelWrapper { } } - fn id_to_token(&self, id: u32) -> Option<&str> { + fn id_to_token(&self, id: u32) -> Option { use ModelWrapper::*; match self { WordLevel(t) => t.id_to_token(id), @@ -85,7 +85,7 @@ impl Model for ModelWrapper { } } - fn get_vocab(&self) -> &HashMap { + fn get_vocab(&self) -> HashMap { use ModelWrapper::*; match self { WordLevel(t) => t.get_vocab(), diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs index fc6d082e7..9a37a7313 100644 --- a/tokenizers/src/models/unigram/model.rs +++ b/tokenizers/src/models/unigram/model.rs @@ -409,8 +409,8 @@ impl<'a> Iterator for UnigramIterator<'a> { impl Model for Unigram { type Trainer = UnigramTrainer; - fn get_vocab(&self) -> &HashMap { - &self.token_to_ids + fn get_vocab(&self) -> HashMap { + self.token_to_ids.clone() } fn get_vocab_size(&self) -> usize { @@ -438,9 +438,9 @@ impl Model for Unigram { self.token_to_ids.get(token).copied() } - fn id_to_token(&self, id: u32) -> Option<&str> { + fn id_to_token(&self, id: u32) -> Option { match self.vocab.get(id as usize) { - Some(item) => Some(&item.0), + Some(item) => Some(item.0.clone()), None => None, } } diff --git a/tokenizers/src/models/wordlevel/mod.rs b/tokenizers/src/models/wordlevel/mod.rs index 2918f8bb0..5d1a4a1a7 100644 --- a/tokenizers/src/models/wordlevel/mod.rs +++ b/tokenizers/src/models/wordlevel/mod.rs @@ -183,12 +183,12 @@ impl Model for WordLevel { self.vocab.get(token).copied() } - fn id_to_token(&self, id: u32) -> Option<&str> 
{ - self.vocab_r.get(&id).map(String::as_ref) + fn id_to_token(&self, id: u32) -> Option { + self.vocab_r.get(&id).cloned() } - fn get_vocab(&self) -> &HashMap { - &self.vocab + fn get_vocab(&self) -> HashMap { + self.vocab.clone() } fn get_vocab_size(&self) -> usize { diff --git a/tokenizers/src/models/wordpiece/mod.rs b/tokenizers/src/models/wordpiece/mod.rs index a3b2e6e04..990cfeabd 100644 --- a/tokenizers/src/models/wordpiece/mod.rs +++ b/tokenizers/src/models/wordpiece/mod.rs @@ -185,10 +185,7 @@ impl WordPiece { /// Create a `WordPiece` model from a `BPE` model. pub fn from_bpe(bpe: &BPE) -> Self { - let mut wp = Self::builder() - .vocab(bpe.get_vocab().clone()) - .build() - .unwrap(); + let mut wp = Self::builder().vocab(bpe.get_vocab()).build().unwrap(); if let Some(unk) = bpe.get_unk_token() { wp.unk_token = unk.to_owned(); } @@ -202,8 +199,8 @@ impl WordPiece { impl Model for WordPiece { type Trainer = WordPieceTrainer; - fn get_vocab(&self) -> &HashMap { - &self.vocab + fn get_vocab(&self) -> HashMap { + self.vocab.clone() } fn get_vocab_size(&self) -> usize { @@ -275,8 +272,8 @@ impl Model for WordPiece { self.vocab.get(token).copied() } - fn id_to_token(&self, id: u32) -> Option<&str> { - self.vocab_r.get(&id).map(String::as_ref) + fn id_to_token(&self, id: u32) -> Option { + self.vocab_r.get(&id).cloned() } fn save(&self, folder: &Path, name: Option<&str>) -> Result> { diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index b61e420ac..09193bac2 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -202,10 +202,10 @@ impl AddedVocabulary { } /// Get the token matching the given id if it exists - pub fn id_to_token<'s>(&'s self, id: u32, model: &'s impl Model) -> Option<&'s str> { + pub fn id_to_token(&self, id: u32, model: &impl Model) -> Option { self.added_tokens_map_r .get(&id) - .map(|t| t.content.as_ref()) + .map(|t| t.content.clone()) .or_else(|| model.id_to_token(id)) } @@ -550,11 +550,11 @@ mod tests { fn token_to_id(&self, token: &str) -> Option { self.vocab.get(token).copied() } - fn id_to_token(&self, id: u32) -> Option<&str> { - self.vocab_r.get(&id).map(String::as_ref) + fn id_to_token(&self, id: u32) -> Option { + self.vocab_r.get(&id).cloned() } - fn get_vocab(&self) -> &HashMap { - &self.vocab + fn get_vocab(&self) -> HashMap { + self.vocab.clone() } fn get_vocab_size(&self) -> usize { self.vocab.len() diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 74ecc315e..08a656d82 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -75,9 +75,9 @@ pub trait Model { /// Find the ID associated to a string token fn token_to_id(&self, token: &str) -> Option; /// Find the string token associated to an ID - fn id_to_token(&self, id: u32) -> Option<&str>; + fn id_to_token(&self, id: u32) -> Option; /// Retrieve the entire vocabulary mapping (token -> ID) - fn get_vocab(&self) -> &HashMap; + fn get_vocab(&self) -> HashMap; /// Retrieve the size of the vocabulary fn get_vocab_size(&self) -> usize; /// Save the current `Model` in the given folder, using the given `prefix` for the various @@ -616,7 +616,7 @@ where } /// Converts an id to the corresponding token. 
- pub fn id_to_token(&self, id: u32) -> Option<&str> { + pub fn id_to_token(&self, id: u32) -> Option { self.added_vocabulary.id_to_token(id, &self.model) } From 3983bce6111f51a3852abff1b2db0f6e06921b94 Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Tue, 10 Nov 2020 11:01:56 -0500 Subject: [PATCH 07/11] Test BPE keeping its options after training --- bindings/python/src/models.rs | 4 +-- bindings/python/src/tokenizer.rs | 10 ++++--- .../python/tests/bindings/test_trainers.py | 4 +-- tokenizers/src/models/bpe/model.rs | 10 +++---- tokenizers/src/tokenizer/mod.rs | 4 +-- tokenizers/tests/documentation.rs | 6 ++-- tokenizers/tests/training.rs | 29 +++++++++++++++++++ 7 files changed, 49 insertions(+), 18 deletions(-) create mode 100644 tokenizers/tests/training.rs diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index 0eb9bde1b..d1f5e5461 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -581,12 +581,12 @@ mod test { match *py_model.model.as_ref().read().unwrap() { ModelWrapper::BPE(_) => (), _ => panic!("Expected Bert postprocessor."), - } + }; let py_model: PyModel = serde_json::from_str(&rs_wrapper_ser).unwrap(); match *py_model.model.as_ref().read().unwrap() { ModelWrapper::BPE(_) => (), _ => panic!("Expected Bert postprocessor."), - } + }; } } diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index cc8f2ed9d..d5cfd9945 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -1045,7 +1045,9 @@ impl PyTokenizer { let trainer = trainer.map_or_else(|| self.tokenizer.get_model().get_trainer(), |t| t.clone()); Python::with_gil(|py| { - py.allow_threads(|| ToPyResult(self.tokenizer.train(&trainer, files)).into()) + py.allow_threads(|| { + ToPyResult(self.tokenizer.train(&trainer, files).map(|_| {})).into() + }) }) } @@ -1173,15 +1175,15 @@ mod test { use super::*; use crate::models::PyModel; use crate::normalizers::{PyNormalizer, PyNormalizerTypeWrapper}; - use std::sync::Arc; + use std::sync::{Arc, RwLock}; use tempfile::NamedTempFile; use tk::normalizers::{Lowercase, NFKC}; #[test] fn serialize() { - let mut tokenizer = Tokenizer::new(PyModel::new(Arc::new( + let mut tokenizer = Tokenizer::new(PyModel::new(Arc::new(RwLock::new( tk::models::bpe::BPE::default().into(), - ))); + )))); tokenizer.with_normalizer(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(vec![ Arc::new(NFKC.into()), Arc::new(Lowercase.into()), diff --git a/bindings/python/tests/bindings/test_trainers.py b/bindings/python/tests/bindings/test_trainers.py index c8f092d79..ee249557e 100644 --- a/bindings/python/tests/bindings/test_trainers.py +++ b/bindings/python/tests/bindings/test_trainers.py @@ -41,7 +41,7 @@ def pre_tokenize(self, pretok): del os.environ["TOKENIZERS_PARALLELISM"] trainer = trainers.BpeTrainer(special_tokens=[""], show_progress=False) - bpe_tokenizer.train(trainer, [train_files["small"]]) + bpe_tokenizer.train([train_files["small"]], trainer=trainer) def test_train_with_special_tokens(self): filename = "tests/data/dummy-unigram-special_tokens-train.txt" @@ -76,7 +76,7 @@ def test_train_with_special_tokens(self): show_progress=False, special_tokens=["[PAD]", "[SEP]", "[CLS]"], unk_token="[UNK]" ) - tokenizer.train(trainer, [filename]) + tokenizer.train([filename], trainer=trainer) assert tokenizer.encode("[CLS] This is a test [SEP]").tokens == [ "[CLS]", diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index 64eb413bd..2ffba7118 100644 --- 
a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -189,15 +189,15 @@ pub struct BPE { cache: Option>, /// Dropout probability for merges. 0 = no dropout is the default. At 1.0, tokenization will /// perform no merges, so the result will just be characters. - pub(super) dropout: Option, + pub dropout: Option, /// The unknown token to be used when we encounter an unknown char - pub(super) unk_token: Option, + pub unk_token: Option, /// An optional prefix to use on any subword that exist only behind another one - pub(super) continuing_subword_prefix: Option, + pub continuing_subword_prefix: Option, /// An optional suffix to caracterize and end-of-word subword - pub(super) end_of_word_suffix: Option, + pub end_of_word_suffix: Option, /// Do multiple unk tokens get fused - pub(super) fuse_unk: bool, + pub fuse_unk: bool, } impl std::fmt::Debug for BPE { diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 08a656d82..be62b12a0 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -1056,7 +1056,7 @@ where } /// Train a model and replace our current Model, using the given Trainer - pub fn train(&mut self, trainer: &T, files: Vec) -> Result<()> + pub fn train(&mut self, trainer: &T, files: Vec) -> Result<&mut Self> where T: Trainer + Sync, { @@ -1065,7 +1065,7 @@ where let special_tokens = trainer.train(words, &mut self.model)?; self.add_special_tokens(&special_tokens); - Ok(()) + Ok(self) } } diff --git a/tokenizers/tests/documentation.rs b/tokenizers/tests/documentation.rs index 9353d76cf..2520ffa63 100644 --- a/tokenizers/tests/documentation.rs +++ b/tokenizers/tests/documentation.rs @@ -8,7 +8,7 @@ use tokenizers::{Tokenizer, TokenizerImpl}; #[test] fn train_tokenizer() { let vocab_size: usize = 100; - let tokenizer = TokenizerBuilder::new() + let mut tokenizer = TokenizerBuilder::new() .with_model(BPE::default()) .with_normalizer(Some(Sequence::new(vec![ Strip::new(true, true).into(), @@ -97,7 +97,7 @@ fn quicktour_slow_train() -> tokenizers::Result<()> { "data/wikitext-103-raw/wiki.test.raw".into(), "data/wikitext-103-raw/wiki.valid.raw".into(), ]; - tokenizer.train_and_replace(&trainer, files)?; + tokenizer.train(&trainer, files)?; // END quicktour_train // START quicktour_reload_model use std::path::Path; @@ -427,7 +427,7 @@ fn train_pipeline_bert() -> tokenizers::Result<()> { "data/wikitext-103-raw/wiki.test.raw".into(), "data/wikitext-103-raw/wiki.valid.raw".into(), ]; - bert_tokenizer.train_and_replace(&trainer, files)?; + bert_tokenizer.train(&trainer, files)?; let model_files = bert_tokenizer .get_model() diff --git a/tokenizers/tests/training.rs b/tokenizers/tests/training.rs new file mode 100644 index 000000000..1454a48b5 --- /dev/null +++ b/tokenizers/tests/training.rs @@ -0,0 +1,29 @@ +use tokenizers::models::bpe::BPE; +use tokenizers::{DecoderWrapper, NormalizerWrapper, PostProcessorWrapper, PreTokenizerWrapper}; +use tokenizers::{Model, TokenizerBuilder}; + +#[test] +fn bpe_values_after_training() { + let mut tokenizer = TokenizerBuilder::< + BPE, + NormalizerWrapper, + PreTokenizerWrapper, + PostProcessorWrapper, + DecoderWrapper, + >::default() + .with_model( + BPE::builder() + .unk_token("[UNK]".to_string()) + .dropout(0.1) + .build() + .unwrap(), + ) + .build() + .unwrap(); + let trainer = tokenizer.get_model().get_trainer(); + tokenizer + .train(&trainer, vec!["./data/small.txt".to_string()]) + .unwrap(); + assert_eq!(tokenizer.get_model().dropout, Some(0.1)); + 
assert_eq!(tokenizer.get_model().unk_token, Some("[UNK]".to_string())); +} From e70a66c37c4cbf2289388643eed430ab5d5ade08 Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Thu, 19 Nov 2020 17:57:58 -0500 Subject: [PATCH 08/11] Generate pyi, fix tests and clippy warnings --- .../py_src/tokenizers/trainers/__init__.pyi | 21 ++++++++++ .../tests/documentation/test_pipeline.py | 15 ++++--- .../tests/documentation/test_quicktour.py | 17 +++++--- docs/source/quicktour.rst | 30 +------------ tokenizers/README.md | 2 +- tokenizers/src/tokenizer/mod.rs | 3 +- tokenizers/tests/documentation.rs | 42 ++++++------------- 7 files changed, 56 insertions(+), 74 deletions(-) diff --git a/bindings/python/py_src/tokenizers/trainers/__init__.pyi b/bindings/python/py_src/tokenizers/trainers/__init__.pyi index 6d0ad3041..189b87934 100644 --- a/bindings/python/py_src/tokenizers/trainers/__init__.pyi +++ b/bindings/python/py_src/tokenizers/trainers/__init__.pyi @@ -83,6 +83,27 @@ class UnigramTrainer(Trainer): def __init__(self, vocab_size=8000, show_progress=True, special_tokens=[]): pass +class WordLevelTrainer(Trainer): + """ + Capable of training a WorldLevel model + + Args: + vocab_size: unsigned int: + The size of the final vocabulary, including all tokens and alphabet. + + min_frequency: unsigned int: + The minimum frequency a pair should have in order to be merged. + + show_progress: boolean: + Whether to show progress bars while training. + + special_tokens: List[Union[str, AddedToken]]: + A list of special tokens the model should know of. + + Returns: + Trainer + """ + class WordPieceTrainer(Trainer): """ Capable of training a WordPiece model diff --git a/bindings/python/tests/documentation/test_pipeline.py b/bindings/python/tests/documentation/test_pipeline.py index 6a0f4626d..46aedf71e 100644 --- a/bindings/python/tests/documentation/test_pipeline.py +++ b/bindings/python/tests/documentation/test_pipeline.py @@ -2,8 +2,13 @@ from tokenizers import Tokenizer +disable_printing = True +original_print = print + + def print(*args, **kwargs): - pass + if not disable_printing: + original_print(*args, **kwargs) class TestPipeline: @@ -103,7 +108,7 @@ def slow_train(): from tokenizers import Tokenizer from tokenizers.models import WordPiece - bert_tokenizer = Tokenizer(WordPiece()) + bert_tokenizer = Tokenizer(WordPiece(unk_token="[UNK]")) # END bert_setup_tokenizer # START bert_setup_normalizer from tokenizers import normalizers @@ -135,10 +140,7 @@ def slow_train(): vocab_size=30522, special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"] ) files = [f"data/wikitext-103-raw/wiki.{split}.raw" for split in ["test", "train", "valid"]] - bert_tokenizer.train(trainer, files) - - model_files = bert_tokenizer.model.save("data", "bert-wiki") - bert_tokenizer.model = WordPiece.from_file(*model_files, unk_token="[UNK]") + bert_tokenizer.train(files, trainer) bert_tokenizer.save("data/bert-wiki.json") # END bert_train_tokenizer @@ -173,6 +175,7 @@ def test_bert_example(self, doc_pipeline_bert_tokenizer): from zipfile import ZipFile import os + disable_printing = False if not os.path.isdir("data/wikitext-103-raw"): print("Downloading wikitext-103...") wiki_text, _ = request.urlretrieve( diff --git a/bindings/python/tests/documentation/test_quicktour.py b/bindings/python/tests/documentation/test_quicktour.py index e6474e19a..b85947ed9 100644 --- a/bindings/python/tests/documentation/test_quicktour.py +++ b/bindings/python/tests/documentation/test_quicktour.py @@ -4,6 +4,14 @@ from tokenizers.trainers import BpeTrainer 
from tokenizers.pre_tokenizers import Whitespace +disable_printing = True +original_print = print + + +def print(*args, **kwargs): + if not disable_printing: + original_print(*args, **kwargs) + class TestQuicktour: # This method contains everything we don't want to run @@ -13,12 +21,8 @@ def slow_train(): # START train files = [f"data/wikitext-103-raw/wiki.{split}.raw" for split in ["test", "train", "valid"]] - tokenizer.train(trainer, files) + tokenizer.train(files, trainer) # END train - # START reload_model - files = tokenizer.model.save("data", "wiki") - tokenizer.model = BPE.from_file(*files, unk_token="[UNK]") - # END reload_model # START save tokenizer.save("data/tokenizer-wiki.json") # END save @@ -29,7 +33,7 @@ def get_tokenizer_trainer(): from tokenizers import Tokenizer from tokenizers.models import BPE - tokenizer = Tokenizer(BPE()) + tokenizer = Tokenizer(BPE(unk_token="[UNK]")) # END init_tokenizer # START init_trainer from tokenizers.trainers import BpeTrainer @@ -181,6 +185,7 @@ def print(*args, **kwargs): from zipfile import ZipFile import os + disable_printing = False if not os.path.isdir("data/wikitext-103-raw"): print("Downloading wikitext-103...") wiki_text, _ = request.urlretrieve( diff --git a/docs/source/quicktour.rst b/docs/source/quicktour.rst index 301c79957..309d83dee 100644 --- a/docs/source/quicktour.rst +++ b/docs/source/quicktour.rst @@ -202,35 +202,7 @@ to use: :end-before: END train :dedent: 8 -This should only take a few seconds to train our tokenizer on the full wikitext dataset! Once this -is done, we need to save the model and reinstantiate it with the unknown token, or this token won't -be used. This will be simplified in a further release, to let you set the :entity:`unk_token` when -first instantiating the model. - -.. only:: python - - .. literalinclude:: ../../bindings/python/tests/documentation/test_quicktour.py - :language: python - :start-after: START reload_model - :end-before: END reload_model - :dedent: 8 - -.. only:: rust - - .. literalinclude:: ../../tokenizers/tests/documentation.rs - :language: rust - :start-after: START quicktour_reload_model - :end-before: END quicktour_reload_model - :dedent: 4 - -.. only:: node - - .. literalinclude:: ../../bindings/node/examples/documentation/quicktour.test.ts - :language: javascript - :start-after: START reload_model - :end-before: END reload_model - :dedent: 8 - +This should only take a few seconds to train our tokenizer on the full wikitext dataset! 
To save the tokenizer in one file that contains all its configuration and vocabulary, just use the :entity:`Tokenizer.save` method: diff --git a/tokenizers/README.md b/tokenizers/README.md index fa473b6d5..97bce6fd6 100644 --- a/tokenizers/README.md +++ b/tokenizers/README.md @@ -84,7 +84,7 @@ fn main() -> Result<()> { ]) .build(); - let tokenizer = TokenizerBuilder::new() + let mut tokenizer = TokenizerBuilder::new() .with_model(BPE::default()) .with_normalizer(Some(Sequence::new(vec![ Strip::new(true, true).into(), diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index be62b12a0..a2b62a94d 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -585,7 +585,7 @@ where /// Get the vocabulary pub fn get_vocab(&self, with_added_tokens: bool) -> HashMap { - let mut final_vocab = self.model.get_vocab().clone(); + let mut final_vocab = self.model.get_vocab(); if with_added_tokens { let added_vocab = self.added_vocabulary.get_vocab(); @@ -763,7 +763,6 @@ where .filter(|token| { !skip_special_tokens || !self.added_vocabulary.is_special_token(token) }) - .map(|t| t.to_owned()) }) .collect::>(); diff --git a/tokenizers/tests/documentation.rs b/tokenizers/tests/documentation.rs index 2520ffa63..e4f1c2599 100644 --- a/tokenizers/tests/documentation.rs +++ b/tokenizers/tests/documentation.rs @@ -70,7 +70,12 @@ fn quicktour_slow_train() -> tokenizers::Result<()> { PreTokenizerWrapper, PostProcessorWrapper, DecoderWrapper, - > = TokenizerImpl::new(BPE::default()); + > = TokenizerImpl::new( + BPE::builder() + .unk_token("[UNK]".to_string()) + .build() + .unwrap(), + ); // END quicktour_init_tokenizer // START quicktour_init_trainer use tokenizers::models::bpe::BpeTrainer; @@ -99,22 +104,6 @@ fn quicktour_slow_train() -> tokenizers::Result<()> { ]; tokenizer.train(&trainer, files)?; // END quicktour_train - // START quicktour_reload_model - use std::path::Path; - use tokenizers::Model; - - let saved_files = tokenizer - .get_model() - .save(&Path::new("data"), Some("wiki"))?; - tokenizer.with_model( - BPE::from_file( - saved_files[0].to_str().unwrap(), - &saved_files[1].to_str().unwrap(), - ) - .unk_token("[UNK]".to_string()) - .build()?, - ); - // END quicktour_reload_model // START quicktour_save tokenizer.save("data/tokenizer-wiki.json", false)?; // END quicktour_save @@ -375,7 +364,12 @@ fn train_pipeline_bert() -> tokenizers::Result<()> { use tokenizers::models::wordpiece::WordPiece; use tokenizers::Tokenizer; - let mut bert_tokenizer = Tokenizer::new(WordPiece::default()); + let mut bert_tokenizer = Tokenizer::new( + WordPiece::builder() + .unk_token("[UNK]".to_string()) + .build() + .unwrap(), + ); // END bert_setup_tokenizer // START bert_setup_normalizer use tokenizers::normalizers::utils::Sequence as NormalizerSequence; @@ -407,9 +401,7 @@ fn train_pipeline_bert() -> tokenizers::Result<()> { ); // END bert_setup_processor // START bert_train_tokenizer - use std::path::Path; use tokenizers::models::{wordpiece::WordPieceTrainer, TrainerWrapper}; - use tokenizers::Model; let trainer: TrainerWrapper = WordPieceTrainer::builder() .vocab_size(30_522) @@ -429,16 +421,6 @@ fn train_pipeline_bert() -> tokenizers::Result<()> { ]; bert_tokenizer.train(&trainer, files)?; - let model_files = bert_tokenizer - .get_model() - .save(&Path::new("data"), Some("bert-wiki"))?; - bert_tokenizer.with_model( - WordPiece::from_file(model_files[0].to_str().unwrap()) - .unk_token("[UNK]".to_string()) - .build() - .unwrap(), - ); - 
bert_tokenizer.save("data/bert-wiki.json", false)?; // END bert_train_tokenizer Ok(()) From 22dabea5d16a559fb3ea113d6e4d4ce95c128271 Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Thu, 19 Nov 2020 19:57:50 -0500 Subject: [PATCH 09/11] Node - Trainers train the Model in-place --- .../examples/documentation/pipeline.test.ts | 11 +--- .../examples/documentation/quicktour.test.ts | 13 +--- bindings/node/native/src/models.rs | 61 +++++++++++++------ bindings/node/native/src/tasks/models.rs | 8 +-- bindings/node/native/src/tokenizer.rs | 24 ++++++-- 5 files changed, 69 insertions(+), 48 deletions(-) diff --git a/bindings/node/examples/documentation/pipeline.test.ts b/bindings/node/examples/documentation/pipeline.test.ts index c7d963e30..f36237105 100644 --- a/bindings/node/examples/documentation/pipeline.test.ts +++ b/bindings/node/examples/documentation/pipeline.test.ts @@ -94,7 +94,7 @@ describe("pipelineExample", () => { let { Tokenizer } = require("tokenizers/bindings/tokenizer"); let { WordPiece } = require("tokenizers/bindings/models"); - let bertTokenizer = new Tokenizer(WordPiece.empty()); + let bertTokenizer = new Tokenizer(WordPiece.init({}, { unkToken: "[UNK]" })); // END bert_setup_tokenizer // START bert_setup_normalizer let { sequenceNormalizer, lowercaseNormalizer, nfdNormalizer, stripAccentsNormalizer } @@ -120,20 +120,13 @@ describe("pipelineExample", () => { // END bert_setup_processor // START bert_train_tokenizer let { wordPieceTrainer } = require("tokenizers/bindings/trainers"); - let { promisify } = require("util"); let trainer = wordPieceTrainer({ vocabSize: 30522, specialTokens: ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"] }); let files = ["test", "train", "valid"].map(split => `data/wikitext-103-raw/wiki.${split}.raw`); - bertTokenizer.train(trainer, files); - - let modelFiles = bertTokenizer.getModel().save("data", "bert-wiki"); - let fromFile = promisify(WordPiece.fromFile); - bertTokenizer.setModel(await fromFile(modelFiles[0], { - unkToken: "[UNK]" - })); + bertTokenizer.train(files, trainer); bertTokenizer.save("data/bert-wiki.json") // END bert_train_tokenizer diff --git a/bindings/node/examples/documentation/quicktour.test.ts b/bindings/node/examples/documentation/quicktour.test.ts index f144efec4..a91964b0b 100644 --- a/bindings/node/examples/documentation/quicktour.test.ts +++ b/bindings/node/examples/documentation/quicktour.test.ts @@ -16,7 +16,7 @@ describe("quicktourExample", () => { let { Tokenizer } = require("tokenizers/bindings/tokenizer"); let { BPE } = require("tokenizers/bindings/models"); - let tokenizer = new Tokenizer(BPE.empty()); + let tokenizer = new Tokenizer(BPE.init({}, [], { unkToken: "[UNK]" })); // END init_tokenizer // START init_trainer let { bpeTrainer } = require("tokenizers/bindings/trainers"); @@ -32,17 +32,8 @@ describe("quicktourExample", () => { // END init_pretok // START train let files = ["test", "train", "valid"].map(split => `data/wikitext-103-raw/wiki.${split}.raw`); - tokenizer.train(trainer, files); + tokenizer.train(files, trainer); // END train - // START reload_model - let { promisify } = require("util"); - - let modelFiles = tokenizer.getModel().save("data", "wiki"); - let fromFile = promisify(BPE.fromFile); - tokenizer.setModel(await fromFile(modelFiles[0], modelFiles[1], { - unkToken: "[UNK]" - })); - // END reload_model // START save tokenizer.save("data/tokenizer-wiki.json"); // END save diff --git a/bindings/node/native/src/models.rs b/bindings/node/native/src/models.rs index c1e45dbbe..a64187e7b 100644 --- 
a/bindings/node/native/src/models.rs +++ b/bindings/node/native/src/models.rs @@ -2,11 +2,12 @@ extern crate tokenizers as tk; use crate::extraction::*; use crate::tasks::models::{BPEFromFilesTask, WordLevelFromFilesTask, WordPieceFromFilesTask}; +use crate::trainers::Trainer; use neon::prelude::*; use std::collections::HashMap; use std::path::Path; use std::path::PathBuf; -use std::sync::Arc; +use std::sync::{Arc, RwLock}; use tk::models::{ bpe::{BpeBuilder, Merges, Vocab}, @@ -21,37 +22,46 @@ use tk::Token; #[derive(Clone, Serialize, Deserialize)] pub struct Model { #[serde(flatten)] - pub model: Option>, + pub model: Option>>, } -impl From for Model { - fn from(wrapper: ModelWrapper) -> Self { +impl From for Model +where + M: Into, +{ + fn from(wrapper: M) -> Self { Self { - model: Some(Arc::new(wrapper)), + model: Some(Arc::new(RwLock::new(wrapper.into()))), } } } impl tk::Model for Model { + type Trainer = Trainer; + fn tokenize(&self, sequence: &str) -> tk::Result> { self.model .as_ref() .ok_or("Uninitialized Model")? + .read() + .unwrap() .tokenize(sequence) } fn token_to_id(&self, token: &str) -> Option { - self.model.as_ref()?.token_to_id(token) + self.model.as_ref()?.read().unwrap().token_to_id(token) } - fn id_to_token(&self, id: u32) -> Option<&str> { - self.model.as_ref()?.id_to_token(id) + fn id_to_token(&self, id: u32) -> Option { + self.model.as_ref()?.read().unwrap().id_to_token(id) } - fn get_vocab(&self) -> &HashMap { + fn get_vocab(&self) -> HashMap { self.model .as_ref() .expect("Uninitialized Model") + .read() + .unwrap() .get_vocab() } @@ -59,6 +69,8 @@ impl tk::Model for Model { self.model .as_ref() .expect("Uninitialized Model") + .read() + .unwrap() .get_vocab_size() } @@ -66,8 +78,20 @@ impl tk::Model for Model { self.model .as_ref() .ok_or("Uninitialized Model")? + .read() + .unwrap() .save(folder, name) } + + fn get_trainer(&self) -> Self::Trainer { + self.model + .as_ref() + .expect("Uninitialized Model") + .read() + .unwrap() + .get_trainer() + .into() + } } declare_types! { @@ -86,7 +110,8 @@ declare_types! 
{ let guard = cx.lock(); let files = this.borrow(&guard) - .model.as_ref().unwrap() + .model.as_ref().expect("Uninitialized Model") + .read().unwrap() .save( Path::new(&folder), name.as_deref() @@ -153,7 +178,7 @@ fn bpe_init(mut cx: FunctionContext) -> JsResult { let mut js_model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?; let guard = cx.lock(); - js_model.borrow_mut(&guard).model = Some(Arc::new(model.into())); + js_model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(model.into()))); Ok(js_model) } @@ -191,7 +216,7 @@ fn bpe_empty(mut cx: FunctionContext) -> JsResult { let bpe = tk::models::bpe::BPE::default(); let guard = cx.lock(); - model.borrow_mut(&guard).model = Some(Arc::new(bpe.into())); + model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(bpe.into()))); Ok(model) } @@ -236,7 +261,7 @@ fn wordpiece_init(mut cx: FunctionContext) -> JsResult { let mut js_model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?; let guard = cx.lock(); - js_model.borrow_mut(&guard).model = Some(Arc::new(model.into())); + js_model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(model.into()))); Ok(js_model) } @@ -270,7 +295,7 @@ fn wordpiece_empty(mut cx: FunctionContext) -> JsResult { let wordpiece = tk::models::wordpiece::WordPiece::default(); let guard = cx.lock(); - model.borrow_mut(&guard).model = Some(Arc::new(wordpiece.into())); + model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(wordpiece.into()))); Ok(model) } @@ -305,7 +330,7 @@ fn wordlevel_init(mut cx: FunctionContext) -> JsResult { let mut js_model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?; let guard = cx.lock(); - js_model.borrow_mut(&guard).model = Some(Arc::new(model.into())); + js_model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(model.into()))); Ok(js_model) } @@ -337,7 +362,7 @@ fn wordlevel_empty(mut cx: FunctionContext) -> JsResult { let wordlevel = tk::models::wordlevel::WordLevel::default(); let guard = cx.lock(); - model.borrow_mut(&guard).model = Some(Arc::new(wordlevel.into())); + model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(wordlevel.into()))); Ok(model) } @@ -362,7 +387,7 @@ fn unigram_init(mut cx: FunctionContext) -> JsResult { let mut js_model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?; let guard = cx.lock(); - js_model.borrow_mut(&guard).model = Some(Arc::new(unigram.into())); + js_model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(unigram.into()))); Ok(js_model) } @@ -373,7 +398,7 @@ fn unigram_empty(mut cx: FunctionContext) -> JsResult { let unigram = tk::models::unigram::Unigram::default(); let guard = cx.lock(); - model.borrow_mut(&guard).model = Some(Arc::new(unigram.into())); + model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(unigram.into()))); Ok(model) } diff --git a/bindings/node/native/src/tasks/models.rs b/bindings/node/native/src/tasks/models.rs index cc266e289..42c3fcede 100644 --- a/bindings/node/native/src/tasks/models.rs +++ b/bindings/node/native/src/tasks/models.rs @@ -2,7 +2,7 @@ extern crate tokenizers as tk; use crate::models::*; use neon::prelude::*; -use std::sync::Arc; +use std::sync::{Arc, RwLock}; use tk::models::bpe::{BpeBuilder, BPE}; use tk::models::wordlevel::{WordLevel, WordLevelBuilder}; use tk::models::wordpiece::{WordPiece, WordPieceBuilder}; @@ -34,7 +34,7 @@ impl Task for WordPieceFromFilesTask { let mut js_model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?; let guard = cx.lock(); - js_model.borrow_mut(&guard).model = Some(Arc::new(wordpiece.into())); + js_model.borrow_mut(&guard).model = 
Some(Arc::new(RwLock::new(wordpiece.into()))); Ok(js_model.upcast()) } @@ -67,7 +67,7 @@ impl Task for WordLevelFromFilesTask { let mut js_model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?; let guard = cx.lock(); - js_model.borrow_mut(&guard).model = Some(Arc::new(wordlevel.into())); + js_model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(wordlevel.into()))); Ok(js_model.upcast()) } @@ -100,7 +100,7 @@ impl Task for BPEFromFilesTask { let mut js_model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?; let guard = cx.lock(); - js_model.borrow_mut(&guard).model = Some(Arc::new(bpe.into())); + js_model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(bpe.into()))); Ok(js_model.upcast()) } diff --git a/bindings/node/native/src/tokenizer.rs b/bindings/node/native/src/tokenizer.rs index 50d5054be..0d5264958 100644 --- a/bindings/node/native/src/tokenizer.rs +++ b/bindings/node/native/src/tokenizer.rs @@ -12,6 +12,7 @@ use crate::trainers::JsTrainer; use neon::prelude::*; use std::sync::{Arc, RwLock}; +use tk::Model as ModelTrait; use tk::TokenizerImpl; // AddedToken @@ -634,7 +635,7 @@ declare_types! { let guard = cx.lock(); let token = this.borrow(&guard) .tokenizer.read().unwrap() - .id_to_token(id).map(|t| t.to_owned()); + .id_to_token(id); if let Some(token) = token { Ok(cx.string(token).upcast()) @@ -745,18 +746,29 @@ declare_types! { } method train(mut cx) { - // train(trainer: JsTrainer, files: string[]) + // train(files: string[], trainer?: Trainer) - let trainer = cx.argument::(0)?; - let files = cx.extract::>(1)?; + let files = cx.extract::>(0)?; + let trainer = if let Some(val) = cx.argument_opt(1) { + let js_trainer = val.downcast::().or_throw(&mut cx)?; + let guard = cx.lock(); + + let trainer = js_trainer.borrow(&guard).clone(); + trainer + } else { + let this = cx.this(); + let guard = cx.lock(); + + let trainer = this.borrow(&guard).tokenizer.read().unwrap().get_model().get_trainer(); + trainer + }; let mut this = cx.this(); let guard = cx.lock(); - let trainer = trainer.borrow(&guard).clone(); this.borrow_mut(&guard) .tokenizer.write().unwrap() - .train_and_replace(&trainer, files) + .train(&trainer, files) .map_err(|e| Error(format!("{}", e)))?; Ok(cx.undefined().upcast()) From 1c45b1ceaa211113279c2afd1e51073bc6a1c4dd Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Thu, 19 Nov 2020 20:01:28 -0500 Subject: [PATCH 10/11] Node - Add WordLevelTrainer --- bindings/node/lib/bindings/trainers.d.ts | 29 ++++++ bindings/node/lib/bindings/trainers.js | 1 + bindings/node/native/src/trainers.rs | 107 ++++++++++++++++++++- bindings/python/src/trainers.rs | 58 ++++++----- tokenizers/src/models/wordlevel/trainer.rs | 17 +++- 5 files changed, 180 insertions(+), 32 deletions(-) diff --git a/bindings/node/lib/bindings/trainers.d.ts b/bindings/node/lib/bindings/trainers.d.ts index a08e6e144..357c096ac 100644 --- a/bindings/node/lib/bindings/trainers.d.ts +++ b/bindings/node/lib/bindings/trainers.d.ts @@ -63,6 +63,35 @@ export function bpeTrainer(options?: TrainerOptions): Trainer; */ export function wordPieceTrainer(options?: TrainerOptions): Trainer; +export interface WordLevelTrainerOptions { + /** + * The minimum frequency a pair should have in order to be merged. + * @default 2 + */ + minFrequency?: number; + /** + * Whether to show progress bars while training. + * @default true + */ + showProgress?: boolean; + /** + * A list of special tokens the model should know of. 
+ * @default [] + */ + specialTokens?: (string | AddedToken)[]; + /** + * The size of the final vocabulary, including all tokens and alphabet. + * @default 30000 + */ + vocabSize?: number; +} + +/** + * Instantiate a new WordLevel Trainer + * @param [options] WordLevel Trainer options + */ +export function wordLevelTrainer(options?: WordLevelTrainerOptions): Trainer; + export interface UnigramTrainerOptions { vocabSize?: number; nSubIterations?: number; diff --git a/bindings/node/lib/bindings/trainers.js b/bindings/node/lib/bindings/trainers.js index 9521d1532..1a6d019b9 100644 --- a/bindings/node/lib/bindings/trainers.js +++ b/bindings/node/lib/bindings/trainers.js @@ -3,5 +3,6 @@ const native = require("./native"); module.exports = { bpeTrainer: native.trainers_BPETrainer, wordPieceTrainer: native.trainers_WordPieceTrainer, + wordLevelTrainer: native.trainers_WordLevelTrainer, unigramTrainer: native.trainers_UnigramTrainer, }; diff --git a/bindings/node/native/src/trainers.rs b/bindings/node/native/src/trainers.rs index 93cba7cc5..a58f77c11 100644 --- a/bindings/node/native/src/trainers.rs +++ b/bindings/node/native/src/trainers.rs @@ -8,7 +8,8 @@ use std::collections::HashMap; use std::sync::Arc; use tk::models::{ - bpe::BpeTrainer, unigram::UnigramTrainer, wordpiece::WordPieceTrainer, TrainerWrapper, + bpe::BpeTrainer, unigram::UnigramTrainer, wordlevel::WordLevelTrainer, + wordpiece::WordPieceTrainer, TrainerWrapper, }; /// Trainer @@ -17,6 +18,14 @@ pub struct Trainer { pub trainer: Option>, } +impl From for Trainer { + fn from(trainer: TrainerWrapper) -> Self { + Self { + trainer: Some(Arc::new(trainer)), + } + } +} + impl tk::Trainer for Trainer { type Model = Model; @@ -27,14 +36,26 @@ impl tk::Trainer for Trainer { .should_show_progress() } - fn train(&self, words: HashMap) -> tk::Result<(Self::Model, Vec)> { - let (model, special_tokens) = self + fn train( + &self, + words: HashMap, + model: &mut Self::Model, + ) -> tk::Result> { + let special_tokens = self .trainer .as_ref() .ok_or("Uninitialized Trainer")? - .train(words)?; + .train( + words, + &mut model + .model + .as_ref() + .ok_or("Uninitialized Model")? + .write() + .unwrap(), + )?; - Ok((model.into(), special_tokens)) + Ok(special_tokens) } fn process_tokens(&self, words: &mut HashMap, tokens: Vec) { @@ -238,6 +259,81 @@ fn wordpiece_trainer(mut cx: FunctionContext) -> JsResult { Ok(js_trainer) } +// WordLevel + +struct WordLevelTrainerOptions(WordLevelTrainer); +impl From for WordLevelTrainer { + fn from(v: WordLevelTrainerOptions) -> Self { + v.0 + } +} +impl FromJsValue for WordLevelTrainerOptions { + fn from_value<'c, C: Context<'c>>(from: Handle<'c, JsValue>, cx: &mut C) -> LibResult { + if let Ok(options) = from.downcast::() { + let mut builder = WordLevelTrainer::builder(); + + if let Ok(size) = options.get(cx, "vocabSize") { + if let Some(size) = Option::from_value(size, cx)? { + builder.vocab_size(size); + } + } + if let Ok(freq) = options.get(cx, "minFrequency") { + if let Some(freq) = Option::from_value(freq, cx)? { + builder.min_frequency(freq); + } + } + if let Ok(tokens) = options.get(cx, "specialTokens") { + if tokens.downcast::().is_err() && tokens.downcast::().is_err() + { + builder.special_tokens( + tokens + .downcast::() + .map_err(|e| Error(format!("{}", e)))? + .to_vec(cx)? + .into_iter() + .map(|token| Ok(AddedToken::from_value(token, cx)?.into())) + .collect::, Error>>()?, + ); + } + } + if let Ok(show) = options.get(cx, "showProgress") { + if let Some(show) = Option::from_value(show, cx)? 
{ + builder.show_progress(show); + } + } + + Ok(Self( + builder + .build() + .expect("WordLevelTrainerBuilder cannot fail"), + )) + } else { + Err(Error("Expected options type: object".into())) + } + } +} + +/// wordlevel_trainer(options?: { +/// vocabSize?: number = 30000, +/// minFrequency?: number = 0, +/// specialTokens?: string[] = [], +/// showProgress?: bool = true, +/// }) +fn wordlevel_trainer(mut cx: FunctionContext) -> JsResult { + let trainer = cx.extract_opt::(0)?.map_or_else( + || WordLevelTrainer::builder().build().unwrap(), + |o| o.into(), + ); + + let mut js_trainer = JsTrainer::new::<_, JsTrainer, _>(&mut cx, vec![])?; + let guard = cx.lock(); + js_trainer.borrow_mut(&guard).trainer = Some(Arc::new(trainer.into())); + + Ok(js_trainer) +} + +// Unigram + struct UnigramTrainerOptions(UnigramTrainer); impl From for UnigramTrainer { fn from(v: UnigramTrainerOptions) -> Self { @@ -337,6 +433,7 @@ fn unigram_trainer(mut cx: FunctionContext) -> JsResult { pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> { m.export_function(&format!("{}_BPETrainer", prefix), bpe_trainer)?; m.export_function(&format!("{}_WordPieceTrainer", prefix), wordpiece_trainer)?; + m.export_function(&format!("{}_WordLevelTrainer", prefix), wordlevel_trainer)?; m.export_function(&format!("{}_UnigramTrainer", prefix), unigram_trainer)?; Ok(()) } diff --git a/bindings/python/src/trainers.rs b/bindings/python/src/trainers.rs index 404adc656..e72f1bd81 100644 --- a/bindings/python/src/trainers.rs +++ b/bindings/python/src/trainers.rs @@ -301,34 +301,41 @@ impl PyWordLevelTrainer { #[new] #[args(kwargs = "**")] pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> { - let mut trainer = tk::models::wordlevel::WordLevelTrainer::default(); + let mut builder = tk::models::wordlevel::WordLevelTrainer::builder(); if let Some(kwargs) = kwargs { for (key, val) in kwargs { let key: &str = key.extract()?; match key { - "vocab_size" => trainer.vocab_size = val.extract()?, - "min_frequency" => trainer.min_frequency = val.extract()?, - "show_progress" => trainer.show_progress = val.extract()?, + "vocab_size" => { + builder.vocab_size(val.extract()?); + } + "min_frequency" => { + builder.min_frequency(val.extract()?); + } + "show_progress" => { + builder.show_progress(val.extract()?); + } "special_tokens" => { - trainer.special_tokens = val - .cast_as::()? - .into_iter() - .map(|token| { - if let Ok(content) = token.extract::() { - Ok(PyAddedToken::from(content, Some(true)).get_token()) - } else if let Ok(mut token) = - token.extract::>() - { - token.is_special_token = true; - Ok(token.get_token()) - } else { - Err(exceptions::PyTypeError::new_err( - "special_tokens must be a List[Union[str, AddedToken]]", - )) - } - }) - .collect::>>()? + builder.special_tokens( + val.cast_as::()? 
+ .into_iter() + .map(|token| { + if let Ok(content) = token.extract::() { + Ok(PyAddedToken::from(content, Some(true)).get_token()) + } else if let Ok(mut token) = + token.extract::>() + { + token.is_special_token = true; + Ok(token.get_token()) + } else { + Err(exceptions::PyTypeError::new_err( + "special_tokens must be a List[Union[str, AddedToken]]", + )) + } + }) + .collect::>>()?, + ); } _ => println!("Ignored unknown kwargs option {}", key), } @@ -337,7 +344,12 @@ impl PyWordLevelTrainer { Ok(( PyWordLevelTrainer {}, - PyTrainer::new(Arc::new(trainer.into())), + PyTrainer::new(Arc::new( + builder + .build() + .expect("WordLevelTrainerBuilder cannot fail") + .into(), + )), )) } } diff --git a/tokenizers/src/models/wordlevel/trainer.rs b/tokenizers/src/models/wordlevel/trainer.rs index 2485b57ca..c864529fe 100644 --- a/tokenizers/src/models/wordlevel/trainer.rs +++ b/tokenizers/src/models/wordlevel/trainer.rs @@ -2,15 +2,20 @@ use super::WordLevel; use crate::{AddedToken, Result, Trainer}; use std::collections::HashMap; +#[derive(Debug, Clone, Builder)] pub struct WordLevelTrainer { /// The minimum frequency a word must have to be part of the vocabulary - pub min_frequency: u32, + #[builder(default)] + min_frequency: u32, /// The target vocabulary size - pub vocab_size: usize, + #[builder(default)] + vocab_size: usize, /// Whether to show progress while training - pub show_progress: bool, + #[builder(default)] + show_progress: bool, /// A list of special tokens that the model should know of - pub special_tokens: Vec, + #[builder(default)] + special_tokens: Vec, } impl Default for WordLevelTrainer { @@ -25,6 +30,10 @@ impl Default for WordLevelTrainer { } impl WordLevelTrainer { + pub fn builder() -> WordLevelTrainerBuilder { + WordLevelTrainerBuilder::default() + } + fn train( &self, word_counts: HashMap, From 2410452c9747408a77677f15f1b7c707d7cc3dac Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Fri, 20 Nov 2020 09:02:28 -0500 Subject: [PATCH 11/11] Make sure TrainerWrapper can only train the right Model --- .../python/tests/bindings/test_trainers.py | 7 ++++++ tokenizers/src/models/mod.rs | 22 +++++++++++++++---- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/bindings/python/tests/bindings/test_trainers.py b/bindings/python/tests/bindings/test_trainers.py index ee249557e..ad671b057 100644 --- a/bindings/python/tests/bindings/test_trainers.py +++ b/bindings/python/tests/bindings/test_trainers.py @@ -92,3 +92,10 @@ def test_train_with_special_tokens(self): "t ", "[SEP]", ] + + def test_cannot_train_different_model(self): + tokenizer = Tokenizer(models.BPE()) + trainer = trainers.UnigramTrainer(show_progress=False) + + with pytest.raises(Exception, match="UnigramTrainer can only train a Unigram"): + tokenizer.train([], trainer) diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs index 44a0d17ca..c4e05c798 100644 --- a/tokenizers/src/models/mod.rs +++ b/tokenizers/src/models/mod.rs @@ -153,19 +153,19 @@ impl Trainer for TrainerWrapper { match self { TrainerWrapper::BpeTrainer(t) => match model { ModelWrapper::BPE(bpe) => t.train(words, bpe), - _ => unreachable!(), + _ => Err("BpeTrainer can only train a BPE".into()), }, TrainerWrapper::WordPieceTrainer(t) => match model { ModelWrapper::WordPiece(wp) => t.train(words, wp), - _ => unreachable!(), + _ => Err("WordPieceTrainer can only train a WordPiece".into()), }, TrainerWrapper::WordLevelTrainer(t) => match model { ModelWrapper::WordLevel(wl) => t.train(words, wl), - _ => unreachable!(), + _ => 
Err("WordLevelTrainer can only train a WordLevel".into()), }, TrainerWrapper::UnigramTrainer(t) => match model { ModelWrapper::Unigram(u) => t.train(words, u), - _ => unreachable!(), + _ => Err("UnigramTrainer can only train a Unigram".into()), }, } } @@ -184,3 +184,17 @@ impl_enum_from!(BpeTrainer, TrainerWrapper, BpeTrainer); impl_enum_from!(WordPieceTrainer, TrainerWrapper, WordPieceTrainer); impl_enum_from!(UnigramTrainer, TrainerWrapper, UnigramTrainer); impl_enum_from!(WordLevelTrainer, TrainerWrapper, WordLevelTrainer); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn trainer_wrapper_train_model_wrapper() { + let trainer = TrainerWrapper::BpeTrainer(BpeTrainer::default()); + let mut model = ModelWrapper::Unigram(Unigram::default()); + + let result = trainer.train(HashMap::new(), &mut model); + assert!(result.is_err()); + } +}