Transformers 4.36 use_cache issue #28056
Thanks, pinging @gante as well since he worked on the cache refactoring. Let's keep this in mind.
Hey @younesbelkada, unfortunately I don't think that fix will work for me, as I use a different training framework to handle activation checkpointing. It'd be great to understand and fix the root cause so that transformers models are fully usable with raw PyTorch. Thanks as always for the quick responses!
Thanks @dakinggg, ok sounds great. I'll spend some time to understand the root cause of it and why it used to fail on transformers main, and provide an update here!
Hi @dakinggg, I had a deeper look. Consider the snippet below:

```python
import torch
from torch.optim import Adam
from transformers import BitsAndBytesConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType

MODEL_ID = "meta-llama/Llama-2-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
inputs = tokenizer("hello world what's up", return_tensors="pt")
inputs = {k: v.to("cuda") for k, v in inputs.items()}
print(inputs)

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", attn_implementation="eager", torch_dtype=torch.float16)

peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, target_modules=['q_proj', 'v_proj'], inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
model.gradient_checkpointing_enable()
model.enable_input_require_grads()

optimizer = Adam(model.parameters(), lr=1e-5)

model.train()

for i in range(10):
    outputs = model(labels=inputs['input_ids'], **inputs)
    loss = outputs.loss
    print(loss)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```

In case the fix in #28031 is not applied, here is what will happen, step by step:

1- On the first forward pass:
1.1- Dummy `past_key_values` get created even though the user passed a `None` `past_key_values`.
1.2- `use_cache` will be force-set to `False` at the gradient-checkpointing check, but the `past_key_values` have already been created above.
1.3- Since `past_key_values` is set to a non-`None` value, it will pass this line as well and create a `past_key_value` for each layer. Note that at that point the `past_key_values` will have a shape of `(batch_size, 1, seq_len, seq_len)`.
2- Once all past key values are populated, the script calls `loss.backward()`, and gradient checkpointing re-runs the forward with the already-populated cache, leading to:
2.1- `kv_seq_len` being set to `2*seq_len`.
2.2- A `ValueError` raised here: https://github.com/huggingface/transformers/blob/2788f8d8d5f9cee2fe33a9292b0f3570bd566a6d/src/transformers/models/llama/modeling_llama.py#L714C13-L714C69, since the shapes do not match anymore.

I don't 100% master what is going on under the hood when one uses torch's gradient checkpointing, but the fix I proposed in #28031 circumvents this issue by making sure no dummy `past_key_values` are created when we are under gradient checkpointing in training mode, hence force-setting `use_cache` to `False` in that case.

The proposed fix worked for PEFT but should be universal to all training frameworks, unless you patch the Llama/Mistral modeling classes with other classes, in which case you should apply the same patch there as well. Let me know if anything is unclear!
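To make the shape mismatch concrete, here is a toy arithmetic sketch (editorial illustration, not transformers code; the numbers mirror the error reported in this issue, with batch size 2 and seq_len 4096):

```python
# Toy illustration of the failure mode described above.
seq_len = 4096
batch_size = 2

# First forward pass: the cache starts empty.
cached_tokens = 0
kv_seq_len = seq_len + cached_tokens                 # 4096
mask_shape = (batch_size, 1, seq_len, kv_seq_len)    # (2, 1, 4096, 4096)

# Gradient checkpointing re-runs the forward during backward, but the cache
# was already populated by the first pass.
cached_tokens = seq_len
kv_seq_len = seq_len + cached_tokens                 # 8192 == 2 * seq_len
expected = (batch_size, 1, seq_len, kv_seq_len)      # (2, 1, 4096, 8192)

# The attention mask passed in still has the original shape, hence:
# "Attention mask should be of size (2, 1, 4096, 8192), but is
#  torch.Size([2, 1, 4096, 4096])"
print(mask_shape, expected)
```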
That mostly makes sense... I didn't quite understand why it wasn't an issue in previous versions though. Shouldn't we just never compute past kv during training, regardless of gradient checkpointing or not? Even if it worked, it's not good to be creating past kv when we're not generating, as it uses significant extra memory. As to why the model's forward gets called again: when you use activation checkpointing, you don't save all of the activations for the backward pass, only some of them, and then you recompute the rest.
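For readers unfamiliar with that recomputation, here is a minimal, self-contained sketch using plain `torch.utils.checkpoint` (not the Transformers integration) showing that the wrapped function really is executed a second time during the backward pass:

```python
import torch
from torch.utils.checkpoint import checkpoint

calls = []

def block(x):
    # stands in for an activation-heavy transformer block
    calls.append(1)
    return torch.nn.functional.gelu(x) * x

x = torch.randn(4, 8, requires_grad=True)

# Forward pass: intermediate activations inside `block` are not stored.
loss = checkpoint(block, x, use_reentrant=False).sum()
print(len(calls))  # 1

# Backward pass: `block` is re-executed to rebuild the activations it needs.
loss.backward()
print(len(calls))  # 2 -- the forward ran a second time
```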
Thanks @dakinggg for your reply!
In previous versions, `past_key_values` was always `None` during training, leading to that block never being called, whereas now `past_key_values` are always created during training since the model falls back to the config's `use_cache` to create them. Thanks also for explaining about GC, it makes sense.
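As a rough sketch of that fallback (simplified pseudologic, not the actual modeling code), the forward now resolves `use_cache` roughly like this, which is why a cache appears during training unless the caller disables it:

```python
from typing import Optional

def resolve_use_cache(
    use_cache: Optional[bool],
    config_use_cache: bool,
    gradient_checkpointing: bool,
    training: bool,
) -> bool:
    """Simplified sketch of the decision discussed above, not the real code."""
    if use_cache is None:
        # 4.36 behaviour: fall back to the config value (True by default),
        # so a plain training forward starts building past_key_values.
        use_cache = config_use_cache
    if gradient_checkpointing and training and use_cache:
        # the kind of guard discussed in this thread: force the cache off
        # so the checkpointed re-forward never sees an already-filled cache
        use_cache = False
    return use_cache

print(resolve_use_cache(None, True, gradient_checkpointing=True, training=True))    # False
print(resolve_use_cache(None, True, gradient_checkpointing=False, training=False))  # True
```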
Ahh I see. So it seems to me that the proper fix is to go back to the old behavior where `past_key_values` is not created during training.
Related, IMO the proper place to default `use_cache` ...
Yep, we'll add a fix
Thanks!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Not sure if this has been fixed.
Not yet, I'll be doing some more refactoring of past key values, in the hope of fixing these issues as well.
Do we know why it produces higher loss? Should we use 4.35.2 until the refactoring is done?
I'm hitting the same problem. I installed transformers from the main branch, but it doesn't work. Has this problem been solved? Thanks! @ArthurZucker
This still seems to be an issue on the latest version of Transformers?
I don't think this is still an issue? I can't find the exact reproducer. cc @dakinggg
@ArthurZucker Confirmed, this is still an issue. I end up with something like this:
The same fix of explicitly specifying `use_cache=False` has continued to work.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Running on main, the shared snippet https://github.com//issues/28056#issuecomment-1858319673 (which is the only one I could find) seems to work.
@dakinggg, if you have a better snippet, I'm down to test and fix.
I suspect you all resolved the issue when using the Hugging Face Trainer (so your code snippet works), but not when using other libraries, which may enable activation checkpointing differently, so any checks that you've put in will not know that activation checkpointing is being done. I don't have a small snippet; my test is just running a training job using LLM Foundry.
I did not use the Hugging Face Trainer, I used FSDP with a custom train loop, and just loading with `use_cache=False` works for me.
Yes, it works fine with `use_cache=False`.
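For anyone landing here, the workaround the last few comments refer to looks roughly like this (a sketch; `MODEL_ID` is a placeholder checkpoint, and either option alone is typically enough):

```python
from transformers import AutoModelForCausalLM

MODEL_ID = "meta-llama/Llama-2-7b-hf"  # placeholder

# Option 1: override the config attribute at load time.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, use_cache=False)

# Option 2: flip it on the loaded config before enabling activation checkpointing.
model.config.use_cache = False
model.gradient_checkpointing_enable()
```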
I agree.
Or:

```python
if self.gradient_checkpointing and self.training and use_cache:
    logger.warning_once(
        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
    )
    use_cache = False
```

which should automatically disable the cache. The other might be:

```python
if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
    return_legacy_cache = True
    past_key_values = DynamicCache.from_legacy_cache(past_key_values)
```

and that is something we are stuck with because of BC....
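For context on the BC path quoted above: the legacy cache format is a per-layer tuple of `(key, value)` tensors, and recent transformers versions (4.36+) expose helpers to convert in both directions. A small sketch with toy shapes:

```python
import torch
from transformers.cache_utils import DynamicCache

# Legacy format: tuple over layers of (key, value) tensors, each shaped
# (batch, num_heads, seq_len, head_dim). Toy sizes below.
legacy = tuple(
    (torch.zeros(1, 8, 6, 64), torch.zeros(1, 8, 6, 64)) for _ in range(2)
)

cache = DynamicCache.from_legacy_cache(legacy)  # wrap the tuples in a Cache object
print(cache.get_seq_length())                   # 6

legacy_again = cache.to_legacy_cache()          # convert back for older callers
print(len(legacy_again), legacy_again[0][0].shape)
```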
Facing the same issue too with FSDP + activation checkpointing. Currently disabled the cache using `use_cache=False`.
@ArthurZucker I think explicitly setting `use_cache=False` ... Can someone on this thread share what exception message you're getting on ...?
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Still an issue; will try to get the trace again this week. IIRC it's the standard PyTorch error for activation checkpoint metadata inconsistency. Also, why can't transformers just set `use_cache` to `False` if `self.training`?
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
System Info

`transformers` version: 4.36.0

Who can help?

@ArthurZucker @younesbelkada

Information

Tasks

An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)

Reproduction

Sorry that I don't really have a minimal reproducer here as I'm in another training framework, but I still think this might be useful for you.

Running training on Llama 2 7B with activation checkpointing has some issues in 4.36 compared to training with 4.35.2:

```
ValueError: Attention mask should be of size (2, 1, 4096, 8192), but is torch.Size([2, 1, 4096, 4096])
```

If I explicitly set `use_cache=False` (which shouldn't have any impact during training because there is no cache), results with 4.36 are similar to 4.35.2.

Expected behavior

No regression from 4.35.2 -> 4.36.