Check if the buffers fit in GPU memory after the device map is auto-inferred
  * Some models, like TheBloke/WizardCoder-33B-V1.1-GPTQ, contain a huge
    buffer, which may cause an OOM on GPU memory when offload_buffers is not
    used. This commit adds a check for such a case.
notsyncing committed Feb 2, 2024
1 parent 7aafa25 commit 60b749c
Showing 2 changed files with 70 additions and 2 deletions.
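
As a usage sketch (not part of the commit): loading a large quantized checkpoint with buffer offloading enabled. The model ID is the one named in the commit message; the local checkpoint path and dtype are assumptions.

```python
import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM

# Model ID taken from the commit message; the checkpoint path below is hypothetical.
config = AutoConfig.from_pretrained("TheBloke/WizardCoder-33B-V1.1-GPTQ")
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

# After this commit, an auto-inferred device map is checked against the GPU budget:
# if the buffers that would stay on the GPU no longer fit, a ValueError is raised
# unless offload_buffers=True is passed.
model = load_checkpoint_and_dispatch(
    model,
    checkpoint="/path/to/wizardcoder-33b-gptq",  # hypothetical local path
    device_map="auto",
    dtype=torch.float16,
    offload_buffers=True,
)
```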
6 changes: 5 additions & 1 deletion src/accelerate/big_modeling.py
@@ -567,7 +567,11 @@ def load_checkpoint_and_dispatch(
low_zero=(device_map == "balanced_low_0"),
)
device_map = infer_auto_device_map(
model, max_memory=max_memory, no_split_module_classes=no_split_module_classes, dtype=dtype
model,
max_memory=max_memory,
no_split_module_classes=no_split_module_classes,
dtype=dtype,
offload_buffers=offload_buffers,
)
if offload_state_dict is None and device_map is not None and "disk" in device_map.values():
offload_state_dict = True
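
The change above only threads the existing offload_buffers argument of load_checkpoint_and_dispatch through to infer_auto_device_map. Below is a minimal sketch of calling the planner directly with the new keyword, assuming a CUDA GPU as device 0 and made-up memory budgets.

```python
import torch
from torch import nn
from accelerate import infer_auto_device_map, init_empty_weights

# Toy model: the BatchNorm layer carries buffers (running_mean, running_var, num_batches_tracked).
with init_empty_weights():
    model = nn.Sequential(
        nn.Linear(1024, 1024),
        nn.BatchNorm1d(1024),
        nn.Linear(1024, 1024),
    )

device_map = infer_auto_device_map(
    model,
    max_memory={0: "4MiB", "cpu": "1GiB"},  # made-up budgets; device 0 assumes a CUDA GPU is present
    dtype=torch.float16,
    offload_buffers=False,  # the default; with False, the new buffer-fit check applies
)
print(device_map)
```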
66 changes: 65 additions & 1 deletion src/accelerate/utils/modeling.py
@@ -680,6 +680,7 @@ def compute_module_sizes(
model: nn.Module,
dtype: Optional[Union[str, torch.device]] = None,
special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
buffers_only: bool = False,
):
"""
Compute the size of each submodule of a given model.
@@ -691,11 +692,18 @@
special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
module_sizes = defaultdict(int)
for name, tensor in named_module_tensors(model, recurse=True):

module_list = named_module_tensors(model, recurse=True) if not buffers_only else model.named_buffers(recurse=True)

for name, tensor in module_list:
if special_dtypes is not None and name in special_dtypes:
size = tensor.numel() * special_dtypes_size[name]
elif dtype is None:
size = tensor.numel() * dtype_byte_size(tensor.dtype)
elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
# According to the code in set_module_tensor_to_device, these types won't be converted
# so use their original size here
size = tensor.numel() * dtype_byte_size(tensor.dtype)
else:
size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
name_parts = name.split(".")
@@ -705,6 +713,18 @@
return module_sizes


def compute_module_total_buffer_size(
model: nn.Module,
dtype: Optional[Union[str, torch.device]] = None,
special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
):
"""
Compute the total size of buffers in each submodule of a given model.
"""
module_sizes = compute_module_sizes(model, dtype, special_dtypes, True)
return module_sizes.get("", 0)


def get_max_layer_size(
modules: List[Tuple[str, torch.nn.Module]], module_sizes: Dict[str, int], no_split_module_classes: List[str]
):
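
The new compute_module_total_buffer_size helper reuses compute_module_sizes with buffers_only=True and reads the aggregate stored under the empty key. A rough sketch of what it reports for a toy module with a registered buffer (fp32 tensors are counted at 2 bytes per element because of the fp16 target dtype):

```python
import torch
from torch import nn
from accelerate.utils.modeling import compute_module_sizes, compute_module_total_buffer_size

class WithBuffer(nn.Module):
    """Toy module: one parameter plus one large registered buffer."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(256, 256))
        self.register_buffer("cache", torch.zeros(1024, 1024))

m = WithBuffer()

# Parameters + buffers per (sub)module; the "" key holds the total for the whole model.
sizes = compute_module_sizes(m, dtype=torch.float16)  # 256*256*2 + 1024*1024*2 = 2_228_224 bytes

# Buffer footprint only, via the helper added in this commit: 1024*1024*2 = 2_097_152 bytes.
buffer_bytes = compute_module_total_buffer_size(m, dtype=torch.float16)
print(sizes[""], buffer_bytes)
```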
@@ -1016,6 +1036,7 @@ def infer_auto_device_map(
special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None,
verbose: bool = False,
clean_result: bool = True,
offload_buffers: bool = False,
):
"""
Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk,
@@ -1052,6 +1073,9 @@
Whether or not to provide debugging statements as the function builds the device_map.
clean_result (`bool`, *optional*, defaults to `True`):
Clean the resulting device_map by grouping all submodules that go on the same device together.
offload_buffers (`bool`, *optional*, defaults to `False`):
In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as
well as the parameters.
"""
# Get default / clean up max_memory
max_memory = get_max_memory(max_memory)
@@ -1084,6 +1108,8 @@
device_map = OrderedDict()
current_device = 0
current_memory_used = 0
device_memory_used = {}
device_buffer_sizes = {}

# Direct submodules and parameters
modules_to_treat = (
@@ -1149,6 +1175,8 @@
# -> no split, we go to the next device
if verbose:
print("This module cannot be split, going to the next device.")

device_memory_used[device] = current_memory_used
current_device += 1
modules_to_treat = [(name, module)] + modules_to_treat
current_memory_used = 0
@@ -1200,6 +1228,15 @@
modules_to_treat.pop(tied_module_index)
device_map[tied_module_name] = devices[current_device]

if not offload_buffers:
if device not in device_buffer_sizes.keys():
device_buffer_sizes[device] = 0

current_buffer_size = compute_module_total_buffer_size(
module, dtype=dtype, special_dtypes=special_dtypes
)
device_buffer_sizes[device] = device_buffer_sizes[device] + current_buffer_size

else:
# We don't fit with the tied modules. Next question is: can we split one of the tied modules to make it
# smaller or do we need to go on the next device?
@@ -1240,6 +1277,8 @@
# If the tied module is not split, we go to the next device
if verbose:
print("None of the tied module can be split, going to the next device.")

device_memory_used[device] = current_memory_used
current_device += 1
modules_to_treat = [(name, module)] + modules_to_treat
current_memory_used = 0
@@ -1256,8 +1295,33 @@
current_memory_used += module_size
device_map[name] = devices[current_device]

if not offload_buffers:
if device not in device_buffer_sizes.keys():
device_buffer_sizes[device] = 0

current_buffer_size = compute_module_total_buffer_size(
module, dtype=dtype, special_dtypes=special_dtypes
)
device_buffer_sizes[device] = device_buffer_sizes[device] + current_buffer_size

if clean_result:
device_map = clean_device_map(device_map)

if not offload_buffers:
non_gpu_buffer_size = device_buffer_sizes.get("cpu", 0) + device_buffer_sizes.get("disk", 0)

for gpu_device, gpu_max_memory in max_memory.items():
if gpu_device == "cpu" or gpu_device == "disk":
continue

gpu_memory_used = device_memory_used.get(gpu_device, 0)
if gpu_max_memory < non_gpu_buffer_size + gpu_memory_used:
raise ValueError(
f"Device {gpu_device} has {gpu_max_memory} bytes of memory, and has planned to use {gpu_memory_used} bytes, "
f"so the remaining memory is not enough to store additional buffers ({non_gpu_buffer_size} bytes)! "
f"Please consider using offload_buffers=True."
)

return device_map


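With offload_buffers=False, buffers of CPU- or disk-offloaded modules still end up on the execution GPU, so the new check compares each GPU's budget against its planned weights plus those buffers: for example, with a 24 GiB budget, 23.5 GiB of planned weights and 1 GiB of such buffers, 23.5 + 1 > 24 and a ValueError is raised. A hedged sketch of how a caller might react (make_device_map is a hypothetical helper; catching ValueError here is a simplification, since other argument errors would be caught too):

```python
import torch
from accelerate import infer_auto_device_map

def make_device_map(model, max_memory):
    """Hypothetical helper: try to keep buffers on the GPU, retry with offloading if they don't fit."""
    try:
        return infer_auto_device_map(
            model, max_memory=max_memory, dtype=torch.float16, offload_buffers=False
        )
    except ValueError:
        # The check added in this commit raises when the planned GPU weights plus the
        # buffers of CPU/disk-offloaded modules exceed a GPU's budget. Retrying with
        # offload_buffers=True moves those buffers off the GPU, as the error message suggests.
        return infer_auto_device_map(
            model, max_memory=max_memory, dtype=torch.float16, offload_buffers=True
        )
```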
