Could not run gemma 9b with fsdp #28

Open
imadoualid opened this issue Sep 29, 2024 · 0 comments

Hello,

I tried running Llama 8B using the same config I used with Llama 3 70B and it worked (I just changed QLoRA to LoRA).
However, when I tried to reproduce this with Gemma 2 9B, it failed and I get the following error:

ExitCode 1
ErrorMessage "torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 7.81 GiB. GPU 7 has a total capacty of 22.19 GiB of which 5.32 GiB is free. Process 55370 has 16.85 GiB memory in use. Of the allocated memory 16.09 GiB is allocated by PyTorch, and 377.60 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
 return model_forward(*args, **kwargs)
 File "/opt/conda/lib/python3.10/site-packages/accelerate/utils/operations.py", line 808, in __call__
 return convert_to_fp32(self.model_forward(*args, **kwargs))
 File "/opt/conda/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
 return func(*args, **kwargs)
 File "/opt/conda/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 839, in forward
 output = self._fsdp_wrapped_m. Check troubleshooting guide for common errors: https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-python-sdk-troubleshooting.html
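
The allocator note at the end of the message suggests trying `max_split_size_mb`. If fragmentation is part of the problem, I guess I could forward that hint through the estimator's `environment` (sketch below; the 512 value is just a guess, not a confirmed fix), but I'm not sure that's the real issue:

```python
from huggingface_hub import HfFolder

# Hypothetical tweak (not a confirmed fix): forward the allocator hint from the
# error message to the training job via the SageMaker estimator's `environment`.
environment = {
    "HUGGINGFACE_HUB_CACHE": "/tmp/.cache",
    "HF_TOKEN": HfFolder.get_token(),
    "ACCELERATE_USE_FSDP": "1",
    "FSDP_CPU_RAM_EFFICIENT_LOADING": "1",
    # value below is illustrative; PyTorch reads it when the CUDA allocator initializes
    "PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:512",
}
```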

This is my script:

import logging
from dataclasses import dataclass, field
import os
import random
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    set_seed,
)
from trl.commands.cli_utils import TrlParser
from peft import LoraConfig


from trl import SFTTrainer

# Comment in if you want to use the Llama 3 instruct template but make sure to add modules_to_save
# LLAMA_3_CHAT_TEMPLATE="{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"

# Anthropic/Vicuna like template without the need for special tokens
LLAMA_3_CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{% if message['role'] == 'system' %}"
    "{{ message['content'] }}"
    "{% elif message['role'] == 'user' %}"
    "{{ '\n\nHuman: ' + message['content'] +  eos_token }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ '\n\nAssistant: '  + message['content'] +  eos_token  }}"
    "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "{{ '\n\nAssistant: ' }}"
    "{% endif %}"
)


# ACCELERATE_USE_FSDP=1 FSDP_CPU_RAM_EFFICIENT_LOADING=1 torchrun --nproc_per_node=4 ./scripts/run_fsdp_qlora.py --config llama_3_70b_fsdp_qlora.yaml


@dataclass
class ScriptArguments:
    train_dataset_path: str = field(
        default=None,
        metadata={"help": "Path to the dataset, e.g. /opt/ml/input/data/train/"},
    )
    test_dataset_path: str = field(
        default=None,
        metadata={"help": "Path to the dataset, e.g. /opt/ml/input/data/valid/"},
    )
    model_id: str = field(
        default=None, metadata={"help": "Model ID to use for SFT training"}
    )
    max_seq_length: int = field(
        default=512, metadata={"help": "The maximum sequence length for SFT Trainer"}
    )


def merge_and_save_model(model_id, adapter_dir, output_dir):
    from peft import PeftModel

    print("Trying to load a Peft model. It might take a while without feedback")
    base_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
    )
    peft_model = PeftModel.from_pretrained(base_model, adapter_dir)
    model = peft_model.merge_and_unload()

    os.makedirs(output_dir, exist_ok=True)
    print(f"Saving the newly created merged model to {output_dir}")
    model.save_pretrained(output_dir, safe_serialization=True)
    base_model.config.save_pretrained(output_dir)


def training_function(script_args, training_args):
    ################
    # Dataset
    ################

    train_dataset = load_dataset(
        "json",
        data_files=os.path.join(script_args.train_dataset_path, "dataset.json"),
        split="train",
    )
    test_dataset = load_dataset(
        "json",
        data_files=os.path.join(script_args.test_dataset_path, "dataset.json"),
        split="train",
    )

    ################
    # Model & Tokenizer
    ################

    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(script_args.model_id, use_fast=True, add_eos_token=False)
    tokenizer.padding_side = "right"
    tokenizer.chat_template = LLAMA_3_CHAT_TEMPLATE

    # template dataset
    def template_dataset(examples):
        return {
            "text": tokenizer.apply_chat_template(examples["messages"], tokenize=False)
        }

    train_dataset = train_dataset.map(template_dataset, remove_columns=["messages"])
    test_dataset = test_dataset.map(template_dataset, remove_columns=["messages"])

    # print random sample on rank 0
    if training_args.distributed_state.is_main_process:
        for index in random.sample(range(len(train_dataset)), 2):
            print(train_dataset[index]["text"])
    training_args.distributed_state.wait_for_everyone()  # wait for all processes to print

    # Model
    torch_dtype = torch.bfloat16
    quant_storage_dtype = torch.bfloat16

    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch_dtype,
        bnb_4bit_quant_storage=quant_storage_dtype,
    )

    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_id,
        #quantization_config=quantization_config,
        #attn_implementation="flash_attention_2",
        attn_implementation="eager",
        torch_dtype=quant_storage_dtype,
        use_cache=(
            False if training_args.gradient_checkpointing else True
        ),  # this is needed for gradient checkpointing
    )

    if training_args.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    ################
    # PEFT
    ################

    # LoRA config based on QLoRA paper & Sebastian Raschka experiment
    peft_config = LoraConfig(
        lora_alpha=8,
        lora_dropout=0.05,
        r=16,
        bias="none",
        target_modules="all-linear",
        task_type="CAUSAL_LM",
        # modules_to_save = ["lm_head", "embed_tokens"] # add if you want to use the Llama 3 instruct template
    )

    ################
    # Training
    ################
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        dataset_text_field="text",
        eval_dataset=test_dataset,
        peft_config=peft_config,
        max_seq_length=script_args.max_seq_length,
        tokenizer=tokenizer,
        packing=True,
        dataset_kwargs={
            "add_special_tokens": False,  # We template with special tokens
            "append_concat_token": False,  # No need to add additional separator token
        },
    )
    if trainer.accelerator.is_main_process:
        trainer.model.print_trainable_parameters()
    if trainer.is_fsdp_enabled:
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
    print("here !!!!")
    ##########################
    # Train model
    ##########################
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    trainer.train(resume_from_checkpoint=checkpoint)

    #########################################
    # SAVE ADAPTER AND CONFIG FOR SAGEMAKER
    #########################################
    # save adapter
    if trainer.is_fsdp_enabled:
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
    trainer.save_model()

    del model
    del trainer
    torch.cuda.empty_cache()  # Clears the cache
    # load and merge
    if training_args.distributed_state.is_main_process:
        merge_and_save_model(
            script_args.model_id, training_args.output_dir, "/opt/ml/model"
        )
        tokenizer.save_pretrained("/opt/ml/model")
    training_args.distributed_state.wait_for_everyone()  # wait for all processes to print


if __name__ == "__main__":
    parser = TrlParser((ScriptArguments, TrainingArguments))
    script_args, training_args = parser.parse_args_and_config()

    # gradient checkpointing kwargs (use the reentrant implementation)
    if training_args.gradient_checkpointing:
        training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}
    # set seed
    set_seed(training_args.seed)

    # launch training
    training_function(script_args, training_args)
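
For comparison, the 4-bit (QLoRA) load path is defined but commented out above; re-enabling it would look roughly like this (just a sketch, reusing the `quantization_config` and other names from the script):

```python
# QLoRA-style variant of the loading call above (disabled in my run):
# passing the BitsAndBytesConfig makes the base weights load as NF4.
model = AutoModelForCausalLM.from_pretrained(
    script_args.model_id,
    quantization_config=quantization_config,
    attn_implementation="eager",
    torch_dtype=quant_storage_dtype,
    use_cache=False if training_args.gradient_checkpointing else True,
)
```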

requirements.txt:

transformers==4.44.2
datasets==3.0.0
accelerate==0.34.2
bitsandbytes==0.44.0
huggingface_hub==0.25.1
trl==0.11.0
peft==0.12.0
torch==2.1.1

YAML file:

%%writefile gemma_2_9b_fsdp_qlora.yaml
# script parameters
model_id: "google/gemma-2-9b-it"# Hugging Face model id
max_seq_length:  2048 # 2048              # max sequence length for model and packing of the dataset
# sagemaker specific parameters
train_dataset_path: "/opt/ml/input/data/train/" # path to where SageMaker saves train dataset
test_dataset_path: "/opt/ml/input/data/test/"   # path to where SageMaker saves test dataset
# output_dir: "/opt/ml/model"          # path to where SageMaker will upload the model
output_dir: "/tmp/llama3"              # intermediate output dir; the merged model is saved to /opt/ml/model by the script
# training parameters
report_to: "tensorboard"               # report metrics to tensorboard
learning_rate: 0.0002                  # learning rate 2e-4
lr_scheduler_type: "constant"          # learning rate scheduler
num_train_epochs: 1                    # number of training epochs
per_device_train_batch_size: 1         # batch size per device during training
per_device_eval_batch_size: 1          # batch size for evaluation
gradient_accumulation_steps: 1         # number of steps before performing a backward/update pass
optim: adamw_torch                     # use torch adamw optimizer
logging_steps: 10                      # log every 10 steps
save_strategy: epoch                   # save checkpoint every epoch
evaluation_strategy: epoch             # evaluate every epoch
max_grad_norm: 0.3                     # max gradient norm
warmup_ratio: 0.03                     # warmup ratio
bf16: true                             # use bfloat16 precision
tf32: true                             # use tf32 precision
gradient_checkpointing: true           # use gradient checkpointing to save memory
# FSDP parameters: https://huggingface.co/docs/transformers/main/en/fsdp
fsdp: "full_shard auto_wrap offload" # remove offload if enough GPU memory
fsdp_config:
  backward_prefetch: "backward_pre"
  forward_prefetch: "false"
  use_orig_params: "false"

Code to instantiate the Hugging Face estimator:

from sagemaker.huggingface import HuggingFace
from huggingface_hub import HfFolder

# define Training Job Name 
job_name = f'gemma-9b-test'

# create the Estimator
huggingface_estimator = HuggingFace(
    entry_point          = 'run_fsdp_qlora.py',      # train script
    source_dir           = 'gemma_scripts',  # directory which includes all the files needed for training
    instance_type        = 'ml.g5.48xlarge',  # instances type used for the training job
    instance_count       = 1,                 # the number of instances used for training
    max_run              = 2*24*60*60,        # maximum runtime in seconds (days * hours * minutes * seconds)
    base_job_name        = job_name,          # the name of the training job
    role                 = role,              # IAM role used in training job to access AWS resources, e.g. S3
    volume_size          = 500,               # the size of the EBS volume in GB
    transformers_version = '4.36.0',          # the transformers version used in the training job
    pytorch_version      = '2.1.0',           # the pytorch_version version used in the training job
    py_version           = 'py310',           # the python version used in the training job
    hyperparameters      =  {
        "config": "/opt/ml/input/data/config/gemma_2_9b_fsdp_qlora.yaml" # path to TRL config which was uploaded to s3
    },
    disable_output_compression = True,        # not compress output to save training time and cost
    distribution={"torch_distributed": {"enabled": True}},   # enables torchrun
    environment  = {
        "HUGGINGFACE_HUB_CACHE": "/tmp/.cache", # set env variable to cache models in /tmp
        "HF_TOKEN": HfFolder.get_token(),       # huggingface token to access gated models, e.g. llama 3
        "ACCELERATE_USE_FSDP": "1",             # enable FSDP
        "FSDP_CPU_RAM_EFFICIENT_LOADING": "1"   # enable CPU RAM efficient loading
    }, 
)
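
I then launch the estimator with train/test/config channels matching the paths above, roughly like this (the S3 URIs here are placeholders for where the dataset JSON files and the YAML were uploaded):

```python
# Launch call: channel names follow the /opt/ml/input/data/{train,test,config}
# paths used by the training script and the YAML; the S3 URIs are placeholders.
data = {
    "train":  "s3://<bucket>/gemma-2-9b/train",   # contains dataset.json
    "test":   "s3://<bucket>/gemma-2-9b/test",    # contains dataset.json
    "config": "s3://<bucket>/gemma-2-9b/config",  # contains gemma_2_9b_fsdp_qlora.yaml
}
huggingface_estimator.fit(data, wait=True)
```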

Any idea what to do?