Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Officially support naive PP for quantized models + PEFT #1523

Merged
merged 4 commits into from
Jun 6, 2023

Conversation

younesbelkada
Copy link
Contributor

@younesbelkada younesbelkada commented Jun 5, 2023

What does this PR do?

Fixes #1515

Naive Pipeline Parallelism should be supported by accelerate and should work, if we properly educate users on how to use it.

What is NPP?

It is the simplest paradigm for running a model across multiple GPUs. It tries to evenly fit the model across all available GPUs (e.g. device_map="auto")

npp

When to use it and when not to use it?

Initially I added that check because I was afraid users will train 8bit models that are loaded across multiple GPUs and under multi-GPU distributed regime. In that case the model will be converted to DDP (which is fine if the model fits in a single GPU and duplicated across multiple GPUs (Data Parallelism)) - which can lead to many breaking behaviours such as huggingface/peft#269 (comment) .
The fix is to relax the check constraint and to also check if we are under multi GPU distributed regime (expects to use DDP).

In TRL library, it is possible to use the PPOTrainer (that calls accelerator.prepare under the hood) to apply Naive Pipeline Parallelism: https://huggingface.co/docs/trl/main/en/lora_tuning_peft#naive-pipeline-parallelism-npp-for-large-models-60b-models to train 60B+ scale models using RLHF. The error was never raised there because I forgot to store the attribute hf_device_map inside the model class we use in TRL.

To reproduce (you need PEFT and run this script in a multi-GPU env):

from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training

model_id = "facebook/opt-350m"
accelerator = Accelerator()

config = LoraConfig(
    r=16, 
    lora_alpha=32, 
    target_modules=["q_proj", "v_proj"], 
    lora_dropout=0.05, 
    bias="none", 
    task_type="CAUSAL_LM"
)

model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_8bit=True)
model = prepare_model_for_int8_training(model)

print(set(model.hf_device_map.values()))

model = get_peft_model(model, config)

model = accelerator.prepare(model)

cc @sgugger @muellerzr

- relax check
- add test
@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Jun 5, 2023

The documentation is not available anymore as the PR was closed or merged.

@younesbelkada younesbelkada requested a review from muellerzr June 5, 2023 10:53
Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

src/accelerate/accelerator.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@muellerzr muellerzr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! LG2M :)

@robinsonmhj
Copy link

robinsonmhj commented Dec 6, 2023

this feature is only in main, is there any plan to put into a new release so that I can use pip to install?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

ValueError: You can't train a model that has been loaded in 8-bit precision on multiple devices.
5 participants