Merged
bindings/python/py_src/tokenizers/decoders/__init__.pyi (2 changes: 1 addition & 1 deletion)
@@ -4,7 +4,7 @@ class DecodeStream:
Class needed for streaming decode

"""
def __init__(self, skip_special_tokens):
def __init__(self, ids=None, skip_special_tokens=False):
pass

class Decoder:
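The stub above reflects the new constructor: a `DecodeStream` can now be seeded with ids that were already generated. A minimal usage sketch under that assumption (the ids are taken from the new tests further down; the model choice is only illustrative):

```python
from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream

tokenizer = Tokenizer.from_pretrained("gpt2")  # illustrative model choice

# Previous form: start empty and feed ids step by step.
stream = DecodeStream(skip_special_tokens=False)

# New form: seed the stream with already-generated ids so that the next
# step() only returns the text contributed by the new id(s).
seeded = DecodeStream(ids=[19567, 255, 19567], skip_special_tokens=False)
print(seeded.step(tokenizer, 109))  # "อั" according to the tests added in this PR
```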
bindings/python/src/decoders.rs (37 changes: 30 additions & 7 deletions)
@@ -646,21 +646,44 @@ pub struct PyDecodeStream {
prefix_index: usize,
}

#[derive(Clone)]
enum StreamInput {
Id(u32),
Ids(Vec<u32>),
}

impl FromPyObject<'_> for StreamInput {
fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
if let Ok(id) = obj.extract::<u32>() {
Ok(StreamInput::Id(id))
} else if let Ok(ids) = obj.extract::<Vec<u32>>() {
Ok(StreamInput::Ids(ids))
} else {
Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
"StreamInput must be either an integer or a list of integers",
))
}
}
}

#[pymethods]
impl PyDecodeStream {
#[new]
#[pyo3(signature = (skip_special_tokens), text_signature = "(self, skip_special_tokens)")]
fn new(skip_special_tokens: bool) -> Self {
#[pyo3(signature = (ids=None, skip_special_tokens=false), text_signature = "(self, ids=None, skip_special_tokens=False)")]
fn new(ids: Option<Vec<u32>>, skip_special_tokens: Option<bool>) -> Self {
PyDecodeStream {
skip_special_tokens,
ids: vec![],
prefix: "".to_string(),
skip_special_tokens: skip_special_tokens.unwrap_or(false),
ids: ids.unwrap_or_default(),
prefix: String::new(),
prefix_index: 0,
}
}

#[pyo3(signature = (tokenizer, id), text_signature = "(self, tokenizer, id)")]
fn step(&mut self, tokenizer: &PyTokenizer, id: u32) -> PyResult<Option<String>> {
fn step(&mut self, tokenizer: &PyTokenizer, id: StreamInput) -> PyResult<Option<String>> {
let id: Vec<u32> = match id {
StreamInput::Id(id) => vec![id],
StreamInput::Ids(ids) => ids,
};
ToPyResult(tk::tokenizer::step_decode_stream(
&tokenizer.tokenizer,
id,
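With the `StreamInput` extractor above, `step()` now accepts either a single id or a list of ids from Python; both are funneled into one `Vec<u32>` before decoding. A short sketch of the two call forms (ids and expected output are borrowed from the new tests further down):

```python
from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream

tokenizer = Tokenizer.from_pretrained("gpt2")
stream = DecodeStream()

# A single int is wrapped into a one-element batch on the Rust side.
stream.step(tokenizer, 19567)

# A list is appended in one call; text is only emitted once the accumulated
# ids decode to something complete (no trailing replacement character).
print(stream.step(tokenizer, [255, 19567, 109]))  # "อั" per test_decode_stream_fallback
```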
bindings/python/src/processors.rs (8 changes: 4 additions & 4 deletions)
@@ -341,7 +341,7 @@ impl PyBertProcessing {
}

#[getter]
fn get_sep(self_: PyRef<Self>) -> Result<Bound<'_, PyTuple>, PyErr> {
fn get_sep(self_: PyRef<'_, Self>) -> Result<Bound<'_, PyTuple>, PyErr> {
let py = self_.py();
let (tok, id) = getter!(self_, Bert, get_sep_copy());
PyTuple::new(
@@ -358,7 +358,7 @@ impl PyBertProcessing {
}

#[getter]
fn get_cls(self_: PyRef<Self>) -> Result<Bound<'_, PyTuple>, PyErr> {
fn get_cls(self_: PyRef<'_, Self>) -> Result<Bound<'_, PyTuple>, PyErr> {
let py = self_.py();
let (tok, id) = getter!(self_, Bert, get_cls_copy());
PyTuple::new(
@@ -422,7 +422,7 @@ impl PyRobertaProcessing {
}

#[getter]
fn get_sep(self_: PyRef<Self>) -> Result<Bound<'_, PyTuple>, PyErr> {
fn get_sep(self_: PyRef<'_, Self>) -> Result<Bound<'_, PyTuple>, PyErr> {
let py = self_.py();
let (tok, id) = getter!(self_, Roberta, get_sep_copy());
PyTuple::new(
@@ -439,7 +439,7 @@ impl PyRobertaProcessing {
}

#[getter]
fn get_cls(self_: PyRef<Self>) -> Result<Bound<'_, PyTuple>, PyErr> {
fn get_cls(self_: PyRef<'_, Self>) -> Result<Bound<'_, PyTuple>, PyErr> {
let py = self_.py();
let (tok, id) = getter!(self_, Roberta, get_cls_copy());
PyTuple::new(
bindings/python/tests/bindings/test_tokenizer.py (104 changes: 104 additions & 0 deletions)
@@ -371,6 +371,110 @@ def test_decode(self):
assert stream.step(tokenizer, 2) == " is"
assert stream.step(tokenizer, 3) == " john"

stream = DecodeStream(ids=[0, 1, 2])
assert stream.step(tokenizer, 3) == " john"

def test_decode_stream_fallback(self):
tokenizer = Tokenizer.from_pretrained("gpt2")
# tokenizer.decode([255]) fails because it's a fallback
# tokenizer.encode("อั").ids = [19567, 255, 19567, 109]
stream = DecodeStream()
stream.step(tokenizer, [19567])
stream.step(tokenizer, [255])
stream.step(tokenizer, [19567])
out = stream.step(tokenizer, [109])
assert out == "ั"

stream = DecodeStream()
out = stream.step(tokenizer, [19567, 255, 19567, 109])
assert out == "อั"
stream = DecodeStream()
stream.step(tokenizer, [19567])
out = stream.step(tokenizer, [255, 19567, 109])
assert out == "อั"

stream = DecodeStream()
stream.step(tokenizer, [19567])
first_out = stream.step(tokenizer, [255])
assert first_out == "อ"
# since we emitted the 'อ', we can't produce 'อั'
out = stream.step(tokenizer, [19567, 109])
assert out == "ั"

stream = DecodeStream([19567, 255, 19567])
# the stream's prefix is 'อ�', which is invalid, so all ids are kept for the next step
out = stream.step(tokenizer, [109])
assert out == "อั"

def test_decode_skip_special_tokens(self):
tokenizer = Tokenizer.from_pretrained("hf-internal-testing/Llama-3.1-8B-Instruct")

stream = DecodeStream([40])
out = stream.step(tokenizer, [2846, 40, 40, 40])
assert out == "'mIII"

stream = DecodeStream(
[
128000,
128006,
9125,
128007,
271,
38766,
1303,
33025,
2696,
25,
6790,
220,
2366,
18,
198,
15724,
2696,
25,
220,
1627,
10263,
220,
2366,
19,
271,
9514,
527,
264,
11190,
18328,
13,
128009,
128006,
882,
128007,
271,
15339,
11,
1268,
527,
499,
30,
128009,
128006,
78191,
128007,
271,
]
)
out = stream.step(tokenizer, 40)
assert out == "I"

stream = DecodeStream([40])
out = stream.step(tokenizer, 2846)
assert out == "'m"

stream = DecodeStream([40])
out = stream.step(tokenizer, [2846, 40, 40, 40])
assert out == "'mIII"

def test_decode_stream(self):
vocab = [
("<unk>", 0.0),
tokenizers/src/models/unigram/model.rs (2 changes: 1 addition & 1 deletion)
@@ -356,7 +356,7 @@ impl Unigram {
}

/// Iterate over the vocabulary of the model as pairs of `(token, score)`.
pub fn iter(&self) -> UnigramIterator {
pub fn iter(&self) -> UnigramIterator<'_> {
UnigramIterator { model: self, i: 0 }
}

tokenizers/src/models/unigram/trie.rs (2 changes: 1 addition & 1 deletion)
@@ -30,7 +30,7 @@ impl<Label: Eq + Hash + Copy> Trie<Label> {
node.is_leaf = true;
}

pub fn common_prefix_search<T>(&self, iterator: T) -> TrieIterator<Label, T>
pub fn common_prefix_search<T>(&self, iterator: T) -> TrieIterator<'_, Label, T>
where
T: Iterator<Item = Label>,
{
tokenizers/src/tokenizer/mod.rs (30 changes: 23 additions & 7 deletions)
@@ -1041,8 +1041,12 @@ pub struct DecodeStream<'tok, M, N, PT, PP, D> {

#[derive(thiserror::Error, Debug)]
pub enum DecodeStreamError {
#[error("Invalid prefix encountered")]
InvalidPrefix,
#[error("Invalid prefix encountered while decoding stream. Token ID: {token_id}, Expected prefix: '{expected_prefix}', Actual string: '{actual_string}'")]
InvalidPrefix {
token_id: u32,
expected_prefix: String,
actual_string: String,
},
}

impl<'tok, M, N, PT, PP, D> DecodeStream<'tok, M, N, PT, PP, D>
@@ -1067,7 +1071,7 @@ where
pub fn step(&mut self, id: u32) -> Result<Option<String>> {
step_decode_stream(
self.tokenizer,
id,
vec![id],
self.skip_special_tokens,
&mut self.ids,
&mut self.prefix,
@@ -1079,7 +1083,7 @@ where
/// Internal function exposed only to bypass python limitations
pub fn step_decode_stream<M, N, PT, PP, D>(
tokenizer: &TokenizerImpl<M, N, PT, PP, D>,
id: u32,
token_ids: Vec<u32>,
skip_special_tokens: bool,
ids: &mut Vec<u32>,
prefix: &mut String,
@@ -1092,12 +1096,25 @@
PP: PostProcessor,
D: Decoder,
{
ids.push(id);
if prefix.is_empty() && !ids.is_empty() {
let new_prefix = tokenizer.decode(ids, skip_special_tokens)?;
if !new_prefix.ends_with('�') {
*prefix = new_prefix;
*prefix_index = ids.len();
}
}

ids.extend(token_ids);
let string = tokenizer.decode(ids.as_slice(), skip_special_tokens)?;
if string.len() > prefix.len() && !string.ends_with('�') {
if !(string.starts_with(&*prefix)) {
return Err(Box::new(DecodeStreamError::InvalidPrefix));
return Err(Box::new(DecodeStreamError::InvalidPrefix {
token_id: *ids.last().unwrap(),
expected_prefix: prefix.clone(),
actual_string: string,
}));
}

let new_text = &string[prefix.len()..].to_string();
let new_prefix_index = ids.len() - *prefix_index;
*ids = ids.drain(*prefix_index..).collect();
@@ -1108,7 +1125,6 @@ where
Ok(None)
}
}

impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
where
M: Model,
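For orientation, the heart of the change is in `step_decode_stream`: it now takes a batch of ids and, when the stream was seeded, first tries to commit the seed as the decoded prefix before appending the new ids. Below is a rough Python restatement of that prefix-tracking logic; it is a sketch of the idea only (the class name is invented, and the tail-rebase step is my reading of the lines collapsed in the hunk above), not the Rust code or the bindings' `DecodeStream`:

```python
class DecodeStreamSketch:
    """Prefix-based incremental decoding, restated in Python for clarity."""

    def __init__(self, ids=None, skip_special_tokens=False):
        self.skip_special_tokens = skip_special_tokens
        self.ids = list(ids) if ids else []  # ids not yet folded into the prefix
        self.prefix = ""                     # text already committed
        self.prefix_index = 0                # how many of self.ids the prefix covers

    def step(self, tokenizer, token_ids):
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        # If the stream was seeded, commit the seed as the prefix, but only
        # when it decodes cleanly (no trailing replacement character).
        if not self.prefix and self.ids:
            seed = tokenizer.decode(self.ids, skip_special_tokens=self.skip_special_tokens)
            if not seed.endswith("\ufffd"):
                self.prefix = seed
                self.prefix_index = len(self.ids)

        self.ids.extend(token_ids)
        text = tokenizer.decode(self.ids, skip_special_tokens=self.skip_special_tokens)

        # Emit only the part that grew past the prefix, and only once the
        # decode no longer ends in an incomplete character.
        if len(text) > len(self.prefix) and not text.endswith("\ufffd"):
            if not text.startswith(self.prefix):
                raise ValueError(f"invalid prefix {self.prefix!r} for {text!r}")
            new_text = text[len(self.prefix):]
            # Drop the ids the prefix already covers and rebase the prefix
            # on the remaining tail for the next call.
            self.ids = self.ids[self.prefix_index:]
            self.prefix = tokenizer.decode(self.ids, skip_special_tokens=self.skip_special_tokens)
            self.prefix_index = len(self.ids)
            return new_text
        return None
```

Seeding this sketch with `[19567, 255, 19567]` and stepping with `109` yields "อั", matching test_decode_stream_fallback above.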