Partial Loading PR2: Add utils to support partial loading of models from CPU to GPU (invoke-ai#7494)

## Summary

This PR adds utilities to support partial loading of models from CPU to
GPU. The new utilities are not yet being used by the ModelCache, so
there should be no functional behavior changes in this PR.

Detailed changes:

- Add autocast modules that are designed to wrap common
`torch.nn.Module`s and enable them to run with automatic device casting.
E.g. a linear layer on the CPU can be executed with an input tensor on
the GPU by streaming the weights to the GPU at runtime (see the usage
sketch after this list).
- Add unit tests for the aforementioned autocast modules to verify that
they work for all supported quantization formats (GGUF, BnB NF4, BnB
LLM.int8()).
- Add `CachedModelWithPartialLoad` and `CachedModelOnlyFullLoad` classes
to manage partial loading at the model level.
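
To make the intended behavior of the autocast modules concrete, here is a hedged usage sketch (the toy model and the CUDA check are illustrative; `apply_custom_layers_to_model` is one of the utilities added in this PR):

```python
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
    apply_custom_layers_to_model,
)

# Toy model whose weights live on the CPU.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 4))

# Swap supported layers for their autocast wrappers (e.g. torch.nn.Linear -> CustomLinear).
apply_custom_layers_to_model(model)

if torch.cuda.is_available():
    x = torch.randn(1, 8, device="cuda")
    y = model(x)  # The CPU weights are streamed to CUDA at runtime; y lives on CUDA.
```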

## Alternative Implementations

Several options were explored for supporting inference on
partially-loaded models. The pros/cons of the explored options are
summarized here for reference. In the end, wrapper modules were selected
as the best overall solution for our use case.

Option 1: Re-implement the .forward() methods of modules to add support
for device conversions
- This is the option implemented in this PR.
- This approach is the most manual of the three, but as a result offers
the broadest compatibility with unusual model types. It is manual in
that we have to explicitly add support for all module types that we wish
to support. Fortunately, the list of foundational module types is
relatively small (e.g. the current set of implemented layers covers all
but 0.04 MB of the full FLUX model).

Option 2: Implement a custom Tensor type that casts tensors to a
`target_device` each time the tensor is used
- This approach has the nice property that it is injected at the tensor
level, and the model does not need to be modified in any way (a minimal
sketch follows this list).
- One challenge with this approach is handling interactions with other
custom tensor types (e.g. GGMLTensor). This problem is solvable, but
definitely introduces a layer of complexity. (There are likely to also
be some similar issues with interactions with the BnB quantization, but
I didn't get as far as testing BnB.)
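
For reference, a minimal sketch of what such a tensor subclass might look like (illustrative only, assuming a recent PyTorch with `__torch_function__` subclass dispatch; the single global `target_device` is a simplification, and re-wrapping results / coexisting with other subclasses like `GGMLTensor` is exactly where the complexity mentioned above arises):

```python
import torch


class DeviceCastingTensor(torch.Tensor):
    """Illustrative sketch: a tensor that is moved to a target device whenever it participates in an op."""

    # Simplification: a single, global target device.
    target_device = torch.device("cuda")

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Disable subclass dispatch while moving tensors so that .to() does not re-enter this handler.
        with torch._C.DisableTorchFunctionSubclass():
            def cast(a):
                return a.to(cls.target_device) if isinstance(a, torch.Tensor) else a

            # Nested containers of tensors are not handled in this sketch.
            args = tuple(cast(a) for a in args)
            kwargs = {k: cast(v) for k, v in kwargs.items()}
            return func(*args, **kwargs)


# Usage sketch: wrap a CPU weight so that any op involving it runs on the target device.
# weight = torch.randn(1024, 1024).as_subclass(DeviceCastingTensor)
```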

Option 3: Override the `__torch_function__` dispatch calls globally and
cast all params to the execution device.
- This approach is nice and simple: just apply a global context manager
and all operations will happen on the compute device regardless of the
device of the participating tensors (a sketch of this approach follows
below).
- Challenges:
- Overriding the `__torch_function__` dispatch calls introduces some
overhead even if the tensors are already on the correct device.
- It is difficult to manage the autocasting context manager. E.g. it is
tempting to apply it to the model's `.forward(...)` method, but we use
some models with non-standard entrypoints. And we don't want to end up
with nested autocasting context managers.
- BnB applies quantization side effects when a param is moved to the GPU,
which interacts in unexpected ways with a global context manager.
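
For reference, a minimal sketch of the global-override approach using `torch.overrides.TorchFunctionMode` (illustrative only; note that the casting closure below runs for every torch op, which is the source of the overhead mentioned above):

```python
import torch
from torch.overrides import TorchFunctionMode


class CastToDeviceMode(TorchFunctionMode):
    """Illustrative sketch: cast all tensor arguments to the compute device for every torch op."""

    def __init__(self, device: torch.device):
        super().__init__()
        self._device = device

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def cast(a):
            return a.to(self._device) if isinstance(a, torch.Tensor) else a

        # This runs for every op, even when all tensors are already on the right device.
        args = tuple(cast(a) for a in args)
        kwargs = {k: cast(v) for k, v in kwargs.items()}
        return func(*args, **kwargs)


# Usage sketch: everything inside the context runs on the compute device, regardless of tensor placement.
# with CastToDeviceMode(torch.device("cuda")):
#     output = cpu_model(gpu_input)
```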


## QA Instructions

Most of the changes in this PR should not impact active code, and thus
should not cause any changes to behavior. The main risks come from
bumping the bitsandbytes dependency and some minor modifications to the
bitsandbytes quantization code.

- [x] Regression test bitsandbytes NF4 quantization
- [x] Regression test bitsandbytes LLM.int8() quantization
- [x] Regression test on macOS (to ensure that there are no lingering
bitsandbytes import errors)

I also tested the new utilities for inference on full models in another
branch to validate that there were no major issues. This functionality
will be tested more thoroughly in a future PR.

## Merge Plan

- [x] invoke-ai#7492 should be merged first so that the target branch can be
updated to main.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
RyanJDick authored Dec 27, 2024
2 parents d3916db + 0fc5387 commit 6bf5b74
Showing 16 changed files with 1,302 additions and 8 deletions.
@@ -0,0 +1,93 @@
from typing import Any

import torch


class CachedModelOnlyFullLoad:
    """A wrapper around a PyTorch model to handle full loads and unloads between the CPU and the compute device.
    Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
    MPS memory, etc.
    """

    def __init__(self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int):
        """Initialize a CachedModelOnlyFullLoad.
        Args:
            model (torch.nn.Module | Any): The model to wrap. Should be on the CPU.
            compute_device (torch.device): The compute device to move the model to.
            total_bytes (int): The total size (in bytes) of all the weights in the model.
        """
        # model is often a torch.nn.Module, but could be any model type. Throughout this class, we handle both cases.
        self._model = model
        self._compute_device = compute_device
        self._offload_device = torch.device("cpu")

        # A CPU read-only copy of the model's state dict.
        self._cpu_state_dict: dict[str, torch.Tensor] | None = None
        if isinstance(model, torch.nn.Module):
            self._cpu_state_dict = model.state_dict()

        self._total_bytes = total_bytes
        self._is_in_vram = False

    @property
    def model(self) -> torch.nn.Module:
        return self._model

    def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None:
        """Get a read-only copy of the model's state dict in RAM."""
        # TODO(ryand): Document this better.
        return self._cpu_state_dict

    def total_bytes(self) -> int:
        """Get the total size (in bytes) of all the weights in the model."""
        return self._total_bytes

    def cur_vram_bytes(self) -> int:
        """Get the size (in bytes) of the weights that are currently in VRAM."""
        if self._is_in_vram:
            return self._total_bytes
        else:
            return 0

    def is_in_vram(self) -> bool:
        """Return true if the model is currently in VRAM."""
        return self._is_in_vram

    def full_load_to_vram(self) -> int:
        """Load all weights into VRAM (if supported by the model).
        Returns:
            The number of bytes loaded into VRAM.
        """
        if self._is_in_vram:
            # Already in VRAM.
            return 0

        if not hasattr(self._model, "to"):
            # Model doesn't support moving to a device.
            return 0

        if self._cpu_state_dict is not None:
            new_state_dict: dict[str, torch.Tensor] = {}
            for k, v in self._cpu_state_dict.items():
                new_state_dict[k] = v.to(self._compute_device, copy=True)
            self._model.load_state_dict(new_state_dict, assign=True)
        self._model.to(self._compute_device)

        self._is_in_vram = True
        return self._total_bytes

    def full_unload_from_vram(self) -> int:
        """Unload all weights from VRAM.
        Returns:
            The number of bytes unloaded from VRAM.
        """
        if not self._is_in_vram:
            # Already in RAM.
            return 0

        if self._cpu_state_dict is not None:
            self._model.load_state_dict(self._cpu_state_dict, assign=True)
        self._model.to(self._offload_device)

        self._is_in_vram = False
        return self._total_bytes
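
A hedged usage sketch of `CachedModelOnlyFullLoad` (the toy model, byte calculation, and CUDA device are illustrative; in this PR the class is not yet wired into the `ModelCache`):

```python
import torch

model = torch.nn.Linear(1024, 1024)  # toy model; weights on the CPU
total_bytes = sum(p.numel() * p.element_size() for p in model.parameters())  # buffers ignored for brevity

cached = CachedModelOnlyFullLoad(model, compute_device=torch.device("cuda"), total_bytes=total_bytes)

loaded = cached.full_load_to_vram()      # returns total_bytes on the first call, 0 if already in VRAM
assert cached.is_in_vram()
freed = cached.full_unload_from_vram()   # restores the CPU copy of the state dict and moves the model back
assert not cached.is_in_vram()
```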
@@ -0,0 +1,201 @@
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
    AUTOCAST_MODULE_TYPE_MAPPING,
    apply_custom_layers_to_model,
    remove_custom_layers_from_model,
)
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
from invokeai.backend.util.logging import InvokeAILogger


def set_nested_attr(obj: object, attr: str, value: object):
    """A helper function that extends setattr() to support nested attributes.
    Example:
        set_nested_attr(model, "module.encoder.conv1.weight", new_conv1_weight)
    """
    attrs = attr.split(".")
    for attr in attrs[:-1]:
        obj = getattr(obj, attr)
    setattr(obj, attrs[-1], value)


class CachedModelWithPartialLoad:
    """A wrapper around a PyTorch model to handle partial loads and unloads between the CPU and the compute device.
    Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
    MPS memory, etc.
    """

    def __init__(self, model: torch.nn.Module, compute_device: torch.device):
        self._model = model
        self._compute_device = compute_device

        # A CPU read-only copy of the model's state dict.
        self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict()

        # TODO(ryand): Handle the case where the model size changes after initial load (e.g. due to dtype casting).
        # Consider how we should handle this for both self._total_bytes and self._cur_vram_bytes.
        self._total_bytes = sum(calc_tensor_size(p) for p in self._cpu_state_dict.values())
        self._cur_vram_bytes: int | None = None

        self._modules_that_support_autocast = self._find_modules_that_support_autocast()
        self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast()

    def _find_modules_that_support_autocast(self) -> dict[str, torch.nn.Module]:
        """Find all modules that support autocasting."""
        return {n: m for n, m in self._model.named_modules() if type(m) in AUTOCAST_MODULE_TYPE_MAPPING}

    def _find_keys_in_modules_that_do_not_support_autocast(self) -> set[str]:
        keys_in_modules_that_do_not_support_autocast = set()
        for key in self._cpu_state_dict.keys():
            for module_name in self._modules_that_support_autocast.keys():
                if key.startswith(module_name):
                    break
            else:
                keys_in_modules_that_do_not_support_autocast.add(key)
        return keys_in_modules_that_do_not_support_autocast

    def _move_non_persistent_buffers_to_device(self, device: torch.device):
        """Move the non-persistent buffers to the target device. These buffers are not included in the state dict,
        so we need to move them manually.
        """
        # HACK(ryand): Typically, non-persistent buffers are moved when calling module.to(device). We don't move entire
        # modules, because we manage the devices of individual tensors using the state dict. Since non-persistent
        # buffers are not included in the state dict, we need to handle them manually. The only way to do this is by
        # using private torch.nn.Module attributes.
        for module in self._model.modules():
            for name, buffer in module.named_buffers():
                if name in module._non_persistent_buffers_set:
                    module._buffers[name] = buffer.to(device, copy=True)

    @property
    def model(self) -> torch.nn.Module:
        return self._model

    def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None:
        """Get a read-only copy of the model's state dict in RAM."""
        # TODO(ryand): Document this better.
        return self._cpu_state_dict

    def total_bytes(self) -> int:
        """Get the total size (in bytes) of all the weights in the model."""
        return self._total_bytes

    def cur_vram_bytes(self) -> int:
        """Get the size (in bytes) of the weights that are currently in VRAM."""
        if self._cur_vram_bytes is None:
            cur_state_dict = self._model.state_dict()
            self._cur_vram_bytes = sum(
                calc_tensor_size(p) for p in cur_state_dict.values() if p.device.type == self._compute_device.type
            )
        return self._cur_vram_bytes

    def full_load_to_vram(self) -> int:
        """Load all weights into VRAM."""
        return self.partial_load_to_vram(self.total_bytes())

    def full_unload_from_vram(self) -> int:
        """Unload all weights from VRAM."""
        return self.partial_unload_from_vram(self.total_bytes())

    @torch.no_grad()
    def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
        """Load more weights into VRAM without exceeding vram_bytes_to_load.
        Returns:
            The number of bytes loaded into VRAM.
        """
        # TODO(ryand): Handle the case where an exception is thrown while loading or unloading weights. At the very
        # least, we should reset self._cur_vram_bytes to None.

        vram_bytes_loaded = 0

        cur_state_dict = self._model.state_dict()

        # First, process the keys that *must* be loaded into VRAM.
        for key in self._keys_in_modules_that_do_not_support_autocast:
            param = cur_state_dict[key]
            if param.device.type == self._compute_device.type:
                continue

            param_size = calc_tensor_size(param)
            cur_state_dict[key] = param.to(self._compute_device, copy=True)
            vram_bytes_loaded += param_size

        if vram_bytes_loaded > vram_bytes_to_load:
            logger = InvokeAILogger.get_logger()
            logger.warning(
                f"Loaded {vram_bytes_loaded / 2**20} MB into VRAM, but only {vram_bytes_to_load / 2**20} MB were "
                "requested. This is the minimum set of weights in VRAM required to run the model."
            )

        # Next, process the keys that can optionally be loaded into VRAM.
        fully_loaded = True
        for key, param in cur_state_dict.items():
            if param.device.type == self._compute_device.type:
                continue

            param_size = calc_tensor_size(param)
            if vram_bytes_loaded + param_size > vram_bytes_to_load:
                # TODO(ryand): Should we just break here? If we couldn't fit this parameter into VRAM, is it really
                # worth continuing to search for a smaller parameter that would fit?
                fully_loaded = False
                continue

            cur_state_dict[key] = param.to(self._compute_device, copy=True)
            vram_bytes_loaded += param_size

        if vram_bytes_loaded > 0:
            # We load the entire state dict, not just the parameters that changed, in case there are modules that
            # override _load_from_state_dict() and do some funky stuff that requires the entire state dict.
            # Alternatively, in the future, grouping parameters by module could probably solve this problem.
            self._model.load_state_dict(cur_state_dict, assign=True)

        if self._cur_vram_bytes is not None:
            self._cur_vram_bytes += vram_bytes_loaded

        if fully_loaded:
            remove_custom_layers_from_model(self._model)
            # TODO(ryand): Warn if the self.cur_vram_bytes() and self.total_bytes() are out of sync.
        else:
            apply_custom_layers_to_model(self._model)

        # Move all non-persistent buffers to the compute device. These are a weird edge case and do not participate in
        # the vram_bytes_loaded tracking.
        self._move_non_persistent_buffers_to_device(self._compute_device)

        return vram_bytes_loaded

    @torch.no_grad()
    def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int:
        """Unload weights from VRAM until vram_bytes_to_free bytes are freed, or the entire model is unloaded.
        Returns:
            The number of bytes unloaded from VRAM.
        """
        vram_bytes_freed = 0

        offload_device = "cpu"
        cur_state_dict = self._model.state_dict()
        for key, param in cur_state_dict.items():
            if vram_bytes_freed >= vram_bytes_to_free:
                break

            if param.device.type == offload_device:
                continue

            cur_state_dict[key] = self._cpu_state_dict[key]
            vram_bytes_freed += calc_tensor_size(param)

        if vram_bytes_freed > 0:
            self._model.load_state_dict(cur_state_dict, assign=True)

        if self._cur_vram_bytes is not None:
            self._cur_vram_bytes -= vram_bytes_freed

        # We may have gone from a fully-loaded model to a partially-loaded model, so we need to reapply the custom
        # layers.
        apply_custom_layers_to_model(self._model)
        return vram_bytes_freed
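
A hedged usage sketch of `CachedModelWithPartialLoad` (the toy model and the byte budget are arbitrary; the `ModelCache` integration that will drive these calls lands in a later PR):

```python
import torch

model = torch.nn.Sequential(*[torch.nn.Linear(1024, 1024) for _ in range(8)])  # ~32 MB of fp32 weights on the CPU

cached = CachedModelWithPartialLoad(model, compute_device=torch.device("cuda"))

# Load at most ~8 MB of weights into VRAM. Layers whose weights stay on the CPU will stream them to the
# compute device at runtime via the custom autocast layers applied while the model is partially loaded.
loaded = cached.partial_load_to_vram(8 * 2**20)
print(f"{loaded / 2**20:.1f} MB loaded; {cached.cur_vram_bytes() / 2**20:.1f} / {cached.total_bytes() / 2**20:.1f} MB in VRAM")

# ... run inference with inputs on the compute device ...

freed = cached.partial_unload_from_vram(cached.cur_vram_bytes())  # free everything that was loaded
```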
Empty file.
@@ -0,0 +1,50 @@
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device

# This file contains custom torch.nn.Module classes that support streaming of weights to the target device.
# Each class sub-classes the original module type that it is replacing, so the following properties are preserved:
# - isinstance(m, torch.nn.OriginalModule) should still work.
# - Patching the weights (e.g. for LoRA) should still work if non-quantized.


class CustomLinear(torch.nn.Linear):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return torch.nn.functional.linear(input, weight, bias)


class CustomConv1d(torch.nn.Conv1d):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return self._conv_forward(input, weight, bias)


class CustomConv2d(torch.nn.Conv2d):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return self._conv_forward(input, weight, bias)


class CustomGroupNorm(torch.nn.GroupNorm):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        bias = cast_to_device(self.bias, input.device)
        return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)


class CustomEmbedding(torch.nn.Embedding):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device)
        return torch.nn.functional.embedding(
            input,
            weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
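
To illustrate how these wrapper layers behave, here is a hedged sketch that swaps the class of a plain `torch.nn.Linear` by hand (in practice the swap is handled by `apply_custom_layers_to_model`; the manual `__class__` assignment here is purely for demonstration):

```python
import torch

linear = torch.nn.Linear(16, 16)  # weights and bias live on the CPU
linear.__class__ = CustomLinear   # swap in the autocast forward()
assert isinstance(linear, torch.nn.Linear)  # the isinstance property mentioned above is preserved

if torch.cuda.is_available():
    x = torch.randn(4, 16, device="cuda")
    y = linear(x)                              # weight/bias are streamed to CUDA for this call only
    assert y.device.type == "cuda"
    assert linear.weight.device.type == "cpu"  # the stored parameters never moved
```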
@@ -0,0 +1,15 @@
from typing import TypeVar

import torch

T = TypeVar("T", torch.Tensor, None, torch.Tensor | None)


def cast_to_device(t: T, to_device: torch.device) -> T:
    """Helper function to cast an optional tensor to a target device."""
    if t is None:
        return t

    if t.device.type != to_device.type:
        return t.to(to_device)
    return t
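
A small usage sketch of the helper (illustrative only): `None` passes through untouched, and a tensor is copied only when the device type differs.

```python
import torch

w = torch.randn(4, 4)  # CPU tensor

assert cast_to_device(None, torch.device("cuda")) is None  # None passes through
assert cast_to_device(w, torch.device("cpu")) is w         # same device type: returned as-is
if torch.cuda.is_available():
    assert cast_to_device(w, torch.device("cuda")).device.type == "cuda"  # different device type: copied
```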
@@ -0,0 +1,27 @@
import bitsandbytes as bnb
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt


class CustomInvokeLinear8bitLt(InvokeLinear8bitLt):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        matmul_state = bnb.MatmulLtState()
        matmul_state.threshold = self.state.threshold
        matmul_state.has_fp16_weights = self.state.has_fp16_weights
        matmul_state.use_pool = self.state.use_pool
        matmul_state.is_training = self.training
        # The underlying InvokeInt8Params weight must already be quantized.
        assert self.weight.CB is not None
        matmul_state.CB = cast_to_device(self.weight.CB, x.device)
        matmul_state.SCB = cast_to_device(self.weight.SCB, x.device)

        # Weights are cast automatically as Int8Params, but the bias has to be cast manually.
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        # NOTE(ryand): The second parameter should not be needed at all given our expected inference configuration, but
        # its dtype field must be accessible, even though it's not used. We pass in self.weight even though it could be
        # on the wrong device.
        return bnb.matmul(x, self.weight, bias=cast_to_device(self.bias, x.device), state=matmul_state)