From 172838ca43744107a509ed93ba8d7eac8985abd9 Mon Sep 17 00:00:00 2001 From: Roman Yurchak Date: Fri, 12 Jun 2020 21:29:24 +0200 Subject: [PATCH] Add pickling support for Python tokenizers (#73) --- Cargo.toml | 3 +- python/Cargo.toml | 2 + python/src/lib.rs | 1 + python/src/stem.rs | 70 +++++++++++++++++--------- python/src/tokenize.rs | 84 ++++++++++++++++++------------- python/src/tokenize_sentence.rs | 25 ++++++++- python/src/utils.rs | 30 +++++++++++ python/vtext/tests/test_common.py | 45 +++++++++++++++++ python/vtext/tests/test_stem.py | 6 +-- src/tokenize/mod.rs | 9 ++-- src/tokenize_sentence/mod.rs | 5 +- 11 files changed, 210 insertions(+), 70 deletions(-) create mode 100644 python/src/utils.rs create mode 100644 python/vtext/tests/test_common.py diff --git a/Cargo.toml b/Cargo.toml index 49eaa3f..56c7dee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ keywords = [ "tokenization", "tfidf", "levenshtein", - "matching" + "text-processing" ] edition = "2018" exclude = [ @@ -41,6 +41,7 @@ lazy_static = "1.4.0" seahash = "4.0.0" itertools = "0.8" ndarray = "0.13.0" +serde = { version = "1.0", features = ["derive"] } sprs = {version = "0.7.1", default-features = false} unicode-segmentation = "1.6.0" hashbrown = { version = "0.7", features = ["rayon"] } diff --git a/python/Cargo.toml b/python/Cargo.toml index 75b33be..d1a50ef 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -9,10 +9,12 @@ crate-type = ["cdylib"] [dependencies] ndarray = "0.13" +serde = { version = "1.0", features = ["derive"] } sprs = {version = "0.7.1", default-features = false} vtext = {"path" = "../", features=["python", "rayon"]} rust-stemmers = "1.1" rayon = "1.3" +bincode = "1.2.1" [dependencies.numpy] version = "0.9.0" diff --git a/python/src/lib.rs b/python/src/lib.rs index 01f4ebd..d795c19 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -18,6 +18,7 @@ use pyo3::wrap_pyfunction; mod stem; mod tokenize; mod tokenize_sentence; +mod utils; mod vectorize; use vtext::metrics; diff --git a/python/src/stem.rs b/python/src/stem.rs index 9ba9176..5a8b573 100644 --- a/python/src/stem.rs +++ b/python/src/stem.rs @@ -4,8 +4,10 @@ // . This file may not be copied, // modified, or distributed except according to those terms. +use crate::utils::{deserialize_params, serialize_params}; use pyo3::exceptions; use pyo3::prelude::*; +use pyo3::types::PyDict; /// __init__(self, lang='english') /// @@ -14,40 +16,43 @@ use pyo3::prelude::*; /// Wraps the rust-stemmers crate that uses an implementation generated /// by the `Snowball compiler `_ /// for Rust. 
-#[pyclass] +#[pyclass(module = "vtext.stem")] pub struct SnowballStemmer { pub lang: String, inner: rust_stemmers::Stemmer, } +fn get_algorithm(lang: &str) -> PyResult { + match lang { + "arabic" => Ok(rust_stemmers::Algorithm::Arabic), + "danish" => Ok(rust_stemmers::Algorithm::Danish), + "dutch" => Ok(rust_stemmers::Algorithm::Dutch), + "english" => Ok(rust_stemmers::Algorithm::English), + "french" => Ok(rust_stemmers::Algorithm::French), + "german" => Ok(rust_stemmers::Algorithm::German), + "greek" => Ok(rust_stemmers::Algorithm::Greek), + "hungarian" => Ok(rust_stemmers::Algorithm::Hungarian), + "italian" => Ok(rust_stemmers::Algorithm::Italian), + "portuguese" => Ok(rust_stemmers::Algorithm::Portuguese), + "romanian" => Ok(rust_stemmers::Algorithm::Romanian), + "russian" => Ok(rust_stemmers::Algorithm::Russian), + "spanish" => Ok(rust_stemmers::Algorithm::Spanish), + "swedish" => Ok(rust_stemmers::Algorithm::Swedish), + "tamil" => Ok(rust_stemmers::Algorithm::Tamil), + "turkish" => Ok(rust_stemmers::Algorithm::Turkish), + _ => Err(exceptions::ValueError::py_err(format!( + "lang={} is unsupported!", + lang + ))), + } +} + #[pymethods] impl SnowballStemmer { #[new] #[args(lang = "\"english\"")] fn new(lang: &str) -> PyResult { - let algorithm = match lang { - "arabic" => Ok(rust_stemmers::Algorithm::Arabic), - "danish" => Ok(rust_stemmers::Algorithm::Danish), - "dutch" => Ok(rust_stemmers::Algorithm::Dutch), - "english" => Ok(rust_stemmers::Algorithm::English), - "french" => Ok(rust_stemmers::Algorithm::French), - "german" => Ok(rust_stemmers::Algorithm::German), - "greek" => Ok(rust_stemmers::Algorithm::Greek), - "hungarian" => Ok(rust_stemmers::Algorithm::Hungarian), - "italian" => Ok(rust_stemmers::Algorithm::Italian), - "portuguese" => Ok(rust_stemmers::Algorithm::Portuguese), - "romanian" => Ok(rust_stemmers::Algorithm::Romanian), - "russian" => Ok(rust_stemmers::Algorithm::Russian), - "spanish" => Ok(rust_stemmers::Algorithm::Spanish), - "swedish" => Ok(rust_stemmers::Algorithm::Swedish), - "tamil" => Ok(rust_stemmers::Algorithm::Tamil), - "turkish" => Ok(rust_stemmers::Algorithm::Turkish), - _ => Err(exceptions::ValueError::py_err(format!( - "lang={} is unsupported!", - lang - ))), - }?; - + let algorithm = get_algorithm(lang)?; let stemmer = rust_stemmers::Stemmer::create(algorithm); Ok(SnowballStemmer { @@ -73,4 +78,21 @@ impl SnowballStemmer { let res = self.inner.stem(word).to_string(); Ok(res) } + + fn get_params<'p>(&self, py: Python<'p>) -> PyResult<&'p PyDict> { + let params = PyDict::new(py); + params.set_item("lang", self.lang.clone())?; + Ok(params) + } + + pub fn __getstate__(&self, py: Python) -> PyResult { + serialize_params(&self.lang, py) + } + + pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { + self.lang = deserialize_params(py, state)?; + let algorithm = get_algorithm(&self.lang)?; + self.inner = rust_stemmers::Stemmer::create(algorithm); + Ok(()) + } } diff --git a/python/src/tokenize.rs b/python/src/tokenize.rs index 304b779..6d22e45 100644 --- a/python/src/tokenize.rs +++ b/python/src/tokenize.rs @@ -7,9 +7,10 @@ use pyo3::prelude::*; use pyo3::types::PyList; +use crate::utils::{deserialize_params, serialize_params}; use vtext::tokenize::*; -#[pyclass] +#[pyclass(module = "vtext.tokenize")] pub struct BaseTokenizer {} #[pymethods] @@ -31,9 +32,8 @@ impl BaseTokenizer { /// References /// ---------- /// - `Unicode® Standard Annex #29 `_ -#[pyclass(extends=BaseTokenizer)] +#[pyclass(extends=BaseTokenizer, module="vtext.tokenize")] 
pub struct UnicodeSegmentTokenizer { - pub word_bounds: bool, inner: vtext::tokenize::UnicodeSegmentTokenizer, } @@ -48,10 +48,7 @@ impl UnicodeSegmentTokenizer { .unwrap(); ( - UnicodeSegmentTokenizer { - word_bounds: word_bounds, - inner: tokenizer, - }, + UnicodeSegmentTokenizer { inner: tokenizer }, BaseTokenizer::new(), ) } @@ -86,6 +83,16 @@ impl UnicodeSegmentTokenizer { fn get_params(&self) -> PyResult { Ok(self.inner.params.clone()) } + + pub fn __getstate__(&self, py: Python) -> PyResult { + serialize_params(&self.inner.params, py) + } + + pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { + let mut params: UnicodeSegmentTokenizerParams = deserialize_params(py, state)?; + self.inner = params.build().unwrap(); + Ok(()) + } } /// __init__(self, lang="en") @@ -104,9 +111,8 @@ impl UnicodeSegmentTokenizer { /// ---------- /// /// - `Unicode® Standard Annex #29 `_ -#[pyclass(extends=BaseTokenizer)] +#[pyclass(extends=BaseTokenizer, module="vtext.tokenize")] pub struct VTextTokenizer { - pub lang: String, inner: vtext::tokenize::VTextTokenizer, } @@ -120,13 +126,7 @@ impl VTextTokenizer { .build() .unwrap(); - ( - VTextTokenizer { - lang: lang.to_string(), - inner: tokenizer, - }, - BaseTokenizer::new(), - ) + (VTextTokenizer { inner: tokenizer }, BaseTokenizer::new()) } /// tokenize(self, x) @@ -159,14 +159,23 @@ impl VTextTokenizer { fn get_params(&self) -> PyResult { Ok(self.inner.params.clone()) } + + pub fn __getstate__(&self, py: Python) -> PyResult { + serialize_params(&self.inner.params, py) + } + + pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { + let mut params: VTextTokenizerParams = deserialize_params(py, state)?; + self.inner = params.build().unwrap(); + Ok(()) + } } /// __init__(self, pattern=r'\\b\\w\\w+\\b') /// /// Tokenize a document using regular expressions -#[pyclass(extends=BaseTokenizer)] +#[pyclass(extends=BaseTokenizer, module="vtext.tokenize")] pub struct RegexpTokenizer { - pub pattern: String, inner: vtext::tokenize::RegexpTokenizer, } @@ -180,13 +189,7 @@ impl RegexpTokenizer { .build() .unwrap(); - ( - RegexpTokenizer { - pattern: pattern.to_string(), - inner: inner, - }, - BaseTokenizer::new(), - ) + (RegexpTokenizer { inner: inner }, BaseTokenizer::new()) } /// tokenize(self, x) @@ -219,6 +222,16 @@ impl RegexpTokenizer { fn get_params(&self) -> PyResult { Ok(self.inner.params.clone()) } + + pub fn __getstate__(&self, py: Python) -> PyResult { + serialize_params(&self.inner.params, py) + } + + pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { + let mut params: RegexpTokenizerParams = deserialize_params(py, state)?; + self.inner = params.build().unwrap(); + Ok(()) + } } /// __init__(self, window_size=4) @@ -237,9 +250,8 @@ impl RegexpTokenizer { /// >>> tokenizer.tokenize('fox can\'t') /// ['fox ', 'ox c', 'x ca', ' can', 'can\'', 'an\'t'] /// -#[pyclass(extends=BaseTokenizer)] +#[pyclass(extends=BaseTokenizer, module="vtext.tokenize")] pub struct CharacterTokenizer { - pub window_size: usize, inner: vtext::tokenize::CharacterTokenizer, } @@ -253,13 +265,7 @@ impl CharacterTokenizer { .build() .unwrap(); - ( - CharacterTokenizer { - window_size: window_size, - inner: inner, - }, - BaseTokenizer::new(), - ) + (CharacterTokenizer { inner: inner }, BaseTokenizer::new()) } /// tokenize(self, x) @@ -292,4 +298,14 @@ impl CharacterTokenizer { fn get_params(&self) -> PyResult { Ok(self.inner.params.clone()) } + + pub fn __getstate__(&self, py: Python) -> PyResult { + 
serialize_params(&self.inner.params, py) + } + + pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { + let mut params: CharacterTokenizerParams = deserialize_params(py, state)?; + self.inner = params.build().unwrap(); + Ok(()) + } } diff --git a/python/src/tokenize_sentence.rs b/python/src/tokenize_sentence.rs index ad16652..c86cd67 100644 --- a/python/src/tokenize_sentence.rs +++ b/python/src/tokenize_sentence.rs @@ -10,6 +10,7 @@ use pyo3::types::PyList; use vtext::tokenize::Tokenizer; use vtext::tokenize_sentence::*; +use crate::utils::{deserialize_params, serialize_params}; // macro located `vtext::tokenize_sentence::vecString` use vtext::vecString; @@ -24,7 +25,7 @@ use vtext::vecString; /// References /// ---------- /// - `Unicode® Standard Annex #29 `_ -#[pyclass(extends=BaseTokenizer)] +#[pyclass(extends=BaseTokenizer, module="vtext.tokenize_sentence")] pub struct UnicodeSentenceTokenizer { inner: vtext::tokenize_sentence::UnicodeSentenceTokenizer, } @@ -73,6 +74,16 @@ impl UnicodeSentenceTokenizer { fn get_params(&self) -> PyResult { Ok(self.inner.params.clone()) } + + pub fn __getstate__(&self, py: Python) -> PyResult { + serialize_params(&self.inner.params, py) + } + + pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { + let mut params: UnicodeSentenceTokenizerParams = deserialize_params(py, state)?; + self.inner = params.build().unwrap(); + Ok(()) + } } /// __init__(self, punctuation=[".", "?", "!"]) @@ -88,7 +99,7 @@ impl UnicodeSentenceTokenizer { /// Punctuation tokens used to determine boundaries. Only the first unicode "character" is used. /// /// -#[pyclass(extends=BaseTokenizer)] +#[pyclass(extends=BaseTokenizer, module="vtext.tokenize_sentence")] pub struct PunctuationTokenizer { inner: vtext::tokenize_sentence::PunctuationTokenizer, } @@ -139,4 +150,14 @@ impl PunctuationTokenizer { fn get_params(&self) -> PyResult { Ok(self.inner.params.clone()) } + + pub fn __getstate__(&self, py: Python) -> PyResult { + serialize_params(&self.inner.params, py) + } + + pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { + let mut params: PunctuationTokenizerParams = deserialize_params(py, state)?; + self.inner = params.build().unwrap(); + Ok(()) + } } diff --git a/python/src/utils.rs b/python/src/utils.rs new file mode 100644 index 0000000..44067d1 --- /dev/null +++ b/python/src/utils.rs @@ -0,0 +1,30 @@ +// Copyright 2019 vtext developers +// +// Licensed under the Apache License, Version 2.0, +// . This file may not be copied, +// modified, or distributed except according to those terms. +use bincode::{deserialize, serialize}; +use pyo3::prelude::*; +use pyo3::types::PyBytes; +use serde::de::DeserializeOwned; +use serde::Serialize; + +pub fn serialize_params(params: &T, py: Python) -> PyResult +where + T: Serialize, +{ + Ok(PyBytes::new(py, &serialize(¶ms).unwrap()).to_object(py)) +} + +pub fn deserialize_params<'p, T>(py: Python<'p>, state: PyObject) -> PyResult +where + T: DeserializeOwned + Clone, +{ + match state.extract::<&PyBytes>(py) { + Ok(s) => { + let params: T = deserialize(s.as_bytes()).unwrap(); + Ok(params) + } + Err(e) => Err(e), + } +} diff --git a/python/vtext/tests/test_common.py b/python/vtext/tests/test_common.py new file mode 100644 index 0000000..47fa0f0 --- /dev/null +++ b/python/vtext/tests/test_common.py @@ -0,0 +1,45 @@ +# Copyright 2019 vtext developers +# +# Licensed under the Apache License, Version 2.0, +# . 
This file may not be copied, +# modified, or distributed except according to those terms. + +import pickle +import pytest +from vtext.tokenize import ( + CharacterTokenizer, + RegexpTokenizer, + UnicodeSegmentTokenizer, + VTextTokenizer, +) +from vtext.tokenize_sentence import UnicodeSentenceTokenizer, PunctuationTokenizer +from vtext.stem import SnowballStemmer + + +TOKENIZERS = [ + CharacterTokenizer, + RegexpTokenizer, + UnicodeSegmentTokenizer, + VTextTokenizer, +] + +SENTENCE_TOKENIZERS = [UnicodeSentenceTokenizer, PunctuationTokenizer] +STEMMERS = [SnowballStemmer] + + +@pytest.mark.parametrize("Estimator", TOKENIZERS + SENTENCE_TOKENIZERS + STEMMERS) +def test_pickle(Estimator): + est = Estimator() + params_ref = est.get_params() + + out = pickle.dumps(est) + + est2 = pickle.loads(out) + assert est2.get_params() == params_ref + + +def test_pickle_non_default_params(): + # check that pickling correctly stores estimator parameters + est = CharacterTokenizer(window_size=10) + est2 = pickle.loads(pickle.dumps(est)) + assert est2.get_params()["window_size"] == 10 diff --git a/python/vtext/tests/test_stem.py b/python/vtext/tests/test_stem.py index a7ba091..92796a3 100644 --- a/python/vtext/tests/test_stem.py +++ b/python/vtext/tests/test_stem.py @@ -12,9 +12,9 @@ def test_snowball_stemmer(): assert stemmer.stem("continuité") == "continu" -def test_snowball_stemmer_api(): - # check that not providing init parameters works - SnowballStemmer() +def test_snowball_stemmer_get_params(): + est = SnowballStemmer() + assert est.get_params() == {"lang": "english"} def test_snowball_stemmer_input_validation(): diff --git a/src/tokenize/mod.rs b/src/tokenize/mod.rs index baf61b9..02ed46d 100644 --- a/src/tokenize/mod.rs +++ b/src/tokenize/mod.rs @@ -54,6 +54,7 @@ use crate::errors::VTextError; #[cfg(feature = "python")] use dict_derive::{FromPyObject, IntoPyObject}; use regex::Regex; +use serde::{Deserialize, Serialize}; use std::fmt; use unicode_segmentation::UnicodeSegmentation; @@ -73,7 +74,7 @@ pub struct RegexpTokenizer { } /// Builder for the regexp tokenizer -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] #[cfg_attr(feature = "python", derive(FromPyObject, IntoPyObject))] pub struct RegexpTokenizerParams { pattern: String, @@ -137,7 +138,7 @@ pub struct UnicodeSegmentTokenizer { } /// Builder for the unicode segmentation tokenizer -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] #[cfg_attr(feature = "python", derive(FromPyObject, IntoPyObject))] pub struct UnicodeSegmentTokenizerParams { word_bounds: bool, @@ -199,7 +200,7 @@ pub struct VTextTokenizer { } /// Builder for the VTextTokenizer -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] #[cfg_attr(feature = "python", derive(FromPyObject, IntoPyObject))] pub struct VTextTokenizerParams { lang: String, @@ -360,7 +361,7 @@ pub struct CharacterTokenizer { pub params: CharacterTokenizerParams, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] #[cfg_attr(feature = "python", derive(FromPyObject, IntoPyObject))] pub struct CharacterTokenizerParams { window_size: usize, diff --git a/src/tokenize_sentence/mod.rs b/src/tokenize_sentence/mod.rs index 2d1e8bc..0a67630 100644 --- a/src/tokenize_sentence/mod.rs +++ b/src/tokenize_sentence/mod.rs @@ -65,6 +65,7 @@ extern crate unicode_segmentation; #[cfg(feature = "python")] use dict_derive::{FromPyObject, IntoPyObject}; +use serde::{Deserialize, Serialize}; use unicode_segmentation::UnicodeSegmentation; use 
crate::errors::VTextError;
@@ -89,7 +90,7 @@ pub struct UnicodeSentenceTokenizer {
 }

 /// Builder for the unicode segmentation tokenizer
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Serialize, Deserialize)]
 #[cfg_attr(feature = "python", derive(FromPyObject, IntoPyObject))]
 pub struct UnicodeSentenceTokenizerParams {}

@@ -138,7 +139,7 @@ pub struct PunctuationTokenizer {
 }

 /// Builder for the punctuation sentence tokenizer
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Serialize, Deserialize)]
 #[cfg_attr(feature = "python", derive(FromPyObject, IntoPyObject))]
 pub struct PunctuationTokenizerParams {
     punctuation: Vec<String>,
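
Note (not part of the patch): a minimal usage sketch of the pickling support added above, mirroring python/vtext/tests/test_common.py from this change. Each tokenizer's __getstate__ hands pickle the bincode-serialized params struct as bytes, and __setstate__ deserializes those bytes and rebuilds the inner Rust tokenizer via params.build(); the module="..." arguments added to #[pyclass] let pickle resolve each class by its import path when loading.

    import pickle

    from vtext.tokenize import VTextTokenizer

    tok = VTextTokenizer(lang="en")

    # __getstate__ -> VTextTokenizerParams serialized with bincode into bytes
    state = pickle.dumps(tok)

    # __setstate__ -> deserialize the params and rebuild the Rust tokenizer
    tok2 = pickle.loads(state)

    assert tok2.get_params() == tok.get_params()
    assert tok2.tokenize("Can't stop") == tok.tokenize("Can't stop")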