Skip to content

Commit

Permalink
Donut: fix generate call from local path (huggingface#31470)
Browse files Browse the repository at this point in the history
* local donut path fix

* engrish

* Update src/transformers/generation/utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
  • Loading branch information
gante and amyeroberts authored Jun 18, 2024
1 parent 76289fb commit cd71f93
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,8 +575,12 @@ def _prepare_decoder_input_ids_for_generation(
# no user input -> use decoder_start_token_id as decoder_input_ids
if decoder_input_ids is None:
decoder_input_ids = decoder_start_token_id
# exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token
elif self.config.model_type == "vision-encoder-decoder" and "donut" in self.name_or_path.lower():
# exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token. Note that the
# original checkpoints can't be detected through `self.__class__.__name__.lower()`, needing custom logic.
# See: https://github.com/huggingface/transformers/pull/31470
elif "donut" in self.__class__.__name__.lower() or (
self.config.model_type == "vision-encoder-decoder" and "donut" in self.config.encoder.model_type.lower()
):
pass
elif self.config.model_type in ["whisper"]:
pass
Expand Down

0 comments on commit cd71f93

Please sign in to comment.