Skip to content

Commit

Permalink
Generate: can_generate() recursive check (huggingface#33718)
Browse files Browse the repository at this point in the history
* add recursive check and test warnings

* missing space

* models without can_generate
  • Loading branch information
gante authored and BenjaminBossan committed Sep 30, 2024
1 parent 502e785 commit 9a8abb5
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 5 deletions.
6 changes: 6 additions & 0 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1645,6 +1645,12 @@ def can_generate(cls) -> bool:
# Model class overwrites `generate` (e.g. time series models) -> can generate
if str(cls.__name__) in str(cls.generate):
return True
# The class inherits from a class that can generate (recursive check) -> can generate
for base in cls.__bases__:
if not hasattr(base, "can_generate"):
continue
if "PreTrainedModel" not in str(base) and base.can_generate():
return True
# BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this
# was how we detected whether a model could generate.
if "GenerationMixin" not in str(cls.prepare_inputs_for_generation):
Expand Down
32 changes: 27 additions & 5 deletions tests/utils/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1718,29 +1718,51 @@ def test_isin_mps_friendly(self):

def test_can_generate(self):
"""Tests the behavior of `PreTrainedModel.can_generate` method."""
logger = logging.get_logger("transformers.modeling_utils")
logger.warning_once.cache_clear()

# 1 - By default, a model CAN'T generate
self.assertFalse(BertModel.can_generate())
can_generate = BertModel.can_generate()
self.assertFalse(can_generate)

# 2 - The most common case for a model to be able to generate is to inherit from `GenerationMixin` directly
class DummyBertWithMixin(BertModel, GenerationMixin):
pass

self.assertTrue(DummyBertWithMixin.can_generate())
with CaptureLogger(logger) as cl:
can_generate = DummyBertWithMixin.can_generate()
self.assertTrue("" == cl.out)
self.assertTrue(can_generate)

# 3 - Alternatively, a model can implement a `generate` method
class DummyBertWithGenerate(BertModel):
def generate(self):
pass

self.assertTrue(DummyBertWithGenerate.can_generate())
with CaptureLogger(logger) as cl:
can_generate = DummyBertWithGenerate.can_generate()
self.assertTrue("" == cl.out)
self.assertTrue(can_generate)

# 4 - Finally, it can inherit from a model that can generate
class DummyBertWithParent(DummyBertWithMixin):
pass

with CaptureLogger(logger) as cl:
can_generate = DummyBertWithParent.can_generate()
self.assertTrue("" == cl.out)
self.assertTrue(can_generate)

# 4 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited
# 5 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited
# `GenerationMixin`)
class DummyBertWithPrepareInputs(BertModel):
def prepare_inputs_for_generation(self):
pass

self.assertTrue(DummyBertWithPrepareInputs.can_generate())
with CaptureLogger(logger) as cl:
can_generate = DummyBertWithPrepareInputs.can_generate()
self.assertTrue("it doesn't directly inherit from `GenerationMixin`" in cl.out)
self.assertTrue(can_generate)

def test_save_and_load_config_with_custom_generation(self):
"""
Expand Down

0 comments on commit 9a8abb5

Please sign in to comment.