Generate: models with custom generate() return True in `can_generate()` (huggingface#25838)
gante authored and parambharat committed Sep 26, 2023
1 parent c98b825 commit d01e813
Showing 5 changed files with 9 additions and 27 deletions.
5 changes: 3 additions & 2 deletions src/transformers/modeling_flax_utils.py
@@ -475,8 +475,9 @@ def can_generate(cls) -> bool:
         Returns whether this model can generate sequences with `.generate()`. Returns:
             `bool`: Whether this model can generate sequences with `.generate()`.
         """
-        # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
-        if "GenerationMixin" in str(cls.prepare_inputs_for_generation):
+        # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.
+        # Alternatively, the model can also have a custom `generate` function.
+        if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
             return False
         return True

5 changes: 3 additions & 2 deletions src/transformers/modeling_tf_utils.py
@@ -1307,8 +1307,9 @@ def can_generate(cls) -> bool:
         Returns:
             `bool`: Whether this model can generate sequences with `.generate()`.
         """
-        # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
-        if "GenerationMixin" in str(cls.prepare_inputs_for_generation):
+        # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.
+        # Alternatively, the model can also have a custom `generate` function.
+        if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
             return False
         return True

5 changes: 3 additions & 2 deletions src/transformers/modeling_utils.py
@@ -1216,8 +1216,9 @@ def can_generate(cls) -> bool:
         Returns:
             `bool`: Whether this model can generate sequences with `.generate()`.
         """
-        # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
-        if "GenerationMixin" in str(cls.prepare_inputs_for_generation):
+        # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.
+        # Alternatively, the model can also have a custom `generate` function.
+        if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
             return False
         return True

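For illustration only (not part of this commit, and the subclass names below are hypothetical), a minimal sketch of how the updated check behaves on a transformers version that includes this change: a model that keeps the `GenerationMixin` defaults still reports False, while a model that ships its own `generate()` now reports True.

from transformers import PreTrainedModel


class PlainModel(PreTrainedModel):
    # Hypothetical model: neither `prepare_inputs_for_generation` nor `generate`
    # is overridden, so both still resolve to the `GenerationMixin` defaults.
    pass


class CustomGenerateModel(PreTrainedModel):
    # Hypothetical model with its own `generate`, similar to Bark or SpeechT5.
    def generate(self, *args, **kwargs):
        ...


# `can_generate()` is a classmethod, so no weights need to be loaded.
print(PlainModel.can_generate())           # False
print(CustomGenerateModel.can_generate())  # True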
14 changes: 0 additions & 14 deletions src/transformers/models/bark/modeling_bark.py
@@ -1231,13 +1231,6 @@ def forward(
             attentions=all_self_attentions,
         )
 
-    def can_generate(self) -> bool:
-        """
-        Returns True. Despite being an autoencoder, BarkFineModel shares some characteristics with generative models
-        due to the way audio are generated.
-        """
-        return True
-
     def generate(
         self,
         coarse_output: torch.Tensor,
@@ -1594,10 +1587,3 @@ def generate(
             self.codec_model_hook.offload()
 
         return audio
-
-    def can_generate(self) -> bool:
-        """
-        Returns True. Despite not having a `self.generate` method, this model can `generate` and thus needs a
-        BarkGenerationConfig.
-        """
-        return True
7 changes: 0 additions & 7 deletions src/transformers/models/speecht5/modeling_speecht5.py
@@ -2779,13 +2779,6 @@ def forward(
             encoder_attentions=outputs.encoder_attentions,
         )
 
-    def can_generate(self) -> bool:
-        """
-        Returns True. This model can `generate` and must therefore have this property set to True in order to be used
-        in the TTS pipeline.
-        """
-        return True
-
     @torch.no_grad()
     def generate(
         self,
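Because the shared `can_generate()` now detects a custom `generate()`, the per-model overrides removed above become redundant. As a hedged sanity check (not part of the commit, assuming a transformers version that includes Bark, SpeechT5, and this change), the affected classes should still report generation support:

from transformers import BarkModel, SpeechT5ForTextToSpeech

# Both classes define their own `generate`, so the updated base-class check
# returns True without the deleted per-model `can_generate` overrides.
print(BarkModel.can_generate())                # True
print(SpeechT5ForTextToSpeech.can_generate())  # True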
