forked from invoke-ai/InvokeAI
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Partial Loading PR2: Add utils to support partial loading of models f…
…rom CPU to GPU (invoke-ai#7494) ## Summary This PR adds utilities to support partial loading of models from CPU to GPU. The new utilities are not yet being used by the ModelCache, so there should be no functional behavior changes in this PR. Detailed changes: - Add autocast modules that are designed to wrap common `torch.nn.Module`s and enable them to run with automatic device casting. E.g. a linear layer on the CPU can be executed with an input tensor on the GPU by streaming the weights to the GPU at runtime. - Add unit tests for the aforementioned autocast modules to verify that they work for all supported quantization formats (GGUF, BnB NF4, BnB LLM.int8()). - Add `CachedModelWithPartialLoad` and `CachedModelOnlyFullLoad` classes to manage partial loading at the model level. ## Alternative Implementations Several options were explored for supporting inference on partially-loaded models. The pros/cons of the explored options are summarized here for reference. In the end, wrapper modules were selected as the best overall solution for our use case. Option 1: Re-implement the .forward() methods of modules to add support for device conversions - This is the option implemented in this PR. - This approach is the most manual of the three, but as a result offers the broadest compatibility with unusual model types. It is manual in that we have to explicitly add support for all module types that we wish to support. Fortunately, the list of foundational module types is relatively small (e.g. the current set of implemented layers covers all but 0.04 MB of the full FLUX model.). Option 2: Implement a custom Tensor type that casts tensors to a `target_device` each time the tensor is used - This approach has the nice property that it is injected at the tensor level, and the model does not need to be modified in any way. - One challenge with this approach is handling interactions with other custom tensor types (e.g. GGMLTensor). This problem is solvable, but definitely introduces a layer of complexity. (There are likely to also be some similar issues with interactions with the BnB quantization, but I didn't get as far as testing BnB.) Option 3: Override the `__torch_function__` dispatch calls globally and cast all params to the execution device. - This approach is nice and simple: just apply a global context manager and all operations will happen on the compute device regardless of the device of the participating tensors. - Challenges: - Overriding the `__torch_function__` dispatch calls introduces some overhead even if the tensors are already on the correct device. - It is difficult to manage the autocasting context manager. E.g. it is tempting to apply it to the model's `.forward(...)` method, but we use some models with non-standard entrypoints. And we don't want to end up with nested autocasting context managers. - BnB applies quantization side effects when a param is moved to the GPU - this interacts in unexpected ways with a global context manager. ## QA Instructions Most of the changes in this PR should not impact active code, and thus should not cause any changes to behavior. The main risks come from bumping the bitsandbytes dependency and some minor modifications to the bitsandbytes quantization code. - [x] Regression test bitsandbytes NF4 quantization - [x] Regression test bitsandbytes LLM.int8() quantization - [x] Regression test on MacOS (to ensure that there are no lingering bitsandbytes import errors) I also tested the new utilities for inference on full models in another branch to validate that there were not major issues. This functionality will be tested more thoroughly in a future PR. ## Merge Plan - [x] invoke-ai#7492 should be merged first so that the target branch can be updated to main. ## Checklist - [x] _The PR has a short but descriptive title, suitable for a changelog_ - [x] _Tests added / updated (if applicable)_ - [x] _Documentation added / updated (if applicable)_ - [ ] _Updated `What's New` copy (if doing a release after this PR)_
- Loading branch information
Showing
16 changed files
with
1,302 additions
and
8 deletions.
There are no files selected for viewing
93 changes: 93 additions & 0 deletions
93
invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
from typing import Any | ||
|
||
import torch | ||
|
||
|
||
class CachedModelOnlyFullLoad: | ||
"""A wrapper around a PyTorch model to handle full loads and unloads between the CPU and the compute device. | ||
Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory, | ||
MPS memory, etc. | ||
""" | ||
|
||
def __init__(self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int): | ||
"""Initialize a CachedModelOnlyFullLoad. | ||
Args: | ||
model (torch.nn.Module | Any): The model to wrap. Should be on the CPU. | ||
compute_device (torch.device): The compute device to move the model to. | ||
total_bytes (int): The total size (in bytes) of all the weights in the model. | ||
""" | ||
# model is often a torch.nn.Module, but could be any model type. Throughout this class, we handle both cases. | ||
self._model = model | ||
self._compute_device = compute_device | ||
self._offload_device = torch.device("cpu") | ||
|
||
# A CPU read-only copy of the model's state dict. | ||
self._cpu_state_dict: dict[str, torch.Tensor] | None = None | ||
if isinstance(model, torch.nn.Module): | ||
self._cpu_state_dict = model.state_dict() | ||
|
||
self._total_bytes = total_bytes | ||
self._is_in_vram = False | ||
|
||
@property | ||
def model(self) -> torch.nn.Module: | ||
return self._model | ||
|
||
def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None: | ||
"""Get a read-only copy of the model's state dict in RAM.""" | ||
# TODO(ryand): Document this better. | ||
return self._cpu_state_dict | ||
|
||
def total_bytes(self) -> int: | ||
"""Get the total size (in bytes) of all the weights in the model.""" | ||
return self._total_bytes | ||
|
||
def cur_vram_bytes(self) -> int: | ||
"""Get the size (in bytes) of the weights that are currently in VRAM.""" | ||
if self._is_in_vram: | ||
return self._total_bytes | ||
else: | ||
return 0 | ||
|
||
def is_in_vram(self) -> bool: | ||
"""Return true if the model is currently in VRAM.""" | ||
return self._is_in_vram | ||
|
||
def full_load_to_vram(self) -> int: | ||
"""Load all weights into VRAM (if supported by the model). | ||
Returns: | ||
The number of bytes loaded into VRAM. | ||
""" | ||
if self._is_in_vram: | ||
# Already in VRAM. | ||
return 0 | ||
|
||
if not hasattr(self._model, "to"): | ||
# Model doesn't support moving to a device. | ||
return 0 | ||
|
||
if self._cpu_state_dict is not None: | ||
new_state_dict: dict[str, torch.Tensor] = {} | ||
for k, v in self._cpu_state_dict.items(): | ||
new_state_dict[k] = v.to(self._compute_device, copy=True) | ||
self._model.load_state_dict(new_state_dict, assign=True) | ||
self._model.to(self._compute_device) | ||
|
||
self._is_in_vram = True | ||
return self._total_bytes | ||
|
||
def full_unload_from_vram(self) -> int: | ||
"""Unload all weights from VRAM. | ||
Returns: | ||
The number of bytes unloaded from VRAM. | ||
""" | ||
if not self._is_in_vram: | ||
# Already in RAM. | ||
return 0 | ||
|
||
if self._cpu_state_dict is not None: | ||
self._model.load_state_dict(self._cpu_state_dict, assign=True) | ||
self._model.to(self._offload_device) | ||
|
||
self._is_in_vram = False | ||
return self._total_bytes |
201 changes: 201 additions & 0 deletions
201
...eai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
import torch | ||
|
||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import ( | ||
AUTOCAST_MODULE_TYPE_MAPPING, | ||
apply_custom_layers_to_model, | ||
remove_custom_layers_from_model, | ||
) | ||
from invokeai.backend.util.calc_tensor_size import calc_tensor_size | ||
from invokeai.backend.util.logging import InvokeAILogger | ||
|
||
|
||
def set_nested_attr(obj: object, attr: str, value: object): | ||
"""A helper function that extends setattr() to support nested attributes. | ||
Example: | ||
set_nested_attr(model, "module.encoder.conv1.weight", new_conv1_weight) | ||
""" | ||
attrs = attr.split(".") | ||
for attr in attrs[:-1]: | ||
obj = getattr(obj, attr) | ||
setattr(obj, attrs[-1], value) | ||
|
||
|
||
class CachedModelWithPartialLoad: | ||
"""A wrapper around a PyTorch model to handle partial loads and unloads between the CPU and the compute device. | ||
Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory, | ||
MPS memory, etc. | ||
""" | ||
|
||
def __init__(self, model: torch.nn.Module, compute_device: torch.device): | ||
self._model = model | ||
self._compute_device = compute_device | ||
|
||
# A CPU read-only copy of the model's state dict. | ||
self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict() | ||
|
||
# TODO(ryand): Handle the case where the model sizes changes after initial load (e.g. due to dtype casting). | ||
# Consider how we should handle this for both self._total_bytes and self._cur_vram_bytes. | ||
self._total_bytes = sum(calc_tensor_size(p) for p in self._cpu_state_dict.values()) | ||
self._cur_vram_bytes: int | None = None | ||
|
||
self._modules_that_support_autocast = self._find_modules_that_support_autocast() | ||
self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast() | ||
|
||
def _find_modules_that_support_autocast(self) -> dict[str, torch.nn.Module]: | ||
"""Find all modules that support autocasting.""" | ||
return {n: m for n, m in self._model.named_modules() if type(m) in AUTOCAST_MODULE_TYPE_MAPPING} | ||
|
||
def _find_keys_in_modules_that_do_not_support_autocast(self) -> set[str]: | ||
keys_in_modules_that_do_not_support_autocast = set() | ||
for key in self._cpu_state_dict.keys(): | ||
for module_name in self._modules_that_support_autocast.keys(): | ||
if key.startswith(module_name): | ||
break | ||
else: | ||
keys_in_modules_that_do_not_support_autocast.add(key) | ||
return keys_in_modules_that_do_not_support_autocast | ||
|
||
def _move_non_persistent_buffers_to_device(self, device: torch.device): | ||
"""Move the non-persistent buffers to the target device. These buffers are not included in the state dict, | ||
so we need to move them manually. | ||
""" | ||
# HACK(ryand): Typically, non-persistent buffers are moved when calling module.to(device). We don't move entire | ||
# modules, because we manage the devices of individual tensors using the state dict. Since non-persistent | ||
# buffers are not included in the state dict, we need to handle them manually. The only way to do this is by | ||
# using private torch.nn.Module attributes. | ||
for module in self._model.modules(): | ||
for name, buffer in module.named_buffers(): | ||
if name in module._non_persistent_buffers_set: | ||
module._buffers[name] = buffer.to(device, copy=True) | ||
|
||
@property | ||
def model(self) -> torch.nn.Module: | ||
return self._model | ||
|
||
def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None: | ||
"""Get a read-only copy of the model's state dict in RAM.""" | ||
# TODO(ryand): Document this better. | ||
return self._cpu_state_dict | ||
|
||
def total_bytes(self) -> int: | ||
"""Get the total size (in bytes) of all the weights in the model.""" | ||
return self._total_bytes | ||
|
||
def cur_vram_bytes(self) -> int: | ||
"""Get the size (in bytes) of the weights that are currently in VRAM.""" | ||
if self._cur_vram_bytes is None: | ||
cur_state_dict = self._model.state_dict() | ||
self._cur_vram_bytes = sum( | ||
calc_tensor_size(p) for p in cur_state_dict.values() if p.device.type == self._compute_device.type | ||
) | ||
return self._cur_vram_bytes | ||
|
||
def full_load_to_vram(self) -> int: | ||
"""Load all weights into VRAM.""" | ||
return self.partial_load_to_vram(self.total_bytes()) | ||
|
||
def full_unload_from_vram(self) -> int: | ||
"""Unload all weights from VRAM.""" | ||
return self.partial_unload_from_vram(self.total_bytes()) | ||
|
||
@torch.no_grad() | ||
def partial_load_to_vram(self, vram_bytes_to_load: int) -> int: | ||
"""Load more weights into VRAM without exceeding vram_bytes_to_load. | ||
Returns: | ||
The number of bytes loaded into VRAM. | ||
""" | ||
# TODO(ryand): Handle the case where an exception is thrown while loading or unloading weights. At the very | ||
# least, we should reset self._cur_vram_bytes to None. | ||
|
||
vram_bytes_loaded = 0 | ||
|
||
cur_state_dict = self._model.state_dict() | ||
|
||
# First, process the keys *must* be loaded into VRAM. | ||
for key in self._keys_in_modules_that_do_not_support_autocast: | ||
param = cur_state_dict[key] | ||
if param.device.type == self._compute_device.type: | ||
continue | ||
|
||
param_size = calc_tensor_size(param) | ||
cur_state_dict[key] = param.to(self._compute_device, copy=True) | ||
vram_bytes_loaded += param_size | ||
|
||
if vram_bytes_loaded > vram_bytes_to_load: | ||
logger = InvokeAILogger.get_logger() | ||
logger.warning( | ||
f"Loaded {vram_bytes_loaded / 2**20} MB into VRAM, but only {vram_bytes_to_load / 2**20} MB were " | ||
"requested. This is the minimum set of weights in VRAM required to run the model." | ||
) | ||
|
||
# Next, process the keys that can optionally be loaded into VRAM. | ||
fully_loaded = True | ||
for key, param in cur_state_dict.items(): | ||
if param.device.type == self._compute_device.type: | ||
continue | ||
|
||
param_size = calc_tensor_size(param) | ||
if vram_bytes_loaded + param_size > vram_bytes_to_load: | ||
# TODO(ryand): Should we just break here? If we couldn't fit this parameter into VRAM, is it really | ||
# worth continuing to search for a smaller parameter that would fit? | ||
fully_loaded = False | ||
continue | ||
|
||
cur_state_dict[key] = param.to(self._compute_device, copy=True) | ||
vram_bytes_loaded += param_size | ||
|
||
if vram_bytes_loaded > 0: | ||
# We load the entire state dict, not just the parameters that changed, in case there are modules that | ||
# override _load_from_state_dict() and do some funky stuff that requires the entire state dict. | ||
# Alternatively, in the future, grouping parameters by module could probably solve this problem. | ||
self._model.load_state_dict(cur_state_dict, assign=True) | ||
|
||
if self._cur_vram_bytes is not None: | ||
self._cur_vram_bytes += vram_bytes_loaded | ||
|
||
if fully_loaded: | ||
remove_custom_layers_from_model(self._model) | ||
# TODO(ryand): Warn if the self.cur_vram_bytes() and self.total_bytes() are out of sync. | ||
else: | ||
apply_custom_layers_to_model(self._model) | ||
|
||
# Move all non-persistent buffers to the compute device. These are a weird edge case and do not participate in | ||
# the vram_bytes_loaded tracking. | ||
self._move_non_persistent_buffers_to_device(self._compute_device) | ||
|
||
return vram_bytes_loaded | ||
|
||
@torch.no_grad() | ||
def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int: | ||
"""Unload weights from VRAM until vram_bytes_to_free bytes are freed. Or the entire model is unloaded. | ||
Returns: | ||
The number of bytes unloaded from VRAM. | ||
""" | ||
vram_bytes_freed = 0 | ||
|
||
offload_device = "cpu" | ||
cur_state_dict = self._model.state_dict() | ||
for key, param in cur_state_dict.items(): | ||
if vram_bytes_freed >= vram_bytes_to_free: | ||
break | ||
|
||
if param.device.type == offload_device: | ||
continue | ||
|
||
cur_state_dict[key] = self._cpu_state_dict[key] | ||
vram_bytes_freed += calc_tensor_size(param) | ||
|
||
if vram_bytes_freed > 0: | ||
self._model.load_state_dict(cur_state_dict, assign=True) | ||
|
||
if self._cur_vram_bytes is not None: | ||
self._cur_vram_bytes -= vram_bytes_freed | ||
|
||
# We may have gone from a fully-loaded model to a partially-loaded model, so we need to reapply the custom | ||
# layers. | ||
apply_custom_layers_to_model(self._model) | ||
return vram_bytes_freed |
Empty file.
50 changes: 50 additions & 0 deletions
50
invokeai/backend/model_manager/load/model_cache/torch_module_autocast/autocast_modules.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import torch | ||
|
||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device | ||
|
||
# This file contains custom torch.nn.Module classes that support streaming of weights to the target device. | ||
# Each class sub-classes the original module type that is is replacing, so the following properties are preserved: | ||
# - isinstance(m, torch.nn.OrginalModule) should still work. | ||
# - Patching the weights (e.g. for LoRA) should still work if non-quantized. | ||
|
||
|
||
class CustomLinear(torch.nn.Linear): | ||
def forward(self, input: torch.Tensor) -> torch.Tensor: | ||
weight = cast_to_device(self.weight, input.device) | ||
bias = cast_to_device(self.bias, input.device) | ||
return torch.nn.functional.linear(input, weight, bias) | ||
|
||
|
||
class CustomConv1d(torch.nn.Conv1d): | ||
def forward(self, input: torch.Tensor) -> torch.Tensor: | ||
weight = cast_to_device(self.weight, input.device) | ||
bias = cast_to_device(self.bias, input.device) | ||
return self._conv_forward(input, weight, bias) | ||
|
||
|
||
class CustomConv2d(torch.nn.Conv2d): | ||
def forward(self, input: torch.Tensor) -> torch.Tensor: | ||
weight = cast_to_device(self.weight, input.device) | ||
bias = cast_to_device(self.bias, input.device) | ||
return self._conv_forward(input, weight, bias) | ||
|
||
|
||
class CustomGroupNorm(torch.nn.GroupNorm): | ||
def forward(self, input: torch.Tensor) -> torch.Tensor: | ||
weight = cast_to_device(self.weight, input.device) | ||
bias = cast_to_device(self.bias, input.device) | ||
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) | ||
|
||
|
||
class CustomEmbedding(torch.nn.Embedding): | ||
def forward(self, input: torch.Tensor) -> torch.Tensor: | ||
weight = cast_to_device(self.weight, input.device) | ||
return torch.nn.functional.embedding( | ||
input, | ||
weight, | ||
self.padding_idx, | ||
self.max_norm, | ||
self.norm_type, | ||
self.scale_grad_by_freq, | ||
self.sparse, | ||
) |
15 changes: 15 additions & 0 deletions
15
invokeai/backend/model_manager/load/model_cache/torch_module_autocast/cast_to_device.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from typing import TypeVar | ||
|
||
import torch | ||
|
||
T = TypeVar("T", torch.Tensor, None, torch.Tensor | None) | ||
|
||
|
||
def cast_to_device(t: T, to_device: torch.device) -> T: | ||
"""Helper function to cast an optional tensor to a target device.""" | ||
if t is None: | ||
return t | ||
|
||
if t.device.type != to_device.type: | ||
return t.to(to_device) | ||
return t |
27 changes: 27 additions & 0 deletions
27
...end/model_manager/load/model_cache/torch_module_autocast/custom_invoke_linear_8_bit_lt.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import bitsandbytes as bnb | ||
import torch | ||
|
||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device | ||
from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt | ||
|
||
|
||
class CustomInvokeLinear8bitLt(InvokeLinear8bitLt): | ||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
matmul_state = bnb.MatmulLtState() | ||
matmul_state.threshold = self.state.threshold | ||
matmul_state.has_fp16_weights = self.state.has_fp16_weights | ||
matmul_state.use_pool = self.state.use_pool | ||
matmul_state.is_training = self.training | ||
# The underlying InvokeInt8Params weight must already be quantized. | ||
assert self.weight.CB is not None | ||
matmul_state.CB = cast_to_device(self.weight.CB, x.device) | ||
matmul_state.SCB = cast_to_device(self.weight.SCB, x.device) | ||
|
||
# weights are cast automatically as Int8Params, but the bias has to be cast manually. | ||
if self.bias is not None and self.bias.dtype != x.dtype: | ||
self.bias.data = self.bias.data.to(x.dtype) | ||
|
||
# NOTE(ryand): The second parameter should not be needed at all given our expected inference configuration, but | ||
# it's dtype field must be accessible, even though it's not used. We pass in self.weight even though it could be | ||
# on the wrong device. | ||
return bnb.matmul(x, self.weight, bias=cast_to_device(self.bias, x.device), state=matmul_state) |
Oops, something went wrong.