
DPOTrainer incompatible with Falcon models and gradient checkpointing #828

Closed

lewtun opened this issue Oct 3, 2023 · 4 comments

@lewtun (Member) commented Oct 3, 2023

Description

Running the dpo.py example with gradient_checkpointing=True produces different errors for the Falcon 1B and 7B models. Note that this happens both with and without DeepSpeed ZeRO-{2,3}, which suggests a deeper problem in the modeling code 🙀.

To trigger the error, first activate gradient checkpointing in the training arguments of dpo.py:

    training_args = TrainingArguments(
        per_device_train_batch_size=script_args.per_device_train_batch_size,
        max_steps=script_args.max_steps,
        remove_unused_columns=False,
        gradient_accumulation_steps=script_args.gradient_accumulation_steps,
        learning_rate=script_args.learning_rate,
        evaluation_strategy="steps",
        logging_first_step=True,
        logging_steps=10,  # match results in blog post
        eval_steps=500,
        output_dir="./test",
        optim="rmsprop",
        warmup_steps=150,
        report_to=script_args.report_to,
        bf16=True,
+        gradient_checkpointing=True,
    )

Then run one of the commands provided below:

Stack trace for https://huggingface.co/tiiuae/falcon-rw-1b

Command to reproduce

# DDP
ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero2.yaml examples/scripts/dpo.py --model_name_or_path tiiuae/falcon-rw-1b

# ZeRO-2
ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero2.yaml examples/scripts/dpo.py --model_name_or_path tiiuae/falcon-rw-1b

Error

File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/trainer.py", line 1892, in _inner_training_loop
  tr_loss_step = self.training_step(model, inputs)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/trainer.py", line 2776, in training_step
  loss = self.compute_loss(model, inputs)
File "/fsx/lewis/git/trl/trl/trainer/dpo_trainer.py", line 500, in compute_loss
  loss, metrics = self.get_batch_metrics(model, inputs, train_eval="train")
File "/fsx/lewis/git/trl/trl/trainer/dpo_trainer.py", line 451, in get_batch_metrics
  ) = self.concatenated_forward(model, batch)
File "/fsx/lewis/git/trl/trl/trainer/dpo_trainer.py", line 417, in concatenated_forward
  all_logits = model(
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
  return forward_call(*args, **kwargs)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
  ret_val = func(*args, **kwargs)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1735, in forward
  loss = self.module(*inputs, **kwargs)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
  return forward_call(*args, **kwargs)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/models/falcon/modeling_falcon.py", line 1303, in forward
  transformer_outputs = self.transformer(
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
  return forward_call(*args, **kwargs)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/models/falcon/modeling_falcon.py", line 1177, in forward
  outputs = torch.utils.checkpoint.checkpoint(
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 249, in checkpoint
  return CheckpointFunction.apply(function, preserve, *args)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
  return super().apply(*args, **kwargs)  # type: ignore[misc]
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 107, in forward
  outputs = run_function(*args)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/models/falcon/modeling_falcon.py", line 1173, in custom_forward
  return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
  return forward_call(*args, **kwargs)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/models/falcon/modeling_falcon.py", line 787, in forward
  attn_outputs = self.self_attention(
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
  return forward_call(*args, **kwargs)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/models/falcon/modeling_falcon.py", line 527, in forward
  attention_probs = attention_probs * head_mask
RuntimeError: The size of tensor a (203) must match the size of tensor b (8) at non-singleton dimension 2

Stack trace for https://huggingface.co/tiiuae/falcon-7b

Command to reproduce

ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero3.yaml examples/scripts/dpo.py --model_name_or_path tiiuae/falcon-7b

Error

Traceback (most recent call last):
File "/fsx/lewis/git/trl/examples/scripts/dpo.py", line 173, in <module>
  dpo_trainer.train()
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/trainer.py", line 1591, in train
  return inner_training_loop(
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/trainer.py", line 1892, in _inner_training_loop
  tr_loss_step = self.training_step(model, inputs)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/trainer.py", line 2776, in training_step
  loss = self.compute_loss(model, inputs)
File "/fsx/lewis/git/trl/trl/trainer/dpo_trainer.py", line 500, in compute_loss
  loss, metrics = self.get_batch_metrics(model, inputs, train_eval="train")
File "/fsx/lewis/git/trl/trl/trainer/dpo_trainer.py", line 451, in get_batch_metrics
  ) = self.concatenated_forward(model, batch)
File "/fsx/lewis/git/trl/trl/trainer/dpo_trainer.py", line 417, in concatenated_forward
  all_logits = model(
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
  return forward_call(*args, **kwargs)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
  ret_val = func(*args, **kwargs)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1735, in forward
  loss = self.module(*inputs, **kwargs)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
  result = forward_call(*args, **kwargs)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/models/falcon/modeling_falcon.py", line 1303, in forward
  transformer_outputs = self.transformer(
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
  result = forward_call(*args, **kwargs)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/models/falcon/modeling_falcon.py", line 1213, in forward
  presents = self._convert_cache_to_standard_format(presents, batch_size)
File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/models/falcon/modeling_falcon.py", line 947, in _convert_cache_to_standard_format
  batch_size_times_num_heads, kv_length, head_dim = past_key_value[0][0].shape
IndexError: tuple index out of range
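
For context on the IndexError: as ArneNx notes further down the thread, presents stays an empty tuple when use_cache is False, which gradient checkpointing forces, so the unpacking in _convert_cache_to_standard_format has nothing to index. A tiny standalone illustration of just that failure mode (not taken from the report):

    # Reproduce only the failing unpacking from the traceback above, using an
    # empty cache tuple (what presents looks like when use_cache=False).
    presents = ()  # no per-layer cache entries were appended
    try:
        batch_size_times_num_heads, kv_length, head_dim = presents[0][0].shape
    except IndexError as err:
        print(err)  # -> tuple index out of range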
@lvwerra (Member) commented Oct 5, 2023

Does this also happen in the transformers.Trainer? It could be a deeper issue. cc @younesbelkada
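
One way to answer that question would be a stripped-down causal-LM run with the plain transformers.Trainer and gradient_checkpointing=True on the same checkpoint. The dataset, sequence length, and step count below are illustrative assumptions for a quick smoke test, not taken from the report:

    # Hedged sketch: does plain transformers.Trainer hit the same failure with
    # gradient checkpointing on tiiuae/falcon-rw-1b? Dataset and hyperparameters
    # are placeholders chosen only to keep the check small.
    from datasets import load_dataset
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        DataCollatorForLanguageModeling,
        Trainer,
        TrainingArguments,
    )

    model_name = "tiiuae/falcon-rw-1b"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_name)

    dataset = load_dataset("imdb", split="train[:64]")
    dataset = dataset.map(
        lambda ex: tokenizer(ex["text"], truncation=True, max_length=256),
        batched=True,
        remove_columns=dataset.column_names,
    )

    trainer = Trainer(
        model=model,
        args=TrainingArguments(
            output_dir="./gc-check",
            per_device_train_batch_size=1,
            max_steps=5,
            gradient_checkpointing=True,  # the flag under suspicion
            report_to="none",
        ),
        train_dataset=dataset,
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )
    trainer.train()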

@younesbelkada (Contributor) commented:
This is indeed a problem in the Falcon modeling code; I will take a look in the next few days.

younesbelkada self-assigned this Oct 10, 2023
@ArneNx commented Oct 19, 2023

Did you find a solution for this issue? I get the same error when running training with gradient_checkpointing=True.
I tracked it down to the point where presents is converted, i.e. the _convert_cache_to_standard_format call in modeling_falcon.py. However, presents is an empty tuple when use_cache is False, which it has to be when gradient checkpointing is used.
Can this be avoided by changing the if statement to if presents:, or would that break other things?
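
For reference, the change being proposed would look roughly like the following (a sketch against the lines of modeling_falcon.py identified in the 7B traceback; the surrounding transformers code is paraphrased, not quoted):

    # transformers/models/falcon/modeling_falcon.py, around line 1213 (per the
    # traceback above). With gradient checkpointing, use_cache is forced to False
    # and presents stays an empty tuple, so guard on truthiness before converting:
    if presents:  # empty tuple -> nothing to convert; non-empty -> behave as before
        presents = self._convert_cache_to_standard_format(presents, batch_size)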


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
