rls2.5: fix return_dict_in_generate issue (#3333)
blzheng authored Oct 24, 2024
1 parent 62e4a7f · commit 584a4e2
Showing 5 changed files with 29 additions and 90 deletions.
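The diff below follows one pattern in all four generation files: the hand-copied ModelOutput subclasses are deleted in favor of the canonical output classes in transformers.generation.utils, and the overrides start reading return_dict_in_generate from the merged generation config instead of ignoring it. A minimal sketch of the behavior this is meant to restore; the model choice and the optimize call are illustrative, not part of the commit:

import torch
import intel_extension_for_pytorch as ipex
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
model = ipex.llm.optimize(model, dtype=torch.bfloat16)

inputs = tok("hello world", return_tensors="pt")
# num_beams=4 exercises the patched beam-search path
out = model.generate(
    inputs.input_ids,
    max_new_tokens=8,
    num_beams=4,
    return_dict_in_generate=True,
)
# With the fix, the result is the Hugging Face output class, not a
# bare tensor or a locally defined copy of the class.
print(type(out).__name__)  # e.g. GenerateBeamDecoderOnlyOutput
print(out.sequences)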
33 changes: 5 additions & 28 deletions intel_extension_for_pytorch/transformers/generation/beam_sample.py
@@ -2,41 +2,18 @@
 from torch import nn
 import torch.distributed as dist
 import warnings
-from typing import Optional, Tuple, Union, List
+from typing import Optional, Union, List
 from transformers.generation.stopping_criteria import (
     StoppingCriteriaList,
     validate_stopping_criteria,
 )
 from transformers.generation.logits_process import LogitsProcessorList
 from transformers.generation.beam_search import BeamScorer
-from transformers.utils import ModelOutput
 import time
-
-
-class GenerateBeamDecoderOnlyOutput(ModelOutput):
-    sequences: torch.LongTensor = None
-    sequences_scores: Optional[torch.FloatTensor] = None
-    scores: Optional[Tuple[torch.FloatTensor]] = None
-    logits: Optional[Tuple[torch.FloatTensor]] = None
-    beam_indices: Optional[torch.LongTensor] = None
-    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
-
-
-class GenerateBeamEncoderDecoderOutput(ModelOutput):
-    sequences: torch.LongTensor = None
-    sequences_scores: Optional[torch.FloatTensor] = None
-    scores: Optional[Tuple[torch.FloatTensor]] = None
-    logits: Optional[Tuple[torch.FloatTensor]] = None
-    beam_indices: Optional[torch.LongTensor] = None
-    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
-    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
-
+from transformers.generation.utils import (
+    GenerateBeamDecoderOnlyOutput,
+    GenerateBeamEncoderDecoderOutput,
+)
 
 GenerateBeamOutput = Union[
     GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDecoderOutput
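Beyond removing drift-prone duplication, importing the canonical classes matters for type identity: a structurally identical local class is still a different type, so isinstance checks in transformers or user code fail against it. A hypothetical illustration with plain dataclasses (not the real classes):

from dataclasses import dataclass

@dataclass
class UpstreamOutput:          # stands in for transformers' class
    sequences: tuple = ()

@dataclass
class LocalDuplicate:          # stands in for the removed IPEX copy
    sequences: tuple = ()

out = LocalDuplicate(sequences=(1, 2, 3))
print(isinstance(out, UpstreamOutput))  # False, despite identical fields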
29 changes: 6 additions & 23 deletions intel_extension_for_pytorch/transformers/generation/beam_search.py
@@ -2,36 +2,18 @@
 from torch import nn
 import torch.distributed as dist
 from ...utils._logger import logger, WarningType
-from typing import Optional, Tuple, Union, List
+from typing import Optional, Union, List
 from transformers.generation.stopping_criteria import (
     StoppingCriteriaList,
     validate_stopping_criteria,
 )
 from transformers.generation.logits_process import LogitsProcessorList
 from transformers.generation.beam_search import BeamScorer
-from transformers.utils import ModelOutput
 import time
-
-
-class BeamSearchEncoderDecoderOutput(ModelOutput):
-    sequences: torch.LongTensor = None
-    sequences_scores: Optional[torch.FloatTensor] = None
-    scores: Optional[Tuple[torch.FloatTensor]] = None
-    beam_indices: Optional[torch.LongTensor] = None
-    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
-    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-
-
-class BeamSearchDecoderOnlyOutput(ModelOutput):
-    sequences: torch.LongTensor = None
-    sequences_scores: Optional[torch.FloatTensor] = None
-    scores: Optional[Tuple[torch.FloatTensor]] = None
-    beam_indices: Optional[torch.LongTensor] = None
-    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+from transformers.generation.utils import (
+    BeamSearchEncoderDecoderOutput,
+    BeamSearchDecoderOnlyOutput,
+)
 
 
 BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]

@@ -55,6 +37,7 @@ def _beam_search(
 ) -> Union[BeamSearchOutput, torch.LongTensor]:
     new_generation_config = model_kwargs.pop("generation_config", None)
     if new_generation_config is not None:
+        return_dict_in_generate = new_generation_config.return_dict_in_generate
         if new_generation_config.do_sample:
             return self._beam_sample(
                 input_ids,
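The one-line addition in the second hunk threads return_dict_in_generate from the popped generation config into the beam-search override. A minimal sketch of how that flag would then be honored when packaging results; the helper and its name are illustrative, not IPEX code:

from transformers.generation.utils import GenerateBeamDecoderOnlyOutput

def finalize_beam_output(sequences, generation_config):
    # Honor return_dict_in_generate the way the hunk above does.
    return_dict_in_generate = False
    if generation_config is not None:
        return_dict_in_generate = generation_config.return_dict_in_generate
    if return_dict_in_generate:
        # wrap in the canonical Hugging Face class, not a local copy
        return GenerateBeamDecoderOnlyOutput(sequences=sequences)
    return sequences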
24 changes: 5 additions & 19 deletions intel_extension_for_pytorch/transformers/generation/greedy_search.py
@@ -1,33 +1,19 @@
 import torch
 import torch.distributed as dist
 from ...utils._logger import logger, WarningType
-from typing import Optional, Tuple, Union, List
+from typing import Optional, Union, List
 from transformers.generation.stopping_criteria import (
     StoppingCriteriaList,
     validate_stopping_criteria,
 )
 from transformers.generation.logits_process import LogitsProcessorList
 from transformers.generation.streamers import BaseStreamer
-from transformers.utils import ModelOutput
 import time
-
-
-class GreedySearchDecoderOnlyOutput(ModelOutput):
-    sequences: torch.LongTensor = None
-    scores: Optional[Tuple[torch.FloatTensor]] = None
-    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-
-
-class GreedySearchEncoderDecoderOutput(ModelOutput):
-    sequences: torch.LongTensor = None
-    scores: Optional[Tuple[torch.FloatTensor]] = None
-    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
-    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+from transformers.generation.utils import (
+    GreedySearchDecoderOnlyOutput,
+    GreedySearchEncoderDecoderOutput,
+)
 
 
 GreedySearchOutput = Union[
     GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput
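Reusing the upstream classes also preserves ModelOutput's dict/attribute duality that downstream consumers rely on. Constructing the imported class directly, purely for illustration:

import torch
from transformers.generation.utils import GreedySearchDecoderOnlyOutput

out = GreedySearchDecoderOnlyOutput(sequences=torch.tensor([[1, 2, 3]]))
print(out.sequences)      # attribute access
print(out["sequences"])   # dict-style access to the same field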
26 changes: 6 additions & 20 deletions intel_extension_for_pytorch/transformers/generation/sample.py
@@ -2,33 +2,18 @@
 from torch import nn
 import torch.distributed as dist
 import warnings
-from typing import Optional, Tuple, Union, List
+from typing import Optional, Union, List
 from transformers.generation.stopping_criteria import (
     StoppingCriteriaList,
     validate_stopping_criteria,
 )
 from transformers.generation.logits_process import LogitsProcessorList
 from transformers.generation.streamers import BaseStreamer
-from transformers.utils import ModelOutput
 import time
-
-
-class SampleEncoderDecoderOutput(ModelOutput):
-    sequences: torch.LongTensor = None
-    scores: Optional[Tuple[torch.FloatTensor]] = None
-    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
-    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-
-
-class SampleDecoderOnlyOutput(ModelOutput):
-    sequences: torch.LongTensor = None
-    scores: Optional[Tuple[torch.FloatTensor]] = None
-    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-
+from transformers.generation.utils import (
+    SampleEncoderDecoderOutput,
+    SampleDecoderOnlyOutput,
+)
 
 SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput]
 

@@ -52,6 +37,7 @@ def _sample(
 ) -> Union[SampleOutput, torch.LongTensor]:
     new_generation_config = model_kwargs.pop("generation_config", None)
     if new_generation_config is not None:
+        return_dict_in_generate = new_generation_config.return_dict_in_generate
         if not new_generation_config.do_sample:
             pad_token_id = new_generation_config._pad_token_tensor
             eos_token_id = new_generation_config._eos_token_tensor
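As in beam_search.py, the second hunk records return_dict_in_generate; additionally, when the merged config has do_sample turned off, the sample override falls back to greedy behavior using the special-token tensors that transformers prepares on the config. An illustrative helper mirroring that logic (the function name is made up; the private attributes are exactly those referenced in the diff):

def read_generation_flags(generation_config):
    # Pick up the return-dict flag, and for do_sample=False reuse the
    # pad/eos tensors transformers set on the config.
    return_dict_in_generate = generation_config.return_dict_in_generate
    pad_token_id = eos_token_id = None
    if not generation_config.do_sample:
        pad_token_id = generation_config._pad_token_tensor
        eos_token_id = generation_config._eos_token_tensor
    return return_dict_in_generate, pad_token_id, eos_token_id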
7 changes: 7 additions & 0 deletions tests/cpu/test_ipex_optimize_transformers.py
@@ -523,6 +523,13 @@ def test_generate_functions(self):
         ipex_res = ipex_m.generate(input_ids, **generate_kwargs)
         ref_res = ref_m.generate(input_ids, **generate_kwargs)
         self.assertEqual(ipex_res, ref_res)
+        ipex_res_dict = ipex_m.generate(
+            input_ids, return_dict_in_generate=True, **generate_kwargs
+        )
+        ref_res_dict = ref_m.generate(
+            input_ids, return_dict_in_generate=True, **generate_kwargs
+        )
+        self.assertEqual(ipex_res_dict.sequences, ref_res_dict.sequences)
 
     def test_cache_weight_for_large_batch(self):
         config = AutoConfig.from_pretrained(
