Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trainer improvements #519

Merged
merged 11 commits into from
Nov 20, 2020
11 changes: 2 additions & 9 deletions bindings/node/examples/documentation/pipeline.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ describe("pipelineExample", () => {
let { Tokenizer } = require("tokenizers/bindings/tokenizer");
let { WordPiece } = require("tokenizers/bindings/models");

let bertTokenizer = new Tokenizer(WordPiece.empty());
let bertTokenizer = new Tokenizer(WordPiece.init({}, { unkToken: "[UNK]" }));
// END bert_setup_tokenizer
// START bert_setup_normalizer
let { sequenceNormalizer, lowercaseNormalizer, nfdNormalizer, stripAccentsNormalizer }
Expand All @@ -120,20 +120,13 @@ describe("pipelineExample", () => {
// END bert_setup_processor
// START bert_train_tokenizer
let { wordPieceTrainer } = require("tokenizers/bindings/trainers");
let { promisify } = require("util");

let trainer = wordPieceTrainer({
vocabSize: 30522,
specialTokens: ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
});
let files = ["test", "train", "valid"].map(split => `data/wikitext-103-raw/wiki.${split}.raw`);
bertTokenizer.train(trainer, files);

let modelFiles = bertTokenizer.getModel().save("data", "bert-wiki");
let fromFile = promisify(WordPiece.fromFile);
bertTokenizer.setModel(await fromFile(modelFiles[0], {
unkToken: "[UNK]"
}));
bertTokenizer.train(files, trainer);

bertTokenizer.save("data/bert-wiki.json")
// END bert_train_tokenizer
Expand Down
13 changes: 2 additions & 11 deletions bindings/node/examples/documentation/quicktour.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ describe("quicktourExample", () => {
let { Tokenizer } = require("tokenizers/bindings/tokenizer");
let { BPE } = require("tokenizers/bindings/models");

let tokenizer = new Tokenizer(BPE.empty());
let tokenizer = new Tokenizer(BPE.init({}, [], { unkToken: "[UNK]" }));
// END init_tokenizer
// START init_trainer
let { bpeTrainer } = require("tokenizers/bindings/trainers");
Expand All @@ -32,17 +32,8 @@ describe("quicktourExample", () => {
// END init_pretok
// START train
let files = ["test", "train", "valid"].map(split => `data/wikitext-103-raw/wiki.${split}.raw`);
tokenizer.train(trainer, files);
tokenizer.train(files, trainer);
// END train
// START reload_model
let { promisify } = require("util");

let modelFiles = tokenizer.getModel().save("data", "wiki");
let fromFile = promisify(BPE.fromFile);
tokenizer.setModel(await fromFile(modelFiles[0], modelFiles[1], {
unkToken: "[UNK]"
}));
// END reload_model
// START save
tokenizer.save("data/tokenizer-wiki.json");
// END save
Expand Down
29 changes: 29 additions & 0 deletions bindings/node/lib/bindings/trainers.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,35 @@ export function bpeTrainer(options?: TrainerOptions): Trainer;
*/
export function wordPieceTrainer(options?: TrainerOptions): Trainer;

export interface WordLevelTrainerOptions {
  /**
   * The minimum frequency a word must have in order to be kept
   * in the vocabulary. (WordLevel has no merge step; words below
   * this count are simply dropped.)
   * @default 2
   */
  minFrequency?: number;
  /**
   * Whether to show progress bars while training.
   * @default true
   */
  showProgress?: boolean;
  /**
   * A list of special tokens the model should know of.
   * @default []
   */
  specialTokens?: (string | AddedToken)[];
  /**
   * The size of the final vocabulary, including all tokens.
   * @default 30000
   */
  vocabSize?: number;
}

/**
 * Instantiate a new WordLevel Trainer
 * @param [options] WordLevel Trainer options
 * @returns A Trainer usable with `tokenizer.train(files, trainer)`
 */
export function wordLevelTrainer(options?: WordLevelTrainerOptions): Trainer;

export interface UnigramTrainerOptions {
vocabSize?: number;
nSubIterations?: number;
Expand Down
1 change: 1 addition & 0 deletions bindings/node/lib/bindings/trainers.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@ const native = require("./native");
// Expose the trainer factories under camelCase JS names, mapping each one
// to the corresponding function exported by the compiled native addon.
module.exports = {
  bpeTrainer: native.trainers_BPETrainer,
  wordPieceTrainer: native.trainers_WordPieceTrainer,
  wordLevelTrainer: native.trainers_WordLevelTrainer,
  unigramTrainer: native.trainers_UnigramTrainer,
};
61 changes: 43 additions & 18 deletions bindings/node/native/src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ extern crate tokenizers as tk;

use crate::extraction::*;
use crate::tasks::models::{BPEFromFilesTask, WordLevelFromFilesTask, WordPieceFromFilesTask};
use crate::trainers::Trainer;
use neon::prelude::*;
use std::collections::HashMap;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::{Arc, RwLock};

use tk::models::{
bpe::{BpeBuilder, Merges, Vocab},
Expand All @@ -21,53 +22,76 @@ use tk::Token;
#[derive(Clone, Serialize, Deserialize)]
pub struct Model {
#[serde(flatten)]
pub model: Option<Arc<ModelWrapper>>,
pub model: Option<Arc<RwLock<ModelWrapper>>>,
}

impl From<ModelWrapper> for Model {
fn from(wrapper: ModelWrapper) -> Self {
impl<M> From<M> for Model
where
M: Into<ModelWrapper>,
{
fn from(wrapper: M) -> Self {
Self {
model: Some(Arc::new(wrapper)),
model: Some(Arc::new(RwLock::new(wrapper.into()))),
}
}
}

impl tk::Model for Model {
type Trainer = Trainer;

fn tokenize(&self, sequence: &str) -> tk::Result<Vec<Token>> {
self.model
.as_ref()
.ok_or("Uninitialized Model")?
.read()
.unwrap()
.tokenize(sequence)
}

fn token_to_id(&self, token: &str) -> Option<u32> {
self.model.as_ref()?.token_to_id(token)
self.model.as_ref()?.read().unwrap().token_to_id(token)
}

fn id_to_token(&self, id: u32) -> Option<&str> {
self.model.as_ref()?.id_to_token(id)
fn id_to_token(&self, id: u32) -> Option<String> {
self.model.as_ref()?.read().unwrap().id_to_token(id)
}

fn get_vocab(&self) -> &HashMap<String, u32> {
fn get_vocab(&self) -> HashMap<String, u32> {
self.model
.as_ref()
.expect("Uninitialized Model")
.read()
.unwrap()
.get_vocab()
}

fn get_vocab_size(&self) -> usize {
self.model
.as_ref()
.expect("Uninitialized Model")
.read()
.unwrap()
.get_vocab_size()
}

fn save(&self, folder: &Path, name: Option<&str>) -> tk::Result<Vec<PathBuf>> {
self.model
.as_ref()
.ok_or("Uninitialized Model")?
.read()
.unwrap()
.save(folder, name)
}

fn get_trainer(&self) -> Self::Trainer {
self.model
.as_ref()
.expect("Uninitialized Model")
.read()
.unwrap()
.get_trainer()
.into()
}
}

declare_types! {
Expand All @@ -86,7 +110,8 @@ declare_types! {
let guard = cx.lock();

let files = this.borrow(&guard)
.model.as_ref().unwrap()
.model.as_ref().expect("Uninitialized Model")
.read().unwrap()
.save(
Path::new(&folder),
name.as_deref()
Expand Down Expand Up @@ -153,7 +178,7 @@ fn bpe_init(mut cx: FunctionContext) -> JsResult<JsModel> {

let mut js_model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?;
let guard = cx.lock();
js_model.borrow_mut(&guard).model = Some(Arc::new(model.into()));
js_model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(model.into())));

Ok(js_model)
}
Expand Down Expand Up @@ -191,7 +216,7 @@ fn bpe_empty(mut cx: FunctionContext) -> JsResult<JsModel> {
let bpe = tk::models::bpe::BPE::default();

let guard = cx.lock();
model.borrow_mut(&guard).model = Some(Arc::new(bpe.into()));
model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(bpe.into())));

Ok(model)
}
Expand Down Expand Up @@ -236,7 +261,7 @@ fn wordpiece_init(mut cx: FunctionContext) -> JsResult<JsModel> {

let mut js_model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?;
let guard = cx.lock();
js_model.borrow_mut(&guard).model = Some(Arc::new(model.into()));
js_model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(model.into())));

Ok(js_model)
}
Expand Down Expand Up @@ -270,7 +295,7 @@ fn wordpiece_empty(mut cx: FunctionContext) -> JsResult<JsModel> {
let wordpiece = tk::models::wordpiece::WordPiece::default();

let guard = cx.lock();
model.borrow_mut(&guard).model = Some(Arc::new(wordpiece.into()));
model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(wordpiece.into())));

Ok(model)
}
Expand Down Expand Up @@ -305,7 +330,7 @@ fn wordlevel_init(mut cx: FunctionContext) -> JsResult<JsModel> {

let mut js_model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?;
let guard = cx.lock();
js_model.borrow_mut(&guard).model = Some(Arc::new(model.into()));
js_model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(model.into())));

Ok(js_model)
}
Expand Down Expand Up @@ -337,7 +362,7 @@ fn wordlevel_empty(mut cx: FunctionContext) -> JsResult<JsModel> {
let wordlevel = tk::models::wordlevel::WordLevel::default();

let guard = cx.lock();
model.borrow_mut(&guard).model = Some(Arc::new(wordlevel.into()));
model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(wordlevel.into())));

Ok(model)
}
Expand All @@ -362,7 +387,7 @@ fn unigram_init(mut cx: FunctionContext) -> JsResult<JsModel> {

let mut js_model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?;
let guard = cx.lock();
js_model.borrow_mut(&guard).model = Some(Arc::new(unigram.into()));
js_model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(unigram.into())));

Ok(js_model)
}
Expand All @@ -373,7 +398,7 @@ fn unigram_empty(mut cx: FunctionContext) -> JsResult<JsModel> {
let unigram = tk::models::unigram::Unigram::default();

let guard = cx.lock();
model.borrow_mut(&guard).model = Some(Arc::new(unigram.into()));
model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(unigram.into())));

Ok(model)
}
Expand Down
8 changes: 4 additions & 4 deletions bindings/node/native/src/tasks/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ extern crate tokenizers as tk;

use crate::models::*;
use neon::prelude::*;
use std::sync::Arc;
use std::sync::{Arc, RwLock};
use tk::models::bpe::{BpeBuilder, BPE};
use tk::models::wordlevel::{WordLevel, WordLevelBuilder};
use tk::models::wordpiece::{WordPiece, WordPieceBuilder};
Expand Down Expand Up @@ -34,7 +34,7 @@ impl Task for WordPieceFromFilesTask {

let mut js_model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?;
let guard = cx.lock();
js_model.borrow_mut(&guard).model = Some(Arc::new(wordpiece.into()));
js_model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(wordpiece.into())));

