Skip to content

Commit 797859c

Browse files
npuichigo and vasqu authored
Update no split modules in T5Gemma model (#40810)
* Update no split modules in T5Gemma model
* Update no_split_modules also for T5Gemma modular
* Remove model_split_percents from test cases

---------

Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
1 parent 6e69b60 commit 797859c

File tree

3 files changed

+2
-4
lines changed

3 files changed

+2
-4
lines changed

src/transformers/models/t5gemma/modeling_t5gemma.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,7 @@ class T5GemmaPreTrainedModel(PreTrainedModel):
     config: T5GemmaConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["T5GemmaBlock"]
+    _no_split_modules = ["T5GemmaEncoderLayer", "T5GemmaDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True

src/transformers/models/t5gemma/modular_t5gemma.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,7 @@ class T5GemmaPreTrainedModel(Gemma2PreTrainedModel):
     config: T5GemmaConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["T5GemmaBlock"]
+    _no_split_modules = ["T5GemmaEncoderLayer", "T5GemmaDecoderLayer"]
 
     def _init_weights(self, module):
         # TODO: support initialization for encoders and decoders separately(?)

tests/models/t5gemma/test_modeling_t5gemma.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -597,7 +597,6 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
     test_pruning = False
     _is_stateful = True
     is_encoder_decoder = True
-    model_split_percents = [0.5, 0.6]
 
     # used in `test_torch_compile_for_training`
     _torch_compile_train_cls = T5GemmaForConditionalGeneration if is_torch_available() else None
@@ -1460,7 +1459,6 @@ class T5GemmaEncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase):
     test_headmasking = False
     _is_stateful = True
     is_encoder_decoder = False
-    model_split_percents = [0.4, 0.5]
 
     # won't fix
     test_torchscript = False

0 commit comments

Comments
 (0)