
fix fsdp_auto_wrap_policy #2167

Merged (3 commits) on Oct 22, 2024

Conversation

@eljandoubi (Contributor) commented Oct 19, 2024

Fixes the issue where fsdp_auto_wrap_policy fails when both the FSDP_TRANSFORMER_CLS_TO_WRAP environment variable and the model's _no_split_modules attribute are None.
@BenjaminBossan @sayakpaul

Fixes #2166
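For context, the resolution order this fix targets can be sketched in plain Python. This is a hypothetical helper, not PEFT's actual code (the real logic lives in peft.utils.other.fsdp_auto_wrap_policy); it only illustrates the lookup order and the graceful fallback:

```python
import os

def get_transformer_cls_names(model, env_var="FSDP_TRANSFORMER_CLS_TO_WRAP"):
    """Hypothetical helper mirroring the resolution order described above:
    prefer the environment variable, then the model's _no_split_modules,
    then fall back to an empty list instead of raising (the failure this
    PR fixes)."""
    names = os.environ.get(env_var)
    if names is not None:
        return [name.strip() for name in names.split(",")]
    no_split = getattr(model, "_no_split_modules", None)
    if no_split is not None:
        return list(no_split)
    # Before the fix, reaching this point raised an exception even though
    # a default policy would still be usable.
    return []
```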

@BenjaminBossan (Member)

Thanks for working on this fix. Could you give an example of a model that would currently fail but work with this fix? Ideally, we can build a unit test based on this.

@eljandoubi (Contributor, Author)

  1. Donut
  2. Pix2Struct

@BenjaminBossan (Member)

I could not replicate for Donut:

    >>> from peft.utils.other import fsdp_auto_wrap_policy
    >>> model = DonutSwinPreTrainedModel.from_pretrained('naver-clova-ix/donut-base')
    >>> model._no_split_modules
    ['DonutSwinStage']
    >>> fsdp_auto_wrap_policy(model)  # works

For Pix2Struct, I do get:

    >>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-base")
    >>> model._no_split_modules
    None
    >>> fsdp_auto_wrap_policy(model)
    Exception: Could not find the transformer layer class to wrap in the model.

but of course I don't need to use PEFT's fsdp_auto_wrap_policy for FSDP training, it's just there to help users. Do you have a use case where you can't switch to another auto wrap policy?
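As a side note on what "another auto wrap policy" means here: FSDP accepts any callable with the signature (module, recurse, nonwrapped_numel) -> bool. Below is a minimal pure-Python sketch of such a policy (no torch import; the real helper is torch.distributed.fsdp.wrap.transformer_auto_wrap_policy, and the layer class name used here is only an example):

```python
import functools

def transformer_wrap_policy(module, recurse, nonwrapped_numel, transformer_cls_names=()):
    # FSDP queries the policy in two modes: with recurse=True to ask
    # whether to descend into a module's children, and with recurse=False
    # to ask whether to wrap this module itself.
    if recurse:
        return True
    return type(module).__name__ in transformer_cls_names

# Bind the layer classes up front, since FSDP expects a callable taking
# only (module, recurse, nonwrapped_numel).
policy = functools.partial(
    transformer_wrap_policy, transformer_cls_names=("Pix2StructVisionLayer",)
)
```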

@eljandoubi (Contributor, Author)

In the Transformers Trainer class, fsdp_auto_wrap_policy is called in _fsdp_qlora_plugin_updates, which is applied automatically when training a PeftModel in FSDP mode.

@BenjaminBossan (Member)

Thanks a lot for the pointer, it makes sense that Trainer should still work in that case.

Let's add a small test to ensure that the function does not fail in such cases. We already have a test class here:

    class TestFSDPWrap:

We can just add the new test in there. As to the model, I think it's sufficient to just create a custom model and check that calling fsdp_auto_wrap_policy does not raise an error. So something like:

    def test_fsdp_auto_wrap_policy_does_not_raise_on_custom_model(self):
        # See #2167
        # Avoid raising on custom models since Trainer uses fsdp_auto_wrap_policy automatically for PEFT + FSDP
        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = nn.Linear(2, 3)

        fsdp_auto_wrap_policy(MyModule())  # does not raise

@eljandoubi (Contributor, Author)

You already have a toy model called SimpleModel. I am using it for testing.

@HuggingFaceDocBuilderDev commented

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan (Member) left a review


Thanks for investigating the issue and providing a fix. LGTM.

> You already have a toy model called SimpleModel. I am using it for testing.

Good catch.

@BenjaminBossan BenjaminBossan merged commit 7717550 into huggingface:main Oct 22, 2024
14 checks passed