diff --git a/litgpt/finetune/lora.py b/litgpt/finetune/lora.py
index ca2c6ca9b2..99ae5db739 100644
--- a/litgpt/finetune/lora.py
+++ b/litgpt/finetune/lora.py
@@ -140,7 +140,7 @@ def setup(
                 " --quantize flag."
             )
         strategy = FSDPStrategy(
-            auto_wrap_policy={Block},
+            auto_wrap_policy={torch.nn.Linear},
             activation_checkpointing_policy={Block},
             state_dict_type="full",
             limit_all_gathers=True,
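
For context, a minimal sketch of the FSDPStrategy construction after this change. The diff switches the FSDP auto-wrap granularity from whole transformer blocks to individual linear layers, while activation checkpointing stays at the block level. The import paths (lightning.fabric.strategies.FSDPStrategy, litgpt.model.Block) are assumptions based on litgpt's usual layout, not part of the diff hunk:

    # Sketch only; reconstructs the post-change call site under assumed imports.
    import torch
    from lightning.fabric.strategies import FSDPStrategy
    from litgpt.model import Block  # assumed import, as in litgpt's finetune scripts

    strategy = FSDPStrategy(
        # Wrap/shard at the level of individual nn.Linear modules rather than
        # whole transformer blocks (the change this diff makes).
        auto_wrap_policy={torch.nn.Linear},
        # Activation checkpointing is still applied per transformer Block.
        activation_checkpointing_policy={Block},
        state_dict_type="full",
        limit_all_gathers=True,
    )

Wrapping at the linear-layer level makes each nn.Linear its own FSDP unit, which changes how parameters are sharded and gathered relative to block-level wrapping; the checkpointing policy is untouched, so recomputation boundaries remain at Block granularity.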