Skip to content

Commit

Permalink
update to prepare_model_for_kbit_training (huggingface#728)
Browse files Browse the repository at this point in the history
* update to `prepare_model_for_kbit_training`

from deprecated `prepare_model_for_int8_training`
and add `use_gradient_checkpointing=args.gradient_checkpointing` to
automatically follow the gradient checkpointing choice

is also the workaround for huggingface#694

* workaround for gradient checkpointing issue

calling model.gradient_checkpointing_enable() twice causes issues
this workaround calls it in prepare_model_for_kbit_training and then
changes the arg to false to make sure it isn't called again in
huggingface trainer inner loop

also changes stack_llama_2 sft trainer to use correct device map for ddp
training so that you can test this issue
  • Loading branch information
mnoukhov authored and Andrew Lapp committed May 10, 2024
1 parent a47f9e2 commit 311ad51
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 17 deletions.
4 changes: 2 additions & 2 deletions docs/source/sft_trainer.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ trainer.train()
Pay attention to the following best practices when training a model with that trainer:

- [`SFTTrainer`] always pads by default the sequences to the `max_seq_length` argument of the [`SFTTrainer`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide default value, so there is a check to retrieve the minimum between 2048 and that value. Make sure to check it before training.
- For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_int8_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it.
- For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_kbit_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it.
- For a more memory-efficient training using adapters, you can load the base model in 8bit, for that simply add `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it.
- If you create a model outside the trainer, make sure to not pass to the trainer any additional keyword arguments that are relative to `from_pretrained()` method.

Expand All @@ -346,4 +346,4 @@ Pay attention to the following best practices when training a model with that tr

## ConstantLengthDataset

[[autodoc]] trainer.ConstantLengthDataset
[[autodoc]] trainer.ConstantLengthDataset
4 changes: 2 additions & 2 deletions docs/source/using_llama_models.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ model = AutoModelForCausalLM.from_pretrained(
load_in_8bit=True,
device_map={"": Accelerator().local_process_index}
)
model = prepare_model_for_int8_training(model)
model = prepare_model_for_kbit_training(model)

# add LoRA to model
lora_config = LoraConfig(
Expand Down Expand Up @@ -157,4 +157,4 @@ for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
ppo_trainer.log_stats(stats, batch, rewards)
```

For the rest of the details and evaluation, please refer to our [blog post on StackLLaMA](https://huggingface.co/blog/stackllama).
For the rest of the details and evaluation, please refer to our [blog post on StackLLaMA](https://huggingface.co/blog/stackllama).
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Optional

import torch
from accelerate import Accelerator
from datasets import load_dataset
from peft import AutoPeftModelForCausalLM, LoraConfig
from tqdm import tqdm
Expand Down Expand Up @@ -148,7 +149,7 @@ def create_datasets(tokenizer, args):
base_model = AutoModelForCausalLM.from_pretrained(
script_args.model_name,
quantization_config=bnb_config,
device_map={"": 0},
device_map={"": Accelerator().local_process_index},
trust_remote_code=True,
use_auth_token=True,
)
Expand Down
12 changes: 6 additions & 6 deletions trl/models/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
PeftModelForSeq2SeqLM,
PromptLearningConfig,
get_peft_model,
prepare_model_for_int8_training,
prepare_model_for_kbit_training,
)
from peft.peft_model import set_peft_model_state_dict

Expand Down Expand Up @@ -108,7 +108,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
`from_pretrained` method. We also pre-process the kwargs to extract
the arguments that are specific to the `transformers.PreTrainedModel`
class and the arguments that are specific to trl models. The kwargs
also support `prepare_model_for_int8_training` arguments from
also support `prepare_model_for_kbit_training` arguments from
`peft` library.
"""
if kwargs is not None:
Expand Down Expand Up @@ -203,7 +203,7 @@ class and the arguments that are specific to trl models. The kwargs
if peft_config is not None:
# Initialize a new peft adapter with the given config
if is_loaded_in_8bit or is_loaded_in_4bit:
pretrained_model = prepare_model_for_int8_training(
pretrained_model = prepare_model_for_kbit_training(
pretrained_model,
**peft_quantization_kwargs,
)
Expand All @@ -216,7 +216,7 @@ class and the arguments that are specific to trl models. The kwargs
if peft_config is not None and isinstance(pretrained_model, PreTrainedModel):
# Initialize a new peft adapter with the given config
if is_loaded_in_8bit or is_loaded_in_4bit:
pretrained_model = prepare_model_for_int8_training(
pretrained_model = prepare_model_for_kbit_training(
pretrained_model,
**peft_quantization_kwargs,
)
Expand Down Expand Up @@ -339,7 +339,7 @@ def _split_kwargs(cls, kwargs):
check_peft_kwargs = False

if is_peft_available():
from peft import prepare_model_for_int8_training
from peft import prepare_model_for_kbit_training

check_peft_kwargs = True

Expand All @@ -354,7 +354,7 @@ def _split_kwargs(cls, kwargs):
unsupported_kwargs[key] = value

if check_peft_kwargs:
if key in prepare_model_for_int8_training.__code__.co_varnames:
if key in prepare_model_for_kbit_training.__code__.co_varnames:
peft_kwargs[key] = value
if key in unsupported_kwargs:
unsupported_kwargs.pop(key)
Expand Down
4 changes: 2 additions & 2 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@


if is_peft_available():
from peft import PeftModel, get_peft_model, prepare_model_for_int8_training
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training


class DPOTrainer(Trainer):
Expand Down Expand Up @@ -116,7 +116,7 @@ def __init__(
)
elif is_peft_available() and peft_config is not None:
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
model = prepare_model_for_int8_training(model)
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing)
model = get_peft_model(model, peft_config)

if model is not None:
Expand Down
4 changes: 2 additions & 2 deletions trl/trainer/reward_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@


if is_peft_available():
from peft import PeftModel, get_peft_model, prepare_model_for_int8_training
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training


class RewardTrainer(Trainer):
Expand Down Expand Up @@ -113,7 +113,7 @@ def __init__(
)
elif is_peft_available() and peft_config is not None:
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
model = prepare_model_for_int8_training(model, use_gradient_checkpointing=args.gradient_checkpointing)
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing)

model = get_peft_model(model, peft_config)

Expand Down
9 changes: 7 additions & 2 deletions trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import warnings
from typing import Callable, Dict, List, Optional, Tuple, Union

Expand All @@ -35,7 +36,7 @@


if is_peft_available():
from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_int8_training
from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training


class SFTTrainer(Trainer):
Expand Down Expand Up @@ -147,7 +148,11 @@ def __init__(
)

if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
model = prepare_model_for_int8_training(model)
model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=args.gradient_checkpointing
)

args = dataclasses.replace(args, gradient_checkpointing=False)

model = get_peft_model(model, peft_config)

Expand Down

0 comments on commit 311ad51

Please sign in to comment.