From 73415a22e1759a59d9f9c4b36620fbe74694b711 Mon Sep 17 00:00:00 2001
From: Anthony MOI
Date: Tue, 10 Nov 2020 11:01:56 -0500
Subject: [PATCH] Test BPE keeping its options after training

---
 bindings/python/src/models.rs                 |  4 +--
 bindings/python/src/tokenizer.rs              | 10 ++++---
 .../python/tests/bindings/test_trainers.py    |  4 +--
 tokenizers/src/models/bpe/model.rs            | 10 +++----
 tokenizers/src/tokenizer/mod.rs               |  4 +--
 tokenizers/tests/documentation.rs             |  6 ++--
 tokenizers/tests/training.rs                  | 29 +++++++++++++++++++
 7 files changed, 49 insertions(+), 18 deletions(-)
 create mode 100644 tokenizers/tests/training.rs

diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs
index 7c604d150..cd1c07cd2 100644
--- a/bindings/python/src/models.rs
+++ b/bindings/python/src/models.rs
@@ -492,12 +492,12 @@ mod test {
         match *py_model.model.as_ref().read().unwrap() {
             ModelWrapper::BPE(_) => (),
             _ => panic!("Expected Bert postprocessor."),
-        }
+        };
 
         let py_model: PyModel = serde_json::from_str(&rs_wrapper_ser).unwrap();
         match *py_model.model.as_ref().read().unwrap() {
             ModelWrapper::BPE(_) => (),
             _ => panic!("Expected Bert postprocessor."),
-        }
+        };
     }
 }

diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index eaa42b19c..967ec5d6d 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -1041,7 +1041,9 @@ impl PyTokenizer {
         let trainer =
             trainer.map_or_else(|| self.tokenizer.get_model().get_trainer(), |t| t.clone());
         Python::with_gil(|py| {
-            py.allow_threads(|| ToPyResult(self.tokenizer.train(&trainer, files)).into())
+            py.allow_threads(|| {
+                ToPyResult(self.tokenizer.train(&trainer, files).map(|_| {})).into()
+            })
         })
     }
 
@@ -1169,15 +1171,15 @@ mod test {
     use super::*;
     use crate::models::PyModel;
     use crate::normalizers::{PyNormalizer, PyNormalizerTypeWrapper};
-    use std::sync::Arc;
+    use std::sync::{Arc, RwLock};
     use tempfile::NamedTempFile;
     use tk::normalizers::{Lowercase, NFKC};
 
     #[test]
     fn serialize() {
-        let mut tokenizer = Tokenizer::new(PyModel::new(Arc::new(
+        let mut tokenizer = Tokenizer::new(PyModel::new(Arc::new(RwLock::new(
             tk::models::bpe::BPE::default().into(),
-        )));
+        ))));
         tokenizer.with_normalizer(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(vec![
             Arc::new(NFKC.into()),
             Arc::new(Lowercase.into()),

diff --git a/bindings/python/tests/bindings/test_trainers.py b/bindings/python/tests/bindings/test_trainers.py
index c8f092d79..ee249557e 100644
--- a/bindings/python/tests/bindings/test_trainers.py
+++ b/bindings/python/tests/bindings/test_trainers.py
@@ -41,7 +41,7 @@ def pre_tokenize(self, pretok):
         del os.environ["TOKENIZERS_PARALLELISM"]
 
         trainer = trainers.BpeTrainer(special_tokens=["<unk>"], show_progress=False)
-        bpe_tokenizer.train(trainer, [train_files["small"]])
+        bpe_tokenizer.train([train_files["small"]], trainer=trainer)
 
     def test_train_with_special_tokens(self):
         filename = "tests/data/dummy-unigram-special_tokens-train.txt"
@@ -76,7 +76,7 @@ def test_train_with_special_tokens(self):
             show_progress=False, special_tokens=["[PAD]", "[SEP]", "[CLS]"], unk_token="[UNK]"
         )
 
-        tokenizer.train(trainer, [filename])
+        tokenizer.train([filename], trainer=trainer)
 
         assert tokenizer.encode("[CLS] This is a test [SEP]").tokens == [
             "[CLS]",

diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs
index 64eb413bd..2ffba7118 100644
--- a/tokenizers/src/models/bpe/model.rs
+++ b/tokenizers/src/models/bpe/model.rs
@@ -189,15 +189,15 @@ pub struct BPE {
     cache: Option<Cache<String, Word>>,
     /// Dropout probability for merges. 0 = no dropout is the default. At 1.0, tokenization will
     /// perform no merges, so the result will just be characters.
-    pub(super) dropout: Option<f32>,
+    pub dropout: Option<f32>,
     /// The unknown token to be used when we encounter an unknown char
-    pub(super) unk_token: Option<String>,
+    pub unk_token: Option<String>,
     /// An optional prefix to use on any subword that exists only behind another one
-    pub(super) continuing_subword_prefix: Option<String>,
+    pub continuing_subword_prefix: Option<String>,
     /// An optional suffix to characterize an end-of-word subword
-    pub(super) end_of_word_suffix: Option<String>,
+    pub end_of_word_suffix: Option<String>,
     /// Do multiple unk tokens get fused
-    pub(super) fuse_unk: bool,
+    pub fuse_unk: bool,
 }
 
 impl std::fmt::Debug for BPE {

diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index 7050b6f55..1e178075b 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -1056,7 +1056,7 @@ where
     }
 
     /// Train a model and replace our current Model, using the given Trainer
-    pub fn train<T>(&mut self, trainer: &T, files: Vec<String>) -> Result<()>
+    pub fn train<T>(&mut self, trainer: &T, files: Vec<String>) -> Result<&mut Self>
     where
         T: Trainer<Model = M> + Sync,
     {
@@ -1065,7 +1065,7 @@ where
         let special_tokens = trainer.train(words, &mut self.model)?;
         self.add_special_tokens(&special_tokens);
 
-        Ok(())
+        Ok(self)
     }
 }

diff --git a/tokenizers/tests/documentation.rs b/tokenizers/tests/documentation.rs
index 9353d76cf..2520ffa63 100644
--- a/tokenizers/tests/documentation.rs
+++ b/tokenizers/tests/documentation.rs
@@ -8,7 +8,7 @@ use tokenizers::{Tokenizer, TokenizerImpl};
 #[test]
 fn train_tokenizer() {
     let vocab_size: usize = 100;
-    let tokenizer = TokenizerBuilder::new()
+    let mut tokenizer = TokenizerBuilder::new()
         .with_model(BPE::default())
         .with_normalizer(Some(Sequence::new(vec![
             Strip::new(true, true).into(),
@@ -97,7 +97,7 @@ fn quicktour_slow_train() -> tokenizers::Result<()> {
         "data/wikitext-103-raw/wiki.test.raw".into(),
         "data/wikitext-103-raw/wiki.valid.raw".into(),
     ];
-    tokenizer.train_and_replace(&trainer, files)?;
+    tokenizer.train(&trainer, files)?;
     // END quicktour_train
     // START quicktour_reload_model
     use std::path::Path;
@@ -427,7 +427,7 @@ fn train_pipeline_bert() -> tokenizers::Result<()> {
         "data/wikitext-103-raw/wiki.test.raw".into(),
         "data/wikitext-103-raw/wiki.valid.raw".into(),
     ];
-    bert_tokenizer.train_and_replace(&trainer, files)?;
+    bert_tokenizer.train(&trainer, files)?;
 
     let model_files = bert_tokenizer
         .get_model()

diff --git a/tokenizers/tests/training.rs b/tokenizers/tests/training.rs
new file mode 100644
index 000000000..1454a48b5
--- /dev/null
+++ b/tokenizers/tests/training.rs
@@ -0,0 +1,29 @@
+use tokenizers::models::bpe::BPE;
+use tokenizers::{DecoderWrapper, NormalizerWrapper, PostProcessorWrapper, PreTokenizerWrapper};
+use tokenizers::{Model, TokenizerBuilder};
+
+#[test]
+fn bpe_values_after_training() {
+    let mut tokenizer = TokenizerBuilder::<
+        BPE,
+        NormalizerWrapper,
+        PreTokenizerWrapper,
+        PostProcessorWrapper,
+        DecoderWrapper,
+    >::default()
+    .with_model(
+        BPE::builder()
+            .unk_token("[UNK]".to_string())
+            .dropout(0.1)
+            .build()
+            .unwrap(),
+    )
+    .build()
+    .unwrap();
+    let trainer = tokenizer.get_model().get_trainer();
+    tokenizer
+        .train(&trainer, vec!["./data/small.txt".to_string()])
+        .unwrap();
+    assert_eq!(tokenizer.get_model().dropout, Some(0.1));
+    assert_eq!(tokenizer.get_model().unk_token, Some("[UNK]".to_string()));
+}
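
A note on the core change above: TokenizerImpl::train now returns Result<&mut Self>
rather than Result<()>, which is why the Python binding maps the result back to ()
with .map(|_| {}) before handing it to ToPyResult. On the Rust side, the returned
&mut Self lets callers chain follow-up calls onto training. A minimal sketch of that
chaining (illustrative only: the corpus path, output path, and dropout value are
placeholders, it assumes a training file exists on disk, and it uses the crate's
existing save(path, pretty) helper):

use tokenizers::models::bpe::BPE;
use tokenizers::{DecoderWrapper, NormalizerWrapper, PostProcessorWrapper, PreTokenizerWrapper};
use tokenizers::{Model, TokenizerBuilder};

fn main() -> tokenizers::Result<()> {
    // Build a tokenizer whose BPE model carries non-default options.
    let mut tokenizer = TokenizerBuilder::<
        BPE,
        NormalizerWrapper,
        PreTokenizerWrapper,
        PostProcessorWrapper,
        DecoderWrapper,
    >::default()
    .with_model(BPE::builder().dropout(0.2).build().unwrap())
    .build()
    .unwrap();

    // `train` hands back `&mut Self`, so saving chains onto the call.
    // ("corpus.txt" and "tokenizer.json" are placeholder paths.)
    let trainer = tokenizer.get_model().get_trainer();
    tokenizer
        .train(&trainer, vec!["corpus.txt".to_string()])?
        .save("tokenizer.json", false)?;

    // The options configured before training are still set afterwards,
    // which is exactly what the new integration test asserts.
    assert_eq!(tokenizer.get_model().dropout, Some(0.2));
    Ok(())
}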