Add support for gradient checkpointing for LLM fine-tuning #3613

Merged
15 commits merged on Sep 15, 2023
12 changes: 12 additions & 0 deletions ludwig/schema/metadata/configs/trainer.yaml
@@ -429,6 +429,18 @@ ecd:
      Suggested to enable this if training is proceeding very slowly in distributed training (and GPU
      utilization is low), or the batch size is very small and the loss curves look very spiky.
    ui_display_name: Gradient Accumulation Steps
  enable_gradient_checkpointing:
    expected_impact: 2
    ui_display_name: Enable Gradient Checkpointing
    default_value_reasoning:
      Gradient checkpointing reduces the memory footprint of the model by trading compute for memory.
      Instead of storing activations in memory during the forward pass, the model recomputes them during
      the backward pass. This is useful when training very large models that quickly run out of memory,
      and it is particularly helpful for non-quantization based training (adapter-based or full
      fine-tuning). It is disabled by default because the extra recomputation is not always beneficial
      and can slow down training.
  validation_field:
    default_value_reasoning:
      Concrete evaluation metrics are usually better than loss,
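The default_value_reasoning above describes the recompute-for-memory tradeoff in prose. As a rough illustration only (plain PyTorch, not code from this PR; the layer sizes are arbitrary), the same idea looks like this with torch.utils.checkpoint:

# Minimal sketch of the tradeoff described above: with checkpointing, activations
# inside the wrapped block are not stored during the forward pass and are
# recomputed when backward() needs them.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))
x = torch.randn(4, 512, requires_grad=True)

y = block(x)  # normal forward: intermediate activations are kept for backward
y_ckpt = checkpoint(block, x, use_reentrant=False)  # forward without storing activations
y_ckpt.sum().backward()  # activations inside `block` are recomputed here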
7 changes: 7 additions & 0 deletions ludwig/schema/trainer.py
@@ -409,6 +409,13 @@ def __post_init__(self):
        parameter_metadata=TRAINER_METADATA[MODEL_ECD]["compile"],
    )

    enable_gradient_checkpointing: bool = schema_utils.Boolean(
        default=False,
        description="Whether to enable gradient checkpointing, which trades compute for memory. "
        "This is useful for training very deep models with limited memory.",
        parameter_metadata=TRAINER_METADATA[MODEL_ECD]["enable_gradient_checkpointing"],
    )

    def update_batch_size_grad_accum(self, num_workers: int):
        from ludwig.utils.trainer_utils import get_rendered_batch_size_grad_accum

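For reference, the new field is set from the trainer section of a user config. The snippet below is a sketch that mirrors the integration tests added later in this PR (the base_model is the same tiny test checkpoint used there), not a recommended production setup:

# Sketch of a user config exercising the new trainer flag; values mirror the tests below.
from ludwig.api import LudwigModel

config = {
    "model_type": "llm",
    "base_model": "HuggingFaceM4/tiny-random-LlamaForCausalLM",
    "input_features": [{"name": "input", "type": "text"}],
    "output_features": [{"name": "output", "type": "text"}],
    "trainer": {
        "type": "finetune",
        "enable_gradient_checkpointing": True,
    },
}

model = LudwigModel(config)
assert model.config_obj.trainer.enable_gradient_checkpointing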
19 changes: 19 additions & 0 deletions ludwig/trainers/trainer.py
@@ -41,6 +41,7 @@
    TRAINING_PROGRESS_TRACKER_FILE_NAME,
)
from ludwig.models.ecd import ECD
from ludwig.models.llm import LLM
from ludwig.models.predictor import Predictor
from ludwig.modules.lr_scheduler import LRScheduler
from ludwig.modules.metric_modules import get_improved_fn, get_initial_validation_value
@@ -214,6 +215,24 @@ def prepare(self):

        # We may need to replace the embedding layer when using 8-bit optimizers from bitsandbytes.
        update_embedding_layer(self.compiled_model, self.config)

        # Enable gradient checkpointing if configured
        if self.config.enable_gradient_checkpointing:
            # TODO(Arnav): Add support for gradient checkpointing in the compiled model
            # when the model is an ECD model, using torch.utils.checkpoint (torch.utils.checkpoint.checkpoint_sequential())
            if not isinstance(self.compiled_model, LLM):
                logger.warning("Gradient checkpointing is currently only supported for model_type: llm. Skipping...")
            elif not hasattr(self.compiled_model, "model") or not hasattr(
                self.compiled_model.model, "gradient_checkpointing_enable"
            ):
logger.warning("Gradient checkpointing is not supported by this model. Skipping...")
elif hasattr(self.compiled_model.model, "gradient_checkpointing_enable"):
self.compiled_model.model.gradient_checkpointing_enable()
self.compiled_model.model.enable_input_require_grads()
logger.info("Gradient checkpointing enabled for training.")
            else:
                raise RuntimeError("Error when trying to enable gradient checkpointing.")

        self.dist_model, self.optimizer = self.distributed.prepare(
            self.compiled_model,
            self.config,
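The branch that enables checkpointing relies on two methods provided by Hugging Face transformers models, gradient_checkpointing_enable and enable_input_require_grads. Below is a standalone sketch of the same pattern outside Ludwig, with the checkpoint name purely illustrative:

# Standalone sketch of the guard used in prepare() above, applied directly to a
# transformers model; the model name is illustrative only.
from transformers import AutoModelForCausalLM

hf_model = AutoModelForCausalLM.from_pretrained("HuggingFaceM4/tiny-random-LlamaForCausalLM")

if hasattr(hf_model, "gradient_checkpointing_enable"):
    hf_model.gradient_checkpointing_enable()  # recompute activations during the backward pass
    hf_model.enable_input_require_grads()  # ensure inputs to checkpointed blocks require grad
else:
    print("Gradient checkpointing is not supported by this model.")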
29 changes: 29 additions & 0 deletions tests/integration_tests/test_llm.py
@@ -423,6 +423,35 @@ def test_llm_finetuning_strategies(tmpdir, csv_filename, backend, finetune_strat
    assert preds


@pytest.mark.parametrize("use_adapter", [True, False], ids=["with_adapter", "without_adapter"])
def test_llm_training_with_gradient_checkpointing(tmpdir, csv_filename, use_adapter):
    input_features = [text_feature(name="input", encoder={"type": "passthrough"})]
    output_features = [text_feature(name="output")]

    df = generate_data(input_features, output_features, filename=csv_filename, num_examples=25)

    config = {
        MODEL_TYPE: MODEL_LLM,
        BASE_MODEL: "HuggingFaceM4/tiny-random-LlamaForCausalLM",
        INPUT_FEATURES: input_features,
        OUTPUT_FEATURES: output_features,
        TRAINER: {
            TYPE: "finetune",
            BATCH_SIZE: 8,
            EPOCHS: 1,
            "enable_gradient_checkpointing": True,
        },
    }

    if use_adapter:
        config[ADAPTER] = {TYPE: "lora"}

    model = LudwigModel(config)
    assert model.config_obj.trainer.enable_gradient_checkpointing

    model.train(dataset=df, output_directory=str(tmpdir), skip_save_processed_input=False)


def test_lora_wrap_on_init():
    from peft import PeftModel
    from transformers import PreTrainedModel
32 changes: 32 additions & 0 deletions tests/integration_tests/test_trainer.py
@@ -376,3 +376,35 @@ def test_gradient_accumulation(gradient_accumulation_steps: int, tmpdir):
    # convergence like gradient magnitudes, etc. Should also add distributed tests.
    model = LudwigModel(config, backend=LocalTestBackend(), logging_level=logging.INFO)
    model.train(training_set=data_csv, validation_set=val_csv, test_set=test_csv, output_directory=tmpdir)


def test_enable_gradient_checkpointing(tmpdir, caplog):
"""Test that gradient checkpointing is enabled when specified in the config and that it does not cause an error
when the model does not have support for gradient checkpointing."""
input_features = [text_feature()]
output_features = [category_feature(decoder={"vocab_size": 2}, reduce_input="sum")]

csv_filename = os.path.join(tmpdir, "training.csv")
data_csv = generate_data(input_features, output_features, csv_filename)
val_csv = shutil.copyfile(data_csv, os.path.join(tmpdir, "validation.csv"))
test_csv = shutil.copyfile(data_csv, os.path.join(tmpdir, "test.csv"))

config = {
"input_features": input_features,
"output_features": output_features,
"combiner": {"type": "concat", "output_size": 14},
TRAINER: {
"train_steps": 2,
"batch_size": 8,
"enable_gradient_checkpointing": True,
},
}

model = LudwigModel(config, backend=LocalTestBackend(), logging_level=logging.INFO)
assert model.config_obj.trainer.enable_gradient_checkpointing

model.train(training_set=data_csv, validation_set=val_csv, test_set=test_csv, output_directory=tmpdir)

# Check that the warning is emitted when the model does not support gradient checkpointing
# but does not prevent training from starting.
assert "Gradient checkpointing is currently only supported for model_type: llm. Skipping..." in caplog.text