Running the dpo.py example with gradient_checkpointing=True produces different errors for different Falcon models (1B and 7B). Note that this happens both with and without DeepSpeed ZeRO-{2,3}, which suggests a deeper problem with the modelling code 🙀.
To trigger the error, first activate gradient checkpointing in the training arguments of dpo.py:
training_args = TrainingArguments(
    per_device_train_batch_size=script_args.per_device_train_batch_size,
    max_steps=script_args.max_steps,
    remove_unused_columns=False,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    learning_rate=script_args.learning_rate,
    evaluation_strategy="steps",
    logging_first_step=True,
    logging_steps=10,  # match results in blog post
    eval_steps=500,
    output_dir="./test",
    optim="rmsprop",
    warmup_steps=150,
    report_to=script_args.report_to,
    bf16=True,
+   gradient_checkpointing=True,
)
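As an aside, gradient checkpointing is incompatible with the KV cache, which is why Falcon's forward pass forces use_cache=False during checkpointed training (and, as the 7B trace below shows, that is exactly the code path that breaks). A minimal sketch of making this explicit on the model side instead, using the 1B repro model as an example:

from transformers import AutoModelForCausalLM

# Checkpointed layers cannot reuse cached key/values, so disable the cache
# on the config explicitly rather than relying on the warn-and-override
# inside the model's forward().
model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-rw-1b")
model.config.use_cache = False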
Then run one of the commands provided below:
Stack trace for https://huggingface.co/tiiuae/falcon-rw-1b
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/trainer.py", line 1892, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/trainer.py", line 2776, in training_step
loss = self.compute_loss(model, inputs)
File "/fsx/lewis/git/trl/trl/trainer/dpo_trainer.py", line 500, in compute_loss
loss, metrics = self.get_batch_metrics(model, inputs, train_eval="train")
File "/fsx/lewis/git/trl/trl/trainer/dpo_trainer.py", line 451, in get_batch_metrics
) = self.concatenated_forward(model, batch)
File "/fsx/lewis/git/trl/trl/trainer/dpo_trainer.py", line 417, in concatenated_forward
all_logits = model(
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1735, in forward
loss = self.module(*inputs, **kwargs)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/models/falcon/modeling_falcon.py", line 1303, in forward
transformer_outputs = self.transformer(
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/models/falcon/modeling_falcon.py", line 1177, in forward
outputs = torch.utils.checkpoint.checkpoint(
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 249, in checkpoint
return CheckpointFunction.apply(function, preserve, *args)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 107, in forward
outputs = run_function(*args)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/models/falcon/modeling_falcon.py", line 1173, in custom_forward
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/models/falcon/modeling_falcon.py", line 787, in forward
attn_outputs = self.self_attention(
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/models/falcon/modeling_falcon.py", line 527, in forward
attention_probs = attention_probs * head_mask
RuntimeError: The size of tensor a (203) must match the size of tensor b (8) at non-singleton dimension 2
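For context on this first failure: torch.utils.checkpoint.checkpoint threads only positional arguments through to the wrapped function, so if the call site's argument order drifts out of sync with the block's signature, tensors are silently rebound rather than raising a TypeError, and the problem only surfaces later as a shape mismatch inside attention — one plausible reading of the head_mask error above. A minimal sketch of that failure mode, with a hypothetical stand-in block and illustrative shapes rather than Falcon's actual ones:

import torch
from torch.utils.checkpoint import checkpoint

# Hypothetical stand-in for a decoder block; the trailing optional arguments
# mimic the pattern in modeling_falcon.py, but the shapes are illustrative.
def block_forward(hidden_states, attention_mask=None, head_mask=None):
    attention_probs = torch.softmax(
        hidden_states @ hidden_states.transpose(-1, -2), dim=-1
    )  # [batch, num_heads, seq, seq]
    if head_mask is not None:
        attention_probs = attention_probs * head_mask
    return attention_probs

hidden_states = torch.randn(1, 8, 203, 64, requires_grad=True)
not_a_head_mask = torch.ones(1, 203, 64)  # wrong tensor landing in the head_mask slot

# checkpoint() forwards its inputs positionally, so a misordered call site
# binds the wrong tensor to head_mask and fails only at the multiply above:
try:
    checkpoint(block_forward, hidden_states, None, not_a_head_mask)
except RuntimeError as e:
    print(e)  # "The size of tensor a (...) must match the size of tensor b (...)"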
Stack trace for https://huggingface.co/tiiuae/falcon-7b
Traceback (most recent call last):
File "/fsx/lewis/git/trl/examples/scripts/dpo.py", line 173, in <module>
dpo_trainer.train()
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/trainer.py", line 1591, in train
return inner_training_loop(
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/trainer.py", line 1892, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/trainer.py", line 2776, in training_step
loss = self.compute_loss(model, inputs)
File "/fsx/lewis/git/trl/trl/trainer/dpo_trainer.py", line 500, in compute_loss
loss, metrics = self.get_batch_metrics(model, inputs, train_eval="train")
File "/fsx/lewis/git/trl/trl/trainer/dpo_trainer.py", line 451, in get_batch_metrics
) = self.concatenated_forward(model, batch)
File "/fsx/lewis/git/trl/trl/trainer/dpo_trainer.py", line 417, in concatenated_forward
all_logits = model(
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1735, in forward
loss = self.module(*inputs, **kwargs)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
result = forward_call(*args, **kwargs)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/models/falcon/modeling_falcon.py", line 1303, in forward
transformer_outputs = self.transformer(
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
result = forward_call(*args, **kwargs)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/models/falcon/modeling_falcon.py", line 1213, in forward
presents = self._convert_cache_to_standard_format(presents, batch_size)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/models/falcon/modeling_falcon.py", line 947, in _convert_cache_to_standard_format
batch_size_times_num_heads, kv_length, head_dim = past_key_value[0][0].shape
IndexError: tuple index out of range
Did you find a solution for the issue? I get the same error when running training with gradient_checkpointing=True.
I tracked it down to here, where the code attempts to convert presents. However, presents is an empty tuple if use_cache is False, which it has to be when gradient checkpointing is used.
Can this be avoided by changing the if statement to if presents:, or would that have other side effects?
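For reference, the guard proposed above would look roughly like this, sketched against the modeling_falcon.py lines in the traceback (not necessarily the fix that landed upstream):

# With gradient checkpointing, use_cache is forced to False, so presents
# stays an empty tuple; only convert the cache when it was actually populated.
if presents:
    presents = self._convert_cache_to_standard_format(presents, batch_size)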
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.