
Commit

fix offload bug in accelerator (#60)
wejoncy authored Oct 9, 2024
1 parent 9ca9fa1 commit 154693c
Showing 1 changed file with 38 additions and 11 deletions.
49 changes: 38 additions & 11 deletions vptq/layers/model_base.py
@@ -44,6 +44,34 @@ def make_quant_linear(module, quant_conf, name="", target_layer=None):
return


+def attach_execution_device_hook(
+    module: torch.nn.Module,
+    execution_device,
+    skip_keys=None,
+    preload_module_classes=None,
+    tied_params_map=None,
+):
+    """
+    There is a bug in accelerate (https://github.com/huggingface/accelerate/issues/3060);
+    we hook this function here to work around it.
+    """
+    if not hasattr(module, "_hf_hook") and len(module.state_dict()) > 0:
+        accelerate.hooks.add_hook_to_module(
+            module,
+            accelerate.hooks.AlignDevicesHook(execution_device, skip_keys=skip_keys, tied_params_map=tied_params_map),
+        )
+
+    # Break the recursion if we get to a preload module.
+    if preload_module_classes is not None and module.__class__.__name__ in preload_module_classes:
+        return
+
+    for child in module.children():
+        attach_execution_device_hook(child,
+                                     execution_device,
+                                     tied_params_map=tied_params_map,
+                                     preload_module_classes=preload_module_classes)


class AutoModelForCausalLM(transformers.AutoModelForCausalLM):

@classmethod
@@ -93,20 +121,19 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# force to use one GPU as much as possible
model_buffer_size = accelerate.utils.modeling.compute_module_sizes(model, dtype=torch_dtype)[""]
local_max_memory = accelerate.utils.modeling.get_max_memory()
-        if 0 in local_max_memory and local_max_memory[0] * 0.85 > model_buffer_size:
+        if 0 in local_max_memory and local_max_memory[0] * 0.015 > model_buffer_size:
local_max_memory = {0: local_max_memory[0]}
if max_memory is None:
max_memory = local_max_memory

-        model = accelerate.load_checkpoint_and_dispatch(
-            model,
-            checkpoint=checkpoint,
-            device_map=device_map,
-            max_memory=max_memory,
-            no_split_module_classes=no_split_module_classes[0],
-            dtype=torch_dtype,
-            # preload_module_classes=["VQuantLinear"]
-        )
+        max_memory[0] *= 0.1
+        accelerate.hooks.attach_execution_device_hook = attach_execution_device_hook
+        model = accelerate.load_checkpoint_and_dispatch(model,
+                                                        checkpoint=checkpoint,
+                                                        device_map=device_map,
+                                                        max_memory=max_memory,
+                                                        no_split_module_classes=no_split_module_classes[0],
+                                                        dtype=torch_dtype,
+                                                        preload_module_classes=["VQuantLinear"])

# check cuda kernel exist
if importlib.util.find_spec("vptq.ops") is not None:
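
In short, the commit swaps accelerate's hook-attachment helper for the fixed one before dispatching the checkpoint. Below is a minimal sketch of that pattern, assuming the import path vptq.layers.model_base for the new helper; the toy module and checkpoint path are placeholders, not taken from this commit.

import accelerate
import torch

# Fixed hook added in this commit; the import path is assumed from the file shown above.
from vptq.layers.model_base import attach_execution_device_hook

# Placeholders: a real call would pass the meta-initialized transformer model and
# its checkpoint directory, as from_pretrained does in the diff.
model = torch.nn.Linear(16, 16)
checkpoint = "path/to/checkpoint"

# Monkey-patch accelerate so that load_checkpoint_and_dispatch attaches hooks with
# the fixed recursion (workaround for huggingface/accelerate issue #3060), then dispatch.
accelerate.hooks.attach_execution_device_hook = attach_execution_device_hook
model = accelerate.load_checkpoint_and_dispatch(
    model,
    checkpoint=checkpoint,
    device_map="auto",
    dtype=torch.float16,
    preload_module_classes=["VQuantLinear"],
)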
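On the caller's side the entry point is unchanged. A hedged usage sketch, assuming the package re-exports this subclass as vptq.AutoModelForCausalLM and using a placeholder model id:

import vptq

# device_map="auto" routes loading through load_checkpoint_and_dispatch,
# which now attaches execution-device hooks via the patched helper.
model = vptq.AutoModelForCausalLM.from_pretrained(
    "some-org/some-vptq-quantized-model",  # placeholder checkpoint id
    device_map="auto",
)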