Ok(js_model.upcast())
}
Expand Down Expand Up @@ -67,7 +67,7 @@ impl Task for WordLevelFromFilesTask {

let mut js_model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?;
let guard = cx.lock();
js_model.borrow_mut(&guard).model = Some(Arc::new(wordlevel.into()));
js_model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(wordlevel.into())));

Ok(js_model.upcast())
}
Expand Down Expand Up @@ -100,7 +100,7 @@ impl Task for BPEFromFilesTask {

let mut js_model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?;
let guard = cx.lock();
js_model.borrow_mut(&guard).model = Some(Arc::new(bpe.into()));
js_model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(bpe.into())));

Ok(js_model.upcast())
}
Expand Down
24 changes: 18 additions & 6 deletions bindings/node/native/src/tokenizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::trainers::JsTrainer;
use neon::prelude::*;
use std::sync::{Arc, RwLock};

use tk::Model as ModelTrait;
use tk::TokenizerImpl;

// AddedToken
Expand Down Expand Up @@ -634,7 +635,7 @@ declare_types! {
let guard = cx.lock();
let token = this.borrow(&guard)
.tokenizer.read().unwrap()
.id_to_token(id).map(|t| t.to_owned());
.id_to_token(id);

if let Some(token) = token {
Ok(cx.string(token).upcast())
Expand Down Expand Up @@ -745,18 +746,29 @@ declare_types! {
}

method train(mut cx) {
// train(trainer: JsTrainer, files: string[])
// train(files: string[], trainer?: Trainer)

let trainer = cx.argument::<JsTrainer>(0)?;
let files = cx.extract::<Vec<String>>(1)?;
let files = cx.extract::<Vec<String>>(0)?;
let trainer = if let Some(val) = cx.argument_opt(1) {
let js_trainer = val.downcast::<JsTrainer>().or_throw(&mut cx)?;
let guard = cx.lock();

let trainer = js_trainer.borrow(&guard).clone();
trainer
} else {
let this = cx.this();
let guard = cx.lock();

let trainer = this.borrow(&guard).tokenizer.read().unwrap().get_model().get_trainer();
trainer
};

let mut this = cx.this();
let guard = cx.lock();

let trainer = trainer.borrow(&guard).clone();
this.borrow_mut(&guard)
.tokenizer.write().unwrap()
.train_and_replace(&trainer, files)
.train(&trainer, files)
.map_err(|e| Error(format!("{}", e)))?;

Ok(cx.undefined().upcast())
Expand Down
Loading