Skip to content

Commit

Permalink
pass kwargs in stopping criteria list (#28927)
Browse files Browse the repository at this point in the history
  • Loading branch information
zucchini-nlp authored Feb 8, 2024
1 parent 0b693e9 commit cc309fd
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/transformers/generation/stopping_criteria.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
class StoppingCriteriaList(list):
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return any(criteria(input_ids, scores) for criteria in self)
return any(criteria(input_ids, scores, **kwargs) for criteria in self)

@property
def max_length(self) -> Optional[int]:
Expand Down

0 comments on commit cc309fd

Please sign in to comment.