Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stream ModelOutputs #29545

Closed
wants to merge 49 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
f901ee0
skeleton OutputStreamer
dmarx Feb 29, 2024
57a77d1
hacky but passes tests
dmarx Feb 29, 2024
b34c0b2
skeleton OutputIteratorStreamer, notes
dmarx Mar 1, 2024
5eb0e69
fleshed out OutputIteratorStreamer w test case
dmarx Mar 1, 2024
ff35e2b
token_id streaming passes
dmarx Mar 1, 2024
76ed5c2
contrastive passes token_id test
dmarx Mar 1, 2024
3885ea1
randomize test seed
dmarx Mar 1, 2024
2958099
greedy scores pass test
dmarx Mar 1, 2024
976eb23
ensure we're building the output incrementally
dmarx Mar 2, 2024
cabb5fe
multinom sampling outputs match
dmarx Mar 4, 2024
9d075c9
throw error if on_ready called on empty buffer
dmarx Mar 5, 2024
ba48d71
enforce list(values) in on_ready()
dmarx Mar 6, 2024
108d3c2
POC output_constructor
dmarx Mar 6, 2024
02ec905
DRY tests following on_ready() fix
dmarx Mar 6, 2024
ebe4798
fix function spelling
dmarx Mar 6, 2024
528d79a
moved output_constructor to class method
dmarx Mar 6, 2024
97812bf
rename output_constructor -> _prepare_output
dmarx Mar 6, 2024
2c03ae6
integrating _prepare_output
dmarx Mar 6, 2024
64732bb
finished integrating _prepare_output
dmarx Mar 6, 2024
e903f04
placeholder args for attention/hidden streaming
dmarx Mar 6, 2024
ca4dc8d
cleanup
dmarx Mar 6, 2024
6e17323
explicit GenerateEncoderDecoderOutput support
dmarx Mar 6, 2024
3e2ff84
parameterized OutputStreamer tests
dmarx Mar 7, 2024
256884f
fixed tests
dmarx Mar 7, 2024
75bf3d7
assert same types on field, cleanup
dmarx Mar 7, 2024
673837b
tuple-of-tensors output type parity
dmarx Mar 7, 2024
8dd7973
cleanup, test type consistency of yielded stream
dmarx Mar 7, 2024
cd644fb
test emits helpful info on failure
dmarx Mar 7, 2024
2ad0ead
output_attentions streaming for greedy decoding
dmarx Mar 7, 2024
e523c53
fix skipped arguments
dmarx Mar 7, 2024
b6c2dc1
'fixed' tests, but now very messy
dmarx Mar 7, 2024
7ce491f
cleaned up
dmarx Mar 7, 2024
290b9cb
test checks all output attrs but past_key_values
dmarx Mar 8, 2024
41cb5e2
test over model varieties
dmarx Mar 8, 2024
a7a8a93
moved instantiation closer to use
dmarx Mar 8, 2024
76ad30d
attention for multinom decoding
dmarx Mar 8, 2024
21adb61
contrastive working after multinom supported
dmarx Mar 8, 2024
916bf43
attention streaming passes all test cases
dmarx Mar 8, 2024
0b28bab
add assistive decoding to test parameterization
dmarx Mar 8, 2024
b89dc83
draft streaming assisted, exceeds max tokens
dmarx Mar 8, 2024
1e02df6
back out assisted decoding changes for the moment
dmarx Mar 8, 2024
965c55c
tuple consistency
dmarx Mar 8, 2024
6c7422e
refactor tests
dmarx Mar 8, 2024
314c5e1
refactored tests, new issues w contrastive
dmarx Mar 8, 2024
33fded0
added note
dmarx Mar 8, 2024
2a4f657
suppress contrastive tests ftm
dmarx Mar 8, 2024
eb0d567
setting scores=logits, contrastive passes
dmarx Mar 8, 2024
014c8f5
block debugging import (lovely_tensors)
dmarx Mar 11, 2024
9358cd0
Merge branch 'main' into dmarx.output_streamer
dmarx Mar 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4917,7 +4917,7 @@
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin

# Generation
from .generation import GenerationConfig, TextIteratorStreamer, TextStreamer
from .generation import GenerationConfig, TextIteratorStreamer, TextStreamer, OutputIteratorStreamer
from .hf_argparser import HfArgumentParser

# Integrations
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/generation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@


_import_structure = {
"configuration_utils": ["GenerationConfig", "GenerationMode"],
"streamers": ["TextIteratorStreamer", "TextStreamer"],
"configuration_utils": ["GenerationConfig"],
"streamers": ["TextIteratorStreamer", "TextStreamer", "OutputIteratorStreamer"],
}

try:
Expand Down
134 changes: 133 additions & 1 deletion src/transformers/generation/streamers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
from queue import Queue
from typing import TYPE_CHECKING, Optional

import torch

from transformers.generation.utils import (GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput)

