
[BUG] Unexpected GPU memory consumption when using transformers PEFT in DeepSpeed Zero3 #29047

Closed
2 of 4 tasks
alekseymalakhov11 opened this issue Feb 15, 2024 · 5 comments

Comments

@alekseymalakhov11

System Info

transformers = "4.35.0"
peft = "0.7.1"
torch = ">=2.0.0"
accelerate = "^0.24.1"
deepspeed = "^0.9.5"

Who can help?

@muellerzr @pacman100

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Description

Llama-30B with LoRA adapters does not fit into 8 x A100 (80GB) GPUs.

Demonstration of Problem and Experiment Setups

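I will illustrate the issue with several experiment setups on a smaller (7B) model. Each setup was accompanied by a screenshot, not reproduced here:

  1. 7B + LoRA + ZeRO stage 3 [screenshot]
  2. 7B + ZeRO stage 3 [screenshot]
  3. 7B + LoRA + ZeRO stage 2 [screenshot]
  4. 7B + ZeRO stage 2 [screenshot]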

All other parameters are identical across these experiments.

Expected behavior

Suspected Cause

One possible cause is that ZeRO-3 does not partition non-trainable (frozen) weights across GPUs. This assumption is based on the following:

  • Without LoRA, memory consumption matches the predicted values.
  • When training with LoRA, ZeRO-2 and ZeRO-3 show nearly the same memory consumption.
  • Examining the ZeRO runtime source code also suggests this could be the case.
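One way to test this suspicion directly is to compare, per parameter, the full size with the size actually stored on the current rank. A minimal sketch, assuming DeepSpeed's ZeRO-3 parameter attributes `ds_numel` (full element count) and `ds_tensor` (this rank's shard); `summarize_zero3_partitioning` and `PartitionReport` are hypothetical names:

```python
from dataclasses import dataclass


@dataclass
class PartitionReport:
    """Full-size vs. locally stored element counts, split by trainability."""
    trainable_full: int = 0
    trainable_local: int = 0
    frozen_full: int = 0
    frozen_local: int = 0


def summarize_zero3_partitioning(named_parameters) -> PartitionReport:
    """Compare each parameter's full size with what this rank stores.

    Assumption: ZeRO-3 parameters carry `ds_numel` (full numel) and
    `ds_tensor` (local shard). Parameters without these attributes are
    treated as unpartitioned, so their local size equals numel().
    """
    report = PartitionReport()
    for _name, p in named_parameters:
        full = getattr(p, "ds_numel", None)
        shard = getattr(p, "ds_tensor", None)
        if full is None:
            # Plain (non-ZeRO-3) parameter: the whole tensor lives here.
            full = local = p.numel()
        else:
            local = shard.numel() if shard is not None else full
        if p.requires_grad:
            report.trainable_full += full
            report.trainable_local += local
        else:
            report.frozen_full += full
            report.frozen_local += local
    return report
```

Passing `model.named_parameters()` once DeepSpeed has wrapped the model (for example from a `TrainerCallback`'s `on_train_begin`) and checking whether `frozen_local` is close to `frozen_full / 8` would indicate whether frozen weights are sharded across eight GPUs. This only counts shard sizes; parameters that DeepSpeed deliberately keeps whole would need separate handling.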

Expected behavior

Training the model with ZeRO-3 and LoRA should consume significantly less memory than ZeRO-2 with LoRA.
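The expectation can be made concrete with back-of-envelope arithmetic for the frozen weights alone. A minimal sketch, assuming that stages 1 and 2 keep a full copy of the weights on every GPU while stage 3 splits them evenly, and ignoring activations, gradients, optimizer states and buffers; `frozen_weight_gib_per_gpu` is a hypothetical helper, and the ~32.5B parameter count for Llama-30B is an approximation:

```python
def frozen_weight_gib_per_gpu(num_params: float, bytes_per_param: int,
                              num_gpus: int, zero_stage: int) -> float:
    """Per-GPU memory (GiB) taken by model weights only.

    Simplifying assumption: ZeRO stages 1-2 replicate the weights on every
    GPU; stage 3 partitions them evenly across num_gpus.
    """
    total_bytes = num_params * bytes_per_param
    if zero_stage == 3:
        total_bytes /= num_gpus
    return total_bytes / 2**30


# bf16 weights (2 bytes each) of a ~32.5B-parameter model on 8 GPUs.
for stage in (2, 3):
    gib = frozen_weight_gib_per_gpu(32.5e9, 2, 8, stage)
    print(f"ZeRO-{stage}: {gib:.1f} GiB of bf16 weights per GPU")
```

Under these assumptions the frozen weights take about 60.5 GiB per GPU with stage 2 but about 7.6 GiB with stage 3, which is why near-identical stage-2 and stage-3 numbers with LoRA look suspicious.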

We also opened an issue in the DeepSpeed repository but have not received any help there. You may also have more experience with the PEFT and DeepSpeed integration in the Transformers Trainer.

@amyeroberts
Collaborator

cc @younesbelkada too :)

@younesbelkada
Contributor

Hi @alekseymalakhov11! Thanks very much for the issue. To help us understand it better, can you share the full command you are using for training?
It might be unrelated, but to be on the safe side, could you try PEFT==0.8.2 and PEFT main, which include some fixes such as huggingface/peft#1450?

@alekseymalakhov11
Author

Thank you for your quick response!

We tried updating PEFT to the versions you suggested, but this didn't resolve the issue. We also updated DeepSpeed and Accelerate to their latest versions, and the problem persists.

I have attached the code snippets we use for training:

training.py

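import torch
from peft import get_peft_model, LoraConfig

from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer

# tokenizer_path, model_path, TRAINER_LOGS_FOLDER, train_dataset, val_dataset and
# data_collator are defined elsewhere in the original script (not shown here).
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

# Register an extra special token; the embedding matrix is resized to match below.
tokenizer.add_special_tokens({'additional_special_tokens': ['<UNK>']})

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
)

model.resize_token_embeddings(len(tokenizer))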


peft_config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=[
        "q_proj",
        "v_proj",
        "k_proj",
        "o_proj"
    ],
    task_type="CAUSAL_LM",
    modules_to_save=["embed_tokens", "lm_head"],
)

model = get_peft_model(model, peft_config)

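# Workaround from https://github.com/huggingface/peft/issues/341:
# upcast both the trainable copies created by modules_to_save and the frozen
# original embed_tokens/lm_head weights from bf16 to float32.
for name, module in model.named_modules():
    if name.endswith('modules_to_save'):
        module.default.weight.data = module.default.weight.data.float()
    elif name.endswith('original_module'):
        module.weight.data = module.weight.data.float()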



training_args = TrainingArguments(
    output_dir=str(TRAINER_LOGS_FOLDER),
    report_to=[],
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=False,
    save_total_limit=5,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=16,
    logging_steps=1,
    learning_rate=0.0004,
    num_train_epochs=5,
    lr_scheduler_type="linear",
    warmup_steps=1,
    fp16=False,
    bf16=True,
    deepspeed="deepspeed_config.json",
    optim="adamw_torch",
    adam_beta1=0.9,
    adam_beta2=0.98,
    adam_epsilon=1e-6,
    weight_decay=0.01,
    max_grad_norm=0.11
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    callbacks=[],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

trainer.train()
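The `LoraConfig` above makes `embed_tokens` and `lm_head` fully trainable through `modules_to_save`, in addition to the LoRA matrices on `q_proj`, `k_proj`, `v_proj` and `o_proj`. A rough count shows how these compare. A minimal sketch, assuming Llama-7B dimensions from its public config (hidden size 4096, 32 layers, vocabulary of about 32000), square attention projections, untied input/output embeddings, and that `modules_to_save` keeps one full trainable copy of each module; the helper names are hypothetical:

```python
def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """LoRA adds A (r x d_in) and B (d_out x r) to one linear layer."""
    return r * (d_in + d_out)


def trainable_params_llama_like(hidden: int, layers: int, vocab: int, r: int) -> dict:
    """Rough trainable-parameter breakdown for the LoraConfig in training.py.

    Assumes q/k/v/o are all hidden x hidden (no grouped-query attention) and
    that embed_tokens and lm_head are separate (vocab x hidden) matrices, each
    fully copied and trained because of modules_to_save.
    """
    lora = layers * 4 * lora_param_count(hidden, hidden, r)
    modules_to_save = 2 * vocab * hidden
    return {"lora": lora, "modules_to_save": modules_to_save}


breakdown = trainable_params_llama_like(hidden=4096, layers=32, vocab=32000, r=16)
print(breakdown)
```

With these numbers the two full copies hold about 262M trainable parameters against about 16.8M in the LoRA matrices. At roughly 16 bytes per trainable parameter for float32 weights, gradients and two Adam moments (the loop in training.py upcasts these copies to float32), that is on the order of 4.2 GB before any ZeRO partitioning.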

DeepSpeed config (deepspeed_config.json, shown here with ZeRO stage 2)

{
  "fp16": {
    "enabled": "auto"
  },
  "bf16": {
    "enabled": "auto"
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": "auto",
      "betas": "auto",
      "eps": "auto",
      "weight_decay": "auto"
    }
  },
  "scheduler": {
    "type": "WarmupDecayLR",
    "params": {
      "warmup_min_lr": "auto",
      "warmup_max_lr": "auto",
      "warmup_num_steps": "auto",
      "total_num_steps": "auto"
    }
  },
  "zero_optimization": {
    "stage": 2,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": "auto"
  },
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "steps_per_print": 2,
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "wall_clock_breakdown": false
}
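The report does not include the stage-3 configuration used for experiments 1 and 2. As a hedged sketch, the stage-2 `zero_optimization` block could be turned into a stage-3 one as follows; the `stage3_*` keys and values mirror the ZeRO-3 example in the Transformers DeepSpeed documentation and are an assumption, not the author's settings (`to_stage3` is a hypothetical helper):

```python
import copy
import json

# The zero_optimization block from the stage-2 config above.
STAGE2_ZERO = {
    "stage": 2,
    "overlap_comm": True,
    "contiguous_gradients": True,
    "sub_group_size": 1e9,
    "reduce_bucket_size": "auto",
}


def to_stage3(zero_cfg: dict) -> dict:
    """Return a ZeRO stage-3 variant without mutating the input block.

    The stage3_* entries are assumed values modelled on the Transformers
    DeepSpeed docs; tune them for the actual hardware and model.
    """
    cfg = copy.deepcopy(zero_cfg)
    cfg.update({
        "stage": 3,
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": True,
    })
    return cfg


if __name__ == "__main__":
    print(json.dumps({"zero_optimization": to_stage3(STAGE2_ZERO)}, indent=2))
```

One setting is relevant to the suspected cause: DeepSpeed documents `stage3_param_persistence_threshold` as the size below which parameters are not partitioned, so some small tensors stay whole on every GPU by design.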

We launch the script with:

deepspeed --num_gpus=8 --no_local_rank training.py

@pacman100
Contributor

Hello, we are able to fine-tune a 70B Llama model on 8 H100 80GB GPUs using PEFT + DeepSpeed. Please refer to the docs here: https://github.com/huggingface/peft/blob/main/docs/source/accelerate/deepspeed.md

A 70B model can't be fully fine-tuned on 8 x 80GB GPUs, whereas with LoRA + DeepSpeed each GPU uses around 64GB of memory.
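The first half of this statement follows from simple arithmetic. A minimal sketch, assuming the ZeRO paper's accounting of 16 bytes per parameter for mixed-precision Adam and ignoring activations and buffers; `mixed_precision_adam_gb` is a hypothetical helper:

```python
def mixed_precision_adam_gb(num_params: float) -> float:
    """Model + optimizer state (GB) for full fine-tuning with mixed-precision Adam.

    16 bytes per parameter: 2 (16-bit weights) + 2 (16-bit gradients)
    + 12 (fp32 master weights, momentum and variance).
    """
    return num_params * 16 / 1e9


cluster_gb = 8 * 80                          # total memory across 8 x 80GB GPUs
full_ft_gb = mixed_precision_adam_gb(70e9)   # full fine-tuning of a 70B model
frozen_bf16_gb = 70e9 * 2 / 1e9              # frozen bf16 weights only
print(full_ft_gb > cluster_gb, frozen_bf16_gb / 8)  # prints: True 17.5
```

Even perfectly partitioned, about 1,120 GB of model and optimizer state exceeds the 640 GB available, while the frozen bf16 weights alone need about 17.5 GB per GPU when sharded eight ways; the rest of the reported ~64GB would go to activations, LoRA states and buffers.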


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.


4 participants