From bb8660a89c5cbb1489df54507f973860d07cac8c Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Fri, 20 Dec 2024 15:52:13 -0500
Subject: [PATCH] account for slightly different update param behavior (#1005)

Signed-off-by: Kyle Sayers
---
 .../compressed_tensors_utils.py              | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
index ce4ae7fb2..eba5c5882 100644
--- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
+++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
@@ -11,7 +11,7 @@
     ModelCompressor,
     SparsityCompressionConfig,
     is_module_offloaded,
-    update_parameter_data,
+    update_offload_parameter,
 )
 from loguru import logger
 from safetensors.torch import storage_ptr
@@ -238,14 +238,15 @@ def patch_tied_tensors_bug(model: torch.nn.Module):
 
     if storage_ptr(input_embed.weight) == storage_ptr(output_embed.weight):
         for module in (input_embed, output_embed):
-            offloaded = is_module_offloaded(module)
-            if offloaded:
-                module._hf_hook.pre_forward(module)
-
-            update_parameter_data(module, module.weight.data.clone(), "weight")
-
-            if offloaded:
-                module._hf_hook.post_forward(module, None)
+            if not is_module_offloaded(module):
+                # create new storage ptr for onloaded weight
+                untied_data = module.weight.data.clone()
+                module.weight.data = untied_data
+            else:
+                # create new storage ptr for offloaded weight
+                # note `update_offload_parameter` does not create a new storage ptr
+                untied_data = module._hf_hook.weights_map["weight"].clone()
+                update_offload_parameter(module, "weight", untied_data)
 
 
 def get_model_compressor(
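
For context, a minimal sketch of the behavior this patch relies on (assuming only torch and safetensors; the module shapes below are illustrative placeholders, not taken from the repository): two distinct Parameters can share one underlying storage, which is the tied-embedding condition that patch_tied_tensors_bug detects via storage_ptr, and cloning into .data gives each module its own storage, mirroring the onloaded branch added above.

    import torch
    from safetensors.torch import storage_ptr

    shared = torch.randn(8, 4)
    input_embed = torch.nn.Embedding(8, 4)
    output_embed = torch.nn.Linear(4, 8, bias=False)
    input_embed.weight = torch.nn.Parameter(shared)   # two distinct Parameter objects
    output_embed.weight = torch.nn.Parameter(shared)  # backed by the same storage

    assert storage_ptr(input_embed.weight) == storage_ptr(output_embed.weight)

    # clone() allocates fresh storage; assigning to .data keeps each Parameter
    # object in place but points it at its own, untied storage
    for module in (input_embed, output_embed):
        module.weight.data = module.weight.data.clone()

    assert storage_ptr(input_embed.weight) != storage_ptr(output_embed.weight)

The offloaded branch needs update_offload_parameter instead, because the live weight lives in the hook's weights_map rather than on the module; as the in-diff comment notes, that helper copies values but does not itself allocate new storage, so the clone is taken from the weights_map before writing it back.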