
Fix: RuntimeError: 'weight' must be 2-D issue #687

Merged: 3 commits into huggingface:main on Sep 1, 2023

Conversation

jp1924 (Contributor) commented Aug 24, 2023

Fix #669

Problem description

In ZeRO-3, issues like #669 are caused by running deepspeed.initialize on only one of the two models passed to DPOTrainer.

Most users call .from_pretrained to load the model weights, and inside from_pretrained is the code below.

    if is_deepspeed_zero3_enabled():
        import deepspeed

        logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
        init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts
    elif load_in_8bit or low_cpu_mem_usage:
        init_contexts.append(init_empty_weights())

    with ContextManagers(init_contexts):
        model = cls(config, *model_args, **model_kwargs)

Models that don't use ZeRO-3 are unaffected. But most heavy models such as LLMs are run with ZeRO-3, so inside from_pretrained, is_deepspeed_zero3_enabled() returns True and the if branch above is taken.

Since ZeRO-3 means parameter partitioning, deepspeed.zero.Init uses _partition_param to split the parameters of each layer across the GPUs.

At this point the partitioned shard is stored in .ds_tensor, and parameter.data (the weight or bias where the original parameter used to live) is replaced with an empty tensor of size 0.
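
One way to see this state (a hedged illustration, not part of the PR; it assumes the model was loaded under DeepSpeed ZeRO-3 with zero.Init active):

    # Hypothetical check: under ZeRO-3 each parameter's local shard lives in
    # param.ds_tensor, while param.data becomes an empty placeholder of size 0.
    for name, param in model.named_parameters():
        shard = getattr(param, "ds_tensor", None)
        print(name, tuple(param.data.shape), None if shard is None else tuple(shard.shape))
    # every partitioned parameter prints an empty shape, e.g. (0,), for param.data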

So far, so good

The problem is that at training time we need to call deepspeed.initialize so that the partitioned parameters in .ds_tensor can be loaded again.
But ref_model never goes through deepspeed.initialize, so its .ds_tensor shards are never loaded.

This causes an error like RuntimeError: 'weight' must be {n}-D, because ref_model's parameter.data does not contain the weight, as shown in the issue.

So to solve the issue, we need to run deepspeed.initialize on ref_model as well as on model.

However, accelerator.prepare_model does not wrap the model with DeepSpeed,
so the fix adds a step to DPOTrainer's init that wraps ref_model with DeepSpeed using _prepare_deepspeed.
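
Roughly, the fix looks like the sketch below (a simplified sketch of the approach rather than the exact diff; it assumes the DeepSpeed config can be taken from the Accelerate plugin and reused for the reference model):

    import deepspeed

    def _prepare_deepspeed(self, model):
        # Sketch: give the frozen reference model its own DeepSpeed engine so
        # its ZeRO-3 shards (.ds_tensor) are gathered at forward time.
        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
        config_kwargs = deepspeed_plugin.deepspeed_config
        model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
        model.eval()
        return model

    # in DPOTrainer.__init__:
    if self.is_deepspeed_enabled:
        self.ref_model = self._prepare_deepspeed(self.ref_model)
    else:
        self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)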

jp1924 changed the title from Update dpo_trainer.py to Fix: RuntimeError: 'weight' must be 2-D issue on Aug 24, 2023
jp1924 marked this pull request as draft on August 29, 2023 00:57
jp1924 marked this pull request as ready for review on August 29, 2023 06:59
jp1924 (Contributor, Author) commented Aug 30, 2023

@lvwerra Can you review this PR if you don't mind?

vwxyzjn (Contributor) commented Aug 30, 2023

Thanks for the PR. A quick question: does DeepSpeed allow you to initialize multiple models? I seem to have run into a related issue.

Could you also give a minimal command / config to run to show that your fix enables running, say, falcon-7b?

HuggingFaceDocBuilderDev commented Aug 30, 2023

The documentation is not available anymore as the PR was closed or merged.

jp1924 (Contributor, Author) commented Aug 31, 2023

@vwxyzjn Here's the example code

I think the test doesn't pass when ref_model is None; I need to do a little more work on that.

example code

from typing import Dict

import torch
import torch._dynamo
from datasets import Dataset, load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    TrainingArguments,
)
from trl import DPOTrainer


def extract_anthropic_prompt(prompt_and_response):
    """Extract the anthropic prompt from a prompt and response pair."""
    search_term = "\n\nAssistant:"
    search_term_idx = prompt_and_response.rfind(search_term)
    assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
    return prompt_and_response[: search_term_idx + len(search_term)]


def get_hh(
    split: str, sanity_check: bool = False, silent: bool = False, cache_dir: str = None
) -> Dataset:
    """Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.

    The dataset is converted to a dictionary with the following structure:
    {
        'prompt': List[str],
        'chosen': List[str],
        'rejected': List[str],
    }

    Prompts should be structured as follows:
      \n\nHuman: <prompt>\n\nAssistant:
    Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
    """
    dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir)
    if sanity_check:
        dataset = dataset.select(range(min(len(dataset), 1000)))

    def split_prompt_and_responses(sample) -> Dict[str, str]:
        prompt = extract_anthropic_prompt(sample["chosen"])
        return {
            "prompt": prompt,
            "chosen": sample["chosen"][len(prompt) :],
            "rejected": sample["rejected"][len(prompt) :],
        }

    return dataset.map(split_prompt_and_responses)


def main():
    training_args = TrainingArguments(
        per_device_train_batch_size=1,
        max_steps=10,
        remove_unused_columns=False,
        gradient_accumulation_steps=1,
        learning_rate=0.00001,
        evaluation_strategy="no",
        save_strategy="no",
        logging_strategy="steps",
        logging_steps=1,
        output_dir="./",
        report_to="none",
        deepspeed="./config.json",
    )

    model_name = "falcon-7b"
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    model_ref = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        device_map="auto",
    )

    tokenizer.bos_token = "<s>"
    tokenizer.eos_token = "</s>"
    tokenizer.pad_token = tokenizer.eos_token

    with training_args.main_process_first():
        train_dataset = get_hh("train", sanity_check=True)
        eval_dataset = get_hh("test", sanity_check=True)

    dpo_trainer = DPOTrainer(
        model,
        model_ref,
        args=training_args,
        beta=0.1,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        max_length=2048,
        max_prompt_length=128,
    )

    # 6. train
    dpo_trainer.train()
    if training_args.local_rank == 0:
        model.save_pretrained(training_args.output_dir)
        tokenizer.save_pretrained(training_args.output_dir)


if __name__ == "__main__":
    main()

DeepSpeed config

{
    "bf16": {
        "enabled": "auto"
    },
    "zero_allow_untested_optimizer": true,
    "zero_force_ds_cpu_optimizer": false,
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "allgather_partitions": true,
        "allgather_bucket_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    },
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto"
}

jp1924 (Contributor, Author) commented Aug 31, 2023

We found the cause of the test failure:
if the path to the DeepSpeed config file is not specified in TrainingArguments,
args.deepspeed stays None, so the comparison
self.args.deepspeed != "" still evaluates to True, causing an error like AttributeError: 'NoneType' object has no attribute 'deepspeed_config'.

So I changed self.args.deepspeed != "" to self.is_deepspeed_enabled.
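
In code terms, the guard changed roughly like this (a sketch, assuming self.is_deepspeed_enabled is the flag the underlying transformers Trainer sets when a DeepSpeed plugin is active):

    # before: breaks when no config path is given, because
    # TrainingArguments.deepspeed defaults to None and None != "" is True
    if self.args.deepspeed != "":
        self.ref_model = self._prepare_deepspeed(self.ref_model)

    # after: rely on the flag set by the Trainer/Accelerate integration
    if self.is_deepspeed_enabled:
        self.ref_model = self._prepare_deepspeed(self.ref_model)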

younesbelkada (Contributor) left a comment

LGTM! Left one comment, what do you think?

trl/trainer/dpo_trainer.py: review comment (outdated, resolved)
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
jp1924 requested a review from younesbelkada on August 31, 2023 23:31
younesbelkada (Contributor) left a comment

Looking great to me, thank you very much @jp1924!

lvwerra merged commit 5bb4668 into huggingface:main on Sep 1, 2023
jp1924 (Contributor, Author) commented Sep 1, 2023

Thank you for accepting the PR!

kushal-tri pushed a commit to kushalarora/trl that referenced this pull request Sep 19, 2023
* Update dpo_trainer.py

* Fix: self.args.deepspeed > self.is_deepspeed_enabled

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
andrew-zm-ml commented

Has this been addressed for PPOTrainer as well? I am getting this issue when trying to use ZeRO3, but inserting the lines

    ppo_trainer.ref_model = ppo_trainer.accelerator._prepare_deepspeed(ppo_trainer.ref_model)[0]
    ppo_trainer.ref_model.eval()

leads to the following error:

AssertionError: You can't use same Accelerator() instance with multiple models when using DeepSpeed

lewtun (Member) commented Sep 29, 2023

@andrew-zm-ml yes, we added full ZeRO-{1,2,3} integration for PPOTrainer in #758, which takes care of the limitation of only having one model per Accelerator().

lapp0 pushed a commit to lapp0/trl that referenced this pull request May 10, 2024
* Update dpo_trainer.py

* Fix: self.args.deepspeed > self.is_deepspeed_enabled

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Successfully merging this pull request may close these issues.

DeepSpeed Zero3 dpo accured embedding weight error