Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent OOM during checkpoint save on colab for llama3-8b qlora recipe #1315

Merged
merged 36 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
9194409
For testing: Set seed, shuffle=False, decrease max_steps so checkpoin…
mikaylagawarecki Aug 12, 2024
c3beb8a
Add wandb logger + change number of epochs/max_steps_per_epoch
mikaylagawarecki Aug 20, 2024
7a825d8
temp commit
mikaylagawarecki Aug 12, 2024
b0b9031
Remove unecessary imports
mikaylagawarecki Aug 15, 2024
327cb10
Changes for loss curves
mikaylagawarecki Aug 21, 2024
923b3b3
Use core APIs, refactor into state_dict post hook
mikaylagawarecki Aug 29, 2024
14a147b
Fix bad merge
mikaylagawarecki Aug 29, 2024
30f3b1e
lint
mikaylagawarecki Aug 29, 2024
731d062
Remove breakpoint, add docstring
mikaylagawarecki Aug 29, 2024
6dffac5
fix return
mikaylagawarecki Aug 29, 2024
1106e47
Remove changes to reload checkpoint every epoch + add gate on versioning
mikaylagawarecki Aug 29, 2024
2a4f19b
Update torch version in runtime error
mikaylagawarecki Aug 29, 2024
cc4bd24
low_cpu_ram --> False
mikaylagawarecki Aug 29, 2024
d17098c
lint
mikaylagawarecki Aug 29, 2024
8f1f633
Make suggested fixes
mikaylagawarecki Aug 30, 2024
f87ca63
Fix botched rebase
mikaylagawarecki Aug 30, 2024
8390868
Fix circular import
mikaylagawarecki Aug 30, 2024
a119a4d
lint
mikaylagawarecki Aug 30, 2024
fc15271
Fix bad rebase
mikaylagawarecki Aug 30, 2024
278e89e
lint
mikaylagawarecki Aug 30, 2024
f309bca
Changes to not use monkeypatching
mikaylagawarecki Sep 3, 2024
244896b
validate with fake resume_from_checkpoint again
mikaylagawarecki Sep 3, 2024
7b094df
Revert "validate with fake resume_from_checkpoint again"
mikaylagawarecki Sep 5, 2024
147b8b4
lint
mikaylagawarecki Sep 5, 2024
05250f5
set low_cpu_ram to False
mikaylagawarecki Sep 5, 2024
c2829db
fix doc
mikaylagawarecki Sep 5, 2024
8a16c76
bump version
mikaylagawarecki Sep 5, 2024
91f7d43
address comments
mikaylagawarecki Sep 9, 2024
b3364f6
fix rebase bug
mikaylagawarecki Sep 9, 2024
645891b
lint
mikaylagawarecki Sep 9, 2024
e44507a
fix bad rebase
mikaylagawarecki Sep 9, 2024
31f8a20
fix bad rebase again
mikaylagawarecki Sep 9, 2024
c808a60
Add changes from https://github.com/pytorch/torchtune/pull/1535
mikaylagawarecki Sep 10, 2024
922c73a
lora_magnitude is not None
mikaylagawarecki Sep 10, 2024
eba1ffa
changes for testing
mikaylagawarecki Sep 10, 2024
dc8c6b0
Revert "changes for testing"
mikaylagawarecki Sep 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions recipes/configs/llama3/8B_qlora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,6 @@ profiler:
warmup_steps: 5
active_steps: 2
num_cycles: 1

# For colab use True
low_cpu_ram: False
25 changes: 20 additions & 5 deletions recipes/lora_finetune_single_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from warnings import warn

import torch
import torchtune.modules.common_utils as common_utils
from omegaconf import DictConfig, ListConfig

from torch import nn
Expand Down Expand Up @@ -213,6 +214,10 @@ def setup(self, cfg: DictConfig) -> None:
self._compile = cfg.compile
checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)

# hack to toggle to the low cpu ram version of the reparametrize_as_dtype
# hook based on the config.
common_utils._use_low_cpu_ram = cfg.get("low_cpu_ram", False)

# set up model
self._model = self._setup_model(
cfg_model=cfg.model,
Expand Down Expand Up @@ -525,6 +530,14 @@ def save_checkpoint(self, epoch: int) -> None:
# Move to CPU to avoid a copy on GPU
state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()}

# Construct the adapter weights
# Do this using the state_dict to avoid running upcast and H2D in state_dict post hook twice
# Must be before get_merged_lora_ckpt because get_merged_lora_ckpt will remove lora keys
adapter_key_filter = lambda x: x in self.adapter_params
adapter_state_dict = {
k: v for k, v in state_dict.items() if adapter_key_filter(k)
}
Comment on lines +534 to +539
Copy link
Contributor Author

@mikaylagawarecki mikaylagawarecki Aug 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally we want to only run state_dict post hooks once, so we should reuse the state_dict

