Multi GPU DDP QLora Doesn't Work; Lora Does #921

Closed
mallorbc opened this issue Oct 27, 2023 · 13 comments

@mallorbc

I am using TRL and PEFT to finetune models. You can see the code I am using here.

When finetuning Llama 7B, I use LoRA, because with A100s I can fit the model in this format and it's a bit faster than QLoRA.

When finetuning 70B, I must use QLoRA.

When finetuning with LoRA, the code automatically utilizes all available GPUs with DDP. I do not need to use accelerate, though I can, and when I do, it works. I imagine there is something going on behind the scenes that uses accelerate.

When using QLoRA, it does not work, and the behavior changes depending on whether or not I am using accelerate explicitly.

When not explicitly using accelerate, the model is loaded onto both GPUs, but one GPU is idle and never used. Because of this, training never progresses.
[screenshot of GPU usage]

We can see in this screenshot that the model is loaded on both GPUs, but training is only occurring on GPU 0.

When I use accelerate explicitly, I get the following error:

Traceback (most recent call last):
  File "trl_finetune.py", line 349, in <module>
    trainer.train()
  File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 1511, in train
    return inner_training_loop(
  File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 1640, in _inner_training_loop
    model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
  File "/usr/local/lib/python3.8/dist-packages/accelerate/accelerator.py", line 1288, in prepare
    result = tuple(
  File "/usr/local/lib/python3.8/dist-packages/accelerate/accelerator.py", line 1289, in <genexpr>
    self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
  File "/usr/local/lib/python3.8/dist-packages/accelerate/accelerator.py", line 1094, in _prepare_one
    return self.prepare_model(obj, device_placement=device_placement)
  File "/usr/local/lib/python3.8/dist-packages/accelerate/accelerator.py", line 1376, in prepare_model
    raise ValueError(
ValueError: You can't train a model that has been loaded in 8-bit precision on a different device than the one you're training on. Make sure you loaded the model on the correct device using for example `device_map={'':torch.cuda.current_device()}` or `device_map={'':torch.xpu.current_device()}`
[2023-10-27 20:30:41,936] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 3870 closing signal SIGTERM
[2023-10-27 20:30:42,301] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 1 (pid: 3871) of binary: /usr/bin/python3
Traceback (most recent call last):
  File "/usr/local/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.8/dist-packages/accelerate/commands/accelerate_cli.py", line 47, in main
    args.func(args)
  File "/usr/local/lib/python3.8/dist-packages/accelerate/commands/launch.py", line 985, in launch_command
    multi_gpu_launcher(args)
  File "/usr/local/lib/python3.8/dist-packages/accelerate/commands/launch.py", line 654, in multi_gpu_launcher
    distrib_run.run(args)
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/run.py", line 797, in run
    elastic_launch(
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
trl_finetune.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-10-27_20:30:41
  host      : dc01289cc6fb
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 3871)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
@mallorbc
Author

mallorbc commented Oct 27, 2023

I think ultimately the core of the issue is that regular LoRA can be loaded on the CPU and accelerate then puts the model where it needs to go, whereas QLoRA has to be loaded onto a GPU, and then:

ValueError: You can't train a model that has been loaded in 8-bit precision on a different device than the one you're training on. Make sure you loaded the model on the correct device using for example `device_map={'':torch.cuda.current_device()}` or `device_map={'':torch.xpu.current_device()}`

Maybe a way to pass a QLoRA config would be ideal?

@younesbelkada
Contributor

younesbelkada commented Oct 27, 2023

Hi @mallorbc!
Thanks a lot for your issue. Quickly looking at your script, I think you need to adapt it a bit to support QLoRA + DDP - please refer to this comment: huggingface/accelerate#1840 (comment). To fix your issue, you need to explicitly pass a device_map that corresponds to {"": Accelerator().process_index}. Refer to that comment for more details and let me know if the fix works on your end!
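
For reference, a minimal sketch of that suggestion (the model name and quantization settings below are illustrative, not taken from this thread):

    import torch
    from accelerate import Accelerator
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    # Illustrative 4-bit quantization config (not the exact one used here).
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # The empty-string key maps the whole model to the given device, so each
    # DDP rank loads its own full copy of the quantized model onto its own GPU.
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-70b-hf",  # placeholder model name
        quantization_config=bnb_config,
        device_map={"": Accelerator().process_index},
    )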

@mallorbc
Author

Hi @mallorbc! Thanks a lot for your issue. Quickly looking at your script, I think you need to adapt it a bit to support QLoRA + DDP - please refer to this comment: huggingface/accelerate#1840 (comment). To fix your issue, you need to explicitly pass a device_map that corresponds to {"": Accelerator().process_index}. Refer to that comment for more details and let me know if the fix works on your end!

I had a similar idea. :)

I ultimately did this and it worked:

        import os

        # pin the quantized model to this process's GPU
        local_rank = os.getenv("LOCAL_RANK")
        device_string = "cuda:" + str(local_rank)
        kwargs["device_map"] = device_string

Thanks so much!

BTW, is there a reason we should use float16 over bfloat16 for QLoRA?

@younesbelkada
Contributor

Hi @mallorbc,
Thanks! I think you are right: one should use bfloat16 for QLoRA training, but for inference float16 is faster.
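
As a concrete illustration of that advice, a QLoRA quantization config for training might set the compute dtype to bfloat16 (all values below are illustrative, not from this thread):

    import torch
    from transformers import BitsAndBytesConfig

    # bfloat16 compute dtype for QLoRA training, per the advice above;
    # float16 could be swapped in when the adapter is later used for inference.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )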

@mallorbc
Author

mallorbc commented Oct 27, 2023

@younesbelkada
My solution only seemed to work; it failed after a few minutes. Activity was definitely occurring on both GPUs, though.

I am now getting this:

You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
/usr/local/lib/python3.8/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.
  warnings.warn(
/usr/local/lib/python3.8/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.
  warnings.warn(
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.float16.
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.float16.
Traceback (most recent call last):
  File "trl_finetune.py", line 356, in <module>
    trainer.train()
  File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 1511, in train
    return inner_training_loop(
  File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 1811, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 2669, in training_step
    self.accelerator.backward(loss)
  File "/usr/local/lib/python3.8/dist-packages/accelerate/accelerator.py", line 1989, in backward
    loss.backward(**kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/function.py", line 288, in apply
    return user_fn(self, *args)
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/checkpoint.py", line 288, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the `forward` function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes. or try to use _set_static_graph() as a workaround if this module graph does not change during training loop.2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple `checkpoint` functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module graph does not change over iterations.
Parameter at index 1119 with name base_model.model.model.layers.79.mlp.down_proj.lora_B.default.weight has been marked as ready twice. This means that multiple autograd engine  hooks have fired for this particular parameter during this iteration.
wandb: Waiting for W&B process to finish... (failed 1).
wandb: You can sync this run to the cloud by running:
wandb: wandb sync /workspace/finetuning_repo/wandb/offline-run-20231027_210208-ogyygr6u
wandb: Find logs at: ./wandb/offline-run-20231027_210208-ogyygr6u/logs
[2023-10-27 21:08:28,560] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 5366 closing signal SIGTERM
[2023-10-27 21:08:29,076] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 5365) of binary: /usr/bin/python3
Traceback (most recent call last):
  File "/usr/local/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.8/dist-packages/accelerate/commands/accelerate_cli.py", line 47, in main
    args.func(args)
  File "/usr/local/lib/python3.8/dist-packages/accelerate/commands/launch.py", line 985, in launch_command
    multi_gpu_launcher(args)
  File "/usr/local/lib/python3.8/dist-packages/accelerate/commands/launch.py", line 654, in multi_gpu_launcher
    distrib_run.run(args)
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/run.py", line 797, in run
    elastic_launch(
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================

Gonna take a closer look at what you shared.

@mallorbc
Author

This error happens with either your solution or mine. I see that it affects a particular LoRA target layer; when I remove that layer, it just moves to another one. So far I have removed down_proj, up_proj, and gate_proj, and now it says o_proj is causing issues.

Trying it again with that removed now.

The QLoRA paper suggests that tuning all linear layers is important, so this is not ideal.

@mallorbc
Author

mallorbc commented Oct 27, 2023

The issues I am having are related to gradient checkpointing. With QLoRA, flash attention, and rope scaling, I can fit 8k tokens on one A100; I can't even fit 2k when I disable gradient checkpointing (which I guess is enabled by default?).

And I guess for the 7B model it is not enabled by default?

@mallorbc mallorbc reopened this Oct 27, 2023
@mallorbc
Author

OK, so normal DDP does not support gradient checkpointing. Thankfully DeepSpeed does, and all ZeRO stages except stage 3 seem to work with QLoRA (or at least it appears so; I still need to train a model, but forward and backward passes work).

Thus the answer is to use gradient checkpointing with DeepSpeed and LoRA/QLoRA.
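
A rough sketch of what that setup could look like, assuming a ZeRO stage 2 config passed directly to TrainingArguments (the config values and output_dir below are illustrative, not the exact setup used in this thread):

    import torch
    from transformers import TrainingArguments

    # Minimal ZeRO stage 2 config passed as a dict; the "auto" values are
    # filled in by the transformers DeepSpeed integration from TrainingArguments.
    ds_config = {
        "zero_optimization": {"stage": 2},
        "bf16": {"enabled": torch.cuda.is_bf16_supported()},
        "train_micro_batch_size_per_gpu": "auto",
        "gradient_accumulation_steps": "auto",
    }

    training_args = TrainingArguments(
        output_dir="out",
        gradient_checkpointing=True,
        deepspeed=ds_config,
    )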

@younesbelkada
Contributor

Hi @mallorbc
regarding the issue

RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) 

Can you try to install TRL, transformers and PEFT from source and try to pass gradient_checkpointing_kwargs={"use_reentrant": False} to the training arguments? I believe this should fix that issue
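
For clarity, a minimal sketch of that suggestion (only the checkpointing arguments matter here; the rest are illustrative):

    from transformers import TrainingArguments

    training_args = TrainingArguments(
        output_dir="out",
        gradient_checkpointing=True,
        # non-reentrant checkpointing avoids the DDP "marked ready twice" error
        gradient_checkpointing_kwargs={"use_reentrant": False},
    )

Note that gradient_checkpointing_kwargs is only available in recent transformers versions, which is why the suggestion above is to install from source.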

@hengjiUSTC

hengjiUSTC commented Dec 29, 2023

Hi @mallorbc regarding the issue

RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) 

Can you try to install TRL, transformers and PEFT from source and try to pass gradient_checkpointing_kwargs={"use_reentrant": False} to the training arguments? I believe this should fix that issue

I tried with

    training_args = TrainingArguments(
        do_train=True,
        do_eval=True,
        output_dir=args.output_dir,
        dataloader_drop_last=True,
        evaluation_strategy="steps",
        save_strategy="steps",
        logging_strategy="steps",
        num_train_epochs=args.epochs,
        eval_steps=args.eval_steps,
        save_steps=args.save_steps,
        logging_steps=args.log_steps,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size * 2,
        optim=optimizer,
        learning_rate=args.learning_rate,
        lr_scheduler_type=args.lr_scheduler_type,
        warmup_steps=args.warmup_steps,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        gradient_checkpointing=args.gradient_checkpointing,
        gradient_checkpointing_kwargs={'use_reentrant': True} if args.gradient_checkpointing else None,
        weight_decay=args.weight_decay,
        report_to="wandb",
        load_best_model_at_end=True,
        save_total_limit=args.save_limit,
        bf16=True if torch.cuda.is_bf16_supported() else False,
        fp16=False if torch.cuda.is_bf16_supported() else True,
    )

It didn't work; I'm still getting the error:

You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Traceback (most recent call last):
  File "/home/ubuntu/learn-llm/trl_finetune.py", line 445, in <module>
    trainer.train()
  File "/home/ubuntu/.local/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 315, in train
    output = super().train(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/trainer.py", line 1543, in train
    return inner_training_loop(
  File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/trainer.py", line 1860, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2746, in training_step
    self.accelerator.backward(loss)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/accelerate/accelerator.py", line 1938, in backward
    self.scaler.scale(loss).backward(**kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 288, in apply
    return user_fn(self, *args)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 288, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the `forward` function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes. or try to use _set_static_graph() as a workaround if this module graph does not change during training loop.2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple `checkpoint` functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module graph does not change over iterations.
Parameter at index 447 with name base_model.model.model.layers.31.mlp.down_proj.lora_B.default.weight has been marked as ready twice. This means that multiple autograd engine  hooks have fired for this particular parameter during this iteration.

@younesbelkada
Contributor

which version of TRL / PEFT / transformers do you have?

@hengjiUSTC

which version of TRL / PEFT / transformers do you have?

As below:

ubuntu@ip-172-31-14-191:~/learn-llm$ pip3 list | grep peft
peft                      0.7.2.dev0
ubuntu@ip-172-31-14-191:~/learn-llm$ pip3 list | grep transformer
transformers              4.37.0.dev0
ubuntu@ip-172-31-14-191:~/learn-llm$ pip3 list | grep trl
trl                       0.7.6

@hengjiUSTC

I changed to {'use_reentrant': False}, and it works.