if TYPE_CHECKING:
from ..models.auto import AutoTokenizer
Expand All @@ -35,7 +38,126 @@ def end(self):
raise NotImplementedError()


class TextStreamer(BaseStreamer):
class OutputStreamer(BaseStreamer):
    """
    Base streamer that buffers generation output objects and releases them in batches.

    Each incoming value is passed through ``filter_func`` (which may transform it or
    drop it by returning ``None``), accumulated in ``cache``, and flushed through
    ``on_ready()`` once the buffer holds more than one item.
    """

    def __init__(self, filter_func=None, cache=None):
        # Fall back to the class-level default filter when none is supplied.
        self.filter_func = self._filter_func if filter_func is None else filter_func
        # Buffer of incoming, not-yet-flushed outputs.
        self.cache = [] if cache is None else cache

    def _filter_func(self, value):
        """
        Class-default behavior for self.filter_func.

        self.filter_func will be called on each incoming value. Can be used to filter the stream to a particular
        attribute on the value object, or to limit the stream to values meeting certain criteria.
        """
        return value

    def process_incoming_value(self, value):
        """
        Apply the configured filter to each incoming value.
        """
        return self.filter_func(value)

    def is_ready(self):
        """
        Return True once the buffer holds more than one item and may be flushed.
        """
        return len(self.cache) > 1

    def on_ready(self):
        """
        Flush the buffer and hand its contents to ``process_outgoing_values``.

        Raises:
            ValueError: if the buffer is empty when called.
        """
        n_buffered = len(self.cache)
        if n_buffered == 0:
            raise ValueError("on_ready() called on an empty buffer. This should not happen. Report this error.")
        # Always emit a list, even when only a single value was buffered.
        values = self.cache[:] if n_buffered > 1 else [self.cache[0]]
        self.cache = []
        return self.process_outgoing_values(values)

    def process_outgoing_values(self, values):
        """
        What to do with the values that were previously in the buffer
        """
        return values

    def put(self, value):
        filtered = self.process_incoming_value(value)
        if filtered is not None:
            # Lists are merged into the buffer element-wise; anything else is appended whole.
            if isinstance(filtered, list):
                self.cache.extend(filtered)
            else:
                self.cache.append(filtered)
        if self.is_ready():
            return self.on_ready()


class OutputIteratorStreamer(OutputStreamer):
    """
    Streamer that exposes flushed output batches through a blocking iterator.

    Flushed batches are pushed onto an internal queue; iterating the streamer
    yields them until ``end()`` enqueues the stop sentinel.

    Args:
        filter_func: optional callable applied to each incoming value (see ``OutputStreamer``).
        cache: optional list used as the incoming buffer.
        queue: optional queue used for outgoing, finalized batches.
        timeout (`float`, *optional*): seconds ``__next__`` waits for the next batch
            before ``queue.get`` raises ``queue.Empty``; blocks forever when ``None``.
    """

    def __init__(
        self,
        filter_func=None,
        cache=None,
        queue=None,
        timeout: Optional[float] = None,
    ):
        super().__init__(filter_func=filter_func, cache=cache)
        if queue is None:
            queue = Queue()
        self.queue = queue  # outgoing finalized outputs
        self.timeout = timeout
        self.stop_signal = None

    def process_outgoing_values(self, values):
        """
        What to do with the values that were previously in the buffer
        """
        self.queue.put(values)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.queue.get(timeout=self.timeout)
        # Identity check for the sentinel: `==` could raise or return a non-bool
        # for values with elementwise equality semantics (e.g. tensors).
        if value is self.stop_signal:
            raise StopIteration()
        return value

    def end(self):
        # Flush any remaining buffered values, then signal the end of the stream.
        if self.cache:
            self.on_ready()
        self.queue.put(self.stop_signal)


###############################

class TokenStreamer(OutputStreamer):
    """
    Filters the output stream on tokens to replicate legacy behavior
    """

    def _filter_func(self, value):
        # Generation-output objects carry their token ids on `.sequences`;
        # raw tensors are forwarded as-is. Either way the result is moved to CPU.
        is_output_object = isinstance(value, (GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput))
        tokens = value.sequences if is_output_object else value
        return tokens.cpu()


class TextStreamer(TokenStreamer):
"""
Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.

Expand Down Expand Up @@ -70,6 +192,7 @@ class TextStreamer(BaseStreamer):
"""

def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
super().__init__()
self.tokenizer = tokenizer
self.skip_prompt = skip_prompt
self.decode_kwargs = decode_kwargs
Expand All @@ -83,6 +206,15 @@ def put(self, value):
"""
Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
"""
# uses the parent classes built-in cache to restrict the "value" object to token_ids
value = self.filter_func(value)
if value is None:
return

#TODO: probably don't need this anymore?
if isinstance(value, list):
value = torch.tensor(value)

if len(value.shape) > 1 and value.shape[0] > 1:
raise ValueError("TextStreamer only supports batch size 1")
elif len(value.shape) > 1:
Expand Down
Loading