Add FxHash and ShortStringOptimization. #1733

Open · wants to merge 12 commits into base: main
4 changes: 2 additions & 2 deletions bindings/node/src/models.rs
@@ -4,7 +4,7 @@ use crate::trainers::Trainer;
use napi::bindgen_prelude::*;
use napi_derive::napi;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use rustc_hash::FxHashMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock};
use tokenizers as tk;
@@ -95,7 +95,7 @@ impl tk::Model for Model {
self.model.as_ref()?.read().unwrap().id_to_token(id)
}

fn get_vocab(&self) -> HashMap<String, u32> {
fn get_vocab(&self) -> FxHashMap<String, u32> {
self
.model
.as_ref()
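
Note on the FxHashMap swap above: rustc-hash's `FxHashMap` is a type alias for std's `HashMap` with the faster, non-cryptographic FxHash hasher, so the existing map API keeps working. A minimal standalone sketch (not part of this PR), assuming rustc-hash 2.x:

```rust
use rustc_hash::FxHashMap;

fn main() {
    // FxHashMap is std's HashMap with the FxHash hasher, so insert/get/iteration
    // are unchanged; construction goes through Default because the hasher is
    // not RandomState (FxHashMap::new() does not exist).
    let mut vocab: FxHashMap<String, u32> = FxHashMap::default();
    vocab.insert("hello".to_string(), 0);
    vocab.insert("world".to_string(), 1);
    assert_eq!(vocab.get("hello"), Some(&0));
}
```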
4 changes: 2 additions & 2 deletions bindings/node/src/tokenizer.rs
@@ -6,7 +6,7 @@ use crate::pre_tokenizers::PreTokenizer;
use crate::processors::Processor;
use crate::tasks::tokenizer::{DecodeBatchTask, DecodeTask, EncodeBatchTask, EncodeTask};
use crate::trainers::Trainer;
use std::collections::HashMap;
use rustc_hash::FxHashMap;
use tokenizers::Model as ModelTrait;

use napi::bindgen_prelude::*;
@@ -433,7 +433,7 @@ impl Tokenizer {
}

#[napi]
pub fn get_vocab(&self, with_added_tokens: Option<bool>) -> HashMap<String, u32> {
pub fn get_vocab(&self, with_added_tokens: Option<bool>) -> FxHashMap<String, u32> {
let with_added_tokens = with_added_tokens.unwrap_or(true);
self.tokenizer.read().unwrap().get_vocab(with_added_tokens)
}
2 changes: 2 additions & 0 deletions bindings/python/Cargo.toml
@@ -18,6 +18,8 @@ pyo3 = { version = "0.23", features = ["abi3", "abi3-py39", "py-clone"] }
numpy = "0.23"
ndarray = "0.16"
itertools = "0.12"
rustc-hash = "2.1.1"
compact_str = { version = "0.8.1", features = ["serde"] }

[dependencies.tokenizers]
path = "../../tokenizers"
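
The two new dependencies back the changes below: rustc-hash provides `FxHashMap`, and compact_str (with its serde feature) provides the short-string-optimized token type. A minimal sketch of what `CompactString` buys, assuming compact_str 0.8 on a 64-bit target:

```rust
use compact_str::{CompactString, ToCompactString};

fn main() {
    // Strings of up to 24 bytes (on 64-bit targets) are stored inline,
    // so typical tokens like "##ing" never touch the heap.
    let short = CompactString::new("##ing");
    let long = "a".repeat(64).to_compact_string();
    assert!(!short.is_heap_allocated());
    assert!(long.is_heap_allocated());
}
```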
94 changes: 68 additions & 26 deletions bindings/python/src/decoders.rs
@@ -3,6 +3,7 @@ use std::sync::{Arc, RwLock};
use crate::pre_tokenizers::from_string;
use crate::tokenizer::PyTokenizer;
use crate::utils::PyPattern;
use compact_str::ToCompactString;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;
@@ -91,7 +92,10 @@ impl PyDecoder {
}

impl Decoder for PyDecoder {
fn decode_chain(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
fn decode_chain<T: ToCompactString>(
&self,
tokens: Vec<T>,
) -> tk::Result<Vec<impl ToCompactString>> {
self.decoder.decode_chain(tokens)
}
}
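
`decode_chain` is now generic over compact_str's `ToCompactString` instead of taking `Vec<String>`, so callers can pass `&str`, `String`, or `CompactString` without converting first. A rough illustration of that signature pattern (hypothetical helper, not from this PR):

```rust
use compact_str::{CompactString, ToCompactString};

// Accept any token type convertible to CompactString and normalize once,
// mirroring the shape of the new generic decode_chain signature.
fn join_tokens<T: ToCompactString>(tokens: Vec<T>) -> CompactString {
    let mut out = CompactString::new("");
    for t in tokens {
        out.push_str(&t.to_compact_string());
    }
    out
}

fn main() {
    // &str and String inputs both work through the same API.
    assert_eq!(join_tokens(vec!["Hel", "lo"]), "Hello");
    assert_eq!(join_tokens(vec![String::from("Hi")]), "Hi");
}
```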
@@ -139,7 +143,12 @@ impl PyDecoder {
/// :obj:`str`: The decoded string
#[pyo3(text_signature = "(self, tokens)")]
fn decode(&self, tokens: Vec<String>) -> PyResult<String> {
ToPyResult(self.decoder.decode(tokens)).into()
ToPyResult(
self.decoder
.decode(tokens)
.map(|t| t.to_compact_string().to_string()),
)
.into()
}

fn __repr__(&self) -> PyResult<String> {
@@ -235,12 +244,12 @@ pub struct PyWordPieceDec {}
impl PyWordPieceDec {
#[getter]
fn get_prefix(self_: PyRef<Self>) -> String {
getter!(self_, WordPiece, prefix.clone())
getter!(self_, WordPiece, prefix.clone().to_string())
}

#[setter]
fn set_prefix(self_: PyRef<Self>, prefix: String) {
setter!(self_, WordPiece, prefix, prefix);
setter!(self_, WordPiece, prefix, prefix.to_compact_string());
}

#[getter]
@@ -256,7 +265,10 @@ impl PyWordPieceDec {
#[new]
#[pyo3(signature = (prefix = String::from("##"), cleanup = true), text_signature = "(self, prefix=\"##\", cleanup=True)")]
fn new(prefix: String, cleanup: bool) -> (Self, PyDecoder) {
(PyWordPieceDec {}, WordPiece::new(prefix, cleanup).into())
(
PyWordPieceDec {},
WordPiece::new(prefix.to_compact_string(), cleanup).into(),
)
}
}

@@ -412,12 +424,12 @@ pub struct PyBPEDecoder {}
impl PyBPEDecoder {
#[getter]
fn get_suffix(self_: PyRef<Self>) -> String {
getter!(self_, BPE, suffix.clone())
getter!(self_, BPE, suffix.to_string())
}

#[setter]
fn set_suffix(self_: PyRef<Self>, suffix: String) {
setter!(self_, BPE, suffix, suffix);
setter!(self_, BPE, suffix, suffix.into());
}

#[new]
@@ -443,22 +455,27 @@ pub struct PyCTCDecoder {}
impl PyCTCDecoder {
#[getter]
fn get_pad_token(self_: PyRef<Self>) -> String {
getter!(self_, CTC, pad_token.clone())
getter!(self_, CTC, pad_token.to_string())
}

#[setter]
fn set_pad_token(self_: PyRef<Self>, pad_token: String) {
setter!(self_, CTC, pad_token, pad_token);
setter!(self_, CTC, pad_token, pad_token.into());
}

#[getter]
fn get_word_delimiter_token(self_: PyRef<Self>) -> String {
getter!(self_, CTC, word_delimiter_token.clone())
getter!(self_, CTC, word_delimiter_token.clone()).to_string()
}

#[setter]
fn set_word_delimiter_token(self_: PyRef<Self>, word_delimiter_token: String) {
setter!(self_, CTC, word_delimiter_token, word_delimiter_token);
setter!(
self_,
CTC,
word_delimiter_token,
word_delimiter_token.into()
);
}

#[getter]
@@ -526,22 +543,33 @@ impl CustomDecoder {
}

impl Decoder for CustomDecoder {
fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
fn decode<T: ToCompactString>(&self, tokens: Vec<T>) -> tk::Result<impl ToCompactString> {
let tokens: Vec<String> = tokens
.into_iter()
.map(|t| t.to_compact_string().to_string())
.collect();
Python::with_gil(|py| {
let decoded = self
.inner
.call_method(py, "decode", (tokens,), None)?
.extract(py)?;
.extract::<String>(py)?;
Ok(decoded)
})
}

fn decode_chain(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
fn decode_chain<T: ToCompactString>(
&self,
tokens: Vec<T>,
) -> tk::Result<Vec<impl ToCompactString>> {
let tokens: Vec<String> = tokens
.into_iter()
.map(|t| t.to_compact_string().to_string())
.collect();
Python::with_gil(|py| {
let decoded = self
.inner
.call_method(py, "decode_chain", (tokens,), None)?
.extract(py)?;
.extract::<Vec<String>>(py)?;
Ok(decoded)
})
}
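
Because the Python callback only receives built-in `str`, `CustomDecoder` materializes the generic tokens as owned `String`s before crossing the FFI boundary. The same pattern in isolation (a sketch, assuming compact_str; PyO3 omitted):

```rust
use compact_str::ToCompactString;

// Collapse any ToCompactString input into plain Strings, e.g. right before
// handing tokens to a layer (such as a Python callback) that expects String.
fn materialize<T: ToCompactString>(tokens: Vec<T>) -> Vec<String> {
    tokens
        .into_iter()
        .map(|t| t.to_compact_string().to_string())
        .collect()
}

fn main() {
    let strings = materialize(vec!["a", "b", "c"]);
    assert_eq!(strings, vec!["a", "b", "c"]);
}
```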
@@ -595,10 +623,21 @@ where
}

impl Decoder for PyDecoderWrapper {
fn decode_chain(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
fn decode_chain<T: ToCompactString>(
&self,
tokens: Vec<T>,
) -> tk::Result<Vec<impl ToCompactString>> {
match self {
PyDecoderWrapper::Wrapped(inner) => inner.read().unwrap().decode_chain(tokens),
PyDecoderWrapper::Custom(inner) => inner.read().unwrap().decode_chain(tokens),
PyDecoderWrapper::Wrapped(inner) => inner
.read()
.unwrap()
.decode_chain(tokens)
.map(|v| v.into_iter().map(|t| t.to_compact_string()).collect()),
PyDecoderWrapper::Custom(inner) => inner
.read()
.unwrap()
.decode_chain(tokens)
.map(|v| v.into_iter().map(|t| t.to_compact_string()).collect()),
}
}
}
@@ -663,14 +702,17 @@ impl PyDecodeStream {

#[pyo3(signature = (tokenizer, id), text_signature = "(self, tokenizer, id)")]
fn step(&mut self, tokenizer: &PyTokenizer, id: u32) -> PyResult<Option<String>> {
ToPyResult(tk::tokenizer::step_decode_stream(
&tokenizer.tokenizer,
id,
self.skip_special_tokens,
&mut self.ids,
&mut self.prefix,
&mut self.prefix_index,
))
ToPyResult(
tk::tokenizer::step_decode_stream(
&tokenizer.tokenizer,
id,
self.skip_special_tokens,
&mut self.ids,
&mut self.prefix.to_compact_string(),
&mut self.prefix_index,
)
.map(|o| o.map(|s| s.to_string())),
)
.into()
}
}
6 changes: 5 additions & 1 deletion bindings/python/src/encoding.rs
@@ -127,7 +127,11 @@ impl PyEncoding {
/// :obj:`List[str]`: The list of tokens
#[getter]
fn get_tokens(&self) -> Vec<String> {
self.encoding.get_tokens().to_vec()
self.encoding
.get_tokens()
.iter()
.map(|x| x.to_string())
.collect()
}

/// The generated word indices.
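
The getter above now copies tokens out as `String` because the pyo3 signature returns `Vec<String>`. The same mapping on a simplified stand-in for the encoding (a sketch; the real field layout lives in the core tokenizers crate), assuming tokens are stored as `CompactString`:

```rust
use compact_str::CompactString;

// Simplified stand-in for an encoding that stores its tokens as CompactString.
struct TokenStore {
    tokens: Vec<CompactString>,
}

impl TokenStore {
    // Copy the tokens out as owned Strings for a consumer that only accepts String.
    fn tokens_as_strings(&self) -> Vec<String> {
        self.tokens.iter().map(|t| t.to_string()).collect()
    }
}

fn main() {
    let store = TokenStore {
        tokens: vec![CompactString::new("hello"), CompactString::new("##world")],
    };
    assert_eq!(store.tokens_as_strings(), vec!["hello", "##world"]);
}
```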