
Mixtral 16bit LoRa OOM with deepspeed zero stage 3 and dpo trainer on 4 80GB A100s #1268

Closed
janphilippfranken opened this issue Jan 23, 2024 · 15 comments
Labels
🏋 DPO Related to DPO

Comments

@janphilippfranken

janphilippfranken commented Jan 23, 2024

Hi!

I am trying to use the DPO trainer to fine-tune a Mixtral 8x7B model in 16-bit precision (I've already completed fine-tuning a 4-bit model without issues, but unfortunately the quantized adapter performs worse than the 16-bit version of the model that I want to compare it to).

My goal is to finish training an adapter in 16-bit precision, then merge and unload the model with the adapter and run inference with vLLM on the merged model; a rough sketch of that merge step is below.
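
For reference, the merge-and-unload step I have in mind looks roughly like this (just a sketch; the paths and base model name are placeholders, not part of my actual setup):

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

# reload the 16-bit base model and attach the trained LoRA adapter
base = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1", torch_dtype=torch.float16
)
model = PeftModel.from_pretrained(base, "output/final_checkpoint")

# fold the adapter weights into the base model and save the result for vLLM
merged = model.merge_and_unload()
merged.save_pretrained("output/merged")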

Unfortunately, I am running into OOM issues when running dpo_trainer.train() with the following setup (any help would be much appreciated):

DeepSpeed config (from https://huggingface.co/blog/accelerate-deepspeed):

{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "weight_decay": "auto"
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "sub_group_size": 1e9,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    },
    "gradient_accumulation_steps": 1,
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}

Accelerate config:

compute_environment: LOCAL_MACHINE
deepspeed_config:
 deepspeed_config_file: ds_config.json
 zero3_init_flag: true
distributed_type: DEEPSPEED
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
num_machines: 1
num_processes: 4
use_cpu: false

Training script:

import logging
import os

import hydra
import torch
import wandb
from omegaconf import DictConfig, OmegaConf

from trl import DPOTrainer
from datasets import Dataset
from peft import LoraConfig
from transformers import TrainingArguments, AutoModelForCausalLM, AutoTokenizer

from helpers import *

logging.basicConfig(level=logging.INFO)

@hydra.main(version_base=None, config_path="conf", config_name="train")
def main(args: DictConfig) -> None:
    
    # wandb
    args_dict = OmegaConf.to_container(args, resolve=True)
    wandb.init(project=args.wandb.project, name=args.wandb.name, config=args_dict)

    # get tokenizer
    tokenizer = AutoTokenizer.from_pretrained(**args.model.tokenizer_config)   
    tokenizer.pad_token, tokenizer.padding_side  = "[PAD]", "right"
    
    # training args
    training_args_dict = OmegaConf.to_container(args.training_args, resolve=True)
    training_args = TrainingArguments(**training_args_dict)
    
    # get model 
    model = AutoModelForCausalLM.from_pretrained(
        **args.model.model_config, 
        torch_dtype=torch.float16,
    )

    # get dataset
    data_dict = jload(args.data)
    logging.info(f"N Train Examples: {len(data_dict)}")  
    
    # get dpo format 
    prompt = [example['prompt'] for example in data_dict]
    chosen = [example['chosen'] for example in data_dict]
    rejected = [example['rejected'] for example in data_dict]
    dataset = Dataset.from_dict(dict(prompt=prompt, chosen=chosen, rejected=rejected))  
    
    dataset = dataset.train_test_split(test_size=args.validation_split_size)
    logging.info(dataset)
    logging.info(dataset['train'][0])
    
    # LoRA and peft setup
    lora_config_dict = OmegaConf.to_container(args.lora_config, resolve=True)
    peft_config = LoraConfig(**lora_config_dict)

    logging.info(f"Devices: {torch.cuda.device_count()}")
    logging.info(f"training_args.output_dir: {training_args.output_dir}")

    # trainer
    dpo_trainer = DPOTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args, 
        beta=args.dpo_beta,
        train_dataset=dataset['train'],
        eval_dataset=dataset['test'],
        peft_config=peft_config,
        max_prompt_length=args.max_prompt_length,
        max_length=args.model.tokenizer_config.model_max_length,
    )
    
    # train
    dpo_trainer.train()
    dpo_trainer.save_model(output_dir=training_args.output_dir)
    
    # save final checkpoint
    output_dir = os.path.join(training_args.output_dir, "final_checkpoint")
    dpo_trainer.model.save_pretrained(output_dir)
    
if __name__ == "__main__":
    main()

Hardware: 4 80GB A100 GPUs

Command: accelerate launch --config_file accelerate_config.yaml train_dpo.py

Error:
File "/scr/jphilipp/miniconda3/envs/scai-tuning/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 1210, in all_gather_coalesced
param_buffer = torch.empty(
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 224.00 MiB. GPU 2 has a total capacty of 79.15 GiB of which 171.25 MiB is free. Including non-PyTorch memory, this process has 78.87 GiB memory in use. Of the allocated memory 76.47 GiB is allocated by PyTorch, and 1.00 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

@janphilippfranken janphilippfranken changed the title Mixtral 16bit LoRA OOM with deepspeed zero stage 3 and dpo trainer on 4 80GB A100s Mixtral 16bit LoRa OOM with deepspeed zero stage 3 and dpo trainer on 4 80GB A100s Jan 23, 2024
@maywind23

It seems that the Mixtral model is being loaded in full onto every GPU rather than partitioned equally across your GPUs (A100s).

    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_model_path,
        torch_dtype=torch.float16,
        device_map="auto",
    )

Try something like the above.

@younesbelkada
Contributor

Thanks a lot for the issue!
I second what @maywind23 said - Mixtral is quite large and I don't think it will fit on a single A100 GPU for inference. You will need to load it with device_map="auto" and simply run python xxx.py.
To improve the performance of the 4-bit model, I advise you to look into the LoftQ initialization technique (https://huggingface.co/docs/peft/main/en/developer_guides/lora#loftq). Could you try that out as well?
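
For reference, a minimal sketch of what LoftQ initialization looks like with PEFT (the rank, alpha, and target modules below are illustrative, not values from this thread):

from peft import LoftQConfig, LoraConfig, get_peft_model

# LoftQ initializes the LoRA weights so they compensate for quantization error;
# the base model should be loaded in full/half precision, PEFT handles the quantization
loftq_config = LoftQConfig(loftq_bits=4)
lora_config = LoraConfig(
    init_lora_weights="loftq",
    loftq_config=loftq_config,
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
peft_model = get_peft_model(model, lora_config)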

@kashif kashif added the 🏋 DPO Related to DPO label Jan 30, 2024
@saeedkhaki92

@younesbelkada Thanks for this solution. I am using the accelerate multi-GPU config and it is working well for Mixtral with DPO. My GPUs are 8 A100 40G. However, it goes OOM if the sequence length is larger than 1024, which is small; I need at least 2048. I have enabled gradient checkpointing, decreased the batch size to 1, and switched to paged AdamW 8-bit, but it still goes OOM. Is there anything else I can do? I am also not sure whether the multi-GPU config allows for CPU offload like DeepSpeed does. I would really appreciate your help. Thanks
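
For context, the memory-saving settings mentioned above map onto TrainingArguments roughly like this (a sketch; output_dir and the accumulation steps are illustrative):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="mixtral-dpo",
    per_device_train_batch_size=1,     # smallest possible micro-batch
    gradient_accumulation_steps=8,     # keep the effective batch size up
    gradient_checkpointing=True,       # trade extra compute for activation memory
    optim="paged_adamw_8bit",          # paged 8-bit AdamW from bitsandbytes
    bf16=True,
)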

@saeedkhaki92

@janphilippfranken Did deepspeed work for you? It does not work for me.

@janphilippfranken
Author

janphilippfranken commented Feb 1, 2024 via email

@saeedkhaki92

I see, but it becomes very slow and does not use the full capacity of the GPUs, i.e. GPU utilization is low.

@younesbelkada
Contributor

@saeedkhaki92 To decrease the memory footprint of training your model, you might consider using Flash Attention 2: simply pass attn_implementation="flash_attention_2" to from_pretrained. Make sure to use TRL built from main to include some important fixes with respect to DPO + FA2 + Mixtral: #1290
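
A minimal sketch of that change (the model name and dtype here are placeholders):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    torch_dtype=torch.bfloat16,               # FA2 requires fp16 or bf16
    attn_implementation="flash_attention_2",  # use the Flash Attention 2 kernels
)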

@saeedkhaki92

saeedkhaki92 commented Feb 2, 2024

@younesbelkada Thanks. It still goes OOM; I added attn_implementation="flash_attention_2" and set use_cache=False.

This is my training script and how I call it:


accelerate launch --config_file ./accelerate_configs/multi_gpu.yaml --num_processes=8 \
 rlhf_dpo_4bit.py \
--model_name_or_path="/mnt/efs/workspace/sakhaki/models/Mixtral-8x7B-Instruct-v0.1" \
--output_dir="/mnt/efs/workspace/sakhaki/models/Mixtral-8x7B-dpo-v5" \
--data_path="/mnt/efs/workspace/sakhaki/data/mixtral_dpo_12858.json" \
--use_lamma2_peft_config False \
--beta 0.1 \
--optimizer_type adamw_bnb_8bit \
--learning_rate 2e-5 \
--warmup_steps 50 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--lora_alpha 16 \
--lora_dropout 0.05 \
--lora_r 8 \
--max_prompt_length 1024 \
--max_length 2048 \
--num_train_epochs 4 \
--logging_steps 2 \
--save_steps 50 \
--save_total_limit 8 \
--eval_steps 10 \
--gradient_checkpointing True \
--report_to "wandb" \
--target_modules q_proj k_proj v_proj o_proj

And this is part of my code where I load the mixtral model inside my script: rlhf_dpo_4bit.py

    quantization_config = BitsAndBytesConfig(
        load_in_8bit=False, load_in_4bit=True
    )

    torch_dtype = torch.bfloat16

    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        quantization_config=quantization_config,
        device_map=get_kbit_device_map(),
        trust_remote_code=True,
        use_cache=False,
        torch_dtype=torch_dtype,
        attn_implementation="flash_attention_2",
        # use_auth_token=script_args.use_auth_token,
    )

    if script_args.ignore_bias_buffers:
        # torch distributed hack
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    model_ref = None


trl version: 0.7.11.dev0

@younesbelkada Could you please let us know if there is any other way around this, such as CPU offloading? As far as I know, accelerate does not have CPU offload options. I tried DeepSpeed but I'm getting errors. Thanks a lot


@saeedkhaki92

@younesbelkada
Update:
I tried using a DeepSpeed ZeRO-2 config with CPU offload options added, but it is still going OOM. Here is my ZeRO config:

compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  gradient_accumulation_steps: 8
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

My understanding is that with ZeRO-2 + offloading we should not go OOM, because the excess memory would be offloaded to the CPU. I would appreciate it if you could comment on this. Thanks

@younesbelkada
Contributor

Hi @saeedkhaki92
Sadly I can't really tell. ZeRO-2 could theoretically give OOMs, and the only solution would be to go for ZeRO-3, but that is not supported by bitsandbytes / QLoRA. The other option could be to restrict the target modules to a smaller set, e.g. by removing o_proj and keeping only the q/k/v projection layers; a sketch of that is below.
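
A sketch of what restricting the target modules could look like (the rank, alpha, and dropout values are just the ones from the command above):

from peft import LoraConfig

peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj"],  # drop o_proj to shrink the trained set
    task_type="CAUSAL_LM",
)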


@saeedkhaki92

@younesbelkada Just a quick update, I managed to get it working with zero3+offloading, and by adding:

from deepspeed.utils import set_z3_leaf_modules
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])

It significantly reduced the memory usage. Per the DeepSpeed documentation, set_z3_leaf_modules is particularly useful in the context of Mixture of Experts (MoE) models. In MoE models, the computation order of experts varies across forward passes. This variability can disrupt ZeRO-3's functionality, as ZeRO-3 relies on tracking the computation order of modules to prefetch parameters efficiently. By designating a module as a 'leaf' node, ZeRO-3 will prefetch parameters for all child modules upon entering the module.

@younesbelkada
Contributor

Very nice, thanks for sharing!
Note also that QLoRA + DeepSpeed ZeRO-3 is now compatible if you use the latest transformers / accelerate: https://huggingface.co/docs/peft/accelerate/deepspeed
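
For anyone following along, the key piece from that guide is making the quantized weight storage dtype match the model/training dtype so ZeRO-3 can shard the quantized weights; a minimal sketch (the exact dtypes are illustrative):

import torch
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_storage=torch.bfloat16,  # storage dtype must match the training dtype for ZeRO-3 sharding
)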

