Check if the buffers fit GPU memory after device map auto inferred #2412

Merged · 3 commits · Mar 8, 2024
src/accelerate/big_modeling.py (5 additions, 1 deletion)
@@ -578,7 +578,11 @@ def load_checkpoint_and_dispatch(
low_zero=(device_map == "balanced_low_0"),
)
device_map = infer_auto_device_map(
model, max_memory=max_memory, no_split_module_classes=no_split_module_classes, dtype=dtype
model,
max_memory=max_memory,
no_split_module_classes=no_split_module_classes,
dtype=dtype,
offload_buffers=offload_buffers,
)
if offload_state_dict is None and device_map is not None and "disk" in device_map.values():
offload_state_dict = True
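As context for this change: with `device_map="auto"`, `load_checkpoint_and_dispatch` builds the map via `infer_auto_device_map`, and the `offload_buffers` flag it already accepted is now forwarded to that call, so buffer placement is considered while the map is inferred rather than only when weights are dispatched afterwards. A minimal sketch under the assumption of a toy model and a throwaway checkpoint (both illustrative, not part of this PR):

```python
# Sketch: offload_buffers is now forwarded to infer_auto_device_map when
# device_map="auto". The toy model and temporary checkpoint are illustrative only.
import os
import tempfile

import torch
from torch import nn
from accelerate import init_empty_weights, load_checkpoint_and_dispatch


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(64, 64)
        self.batchnorm = nn.BatchNorm1d(64)  # contributes non-parameter buffers
        self.linear2 = nn.Linear(64, 64)

    def forward(self, x):
        return self.linear2(self.batchnorm(self.linear1(x)))


with tempfile.TemporaryDirectory() as tmp_dir:
    checkpoint = os.path.join(tmp_dir, "model.pt")
    torch.save(ToyModel().state_dict(), checkpoint)

    with init_empty_weights():
        model = ToyModel()

    # device_map="auto" makes load_checkpoint_and_dispatch call
    # infer_auto_device_map; with this PR the offload_buffers flag is taken
    # into account while that map is built.
    model = load_checkpoint_and_dispatch(
        model,
        checkpoint=checkpoint,
        device_map="auto",
        offload_buffers=True,
    )
```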
src/accelerate/utils/modeling.py (69 additions, 1 deletion)
@@ -22,6 +22,7 @@
import re
import shutil
import tempfile
import warnings
from collections import OrderedDict, defaultdict
from typing import Dict, List, Optional, Tuple, Union

@@ -690,6 +691,7 @@ def compute_module_sizes(
model: nn.Module,
dtype: Optional[Union[str, torch.device]] = None,
special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
buffers_only: bool = False,
):
"""
Compute the size of each submodule of a given model.
@@ -701,7 +703,15 @@
special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
module_sizes = defaultdict(int)
for name, tensor in named_module_tensors(model, recurse=True):

module_list = []

if not buffers_only:
module_list = named_module_tensors(model, recurse=True)
else:
module_list = model.named_buffers(recurse=True)

for name, tensor in module_list:
if special_dtypes is not None and name in special_dtypes:
size = tensor.numel() * special_dtypes_size[name]
elif dtype is None:
@@ -719,6 +729,18 @@
return module_sizes


def compute_module_total_buffer_size(
model: nn.Module,
dtype: Optional[Union[str, torch.device]] = None,
special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
):
"""
Compute the total size of buffers in each submodule of a given model.
"""
module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes, buffers_only=True)
return module_sizes.get("", 0)


def get_max_layer_size(
modules: List[Tuple[str, torch.nn.Module]], module_sizes: Dict[str, int], no_split_module_classes: List[str]
):
@@ -1030,6 +1052,7 @@ def infer_auto_device_map(
special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None,
verbose: bool = False,
clean_result: bool = True,
offload_buffers: bool = False,
):
"""
Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk,
@@ -1066,6 +1089,9 @@
Whether or not to provide debugging statements as the function builds the device_map.
clean_result (`bool`, *optional*, defaults to `True`):
Clean the resulting device_map by grouping all submodules that go on the same device together.
offload_buffers (`bool`, *optional*, defaults to `False`):
In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as
well as the parameters.
"""
# Get default / clean up max_memory
max_memory = get_max_memory(max_memory)
@@ -1098,6 +1124,8 @@
device_map = OrderedDict()
current_device = 0
current_memory_used = 0
device_memory_used = {}
device_buffer_sizes = {}

# Direct submodules and parameters
modules_to_treat = (
@@ -1147,9 +1175,11 @@

device = devices[current_device]
current_max_size = max_memory[device] if device != "disk" else None
current_memory_reserved = 0
# Reduce max size available by the largest layer.
if devices[current_device] in main_devices:
current_max_size = current_max_size - max_layer_size
current_memory_reserved = max_layer_size
# Case 1 -> We're too big!
if current_max_size is not None and current_memory_used + module_size > current_max_size:
# Split or not split?
@@ -1167,6 +1197,8 @@
# -> no split, we go to the next device
if verbose:
print("This module cannot be split, going to the next device.")

device_memory_used[device] = current_memory_used + current_memory_reserved
current_device += 1
modules_to_treat = [(name, module)] + modules_to_treat
current_memory_used = 0
@@ -1218,6 +1250,12 @@
modules_to_treat.pop(tied_module_index)
device_map[tied_module_name] = devices[current_device]

if not offload_buffers and isinstance(module, nn.Module):
current_buffer_size = compute_module_total_buffer_size(
module, dtype=dtype, special_dtypes=special_dtypes
)
device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size

else:
# We don't fit with the tied modules. Next question is: can we split one of the tied modules to make it
# smaller or do we need to go on the next device?
@@ -1258,6 +1296,8 @@
# If the tied module is not split, we go to the next device
if verbose:
print("None of the tied module can be split, going to the next device.")

device_memory_used[device] = current_memory_used + current_memory_reserved
current_device += 1
modules_to_treat = [(name, module)] + modules_to_treat
current_memory_used = 0
@@ -1272,10 +1312,38 @@
f"(available={current_max_size - current_memory_used})."
)
current_memory_used += module_size
device_memory_used[device] = current_memory_used + current_memory_reserved
device_map[name] = devices[current_device]

if not offload_buffers and isinstance(module, nn.Module):
current_buffer_size = compute_module_total_buffer_size(
module, dtype=dtype, special_dtypes=special_dtypes
)
device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size

if clean_result:
device_map = clean_device_map(device_map)

non_gpu_buffer_size = device_buffer_sizes.get("cpu", 0) + device_buffer_sizes.get("disk", 0)
if non_gpu_buffer_size > 0 and not offload_buffers:
is_buffer_fit_any_gpu = False
for gpu_device, gpu_max_memory in max_memory.items():
if gpu_device == "cpu" or gpu_device == "disk":
continue

if not is_buffer_fit_any_gpu:
gpu_memory_used = device_memory_used.get(gpu_device, 0)

if gpu_max_memory >= non_gpu_buffer_size + gpu_memory_used:
is_buffer_fit_any_gpu = True

if len(gpus) > 0 and not is_buffer_fit_any_gpu:
warnings.warn(
f"Current model requires {non_gpu_buffer_size} bytes of buffer for offloaded layers, which seems does "
f"not fit any GPU's remaining memory. If you are experiencing a OOM later, please consider using "
f"offload_buffers=True."
)

return device_map


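In short, the modeling.py changes add `compute_module_total_buffer_size` and make `infer_auto_device_map` warn when the buffers of CPU/disk-offloaded layers would not fit into any GPU's leftover memory unless `offload_buffers=True` is passed. A hedged sketch of both entry points, mirroring the unit tests added below (the toy model and byte budgets are illustrative):

```python
# Sketch of the two new behaviours in accelerate.utils.modeling; the model and
# max_memory budgets mirror the unit tests added in this PR.
import warnings

import torch
from torch import nn
from accelerate.utils.modeling import compute_module_total_buffer_size, infer_auto_device_map


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(3, 4)
        self.batchnorm = nn.BatchNorm1d(4)  # owns running_mean/running_var buffers
        self.linear2 = nn.Linear(4, 5)

    def forward(self, x):
        return self.linear2(self.batchnorm(self.linear1(x)))


model = ToyModel()
model.linear1.register_buffer("test_buffer1", torch.zeros(10, 2))
model.batchnorm.register_buffer("test_buffer2", torch.zeros(10, 3))
model.linear2.register_buffer("test_buffer3", torch.zeros(10, 3))

# Total buffer bytes across the whole model (the "" key of compute_module_sizes).
print(compute_module_total_buffer_size(model))

# offload_buffers=False (the default): layers that spill to "cpu" still send
# their buffers to the GPU at run time, so the function is expected to warn
# when those buffers do not fit next to what is already on the GPU.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    device_map = infer_auto_device_map(model, max_memory={0: 400, "cpu": "1GB"})
print(device_map, [str(w.message) for w in caught])

# offload_buffers=True: buffers stay with their offloaded layers, so no warning
# is expected for the same budgets.
device_map = infer_auto_device_map(
    model, max_memory={0: 400, "cpu": "1GB"}, offload_buffers=True
)
print(device_map)
```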
tests/test_modeling_utils.py (70 additions, 0 deletions)
@@ -16,6 +16,7 @@
import os
import tempfile
import unittest
import warnings
from collections import OrderedDict

import torch
@@ -28,6 +29,7 @@
check_device_map,
clean_device_map,
compute_module_sizes,
compute_module_total_buffer_size,
convert_file_size_to_int,
find_tied_parameters,
get_balanced_memory,
@@ -297,6 +299,18 @@ def test_compute_module_sizes(self):
module_sizes = compute_module_sizes(model)
assert module_sizes == expected_sizes

def test_compute_module_total_buffer_size(self):
model = ModelForTest()
model.linear1.register_buffer("test_buffer", torch.zeros(10, 10))
model.register_buffer("test_buffer2", torch.zeros(20, 10))

buffer_size = compute_module_total_buffer_size(model)
assert buffer_size == 1240

model.half()
buffer_size = compute_module_total_buffer_size(model)
assert buffer_size == 624

def test_check_device_map(self):
model = ModelForTest()
check_device_map(model, {"": 0})
@@ -604,6 +618,62 @@ def test_infer_auto_device_map_on_t0pp(self):
assert device_map["encoder.embed_tokens"] == 0
assert device_map["decoder.embed_tokens"] == 0

def test_infer_auto_device_map_with_buffer_check(self):
model = ModelForTest()
model.linear1.register_buffer("test_buffer1", torch.zeros(10, 2))
model.batchnorm.register_buffer("test_buffer2", torch.zeros(10, 3))
model.linear2.register_buffer("test_buffer3", torch.zeros(10, 3))
# model has size 236(parameters) + 360(buffers): linear1 64 + 80, batchnorm 72 + 160, linear2 100 + 120

# Only linear1 (144) fits on device 0, and remaining buffers (batchnorm's 160 + linear2's 120 = 280) won't fit
# device 0, because they will also be loaded to device 0 all at once when inferencing without offload_buffers
# Should print a warning as intended in such case
with self.assertWarns(Warning):
device_map = infer_auto_device_map(model, max_memory={0: 400, "cpu": "1GB"})
assert device_map == {"linear1": 0, "batchnorm": "cpu", "linear2": "cpu"}

# Only linear1 (144) fits on device 0, and remaining buffers (batchnorm's 160 + linear2's 120 = 280) won't fit
# device 0, but with offload_buffers they won't be loaded to device 0 all at once, so it's ok now
# Should NOT print a warning in such case
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
device_map = infer_auto_device_map(model, max_memory={0: 400, "cpu": "1GB"}, offload_buffers=True)
assert len(w) == 0
assert device_map == {"linear1": 0, "batchnorm": "cpu", "linear2": "cpu"}

def test_infer_auto_device_map_with_buffer_check_and_multi_devices(self):
model = ModelForTest()
model.linear1.register_buffer("test_buffer1", torch.zeros(10, 2))
model.batchnorm.register_buffer("test_buffer2", torch.zeros(10, 3))
model.linear2.register_buffer("test_buffer3", torch.zeros(10, 3))
model.linear3 = nn.Linear(4, 5)
model.linear3.register_buffer("test_buffer4", torch.zeros(10, 2))
# model has size 336(parameters) + 440(buffers): linear1 64 + 80, batchnorm 72 + 160, linear2 100 + 120,
# linear3 100 + 80

# Now we have two devices, linear1 will fit on device 0, batchnorm will fit on device 1, and the second device
# can hold all remaining buffers
# Should NOT print a warning in such case
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
device_map = infer_auto_device_map(model, max_memory={0: 400, 1: 400, "cpu": "1GB"})
assert len(w) == 0
assert device_map == {"linear1": 0, "batchnorm": 1, "linear2": "cpu", "linear3": "cpu"}

# Now we have two devices, but neither the first nor the second device can hold all remaining buffers
# Should print a warning as intended in such case
with self.assertWarns(Warning):
device_map = infer_auto_device_map(model, max_memory={0: 400, 1: 200, "cpu": "1GB"})
assert device_map == {"linear1": 0, "batchnorm": 1, "linear2": "cpu", "linear3": "cpu"}

# Now we have two devices, neither can hold all the buffers, but we are using the offload_buffers=True
# Should NOT print a warning in such case
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
device_map = infer_auto_device_map(model, max_memory={0: 400, 1: 200, "cpu": "1GB"}, offload_buffers=True)
assert len(w) == 0
assert device_map == {"linear1": 0, "batchnorm": 1, "linear2": "cpu", "linear3": "cpu"}

@require_cuda
def test_get_balanced_memory(self):
model = ModelForTest()
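As a sanity check on the constants in test_compute_module_total_buffer_size: the expected 1240 and 624 bytes line up if ModelForTest's only intrinsic buffers come from an nn.BatchNorm1d(4) layer (an assumption here, since the model definition is not shown in this diff); its running stats are float32 and num_batches_tracked is int64, which model.half() leaves untouched:

```python
# Back-of-the-envelope check of the expected test values, assuming the model's
# only intrinsic buffers come from an nn.BatchNorm1d(4) layer (hypothetical here).
float_elems = 10 * 10 + 20 * 10 + 4 + 4  # test_buffer, test_buffer2, running_mean, running_var
assert float_elems * 4 + 8 == 1240       # float32 buffers + 8-byte int64 num_batches_tracked
assert float_elems * 2 + 8 == 624        # after model.half(); num_batches_tracked stays int64
```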