diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index fd6cb96c1d24..dbec410a4b63 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -4917,7 +4917,7 @@ from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin # Generation - from .generation import GenerationConfig, TextIteratorStreamer, TextStreamer + from .generation import GenerationConfig, TextIteratorStreamer, TextStreamer, OutputIteratorStreamer from .hf_argparser import HfArgumentParser # Integrations diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py index 8f2a6ad9600d..c79a6d949c22 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -18,8 +18,8 @@ _import_structure = { - "configuration_utils": ["GenerationConfig", "GenerationMode"], - "streamers": ["TextIteratorStreamer", "TextStreamer"], + "configuration_utils": ["GenerationConfig"], + "streamers": ["TextIteratorStreamer", "TextStreamer", "OutputIteratorStreamer"], } try: diff --git a/src/transformers/generation/streamers.py b/src/transformers/generation/streamers.py index c75b43466af7..56741b7d1945 100644 --- a/src/transformers/generation/streamers.py +++ b/src/transformers/generation/streamers.py @@ -16,6 +16,9 @@ from queue import Queue from typing import TYPE_CHECKING, Optional +import torch + +from transformers.generation.utils import (GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput) if TYPE_CHECKING: from ..models.auto import AutoTokenizer @@ -35,7 +38,126 @@ def end(self): raise NotImplementedError() -class TextStreamer(BaseStreamer): +class OutputStreamer(BaseStreamer): + """ + Streams Output objects + """ + def __init__(self, + filter_func=None, + cache = None, + ): + if filter_func is None: + filter_func = self._filter_func + self.filter_func = filter_func + if cache is None: + cache = [] + self.cache = cache # incoming unprocessed outputs + + def _filter_func(self, value): + """ + Class-default behavior for self.filter_func. + + self.filter_func will be called on each incoming value. Can be used to filter the stream to a particular + attribute on the value object, or to limit the stream to values meeting certain criteria. + """ + return value + + def process_incoming_value(self, value): + """ + Called on each incoming value + """ + return self.filter_func(value) + + def is_ready(self): + """ + Test whether the buffer is ready + """ + return len(self.cache) > 1 + + def on_ready(self): + """ + When the buffer is ready, flush it and do something with the values it was holding + """ + if len(self.cache) > 1: + values = self.cache[:] + elif len(self.cache) == 1: + values = self.cache[0] + values = [values] # put it in a list to be consistent + else: + raise ValueError("on_ready() called on an empty buffer. This should not happen. 
Report this error.") + self.cache = [] + return self.process_outgoing_values(values) + + def process_outgoing_values(self, values): + """ + What to do with the values that were previously in the buffer + """ + return values + + def put(self, value): + value = self.process_incoming_value(value) + if value is not None: + if isinstance(value, list): + self.cache.extend(value) + else: + self.cache.append(value) + + if self.is_ready(): + return self.on_ready() + + +class OutputIteratorStreamer(OutputStreamer): + def __init__(self, + filter_func=None, + cache = None, + queue=None, + timeout: Optional[float] = None, + ): + super().__init__(filter_func=filter_func, cache=cache) + if queue is None: + queue = Queue() + self.queue = queue # outgoing finalized outputs + self.timeout = timeout + self.stop_signal = None + + def process_outgoing_values(self, values): + """ + What to do with the values that were previously in the buffer + """ + self.queue.put(values) + + + def __iter__(self): + return self + + def __next__(self): + value = self.queue.get(timeout=self.timeout) + if value == self.stop_signal: + raise StopIteration() + else: + return value + + def end(self): + # flush the cache if there's anything in it + if self.cache: + self.on_ready() + self.queue.put(self.stop_signal) + + +############################### + +class TokenStreamer(OutputStreamer): + """ + Filters the output stream on tokens to replicate legacy behavior + """ + def _filter_func(self, value): + if isinstance(value, (GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput)): + return value.sequences.cpu() + else: + return value.cpu() + + +class TextStreamer(TokenStreamer): """ Simple text streamer that prints the token(s) to stdout as soon as entire words are formed. @@ -70,6 +192,7 @@ class TextStreamer(BaseStreamer): """ def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs): + super().__init__() self.tokenizer = tokenizer self.skip_prompt = skip_prompt self.decode_kwargs = decode_kwargs @@ -83,6 +206,15 @@ def put(self, value): """ Receives tokens, decodes them, and prints them to stdout as soon as they form entire words. """ + # uses the parent classes built-in cache to restrict the "value" object to token_ids + value = self.filter_func(value) + if value is None: + return + + #TODO: probably don't need this anymore? + if isinstance(value, list): + value = torch.tensor(value) + if len(value.shape) > 1 and value.shape[0] > 1: raise ValueError("TextStreamer only supports batch size 1") elif len(value.shape) > 1: diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 1d7eef755bf9..414e3db6c35b 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -351,6 +351,32 @@ class GenerationMixin: learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies). 
""" + def _prepare_output( + self, *, + return_dict_in_generate, + **output_kargs): + if return_dict_in_generate: + if self.config.is_encoder_decoder: + cls = GenerateEncoderDecoderOutput + else: + cls =GenerateDecoderOnlyOutput + if 'decoder_attentions' in output_kargs: + output_kargs['attentions'] = output_kargs.pop('decoder_attentions') + if 'decoder_hidden_states' in output_kargs: + output_kargs['hidden_states'] = output_kargs.pop('decoder_hidden_states') + + if 'encoder_attentions' in output_kargs: + output_kargs.pop('encoder_attentions') + if 'encoder_hidden_states' in output_kargs: + output_kargs.pop('encoder_hidden_states') + if 'cross_attentions' in output_kargs: + output_kargs.pop('cross_attentions') + + outv = cls(**output_kargs) + else: + outv = output_kargs['sequences'] + return outv + def prepare_inputs_for_generation(self, *args, **kwargs): raise NotImplementedError( "A model class needs to define a `prepare_inputs_for_generation` method in order to use `.generate()`." @@ -1402,8 +1428,15 @@ def generate( else: input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") + # echo back the prompt + # NB: if user wants prompt logits, this will prob need to be moved down if streamer is not None: - streamer.put(input_ids.cpu()) + output_stub = self._prepare_output( + return_dict_in_generate=generation_config.return_dict_in_generate, + sequences=input_ids, + # no scores/logits/attention/hidden here because they haven't been computed yet. + ) + streamer.put(output_stub) # 6. Prepare `max_length` depending on other stopping criteria. input_ids_length = input_ids.shape[-1] @@ -1919,6 +1952,7 @@ def _contrastive_search( decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + encoder_attentions = encoder_hidden_states = None # initialize variables for self._prepare_output if return_dict_in_generate and self.config.is_encoder_decoder: encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( @@ -2073,6 +2107,12 @@ def _contrastive_search( context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) + ### dmarx + # NB: I'm a bit confused why the logic here is disguised in this _ranking_fast function, which is only used here + # but is defined 2000 lines later down the file. Moreover, I think that means the returned scores will never + # take into account this "degeneration penalty" that's applied here for the re-ranking + ### /dmarx + # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the # model confidence. Keeping `selected_idx` on CPU enables multi-device contrastive search and doesn't # introduce (noticeable) slowdowns on single-device runs. @@ -2122,9 +2162,11 @@ def _contrastive_search( logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :] # Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration + next_step_cross_attentions = () + next_step_decoder_attentions = () if self.config.is_encoder_decoder: - next_step_cross_attentions = () - next_step_decoder_attentions = () + #next_step_cross_attentions = () + #next_step_decoder_attentions = () if output_attentions: for layer in outputs.cross_attentions: layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...] 
@@ -2163,7 +2205,20 @@ def _contrastive_search( # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) if streamer is not None: - streamer.put(next_tokens.cpu()) + output_stub = self._prepare_output( + return_dict_in_generate=return_dict_in_generate, + sequences=next_tokens, + scores=(processed_logit_for_next_step,), #(scores,), + logits=(processed_logit_for_next_step,), # I think there's an issue with the contrastive sampling implementation that is currently returning the same values for logits as scores #(logits[selected_idx,:],), #(logit_for_next_step,), # `logit_for_next_step`: values don't match, `logits`: shapes don't match + encoder_attentions=None, # probably doesn't make sense to stream this + encoder_hidden_states=None, # probably doesn't make sense to stream this + decoder_attentions=(next_step_decoder_attentions,), # ([0],),# very concerning that if I set this to `([0],)` my tests don't fail + cross_attentions=(next_step_cross_attentions,), + decoder_hidden_states=(next_decoder_hidden_states,), + past_key_values=None, # probably doesn't make sense to stream this + ) + streamer.put(output_stub) + model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs ) @@ -2198,29 +2253,18 @@ def _contrastive_search( past_key_values.append(tuple(layer_past_key_values)) model_kwargs["past_key_values"] = tuple(past_key_values) - if self.config.is_encoder_decoder: - return GenerateEncoderDecoderOutput( - sequences=input_ids, - scores=scores, - logits=raw_logits, - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return GenerateDecoderOnlyOutput( - sequences=input_ids, - scores=scores, - logits=raw_logits, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return input_ids + return self._prepare_output( + return_dict_in_generate=return_dict_in_generate, + sequences=input_ids, + scores=scores, + logits=raw_logits, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values") + ) def greedy_search(self, *args, **kwargs): logger.warning_once( @@ -2377,12 +2421,16 @@ def _greedy_search( decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + encoder_attentions = encoder_hidden_states = None # initialize variables for self._prepare_output(...) 
if return_dict_in_generate and self.config.is_encoder_decoder: encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) + # initialize variables for streamer.put(self._prepare_output(...)) + next_decoder_attentions = next_cross_attentions = next_decoder_hidden_states = None + # keep track of which sequences are already finished unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) @@ -2424,20 +2472,23 @@ def _greedy_search( if output_logits: raw_logits += (next_token_logits,) if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + next_decoder_attentions = ( + outputs.decoder_attentions if self.config.is_encoder_decoder else outputs.attentions ) + decoder_attentions += (next_decoder_attentions,) if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) + next_cross_attentions = outputs.cross_attentions + cross_attentions += (next_cross_attentions,) if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) + next_decoder_hidden_states = ( + outputs.decoder_hidden_states if self.config.is_encoder_decoder - else (outputs.hidden_states,) + else outputs.hidden_states ) + decoder_hidden_states += (next_decoder_hidden_states,) - # argmax + # argmax next_tokens = torch.argmax(next_tokens_scores, dim=-1) # finished sentences should have their next token be a padding token @@ -2449,7 +2500,20 @@ def _greedy_search( # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) if streamer is not None: - streamer.put(next_tokens.cpu()) + output_stub = self._prepare_output( + return_dict_in_generate=return_dict_in_generate, + sequences=next_tokens, + scores=(next_tokens_scores,), + logits=(next_token_logits,), + encoder_attentions=None, # (encoder_attentions,), # this will always be the same values for each streamed token. not sure it makes sense to stream it + encoder_hidden_states=None, # (encoder_hidden_states,), # this will always be the same values for each streamed token. not sure it makes sense to stream it + decoder_attentions=(next_decoder_attentions,), # ok this time changing it to `([0],)` causes a test failure. so that's good. 
+ cross_attentions=(next_cross_attentions,), # not sure this is right ### changing to `([0],)` does not cause test failure :( + decoder_hidden_states=(next_decoder_hidden_states,), # not sure this is right + past_key_values=None, # probably don't want to stream this just in general + ) + streamer.put(output_stub) + model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, @@ -2475,30 +2539,18 @@ def _greedy_search( if streamer is not None: streamer.end() - if return_dict_in_generate: - if self.config.is_encoder_decoder: - return GenerateEncoderDecoderOutput( - sequences=input_ids, - scores=scores, - logits=raw_logits, - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return GenerateDecoderOnlyOutput( - sequences=input_ids, - scores=scores, - logits=raw_logits, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return input_ids + return self._prepare_output( + return_dict_in_generate=return_dict_in_generate, + sequences=input_ids, + scores=scores, + logits=raw_logits, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values") + ) def sample(self, *args, **kwargs): logger.warning_once( @@ -2676,6 +2728,8 @@ def _sample( decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + encoder_attentions = encoder_hidden_states = None # initialize variables for self._prepare_output(...) 
+ next_decoder_attentions = next_cross_attentions = next_decoder_hidden_states = None # initialize variables for streamer.put(self._prepare_output(...)) if return_dict_in_generate and self.config.is_encoder_decoder: encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( @@ -2725,20 +2779,23 @@ def _sample( if output_logits: raw_logits += (next_token_logits,) if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + next_decoder_attentions = ( + outputs.decoder_attentions if self.config.is_encoder_decoder else outputs.attentions ) + decoder_attentions += (next_decoder_attentions,) if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) + next_cross_attentions = outputs.cross_attentions + cross_attentions += (next_cross_attentions,) if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) + next_decoder_hidden_states = ( + outputs.decoder_hidden_states if self.config.is_encoder_decoder - else (outputs.hidden_states,) + else outputs.hidden_states ) + decoder_hidden_states += (next_decoder_hidden_states,) - # sample + # sample probs = nn.functional.softmax(next_token_scores, dim=-1) next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) @@ -2750,8 +2807,23 @@ def _sample( # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + if streamer is not None: - streamer.put(next_tokens.cpu()) + output_stub = self._prepare_output( + return_dict_in_generate=return_dict_in_generate, + #sequences=next_tokens[:, None] # doesn't seem to make a difference. Handled downstream + sequences=next_tokens, # this seems to be getting coerced to float somewhere? 
+ scores=(next_token_scores,), + logits=(next_token_logits,), + encoder_attentions=None, # probably don't want to stream this + encoder_hidden_states=None, # probably don't want to stream this + decoder_attentions=(next_decoder_attentions,), + cross_attentions=(next_cross_attentions,), + decoder_hidden_states=(next_decoder_hidden_states,), + past_key_values=None, # probably don't want to stream this + ) + streamer.put(output_stub) + model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs ) @@ -2774,30 +2846,18 @@ def _sample( if streamer is not None: streamer.end() - if return_dict_in_generate: - if self.config.is_encoder_decoder: - return GenerateEncoderDecoderOutput( - sequences=input_ids, - scores=scores, - logits=raw_logits, - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return GenerateDecoderOnlyOutput( - sequences=input_ids, - scores=scores, - logits=raw_logits, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return input_ids + return self._prepare_output( + return_dict_in_generate=return_dict_in_generate, + sequences=input_ids, + scores=scores, + logits=raw_logits, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values") + ) def _temporary_reorder_cache(self, past_key_values, beam_idx): """ @@ -4499,6 +4559,7 @@ def _assisted_decoding( decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + encoder_hidden_states = encoder_attentions = None # initialize variables for self._prepare_output() if return_dict_in_generate and self.config.is_encoder_decoder: encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( @@ -4532,9 +4593,11 @@ def _assisted_decoding( candidate_logits = candidate_logits.to(self.device) candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] + #last_assistant_token_is_eos = False + #if eos_token_id_tensor is not None: last_assistant_token_is_eos = ( ~candidate_input_ids[:, -1] - .tile(eos_token_id_tensor.shape[0], 1) + .tile(eos_token_id_tensor.shape[0], 1) # <<< throwing error in streamer tests. looks like valid behavior for eos_token_id_tensor to be None. .ne(eos_token_id_tensor.unsqueeze(1)) .prod(dim=0) .bool() @@ -4574,6 +4637,7 @@ def _assisted_decoding( # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding) # 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf). max_matches = max_len - cur_len - 1 + #n_matches = None # initialize variable for streamer.put(...) if do_sample and candidate_logits is not None: valid_tokens, n_matches = _speculative_sampling( candidate_input_ids, @@ -4610,8 +4674,22 @@ def _assisted_decoding( # 4.1. 
Get the valid continuation, after the matching tokens input_ids = torch.cat((input_ids, valid_tokens), dim=-1) - if streamer is not None: - streamer.put(valid_tokens.cpu()) + # move this down + # if streamer is not None: + # output_stub = self._prepare_output( + # return_dict_in_generate=return_dict_in_generate, + # sequences=valid_tokens, + # scores=tuple(new_logits[:, i, :] for i in range(n_matches + 1)), + # # todo: just slice a view into the tensor... new_logits[:, :(n_matches+1), :], right? + # logits=(next_token_logits,), + # encoder_attentions=None, + # encoder_hidden_states=None, + # decoder_attentions=None, + # cross_attentions=None, + # decoder_hidden_states=None, + # past_key_values=None, + # ) + # streamer.put(output_stub) new_cur_len = input_ids.shape[-1] # 4.2. Discard past key values relative to unused assistant tokens @@ -4620,6 +4698,10 @@ def _assisted_decoding( # 5. Update the candidate generation strategy if needed candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches) + ### dmarx + # NTS: make sure .update_candidate_stragety() isn't mutating its inputs. + # otw we need to move the streamer.put() further down + ### /dmarx if synced_gpus and this_peer_finished: continue # don't waste resources running the code we don't need @@ -4671,6 +4753,32 @@ def _assisted_decoding( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs ) + if streamer is not None: + # if n_matches is None: + # n_matches = len(valid_tokens) + # if decoder_attentions is not None: + # decoder_attentions = decoder_attentions[: n_matches + 1] + # if cross_attentions is not None: + # cross_attentions = cross_attentions[: n_matches + 1] + # if decoder_hidden_states is not None: + # decoder_hidden_states = decoder_hidden_states[: n_matches + 1] + + output_stub = self._prepare_output( + return_dict_in_generate=return_dict_in_generate, + sequences=valid_tokens, + scores=tuple(new_logits[:, i, :] for i in range(n_matches + 1)), + # todo: just slice a view into the tensor... new_logits[:, :(n_matches+1), :], right? 
+ logits=(next_token_logits,), + encoder_attentions=None, + encoder_hidden_states=None, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + past_key_values=None, + ) + streamer.put(output_stub) + + # if eos_token was found in one sentence, set sentence to finished if eos_token_id_tensor is not None: unfinished_sequences = unfinished_sequences.mul( @@ -4699,30 +4807,19 @@ def _assisted_decoding( candidate_generator.assistant_model.generation_config.num_assistant_tokens = ( candidate_generator.num_assistant_tokens ) - if return_dict_in_generate: - if self.config.is_encoder_decoder: - return GenerateEncoderDecoderOutput( - sequences=input_ids, - scores=scores, - logits=raw_logits, - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return GenerateDecoderOnlyOutput( - sequences=input_ids, - scores=scores, - logits=raw_logits, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return input_ids + + return self._prepare_output( + return_dict_in_generate=return_dict_in_generate, + sequences=input_ids, + scores=scores, + logits=raw_logits, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values") + ) def _speculative_sampling( diff --git a/tests/generation/test_streamers.py b/tests/generation/test_streamers.py index c82a5e99e0de..3cfd5e5db99e 100644 --- a/tests/generation/test_streamers.py +++ b/tests/generation/test_streamers.py @@ -13,11 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import unittest +from collections import Counter +import copy from queue import Empty +import random from threading import Thread +import unittest +import pytest -from transformers import AutoTokenizer, TextIteratorStreamer, TextStreamer, is_torch_available +from transformers import AutoTokenizer, TextIteratorStreamer, TextStreamer, is_torch_available #, OutputIteratorStreamer +from transformers.generation.streamers import OutputIteratorStreamer # TODO: fix import from transformers.testing_utils import CaptureStdout, require_torch, torch_device from ..test_modeling_common import ids_tensor @@ -28,6 +33,9 @@ from transformers import AutoModelForCausalLM +## for debugging only +# import lovely_tensors as lt +# lt.monkey_patch() @require_torch class StreamerTester(unittest.TestCase): @@ -120,3 +128,282 @@ def test_iterator_streamer_timeout(self): streamer_text = "" for new_text in streamer: streamer_text += new_text + + +def nested_tensor_equality(left, right): + """ + Recursively check equality of tensors nested in tuple of tuples + """ + assert type(left) == type(right) + assert len(left) == len(right) + if isinstance(left, torch.Tensor): + assert torch.equal(left, right) + else: + for left2, right2 in zip(left, right): + assert nested_tensor_equality(left2, right2) + return True + +@require_torch +#class OutputIteratorStreamerTester(unittest.TestCase): # incompatible with pytest.mark.parameterize +class TestOutputIteratorStreamer: + + def _setup(self, + model="hf-internal-testing/tiny-random-gpt2", + # assistant_model, + do_sample=False, + top_k=None, + penalty_alpha=None, + output_scores=False, + output_logits=False, + output_attentions=False, + max_new_tokens=10, + return_dict_in_generate=True, + output_hidden_states=False, + ): + model = AutoModelForCausalLM.from_pretrained(model).to(torch_device) + model.config.eos_token_id = -1 + print(model.config) + + generation_kwargs = dict( + # input_ids=input_ids, + max_new_tokens=max_new_tokens, + return_dict_in_generate=return_dict_in_generate, + do_sample=do_sample, + top_k=top_k, + penalty_alpha=penalty_alpha, + output_scores=output_scores, + output_logits=output_logits, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + # if assistant_model: + # # attentions acting funny. suppress for now + # if not output_attentions: + # generation_kwargs['assistant_model'] = copy.deepcopy(model) + # generation_kwargs['assistant_model'].config.eos_token_id = 999 # assistant model needs to have a valid eos_token_id I think + + ### dmarx Force behaviors here for development ########################################### + # lol maybe these should just be separate tests.... + + # output attentions for... 
+ # ...greedy decoding + # generation_kwargs['output_attentions'] = False + # if (not generation_kwargs['do_sample']) and (generation_kwargs['penalty_alpha'] is None): + # generation_kwargs['output_attentions'] = True + # + # # ...multinomial sampling + # if (generation_kwargs['do_sample']) and (generation_kwargs['penalty_alpha'] is None): + # generation_kwargs['output_attentions'] = True + # + # # output attentions for contrastive decoding + # if (generation_kwargs['do_sample']) and (generation_kwargs['penalty_alpha'] is not None) and (generation_kwargs['top_k'] is not None) : + # generation_kwargs['output_attentions'] = True + #### /dmarx ############################################################################## + + print(generation_kwargs) # easier than decoding pytest parameterization shorthand on error + + input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device) + generation_kwargs['input_ids'] = input_ids + + baseline_kwargs = copy.deepcopy(generation_kwargs) + test_kwargs = copy.deepcopy(generation_kwargs) + + seed = random.randint(0, int(1e9)) + torch.manual_seed(seed) + baseline_outputs = model.generate(**baseline_kwargs) + print("baseline_outputs") + print(baseline_outputs) + + streamer = OutputIteratorStreamer() + test_kwargs['streamer'] = streamer + torch.manual_seed(seed) + thread = Thread(target=model.generate, kwargs=test_kwargs) + thread.start() + + outputs = {'sequences':torch.Tensor()} + for attr_name in ( + 'scores', 'logits', + 'attentions', 'encoder_attentions', 'decoder_attentions', 'cross_attentions', + 'hidden_states', 'encoder_hidden_states', 'decoder_hidden_states', + #'past_key_values' # uh... let's just say we're not going to support streaming the cache. + ): + if hasattr(baseline_outputs, attr_name): + if getattr(baseline_outputs, attr_name) is not None: + #print(attr_name) + #print(getattr(baseline_outputs, attr_name)) + outputs[attr_name] = () + + return baseline_outputs, outputs, streamer + + + # @pytest.mark.parametrize("do_sample,top_k", [(False,None), (True,4)]) + # @pytest.mark.parametrize("penalty_alpha", [None, 0.6]) + # @pytest.mark.parametrize("output_scores", [False, True]) + # @pytest.mark.parametrize("output_logits", [False, True]) + # @pytest.mark.parametrize("output_attentions", [False, True]) + # @pytest.mark.parametrize("model", ["hf-internal-testing/tiny-random-gpt2", "hf-internal-testing/tiny-random-bert", "hf-internal-testing/tiny-random-bart"]) # decoder, encoder, encoder-decoder + #@pytest.mark.parametrize("assistant_model", [False, True]) # having issues + def check_outputs_match(self, + *, + model="hf-internal-testing/tiny-random-gpt2", + #assistant_model, + do_sample=False, + top_k=None, + max_new_tokens=10, + penalty_alpha=None, + output_scores=False, + output_logits=False, + output_attentions=False, + output_hidden_states=False, + ): + + baseline_outputs, outputs, streamer = self._setup( + model=model, + # assistant_model=assistant_model, + do_sample=do_sample, + top_k=top_k, + penalty_alpha=penalty_alpha, + output_scores=output_scores, + output_logits=output_logits, + output_attentions=output_attentions, + max_new_tokens=max_new_tokens, + return_dict_in_generate=True, + output_hidden_states=output_hidden_states, + ) + + n_times_field_extended = Counter() + for answer in streamer: + #if isinstance(answer, list): + assert isinstance(answer, list) + for output_object in answer: + for output_name in outputs.keys(): + #print(output_name) + new_values = getattr(output_object, output_name) + if (new_values is 
not None) and (len(new_values) > 0): + + #print(type(outputs[output_name]), type(new_values)) + if output_name == 'sequences': + new_values = new_values.cpu() # fml.... + if new_values.ndim == 1: + new_values = new_values.unsqueeze(0) + outputs[output_name] = torch.cat([outputs[output_name], new_values], axis=-1) + else: + outputs[output_name] += new_values # tuples gonna tuple... + + print(outputs) + for output_name in outputs.keys(): + print(output_name) + baseline_values = getattr(baseline_outputs, output_name) + if isinstance(baseline_values, torch.Tensor): + baseline_values = baseline_values.cpu() + #assert (baseline_values is not None) and (baseline_values != tuple()) + assert (baseline_values is not None) + #assert type(baseline_values) == type(getattr(output_object, output_name)) + #assert n_times_field_extended[output_name] > 1 # make sure we're not just comparing to the final output tensor + # TODO: pick a better "are these literally the same object" test + + #if not isinstance(baseline_values, torch.Tensor): + # baseline_values = torch.cat(baseline_values, axis=-1) + target_values = outputs[output_name] + #assert baseline_values.shape == target_values.shape + print("baseline", baseline_values) + print("target", target_values) + assert type(baseline_values) == type(target_values) + assert len(baseline_values) == len(target_values) + + + # attention/hidden = tuples of tuples + assert nested_tensor_equality(baseline_values, target_values) + + # @pytest.mark.parametrize("do_sample,top_k", [(False,None), (True,4)]) + # @pytest.mark.parametrize("penalty_alpha", [None, 0.6]) + def check_ids_only_match(self, + do_sample=False, + top_k=None, + penalty_alpha=None, + max_new_tokens=10, + model="hf-internal-testing/tiny-random-gpt2", + ): + baseline_values, outputs, streamer = self._setup( + model=model, + # assistant_model=assistant_model, + do_sample=do_sample, + top_k=top_k, + penalty_alpha=penalty_alpha, + max_new_tokens=max_new_tokens, + return_dict_in_generate=False, + output_scores=False, + output_logits=False, + output_attentions=False, + output_hidden_states=False, + ) + + target_values = torch.Tensor() + for answer in streamer: + assert isinstance(answer, list) + for output_object in answer: + new_ids = output_object.cpu() + if new_ids.ndim == 1: + new_ids = new_ids.unsqueeze(0) + target_values = torch.cat([target_values, new_ids], axis=-1) + + assert baseline_values.shape == target_values.shape + assert baseline_values.tolist() == target_values.tolist() + + def test_greedy_ids_only(self): + self.check_ids_only_match(do_sample=False) + + def test_multinomial_ids_only(self): + self.check_ids_only_match(do_sample=True) + + def test_contrastive_ids_only(self): + self.check_ids_only_match(do_sample=False, penalty_alpha=0.6, top_k=4) + + #def test_assisted_ids_only(self): + # + + @pytest.mark.parametrize("output_scores", [False, True]) + @pytest.mark.parametrize("output_logits", [False, True]) + @pytest.mark.parametrize("output_attentions", [False, True]) + def test_greedy_outputs(self, + output_scores, + output_logits, + output_attentions, + ): + self.check_outputs_match( + do_sample=False, + output_scores=output_scores, + output_logits=output_logits, + output_attentions=output_attentions) + + @pytest.mark.parametrize("output_scores", [False, True]) + @pytest.mark.parametrize("output_logits", [False, True]) + @pytest.mark.parametrize("output_attentions", [False, True]) + def test_multinomial_outputs(self, + output_scores, + output_logits, + output_attentions, + ): + 
+        self.check_outputs_match(
+            do_sample=True,
+            output_scores=output_scores,
+            output_logits=output_logits,
+            output_attentions=output_attentions)
+
+    @pytest.mark.parametrize("output_scores", [False, True])
+    @pytest.mark.parametrize("output_logits", [False, True])
+    # TODO: re-enable the output_attentions=True fixture (and re-check output_logits) for contrastive search
+    @pytest.mark.parametrize("output_attentions", [False])
+    # @pytest.mark.parametrize("output_attentions", [False, True])
+    def test_contrastive_outputs(self,
+                                 output_scores,
+                                 output_logits,
+                                 output_attentions,
+                                 ):
+        self.check_outputs_match(
+            do_sample=False,
+            penalty_alpha=0.6,
+            top_k=4,
+            output_scores=output_scores,
+            output_logits=output_logits,
+            output_attentions=output_attentions)
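
For reviewers, a minimal usage sketch (not part of the diff) of the new OutputIteratorStreamer: generate() runs on a background thread, and the streamer yields lists of per-step output objects whose `sequences` field carries the newly generated ids (the prompt is echoed first). The tiny checkpoint is the one already used in the tests, the import path mirrors the test file, and the snippet assumes this patch is applied.

from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.streamers import OutputIteratorStreamer

checkpoint = "hf-internal-testing/tiny-random-gpt2"  # same tiny model as the tests
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

inputs = tokenizer("Hello", return_tensors="pt")
streamer = OutputIteratorStreamer()
generation_kwargs = dict(
    **inputs,
    max_new_tokens=10,
    return_dict_in_generate=True,  # stream GenerateDecoderOnlyOutput chunks instead of bare ids
    output_scores=True,            # each streamed step then also carries its per-step scores
    streamer=streamer,
)

# run generation in the background; the streamer is the consumer end of the queue
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

pieces = []
for chunk in streamer:             # each chunk is a list of per-step output objects
    for step_output in chunk:
        ids = step_output.sequences.cpu()
        pieces.append(ids if ids.ndim > 1 else ids.unsqueeze(0))  # prompt is 2-D, steps are 1-D
thread.join()

sequences = torch.cat(pieces, dim=-1)
print(tokenizer.decode(sequences[0]))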