Upgrade hydra-node fork to latest transformers #2

Merged: 9 commits, Sep 10, 2024
106 changes: 106 additions & 0 deletions src/transformers/generation/streamers.py
@@ -16,6 +16,7 @@
from queue import Queue
from typing import TYPE_CHECKING, Optional

# from transformers.generation.utils import (GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput)

if TYPE_CHECKING:
from ..models.auto import AutoTokenizer
@@ -225,3 +226,108 @@ def __next__(self):
raise StopIteration()
else:
return value


class OutputStreamer(BaseStreamer):
    """
    Streams generation output objects rather than raw token ids.
    """
    def __init__(self,
                 filter_func=None,
                 cache=None,
                 ):
if filter_func is None:
filter_func = self._filter_func
self.filter_func = filter_func
if cache is None:
cache = []
self.cache = cache # incoming unprocessed outputs

def _filter_func(self, value):
"""
Class-default behavior for self.filter_func.
self.filter_func will be called on each incoming value. Can be used to filter the stream to a particular
attribute on the value object, or to limit the stream to values meeting certain criteria.
"""
return value
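
    # Illustrative (not part of the PR): a filter_func that keeps only the token ids
    # from each incoming output object:
    #     OutputIteratorStreamer(filter_func=lambda out: getattr(out, "sequences", out))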

def process_incoming_value(self, value):
"""
Called on each incoming value
"""
return self.filter_func(value)

def is_ready(self):
"""
Test whether the buffer is ready
"""
return len(self.cache) > 1

def on_ready(self):
"""
When the buffer is ready, flush it and do something with the values it was holding
"""
        if self.cache:
            values = self.cache[:]  # copy; a lone item still comes out as a one-element list
else:
raise ValueError("on_ready() called on an empty buffer. This should not happen. Report this error.")
self.cache = []
return self.process_outgoing_values(values)

def process_outgoing_values(self, values):
"""
What to do with the values that were previously in the buffer
"""
return values

def put(self, value):
value = self.process_incoming_value(value)
if value is not None:
if isinstance(value, list):
self.cache.extend(value)
else:
self.cache.append(value)

if self.is_ready():
return self.on_ready()


class OutputIteratorStreamer(OutputStreamer):
    def __init__(self,
                 filter_func=None,
                 cache=None,
                 queue=None,
                 timeout: Optional[float] = None,
                 ):
super().__init__(filter_func=filter_func, cache=cache)
if queue is None:
queue = Queue()
self.queue = queue # outgoing finalized outputs
self.timeout = timeout
self.stop_signal = None

def process_outgoing_values(self, values):
"""
What to do with the values that were previously in the buffer
"""
self.queue.put(values)


def __iter__(self):
return self

def __next__(self):
value = self.queue.get(timeout=self.timeout)
if value == self.stop_signal:
raise StopIteration()
else:
return value

def end(self):
# flush the cache if there's anything in it
if self.cache:
self.on_ready()
self.queue.put(self.stop_signal)
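
For orientation, a minimal consumption sketch (not part of the diff; the model name
and generate() kwargs are illustrative assumptions). It drives OutputIteratorStreamer
from a worker thread, mirroring the usual TextIteratorStreamer pattern:

    from threading import Thread

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.generation.streamers import OutputIteratorStreamer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    streamer = OutputIteratorStreamer()
    inputs = tokenizer("Hello", return_tensors="pt")

    # generate() blocks until completion, so it runs in a worker thread while the
    # main thread consumes per-step output objects as they arrive
    thread = Thread(
        target=model.generate,
        kwargs=dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=20,
            return_dict_in_generate=True,
            output_scores=True,
        ),
    )
    thread.start()
    for values in streamer:  # each item is a list of per-step output stubs
        print(values)
    thread.join()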
146 changes: 94 additions & 52 deletions src/transformers/generation/utils.py
@@ -358,6 +358,33 @@ class GenerationMixin:
To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
"""

    def _prepare_output(
        self,
        *,
        return_dict_in_generate,
        **output_kwargs,
    ):
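        """
        Package generation results either as a `Generate*Output` object (when
        `return_dict_in_generate` is True) or as the bare `sequences` tensor.
        For decoder-only models, encoder-decoder-specific kwargs are renamed or
        dropped before constructing the output object.
        """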
        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                cls = GenerateEncoderDecoderOutput
            else:
                cls = GenerateDecoderOnlyOutput
                # decoder-only outputs use plain `attentions`/`hidden_states` field names
                if "decoder_attentions" in output_kwargs:
                    output_kwargs["attentions"] = output_kwargs.pop("decoder_attentions")
                if "decoder_hidden_states" in output_kwargs:
                    output_kwargs["hidden_states"] = output_kwargs.pop("decoder_hidden_states")

                # encoder-side fields have no counterpart on decoder-only outputs
                output_kwargs.pop("encoder_attentions", None)
                output_kwargs.pop("encoder_hidden_states", None)
                output_kwargs.pop("cross_attentions", None)

            outv = cls(**output_kwargs)
        else:
            outv = output_kwargs["sequences"]
        return outv


def prepare_inputs_for_generation(self, *args, **kwargs):
raise NotImplementedError(
"A model class needs to define a `prepare_inputs_for_generation` method in order to use `.generate()`."
@@ -1858,7 +1885,12 @@ def generate(
input_ids = self.heal_tokens(input_ids, tokenizer)

if streamer is not None:
streamer.put(input_ids.cpu())
output_stub = self._prepare_output(
return_dict_in_generate=generation_config.return_dict_in_generate,
sequences=input_ids,
# no scores/logits/attention/hidden here because they haven't been computed yet.
)
streamer.put(output_stub)

# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_length = input_ids.shape[-1]
@@ -2546,6 +2578,11 @@ def _contrastive_search(
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

# initialize variables for self._prepare_output
encoder_attentions = encoder_hidden_states = None
next_step_cross_attentions = ()
next_step_decoder_attentions = ()

# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
@@ -2781,8 +2818,6 @@

# Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration
if self.config.is_encoder_decoder:
next_step_cross_attentions = ()
next_step_decoder_attentions = ()
if output_attentions:
for layer in outputs.cross_attentions:
layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
@@ -2819,7 +2854,21 @@
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if streamer is not None:
streamer.put(next_tokens.cpu())
                output_stub = self._prepare_output(
                    return_dict_in_generate=return_dict_in_generate,
                    sequences=next_tokens,
                    scores=(processed_logit_for_next_step,),
                    # NOTE: the contrastive search implementation currently appears to yield
                    # the same values for `logits` as for `scores`. Alternatives tried here,
                    # scores=(scores,), logits=(logits[selected_idx, :],) and
                    # logits=(logit_for_next_step,), either don't match in values or in shape.
                    logits=(processed_logit_for_next_step,),
                    encoder_attentions=None,  # probably doesn't make sense to stream this
                    encoder_hidden_states=None,  # probably doesn't make sense to stream this
                    # TODO: tests still pass even if this is set to `([0],)`, which suggests
                    # the streamed attentions are not yet covered by tests
                    decoder_attentions=(next_step_decoder_attentions,),
                    cross_attentions=(next_step_cross_attentions,),
                    decoder_hidden_states=(next_decoder_hidden_states,),
                    past_key_values=None,  # probably doesn't make sense to stream this
                )
                streamer.put(output_stub)
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
@@ -2851,29 +2900,18 @@
past_key_values.append(tuple(layer_past_key_values))
model_kwargs["past_key_values"] = tuple(past_key_values)

if self.config.is_encoder_decoder:
return GenerateEncoderDecoderOutput(
sequences=input_ids,
scores=scores,
logits=raw_logits,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return GenerateDecoderOnlyOutput(
sequences=input_ids,
scores=scores,
logits=raw_logits,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return input_ids
return self._prepare_output(
return_dict_in_generate=return_dict_in_generate,
sequences=input_ids,
scores=scores,
logits=raw_logits,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values")
)

def _sample(
self,
@@ -2934,6 +2972,10 @@ def _sample(
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

# initialize variables for self._prepare_output(...)
encoder_attentions = encoder_hidden_states = None
next_decoder_attentions = next_cross_attentions = next_decoder_hidden_states = None

# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
@@ -3006,7 +3048,19 @@ def _sample(
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if streamer is not None:
streamer.put(next_tokens.cpu())
output_stub = self._prepare_output(
return_dict_in_generate=return_dict_in_generate,
sequences=next_tokens,
scores=(next_token_scores,),
logits=(next_token_logits,),
encoder_attentions=None,
encoder_hidden_states=None,
decoder_attentions=(next_decoder_attentions,),
cross_attentions=(next_cross_attentions,),
decoder_hidden_states=(next_decoder_hidden_states,),
past_key_values=None,
)
streamer.put(output_stub)
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
@@ -3024,30 +3078,18 @@
if streamer is not None:
streamer.end()

if return_dict_in_generate:
if self.config.is_encoder_decoder:
return GenerateEncoderDecoderOutput(
sequences=input_ids,
scores=scores,
logits=raw_logits,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return GenerateDecoderOnlyOutput(
sequences=input_ids,
scores=scores,
logits=raw_logits,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return input_ids
return self._prepare_output(
return_dict_in_generate=return_dict_in_generate,
sequences=input_ids,
scores=scores,
logits=raw_logits,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values")
)

def _temporary_reorder_cache(self, past_key_values, beam_idx):
"""