this does change the semantic slightly though -- before adapter_*.pt contained weights tagged with CUDA, but now it contains weights tagged with CPU.

Not sure whether the old behavior was intended/whether this change is ok (but no CI seems to fail :D) Also when loading, we map_location="cpu" regardless

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I missed this comment before. But yeah I think this change makes sense, I don't think there's any reason we need to require CUDA weights. And as you point out since we load on CPU when resuming it was probably never really an issue. Plus not re-running the state dict post hooks is a nice bonus.


# Construct the full state dict with LoRA weights merged into base LLM weights
merged_state_dict = get_merged_lora_ckpt(
state_dict,
Expand All @@ -533,11 +546,6 @@ def save_checkpoint(self, epoch: int) -> None:
)
ckpt_dict.update({training.MODEL_KEY: merged_state_dict})

# Construct the adapter weights
adapter_key_filter = lambda x: x in self.adapter_params
adapter_state_dict = {
k: v for k, v in self._model.state_dict().items() if adapter_key_filter(k)
}
ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})
adapter_config = {
"r": self._lora_rank,
Expand Down Expand Up @@ -698,7 +706,14 @@ def train(self) -> None:
prof.step()

self.epochs_run += 1
start_save_checkpoint = time.perf_counter()
log.info("Starting checkpoint save...")
self.save_checkpoint(epoch=curr_epoch)
log.info(
"Checkpoint saved in {:.2f} seconds.".format(
time.perf_counter() - start_save_checkpoint
)
)

def cleanup(self) -> None:
self._metric_logger.close()
Expand Down
6 changes: 2 additions & 4 deletions torchtune/models/llama3/_component_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
TransformerSelfAttentionLayer,
)

from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook
from torchtune.modules.common_utils import _register_reparametrize_state_dict_hooks

from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear

