Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decode stream python #1678

Merged
merged 2 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bindings/python/py_src/tokenizers/decoders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
BPEDecoder = decoders.BPEDecoder
CTC = decoders.CTC
Sequence = decoders.Sequence
DecodeStream = decoders.DecodeStream
8 changes: 8 additions & 0 deletions bindings/python/py_src/tokenizers/decoders/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
# Generated content DO NOT EDIT
class DecodeStream:
"""
Class needed for streaming decode

"""
def __init__(self, skip_special_tokens):
pass

class Decoder:
"""
Base class for all decoders
Expand Down
63 changes: 63 additions & 0 deletions bindings/python/src/decoders.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::sync::{Arc, RwLock};

use crate::pre_tokenizers::from_string;
use crate::tokenizer::PyTokenizer;
use crate::utils::PyPattern;
use pyo3::exceptions;
use pyo3::prelude::*;
Expand Down Expand Up @@ -590,9 +591,71 @@ pub fn decoders(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyBPEDecoder>()?;
m.add_class::<PyCTCDecoder>()?;
m.add_class::<PySequenceDecoder>()?;
m.add_class::<PyDecodeStream>()?;
Ok(())
}

/// Class needed for streaming decode
///
#[pyclass(module = "tokenizers.decoders", name = "DecodeStream")]
#[derive(Clone)]
pub struct PyDecodeStream {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

intersting, so we don't use the rust DecodeStream object. I am guessing it's for ownership reasons? Otherwise we needs to wrap the PyDecodeStream {stream: DecodeStream} with arc?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Arc doesn't save us. The issue is with the borrow of the tokenizer's lifetime.

You can technically reborrow in every call in rust too, it's just not very "rusty".
For Python, we need to get access to the tokenizer on every call, and cloning into an Arc feels super wasteful (and breaks every update you might do on the tokenizer afterwards).

This seems innocuous enough since currently users have to hold already a reference to the tokenizer anyway.

/// Regular decode option that is kept throughout.
skip_special_tokens: bool,
/// A temporary buffer of the necessary token_ids needed
/// to produce valid string chunks.
/// This typically contains 3 parts:
/// - read
/// - prefix
/// - rest
///
/// Read is the bit necessary to surround the prefix
/// so decoding the whole ids produces a valid prefix.
/// Prefix is the previously produced string, kept around to trim off of
/// the next valid chunk
ids: Vec<u32>,
/// The previously returned chunk that needs to be discarded from the
/// decoding of the current ids to produce the next chunk
prefix: String,
/// The index within the ids corresponding to the prefix so we can drain
/// correctly
prefix_index: usize,
/// We need to keep 2 prefixes.
/// Prefix is the second one that was already emitted to discard the part
/// of the text of all the ids
/// read is the prefix kept only for starting side effects of the prefix
read_index: usize,
}

#[pymethods]
impl PyDecodeStream {
#[new]
#[pyo3(signature = (skip_special_tokens), text_signature = "(self, skip_special_tokens)")]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Narsil @ArthurZucker

Thanks for this API! Im working on integrating it into VLLM.

QQ - this API is clear for step(). However, would it be possible to pass the prefill tokens to new()?

IIUC, the current API requires me to call step N times for N prefill tokens before I get into the decode phase. Is that right?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep we'll add if not already possible!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#1725 for the record!

fn new(skip_special_tokens: bool) -> Self {
PyDecodeStream {
skip_special_tokens,
ids: vec![],
prefix: "".to_string(),
prefix_index: 0,
read_index: 0,
}
}

#[pyo3(signature = (tokenizer, id), text_signature = "(self, tokenizer, id)")]
fn step(&mut self, tokenizer: &PyTokenizer, id: u32) -> PyResult<Option<String>> {
ToPyResult(tk::tokenizer::step_decode_stream(
&tokenizer.tokenizer,
id,
self.skip_special_tokens,
&mut self.ids,
&mut self.prefix,
&mut self.prefix_index,
&mut self.read_index,
))
.into()
}
}

#[cfg(test)]
mod test {
use std::sync::{Arc, RwLock};
Expand Down
2 changes: 1 addition & 1 deletion bindings/python/src/tokenizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ type Tokenizer = TokenizerImpl<PyModel, PyNormalizer, PyPreTokenizer, PyPostProc
#[derive(Clone, Serialize)]
#[serde(transparent)]
pub struct PyTokenizer {
tokenizer: Tokenizer,
pub(crate) tokenizer: Tokenizer,
}

impl PyTokenizer {
Expand Down
32 changes: 32 additions & 0 deletions bindings/python/tests/bindings/test_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from tokenizers.pre_tokenizers import ByteLevel, Metaspace
from tokenizers.processors import RobertaProcessing, TemplateProcessing
from tokenizers.normalizers import Strip, Lowercase, Sequence
from tokenizers.decoders import ByteFallback, DecodeStream, Metaspace as DecoderMetaspace


from ..utils import bert_files, data_dir, multiprocessing_with_parallelism, roberta_files
Expand Down Expand Up @@ -365,6 +366,37 @@ def test_decode(self):
output = tokenizer.decode_batch([[0, 1, 2, 3], [4]])
assert output == ["my name is john", "pair"]

# Can decode stream
stream = DecodeStream(skip_special_tokens=False)
assert stream.step(tokenizer, 0) == "my"
assert stream.step(tokenizer, 1) == " name"
assert stream.step(tokenizer, 2) == " is"
assert stream.step(tokenizer, 3) == " john"

def test_decode_stream(self):
vocab = [
("<unk>", 0.0),
("<0x20>", -0.1),
("<0xC3>", -0.2),
("<0xA9>", -0.3),
]
tokenizer = Tokenizer(Unigram(vocab, 0, byte_fallback=True))
tokenizer.decoder = ByteFallback()
stream = DecodeStream(skip_special_tokens=False)
assert stream.step(tokenizer, 1) == " "
assert stream.step(tokenizer, 2) == None
assert stream.step(tokenizer, 3) == "é"

vocab = [
("<unk>", 0.0),
("▁This", -0.1),
]
tokenizer = Tokenizer(Unigram(vocab, 0, byte_fallback=False))
tokenizer.decoder = DecoderMetaspace()
stream = DecodeStream(skip_special_tokens=False)
assert stream.step(tokenizer, 1) == "This"
assert stream.step(tokenizer, 1) == " This"

def test_get_vocab(self):
tokenizer = Tokenizer(BPE())
tokenizer.add_tokens(["my", "name", "is", "john", "pair"])
Expand Down
60 changes: 43 additions & 17 deletions tokenizers/src/tokenizer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1069,24 +1069,50 @@ where

/// See [`DecodeStream`]
pub fn step(&mut self, id: u32) -> Result<Option<String>> {
self.ids.push(id);
let string = self
.tokenizer
.decode(self.ids.as_slice(), self.skip_special_tokens)?;
if string.len() > self.prefix.len() && !string.ends_with('�') {
if !(string.starts_with(&self.prefix)) {
return Err(Box::new(DecodeStreamError::InvalidPrefix));
}
let new_text = &string[self.prefix.len()..].to_string();
let new_prefix_index = self.ids.len() - self.prefix_index;
self.ids = self.ids.drain(self.read_index..).collect();
self.prefix = self.tokenizer.decode(&self.ids, self.skip_special_tokens)?;
self.read_index = self.prefix_index;
self.prefix_index = new_prefix_index;
Ok(Some(new_text.to_string()))
} else {
Ok(None)
step_decode_stream(
self.tokenizer,
id,
self.skip_special_tokens,
&mut self.ids,
&mut self.prefix,
&mut self.prefix_index,
&mut self.read_index,
)
}
}

/// Internal function exposed only to bypass python limitations
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what were the limitations?

pub fn step_decode_stream<M, N, PT, PP, D>(
tokenizer: &TokenizerImpl<M, N, PT, PP, D>,
id: u32,
skip_special_tokens: bool,
ids: &mut Vec<u32>,
prefix: &mut String,
prefix_index: &mut usize,
read_index: &mut usize,
) -> Result<Option<String>>
where
M: Model,
N: Normalizer,
PT: PreTokenizer,
PP: PostProcessor,
D: Decoder,
{
ids.push(id);
let string = tokenizer.decode(ids.as_slice(), skip_special_tokens)?;
if string.len() > prefix.len() && !string.ends_with('�') {
if !(string.starts_with(&*prefix)) {
return Err(Box::new(DecodeStreamError::InvalidPrefix));
}
let new_text = &string[prefix.len()..].to_string();
let new_prefix_index = ids.len() - *prefix_index;
*ids = ids.drain(*read_index..).collect();
*prefix = tokenizer.decode(ids, skip_special_tokens)?;
*read_index = *prefix_index;
*prefix_index = new_prefix_index;
Ok(Some(new_text.to_string()))
} else {
Ok(None)
}
}

Expand Down
Loading