codellama training issue with Multiple GPUs in SFTTrainer #844

Closed
humza-sami opened this issue Oct 7, 2023 · 6 comments

Comments

humza-sami commented Oct 7, 2023

I am trying to train codellama-7B in int8 using the SFTTrainer from TRL. The model size after quantization is around 8 GB. I first trained it on an RTX 3090 (24 GB, ~35 TFLOPS) and complete training took ~380 hours. I then upgraded my system and am now training on 4x A4000 (~64 GB total VRAM, ~82 TFLOPS), but the reported training time increased to ~4200 hours, which seems clearly wrong: it should be lower than on the previous setup, since both VRAM and compute have increased. What is the best way to utilize multiple GPUs for LLM training?

I am using following code block:

import json
import torch
import pandas as pd
import datasets
from peft import LoraConfig,PeftModel
from transformers import (AutoModelForCausalLM,AutoTokenizer,TrainingArguments,BitsAndBytesConfig)
import transformers
from trl import SFTTrainer
from training_args import *
import os

import logging
import sys

# configure a logger so the exception handler at the bottom of the script works
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)

output_dir = "../training-results/"


if not os.path.exists(output_dir):
    # If the directory doesn't exist, create it
    os.makedirs(output_dir)
    print(f"Directory '{output_dir}' created.")
else:
    print(f"Directory '{output_dir}' already exists.")



MODEL_NAME = "codellama/CodeLlama-7b-Instruct-hf"


# loading dataset
dataset = datasets.load_from_disk("../dataset/codellama_1000L_59248E")
# loading model
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    use_safetensors=True,
    load_in_8bit=True,
    trust_remote_code=True,
    device_map="auto",
)
# loading tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    add_special_tokens=False,
    add_eos_token=False,
    add_bos_token=False,
)
tokenizer.pad_token = "[PAD]"
# LORA Configuration
peft_config = LoraConfig(
    lora_alpha=32,
    lora_dropout=0.05,
    r = 32,
    bias="none",
    task_type = "CAUSAL_LM",
    target_modules = ["q_proj", "v_proj","k_proj","o_proj","gate_proj","up_proj","down_proj","lm_head"]
)



training_arguments = TrainingArguments(
    per_device_train_batch_size=32,
    gradient_accumulation_steps=1,
    optim="paged_adamw_32bit",
    learning_rate=4e-4,
    fp16=True,
    max_grad_norm=0.3,
    num_train_epochs=15,
    warmup_ratio=0.05,
    logging_steps=5,
    save_total_limit=50,
    save_strategy="steps",
    save_steps=1000,
    group_by_length=True,
    output_dir=output_dir,
    report_to="tensorboard",
    save_safetensors=True,
    lr_scheduler_type="cosine",
    seed=42)

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=4096,
    tokenizer=tokenizer,
    args=training_arguments,
)

# trainer.tokenizer.pad_token = False
# trainer.tokenizer.pad_token

try:
    trainer.train()
except Exception as e:
    logger.error(f"Error in Logs due to {e}")

GPU USAGE

(screenshot of GPU utilization omitted)

accelerate env

  • Accelerate version: 0.23.0
  • Platform: Linux-5.15.0-79-generic-x86_64-with-glibc2.29
  • Python version: 3.8.10
  • Numpy version: 1.24.4
  • PyTorch version (GPU?): 2.1.0+cu121 (True)
  • PyTorch XPU available: False
  • PyTorch NPU available: False
  • System RAM: 125.83 GB
  • GPU type: NVIDIA RTX A4000
  • Accelerate default config:
    - compute_environment: LOCAL_MACHINE
    - distributed_type: MULTI_GPU
    - mixed_precision: bf16
    - use_cpu: False
    - debug: False
    - num_processes: 4
    - machine_rank: 0
    - num_machines: 1
    - gpu_ids: all
    - rdzv_backend: static
    - same_network: True
    - main_training_function: main
    - downcast_bf16: no
    - tpu_use_cluster: False
    - tpu_use_sudo: False
    - tpu_env: []

lvwerra commented Oct 10, 2023

Can you also compare the bitsandbytes versions? It could be that with an older version you get the slower kernels.
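
A quick way to compare, assuming bitsandbytes is importable in both environments, is to print the library versions on each machine (a minimal diagnostic sketch, not taken from the thread):

import bitsandbytes as bnb
import torch

# run this on both the RTX 3090 machine and the 4x A4000 machine and compare the output
print("bitsandbytes:", bnb.__version__)
print("torch:", torch.__version__, "| CUDA:", torch.version.cuda)
print("GPU:", torch.cuda.get_device_name(0))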

@younesbelkada (Contributor) commented:

Hi @Humza1996
LLM.int8 is slower than the more recent QLoRA approach, so I suggest first trying fine-tuning with 4-bit quantization, making sure you use bnb_4bit_compute_dtype=torch.float16 for faster training.
An example is here: https://gist.github.com/younesbelkada/9f7f75c94bdc1981c8ca5cc937d4a4da
To further increase training speed, I suggest trying out Flash Attention with packing; more details can be found here: https://huggingface.co/docs/trl/main/en/sft_trainer#using-flash-attention-and-flash-attention-2
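
A minimal sketch of both suggestions (not taken from the linked gist), reusing MODEL_NAME, dataset, peft_config, tokenizer, and training_arguments from the script above; it assumes flash-attn is installed and a reasonably recent transformers/trl (older transformers versions used use_flash_attention_2=True instead of attn_implementation):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from trl import SFTTrainer

# 4-bit (QLoRA-style) quantization with a float16 compute dtype
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    attn_implementation="flash_attention_2",  # requires the flash-attn package
    device_map="auto",
)

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=4096,
    tokenizer=tokenizer,
    args=training_arguments,
    packing=True,  # concatenate short samples into full 4096-token sequences
)
trainer.train()

Packing reduces the number of padding tokens per batch, so more of each forward pass is spent on real data.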

@mallorbc commented:

Regarding the suggestion above to fine-tune in 4-bit with bnb_4bit_compute_dtype=torch.float16: is torch.bfloat16 not a good option?

@mallorbc commented:

This looks similar to the issue I created at #921, specifically in that all the GPUs are idle except one.


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
