@mehdiir,
We tried to reproduce your work in our environment and ran into one odd issue: with your code, `gradient_checkpointing=True` runs much faster than `gradient_checkpointing=False` (2 h vs. 6 h in our CPU environment), which goes against intuition. So we did some analysis:

With `gradient_checkpointing=True` (and, implicitly, PyTorch's `use_reentrant=True`), the LoRA weights are wrapped inside transformer blocks whose inputs and outputs both have `requires_grad=False`. Under reentrant checkpointing, those blocks are skipped entirely during the backward pass, so in this setting only the classifier head is actually trained; the LoRA weights never receive gradients and stay at their identity initialization.

After upgrading transformers to 4.37.2 and adding two lines in `get_lora_model` to set `use_reentrant` to `False`, things went back to normal and the LoRA weights are trained (see the sketch below).

FYI in case other people run into the same issue.
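For reference, this is roughly what the change looks like. The exact two lines from `get_lora_model` are not reproduced here, so the model name, label count, and LoRA settings below are placeholder assumptions; the relevant part is passing `gradient_checkpointing_kwargs={"use_reentrant": False}` (available in transformers >= 4.36):

```python
# Sketch only: model name, label count, and LoRA settings are placeholders,
# not the ones from the original repo.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForSequenceClassification


def get_lora_model(model_name="roberta-base", num_labels=2):
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=num_labels
    )

    # The fix: enable gradient checkpointing in non-reentrant mode so autograd
    # still backpropagates through transformer blocks whose inputs have
    # requires_grad=False, which is exactly where the LoRA weights sit.
    model.gradient_checkpointing_enable(
        gradient_checkpointing_kwargs={"use_reentrant": False}
    )

    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["query", "value"],
        lora_dropout=0.05,
        task_type="SEQ_CLS",
    )
    return get_peft_model(model, lora_config)
```

The same flag can also be passed through `TrainingArguments(gradient_checkpointing=True, gradient_checkpointing_kwargs={"use_reentrant": False})` when training with `Trainer`. A quick sanity check is that the `lora_` parameters have non-`None` `.grad` after a backward pass.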