Expand Down Expand Up @@ -256,9 +256,7 @@ def lora_llama3(
if quantize_base:
# For QLoRA, we reparametrize 4-bit tensors to bf16, and offload to CPU on the fly
# so as to not increase peak memory
model._register_state_dict_hook(
partial(reparametrize_as_dtype_state_dict_post_hook, offload_to_cpu=True)
)
_register_reparametrize_state_dict_hooks(model)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't have to be done in this PR, but we can think about adding this for other models that would have a similar memory situation when running QLoRA (Llama 3.1 8B is an obvious choice, but there are a handful of other similarly-sized models supported in our repo that could benefit from this)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will do in followup


return model

Expand Down
115 changes: 115 additions & 0 deletions torchtune/modules/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,20 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import mmap
import sys
from collections import OrderedDict
from functools import partial
from typing import Any, Dict, Tuple

import torch

import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorConverter, FakeTensorMode
from torchao.dtypes.nf4tensor import NF4Tensor

_use_low_cpu_ram: bool = False


def reparametrize_as_dtype_state_dict_post_hook(
model: nn.Module,
Expand Down Expand Up @@ -48,3 +55,111 @@ def reparametrize_as_dtype_state_dict_post_hook(
state_dict[k] = v.to(dtype)
if offload_to_cpu:
state_dict[k] = state_dict[k].cpu()


def _low_ram_reparametrize_as_dtype_state_dict_post_hook(
model: nn.Module,
state_dict: Dict[str, Any],
*args: Tuple[Any, ...],
dtype: torch.dtype = torch.bfloat16,
offload_to_cpu: bool = True,
**kwargs: Dict[Any, Any],
):
"""
A state_dict hook that replaces NF4 tensors with their restored
higher-precision weight and optionally offloads the restored weight to CPU.
Use this hook to avoid increased peak GPU memory usage during checkpoint
save when training with QLoRA.

This hook is similar to ``reparametrize_as_dtype_state_dict_post_hook`` but uses
FakeTensor and mmap(2) to avoid CPU OOM on colab.

This function is meant to be used with PyTorch's ``nn.Module._register_state_dict_hook``, i.e.

>>> m = MyModule()
>>> m._register_state_dict_hook(reparametrize_as_dtype_state_dict_post_hook)

If the hook is registered per the above process, this hook will be called _after_ the module's
``state_dict`` method is called. The hook will replace all ``NF4Tensor`` instances by unquantizing
them to the original dtype, and optionally offload the restored weight to CPU.

Args:
model (nn.Module): the model to take ``state_dict()`` on
state_dict (Dict[str, Any]): the state dict to modify
*args (Tuple[Any, ...]): Unused args passed when running this as a state_dict hook.
dtype (torch.dtype): the dtype to restore the weight to. Default is ``torch.bfloat16``.
offload_to_cpu (bool): whether to offload the restored weight to CPU. Default is ``True``.
**kwargs (Dict[Any, Any]): Unused keyword args passed when running this as a state_dict hook.
"""
# Create a state dict of FakeTensors that matches the state_dict
mode = FakeTensorMode()
converter = FakeTensorConverter()
fake_state_dict = OrderedDict()
for k, v in state_dict.items():
if isinstance(v, NF4Tensor):
fake_state_dict[k] = converter.from_real_tensor(mode, v).to(dtype)
else:
fake_state_dict[k] = converter.from_real_tensor(mode, v)

if offload_to_cpu:
fake_state_dict[k] = fake_state_dict[k].cpu()

# Create a state_dict on disk with space reserved for storage bytes
# Then load with mmap and MAP_SHARED (can writeback to disk file)
dest_state_dict_path = "/tmp/fake_state_dict.pt"
with torch.serialization.skip_data(materialize_fake_tensors=True):
torch.save(fake_state_dict, dest_state_dict_path)
with torch.serialization.set_default_mmap_options(mmap.MAP_SHARED):
dest_state_dict = torch.load(dest_state_dict_path, mmap=True, weights_only=True)

# Do D2H and upcast one by one and since dest_state_dict is backed by mmap --> won't OOM
# even when there is no swap space (e.g. colab)
for k in state_dict.keys():
if isinstance(state_dict[k], NF4Tensor):
dest_state_dict[k].copy_(state_dict[k].to(dtype))
else:
dest_state_dict[k].copy_(state_dict[k])

# In place update original state_dict object. Although the private state dict
# post hook supports out of place behavior, the semantic actually buggy. We eventually want
# to use the public state_dict post hook which does not support out of place behavior.
for k in state_dict.keys():
state_dict[k] = dest_state_dict[k]


def _register_reparametrize_state_dict_hooks(
module: nn.Module,
dtype: torch.dtype = torch.bfloat16,
offload_to_cpu: bool = True,
):
"""
Register the reparametrize state dict hooks to the module and its submodules.

This function is a wrapper that is meant to toggle between the low_cpu_ram
and regular versions of the ``reparametrize_as_dtype`` state dict hooks.

Args:
module (nn.Module): the module to register the hooks to.
dtype (torch.dtype): the dtype to restore the weight to. Default is ``torch.bfloat16``.
offload_to_cpu (bool): whether to offload the restored weight to CPU. Default is ``True``.

Raises:
RuntimeError: If the low RAM reparametrize hook is used on Windows or an incompatible torch version.
"""
if _use_low_cpu_ram:
if torch.__version__ < "2.5.0.dev20240906":
raise RuntimeError(
"Low RAM reparametrize_as_dtype_state_dict_post_hook requires PyTorch 2.5.0.dev20240906 or later."
)
elif sys.platform == "win32":
# mmap.MAP_SHARED is not supported on Windows but this change targets colab.
raise RuntimeError(
"Low RAM reparametrize_as_dtype_state_dict_post_hook is not supported on Windows."
)
else:
hook = _low_ram_reparametrize_as_dtype_state_dict_post_hook
else:
hook = reparametrize_as_dtype_state_dict_post_hook
module._register_state_dict_hook(
partial(hook, dtype=dtype, offload_to_cpu=offload_to_cpu)
)
19 changes: 13 additions & 6 deletions torchtune/modules/peft/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,21 +243,28 @@ def get_merged_lora_ckpt(
for module in lora_modules:
lora_a_weight = state_dict[f"{module}.lora_a.weight"]
lora_b_weight = state_dict[f"{module}.lora_b.weight"]
base_weight = state_dict[f"{module}.weight"].to(lora_a_weight.dtype)
lora_magnitude = state_dict.get(f"{module}.magnitude", None)

lora_weight = (alpha / rank) * lora_b_weight @ lora_a_weight
merged_weight = base_weight + lora_weight
# If magnitude is present, calculate merged DoRA weight
if lora_magnitude is not None:
base_weight = state_dict[f"{module}.weight"].to(lora_a_weight.dtype)

lora_weight = (alpha / rank) * lora_b_weight @ lora_a_weight
merged_weight = base_weight + lora_weight
weight_norm = torch.linalg.norm(base_weight + lora_weight, dim=1)
mag_norm_scale = (lora_magnitude / weight_norm).view(-1, 1)
merged_weight *= mag_norm_scale
state_dict[f"{module}.weight"] = merged_weight
state_dict[f"{module}.weight"] = merged_weight
del state_dict[f"{module}.magnitude"]

# Otherwise it is just vanilla LoRA
else:
state_dict[f"{module}.weight"] += (
(alpha / rank) * lora_b_weight @ lora_a_weight
)

del state_dict[f"{module}.lora_a.weight"]
del state_dict[f"{module}.lora_b.weight"]
if lora_magnitude is not None:
del state_dict[f"{module}.magnitude"]

return state_dict

Expand Down
Loading