Fix: RuntimeError: 'weight' must be 2-D issue #687
Conversation
@lvwerra Can you review this PR if you don't mind?
Thanks for the PR. A quick question: does DeepSpeed allow you to initialize multiple models? I seem to have run into a related issue. Could you also give a minimal command / config to run to show that your fix enables running, say, falcon-7b?
@vwxyzjn Here's the example code. I think the test doesn't pass if the ref_model is None, and I need to do a little more work on that.

Example code:

```python
from typing import Dict

import torch
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments

from trl import DPOTrainer


def extract_anthropic_prompt(prompt_and_response):
    """Extract the anthropic prompt from a prompt and response pair."""
    search_term = "\n\nAssistant:"
    search_term_idx = prompt_and_response.rfind(search_term)
    assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
    return prompt_and_response[: search_term_idx + len(search_term)]


def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: str = None) -> Dataset:
    """Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.

    The dataset is converted to a dictionary with the following structure:
    {
        'prompt': List[str],
        'chosen': List[str],
        'rejected': List[str],
    }

    Prompts should be structured as follows:
        \n\nHuman: <prompt>\n\nAssistant:
    Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
    """
    dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir)
    if sanity_check:
        dataset = dataset.select(range(min(len(dataset), 1000)))

    def split_prompt_and_responses(sample) -> Dict[str, str]:
        prompt = extract_anthropic_prompt(sample["chosen"])
        return {
            "prompt": prompt,
            "chosen": sample["chosen"][len(prompt) :],
            "rejected": sample["rejected"][len(prompt) :],
        }

    return dataset.map(split_prompt_and_responses)


def main():
    training_args = TrainingArguments(
        per_device_train_batch_size=1,
        max_steps=10,
        remove_unused_columns=False,
        gradient_accumulation_steps=1,
        learning_rate=0.00001,
        evaluation_strategy="no",
        save_strategy="no",
        logging_strategy="steps",
        logging_steps=1,
        output_dir="./",
        report_to="none",
        deepspeed="./config.json",
    )

    model_name = "tiiuae/falcon-7b"
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

    # Note: no device_map here; placement is handled by DeepSpeed ZeRO-3, and
    # from_pretrained rejects a device_map when ZeRO-3 is enabled.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    model_ref = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )

    tokenizer.bos_token = "<s>"
    tokenizer.eos_token = "</s>"
    tokenizer.pad_token = tokenizer.eos_token

    with training_args.main_process_first():
        train_dataset = get_hh("train", sanity_check=True)
        eval_dataset = get_hh("test", sanity_check=True)

    dpo_trainer = DPOTrainer(
        model,
        model_ref,
        args=training_args,
        beta=0.1,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        max_length=2048,
        max_prompt_length=128,
    )

    # 6. train
    dpo_trainer.train()

    if training_args.local_rank == 0:
        model.save_pretrained(training_args.output_dir)
        tokenizer.save_pretrained(training_args.output_dir)


if __name__ == "__main__":
    main()
```

DeepSpeed config:

```json
{
  "bf16": {
    "enabled": "auto"
  },
  "zero_allow_untested_optimizer": true,
  "zero_force_ds_cpu_optimizer": false,
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "allgather_partitions": true,
    "allgather_bucket_size": 1e9,
    "reduce_bucket_size": "auto",
    "stage3_prefetch_bucket_size": "auto",
    "stage3_param_persistence_threshold": "auto",
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_gather_16bit_weights_on_model_save": true
  },
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto"
}
```
We found the cause of the crash in the test, so I changed it.
LGTM! Left one comment, what do you think?
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Looking great to me, thank you very much @jp1924!
Thank you for accepting the PR!
* Update dpo_trainer.py
* Fix: self.args.deepspeed > self.is_deepspeed_enabled
* Update trl/trainer/dpo_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Has this been addressed for …? It leads to the following error: …
@andrew-zm-ml yes, we added full ZeRO-{1,2,3} integration for …
Fix #669
Problem description
In ZeRO3, issues like #669 are caused by running `deepspeed.initialize` on only one of the two models passed to `DPOTrainer`.

Most users call `.from_pretrained` to load the model weights, and inside `from_pretrained` is the code below. For models that don't use ZeRO3 this doesn't matter, but most of the heavier models like LLMs are run with ZeRO3, so in `from_pretrained`, `is_deepspeed_zero3_enabled()` becomes True and the ZeRO3 branch of the if statement is taken.
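The snippet referred to here was lost from the thread; paraphrased from transformers' `modeling_utils.py` (illustrative, and the exact code varies by version), the relevant logic looks roughly like this:

```python
# Paraphrase of the ZeRO-3 branch inside from_pretrained (names such as
# init_contexts, logger and deepspeed_config come from the surrounding
# library scope): when a ZeRO-3 config is active, the model is constructed
# under deepspeed.zero.Init, which partitions each parameter as it is created.
if is_deepspeed_zero3_enabled():
    import deepspeed

    logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
    init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts
```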
Since ZeRO3 means parameter partitioning, `deepspeed.zero.Init` uses `_partition_param` to divide the parameters of each layer across the number of GPUs. At this time, the partitioned parameter is put into a `.ds_tensor`, and `parameter.data` (the weight or bias), where the original parameter was, is filled with a zero tensor of size 0.

So far, so good.
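As a standalone illustration of that state (a minimal sketch, assuming it is run under the `deepspeed` launcher so distributed is initialized; the config dict here is a hypothetical minimal ZeRO-3 config):

```python
# Minimal sketch of what zero.Init does to a parameter.
import deepspeed
import torch.nn as nn

ds_config = {"train_batch_size": 1, "zero_optimization": {"stage": 3}}  # hypothetical minimal config

with deepspeed.zero.Init(config_dict_or_path=ds_config):
    layer = nn.Linear(4096, 4096)

print(layer.weight.shape)            # torch.Size([0]) -- the data is a size-0 placeholder
print(layer.weight.ds_tensor.shape)  # this rank's partition of the flattened weight
```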
The problem is that we need to run `deepspeed.initialize` at training time to load the partitioned parameters back from `.ds_tensor`. But `deepspeed.initialize` was never run on ref_model, so it can't load its `.ds_tensor`. This causes an error like `RuntimeError: 'weight' must be {n}-D`, because ref_model's `parameter.data` does not contain the weight, as shown in the issue.

So to solve the issue, we need to `deepspeed.initialize` ref_model as well as model. However, since `prepare_model` does not apply the DeepSpeed wrapper, the fix adds a step to the trainer's init that wraps the model with DeepSpeed using `_prepare_deepspeed`.
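To make the failure concrete: the error message can be reproduced by any op that receives the size-0 placeholder instead of a real weight, e.g. an embedding lookup (a minimal repro, independent of DeepSpeed):

```python
import torch
import torch.nn.functional as F

# What ref_model's embedding weight effectively looks like when zero.Init has
# partitioned it but deepspeed.initialize was never called to wire it back up:
weight = torch.empty(0)
input_ids = torch.tensor([[1, 2, 3]])

F.embedding(input_ids, weight)  # RuntimeError: 'weight' must be 2-D
```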
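And in code, the change is along these lines (a condensed sketch of the idea, not the exact diff):

```python
import deepspeed

# Condensed sketch (not the exact PR diff). In DPOTrainer.__init__, the
# ref_model now goes through deepspeed.initialize instead of a plain
# accelerator.prepare_model call:
#
#     if self.is_deepspeed_enabled:
#         self.ref_model = self._prepare_deepspeed(self.ref_model)
#     else:
#         self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

def _prepare_deepspeed(self, model):
    # Reuse the DeepSpeed config that accelerate already built for training.
    deepspeed_plugin = self.accelerator.state.deepspeed_plugin
    config_kwargs = deepspeed_plugin.deepspeed_config

    # The ref model is inference-only, so no optimizer or scheduler is passed:
    # deepspeed.initialize just wraps the model and wires up the ZeRO partitions.
    engine, *_ = deepspeed.initialize(model=model, config=config_kwargs)
    engine.eval()
    return engine
```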