Skip to content

Commit

Permalink
Revert changes to vectorizers
Browse files Browse the repository at this point in the history
  • Loading branch information
rth committed May 3, 2019
1 parent ba87223 commit 3d47c8c
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 57 deletions.
19 changes: 8 additions & 11 deletions python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ use pyo3::wrap_pyfunction;

use vtext::metrics;
use vtext::tokenize;
use vtext::tokenize::Tokenizer;
use vtext::vectorize;

type PyCsrArray = (Py<PyArray1<i32>>, Py<PyArray1<i32>>, Py<PyArray1<i32>>);
Expand Down Expand Up @@ -60,16 +59,15 @@ fn result_to_csr(py: Python, x: CsMat<i32>) -> PyResult<PyCsrArray> {
}

#[pyclass]
pub struct _HashingVectorizerWrapper<'b> {
inner: vtext::vectorize::HashingVectorizer<'b>,
pub struct _HashingVectorizerWrapper {
inner: vtext::vectorize::HashingVectorizer,
}

#[pymethods]
impl<'b> _HashingVectorizerWrapper<'b> {
impl _HashingVectorizerWrapper {
#[new]
fn new(obj: &PyRawObject) {
let tokenizer = vtext::tokenize::RegexpTokenizer::new("\\b\\w\\w+\\b".to_string());
let estimator = vtext::vectorize::HashingVectorizer::new(tokenizer);
let estimator = vtext::vectorize::HashingVectorizer::new();
obj.init(_HashingVectorizerWrapper { inner: estimator });
}

Expand All @@ -85,16 +83,15 @@ impl<'b> _HashingVectorizerWrapper<'b> {
}

#[pyclass]
pub struct _CountVectorizerWrapper<'b> {
inner: vtext::vectorize::CountVectorizer<'b>,
pub struct _CountVectorizerWrapper {
inner: vtext::vectorize::CountVectorizer,
}

#[pymethods]
impl<'b> _CountVectorizerWrapper<'b> {
impl _CountVectorizerWrapper {
#[new]
fn new(obj: &PyRawObject) {
let tokenizer = vtext::tokenize::RegexpTokenizer::new("\\b\\w\\w+\\b".to_string());
let estimator = vtext::vectorize::CountVectorizer::new(tokenizer);
let estimator = vtext::vectorize::CountVectorizer::new();
obj.init(_CountVectorizerWrapper { inner: estimator });
}

Expand Down
31 changes: 14 additions & 17 deletions src/vectorize/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,19 @@ This module allows computing a sparse document term matrix from a text corpus.
```rust
extern crate vtext;
use vtext::tokenize::{VTextTokenizer,Tokenizer};
use vtext::vectorize::CountVectorizer;
let documents = vec![
String::from("Some text input"),
String::from("Another line"),
];
let tokenizer = VTextTokenizer::new("en");
let mut vectorizer = CountVectorizer::new(&tokenizer);
let mut vectorizer = CountVectorizer::new();
let X = vectorizer.fit_transform(&documents);
// returns a sparse CSR matrix with document-terms counts
*/

use crate::math::CSRArray;
use crate::tokenize;
use crate::tokenize::Tokenizer;
use hashbrown::HashMap;
use ndarray::Array;
use sprs::CsMat;
Expand Down Expand Up @@ -82,29 +77,29 @@ fn _sum_duplicates(tf: &mut CSRArray, indices_local: &[i32], nnz: &mut usize) {
}

#[derive(Debug)]
pub struct HashingVectorizer<'b> {
pub struct HashingVectorizer {
lowercase: bool,
tokenizer: &'b Tokenizer,
token_pattern: String,
n_features: u64,
}

#[derive(Debug)]
pub struct CountVectorizer<'b> {
pub struct CountVectorizer {
lowercase: bool,
tokenizer: &'b Tokenizer,
token_pattern: String,
// vocabulary uses i32 indices, to avoid memory copies when converting
// to sparse CSR arrays in Python with scipy.sparse
pub vocabulary: HashMap<String, i32>,
}

pub enum Vectorizer {}

impl<'b> CountVectorizer<'b> {
impl CountVectorizer {
/// Initialize a CountVectorizer estimator
pub fn new(tokenizer: &'b Tokenizer) -> Self {
pub fn new() -> Self {
CountVectorizer {
lowercase: true,
tokenizer: tokenizer,
token_pattern: String::from(TOKEN_PATTERN_DEFAULT),
vocabulary: HashMap::with_capacity_and_hasher(1000, Default::default()),
}
}
Expand Down Expand Up @@ -179,12 +174,12 @@ impl<'b> CountVectorizer<'b> {
}
}

impl<'b> HashingVectorizer<'b> {
impl HashingVectorizer {
/// Create a new HashingVectorizer estimator
pub fn new(tokenizer: &'b Tokenizer) -> Self {
pub fn new() -> Self {
HashingVectorizer {
lowercase: true,
tokenizer: tokenizer,
token_pattern: String::from(TOKEN_PATTERN_DEFAULT),
n_features: 1048576,
}
}
Expand All @@ -209,6 +204,8 @@ impl<'b> HashingVectorizer<'b> {
let mut indices_local = Vec::new();
let mut nnz: usize = 0;

let tokenizer = tokenize::RegexpTokenizer::new(TOKEN_PATTERN_DEFAULT.to_string());

// String.to_lowercase() is very slow
// https://www.reddit.com/r/rust/comments/6wbru2/performance_issue_can_i_avoid_of_using_the_slow/
// https://github.com/rust-lang/rust/issues/26244
Expand All @@ -217,7 +214,7 @@ impl<'b> HashingVectorizer<'b> {
let pipe = X.iter().map(|doc| doc.to_ascii_lowercase());

for (_document_id, document) in pipe.enumerate() {
let tokens = self.tokenizer.tokenize(&document);
let tokens = tokenizer.tokenize(&document);
indices_local.clear();
for token in tokens {
// set the RNG seeds to get reproducible hashing
Expand Down
34 changes: 5 additions & 29 deletions src/vectorize/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,14 @@
// <http://apache.org/licenses/LICENSE-2.0>. This file may not be copied,
// modified, or distributed except according to those terms.

use crate::tokenize::*;
use crate::vectorize::*;

#[test]
fn test_count_vectorizer_simple() {
// Example 1
let tokenizer = RegexpTokenizer::new("\\b\\w+\\w\\b".to_string());

let documents = vec![String::from("cat dog cat")];
let mut vect = CountVectorizer::new(&tokenizer);
let mut vect = CountVectorizer::new();
let X = vect.fit_transform(&documents);
assert_eq!(X.to_dense(), array![[2, 1]]);

Expand All @@ -23,7 +21,7 @@ fn test_count_vectorizer_simple() {
String::from("The sky sky sky is blue"),
];

let mut vect = CountVectorizer::new(&tokenizer);
let mut vect = CountVectorizer::new();
vect.fit(&documents);
let X = vect.transform(&documents);

Expand All @@ -48,9 +46,7 @@ fn test_hashing_vectorizer_simple() {
String::from("The sky is blue"),
];

let tokenizer = VTextTokenizer::new("en");

let vect = HashingVectorizer::new(&tokenizer);
let vect = HashingVectorizer::new();
let vect = vect.fit(&documents);
let X = vect.transform(&documents);
assert_eq!(X.indptr(), &[0, 4, 8]);
Expand Down Expand Up @@ -79,37 +75,17 @@ fn test_hashing_vectorizer_simple() {
fn test_empty_dataset() {
let documents: Vec<String> = vec![];

let tokenizer = VTextTokenizer::new("en");
let mut vectorizer = CountVectorizer::new(&tokenizer);
let mut vectorizer = CountVectorizer::new();

let X = vectorizer.fit_transform(&documents);
assert_eq!(X.data(), &[]);
assert_eq!(X.indices(), &[]);
assert_eq!(X.indptr(), &[0]);

let vectorizer = HashingVectorizer::new(&tokenizer);
let vectorizer = HashingVectorizer::new();

let X = vectorizer.fit_transform(&documents);
assert_eq!(X.data(), &[]);
assert_eq!(X.indices(), &[]);
assert_eq!(X.indptr(), &[0]);
}

#[test]
fn test_dynamic_dispatch_tokenizer() {
let tokenizer = VTextTokenizer::new("en");
CountVectorizer::new(&tokenizer);
HashingVectorizer::new(&tokenizer);

let tokenizer = UnicodeSegmentTokenizer::new(false);
CountVectorizer::new(&tokenizer);
HashingVectorizer::new(&tokenizer);

let tokenizer = RegexpTokenizer::new("\\b\\w+\\w\\b".to_string());
CountVectorizer::new(&tokenizer);
HashingVectorizer::new(&tokenizer);

let tokenizer = CharacterTokenizer::new(4);
CountVectorizer::new(&tokenizer);
HashingVectorizer::new(&tokenizer);
}

0 comments on commit 3d47c8c

Please sign in to comment.