
Commit 73415a2

Test BPE keeping its options after training
n1t0 committed Nov 10, 2020
1 parent 4ca7a1f commit 73415a2
Showing 7 changed files with 49 additions and 18 deletions.
4 changes: 2 additions & 2 deletions bindings/python/src/models.rs
@@ -492,12 +492,12 @@ mod test {
match *py_model.model.as_ref().read().unwrap() {
ModelWrapper::BPE(_) => (),
_ => panic!("Expected Bert postprocessor."),
-}
+};

let py_model: PyModel = serde_json::from_str(&rs_wrapper_ser).unwrap();
match *py_model.model.as_ref().read().unwrap() {
ModelWrapper::BPE(_) => (),
_ => panic!("Expected Bert postprocessor."),
-}
+};
}
}
10 changes: 6 additions & 4 deletions bindings/python/src/tokenizer.rs
@@ -1041,7 +1041,9 @@ impl PyTokenizer {
let trainer =
trainer.map_or_else(|| self.tokenizer.get_model().get_trainer(), |t| t.clone());
Python::with_gil(|py| {
-py.allow_threads(|| ToPyResult(self.tokenizer.train(&trainer, files)).into())
+py.allow_threads(|| {
+    ToPyResult(self.tokenizer.train(&trainer, files).map(|_| {})).into()
+})
})
}

@@ -1169,15 +1171,15 @@ mod test {
use super::*;
use crate::models::PyModel;
use crate::normalizers::{PyNormalizer, PyNormalizerTypeWrapper};
-use std::sync::Arc;
+use std::sync::{Arc, RwLock};
use tempfile::NamedTempFile;
use tk::normalizers::{Lowercase, NFKC};

#[test]
fn serialize() {
-let mut tokenizer = Tokenizer::new(PyModel::new(Arc::new(
+let mut tokenizer = Tokenizer::new(PyModel::new(Arc::new(RwLock::new(
tk::models::bpe::BPE::default().into(),
-)));
+))));
tokenizer.with_normalizer(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(vec![
Arc::new(NFKC.into()),
Arc::new(Lowercase.into()),
4 changes: 2 additions & 2 deletions bindings/python/tests/bindings/test_trainers.py
@@ -41,7 +41,7 @@ def pre_tokenize(self, pretok):
del os.environ["TOKENIZERS_PARALLELISM"]

trainer = trainers.BpeTrainer(special_tokens=["<unk>"], show_progress=False)
-bpe_tokenizer.train(trainer, [train_files["small"]])
+bpe_tokenizer.train([train_files["small"]], trainer=trainer)

def test_train_with_special_tokens(self):
filename = "tests/data/dummy-unigram-special_tokens-train.txt"
@@ -76,7 +76,7 @@ def test_train_with_special_tokens(self):
show_progress=False, special_tokens=["[PAD]", "[SEP]", "[CLS]"], unk_token="[UNK]"
)

-tokenizer.train(trainer, [filename])
+tokenizer.train([filename], trainer=trainer)

assert tokenizer.encode("[CLS] This is a test [SEP]").tokens == [
"[CLS]",
10 changes: 5 additions & 5 deletions tokenizers/src/models/bpe/model.rs
@@ -189,15 +189,15 @@ pub struct BPE {
cache: Option<Cache<String, Word>>,
/// Dropout probability for merges. 0 = no dropout is the default. At 1.0, tokenization will
/// perform no merges, so the result will just be characters.
-pub(super) dropout: Option<f32>,
+pub dropout: Option<f32>,
/// The unknown token to be used when we encounter an unknown char
-pub(super) unk_token: Option<String>,
+pub unk_token: Option<String>,
/// An optional prefix to use on any subword that exists only behind another one
-pub(super) continuing_subword_prefix: Option<String>,
+pub continuing_subword_prefix: Option<String>,
/// An optional suffix to characterize an end-of-word subword
-pub(super) end_of_word_suffix: Option<String>,
+pub end_of_word_suffix: Option<String>,
/// Do multiple unk tokens get fused
-pub(super) fuse_unk: bool,
+pub fuse_unk: bool,
}

impl std::fmt::Debug for BPE {
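With these fields promoted from pub(super) to pub, callers outside the bpe module can read the configured options straight off the model, which is what the new training.rs test below relies on. A minimal sketch (not part of the commit) of what this exposes, using only builder calls already shown in that test:

use tokenizers::models::bpe::BPE;

fn main() {
    // The option fields are now public, so they can be read back directly.
    let bpe = BPE::builder()
        .unk_token("[UNK]".to_string())
        .dropout(0.1)
        .build()
        .unwrap();

    assert_eq!(bpe.dropout, Some(0.1));
    assert_eq!(bpe.unk_token, Some("[UNK]".to_string()));
    println!("fuse_unk: {}", bpe.fuse_unk);
}

These are the same fields the bpe_values_after_training test asserts on after training.
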
4 changes: 2 additions & 2 deletions tokenizers/src/tokenizer/mod.rs
@@ -1056,7 +1056,7 @@ where
}

/// Train a model and replace our current Model, using the given Trainer
-pub fn train<T>(&mut self, trainer: &T, files: Vec<String>) -> Result<()>
+pub fn train<T>(&mut self, trainer: &T, files: Vec<String>) -> Result<&mut Self>
where
T: Trainer<Model = M> + Sync,
{
@@ -1065,7 +1065,7 @@
let special_tokens = trainer.train(words, &mut self.model)?;
self.add_special_tokens(&special_tokens);

-Ok(())
+Ok(self)
}
}

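Because train now returns Result<&mut Self> instead of Result<()>, the caller gets the trained tokenizer back and can keep working with it in the same expression. A minimal sketch (not part of the commit) reusing the builder setup and corpus path from the new training.rs test; the trailing get_vocab_size call is only there to illustrate the chaining:

use tokenizers::models::bpe::BPE;
use tokenizers::{DecoderWrapper, NormalizerWrapper, PostProcessorWrapper, PreTokenizerWrapper};
use tokenizers::{Model, TokenizerBuilder};

fn main() {
    let mut tokenizer = TokenizerBuilder::<
        BPE,
        NormalizerWrapper,
        PreTokenizerWrapper,
        PostProcessorWrapper,
        DecoderWrapper,
    >::default()
    .with_model(BPE::default())
    .build()
    .unwrap();

    let trainer = tokenizer.get_model().get_trainer();

    // train hands back &mut Self, so the trained tokenizer can be used
    // directly instead of being re-borrowed on the next line.
    let vocab_size = tokenizer
        .train(&trainer, vec!["./data/small.txt".to_string()])
        .unwrap()
        .get_vocab_size(true);
    println!("vocab size after training: {}", vocab_size);
}

This is also why the Python binding above now discards the return value with .map(|_| {}), keeping its train method returning None.
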
6 changes: 3 additions & 3 deletions tokenizers/tests/documentation.rs
@@ -8,7 +8,7 @@ use tokenizers::{Tokenizer, TokenizerImpl};
#[test]
fn train_tokenizer() {
let vocab_size: usize = 100;
-let tokenizer = TokenizerBuilder::new()
+let mut tokenizer = TokenizerBuilder::new()
.with_model(BPE::default())
.with_normalizer(Some(Sequence::new(vec![
Strip::new(true, true).into(),
@@ -97,7 +97,7 @@ fn quicktour_slow_train() -> tokenizers::Result<()> {
"data/wikitext-103-raw/wiki.test.raw".into(),
"data/wikitext-103-raw/wiki.valid.raw".into(),
];
-tokenizer.train_and_replace(&trainer, files)?;
+tokenizer.train(&trainer, files)?;
// END quicktour_train
// START quicktour_reload_model
use std::path::Path;
@@ -427,7 +427,7 @@ fn train_pipeline_bert() -> tokenizers::Result<()> {
"data/wikitext-103-raw/wiki.test.raw".into(),
"data/wikitext-103-raw/wiki.valid.raw".into(),
];
-bert_tokenizer.train_and_replace(&trainer, files)?;
+bert_tokenizer.train(&trainer, files)?;

let model_files = bert_tokenizer
.get_model()
29 changes: 29 additions & 0 deletions tokenizers/tests/training.rs
@@ -0,0 +1,29 @@
use tokenizers::models::bpe::BPE;
use tokenizers::{DecoderWrapper, NormalizerWrapper, PostProcessorWrapper, PreTokenizerWrapper};
use tokenizers::{Model, TokenizerBuilder};

#[test]
fn bpe_values_after_training() {
let mut tokenizer = TokenizerBuilder::<
BPE,
NormalizerWrapper,
PreTokenizerWrapper,
PostProcessorWrapper,
DecoderWrapper,
>::default()
.with_model(
BPE::builder()
.unk_token("[UNK]".to_string())
.dropout(0.1)
.build()
.unwrap(),
)
.build()
.unwrap();
let trainer = tokenizer.get_model().get_trainer();
tokenizer
.train(&trainer, vec!["./data/small.txt".to_string()])
.unwrap();
assert_eq!(tokenizer.get_model().dropout, Some(0.1));
assert_eq!(tokenizer.get_model().unk_token, Some("[UNK]".to_string()));
}
