account for slightly different update param behavior (#1005)
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
kylesayrs authored Dec 20, 2024
1 parent 7366a2d commit bb8660a
Showing 1 changed file with 10 additions and 9 deletions.
@@ -11,7 +11,7 @@
     ModelCompressor,
     SparsityCompressionConfig,
     is_module_offloaded,
-    update_parameter_data,
+    update_offload_parameter,
 )
 from loguru import logger
 from safetensors.torch import storage_ptr
@@ -238,14 +238,15 @@ def patch_tied_tensors_bug(model: torch.nn.Module):

     if storage_ptr(input_embed.weight) == storage_ptr(output_embed.weight):
         for module in (input_embed, output_embed):
-            offloaded = is_module_offloaded(module)
-            if offloaded:
-                module._hf_hook.pre_forward(module)
-
-            update_parameter_data(module, module.weight.data.clone(), "weight")
-
-            if offloaded:
-                module._hf_hook.post_forward(module, None)
+            if not is_module_offloaded(module):
+                # create new storage ptr for onloaded weight
+                untied_data = module.weight.data.clone()
+                module.weight.data = untied_data
+            else:
+                # create new storage ptr for offloaded weight
+                # note `update_offload_parameter` does not create a new storage ptr
+                untied_data = module._hf_hook.weights_map["weight"].clone()
+                update_offload_parameter(module, "weight", untied_data)


 def get_model_compressor(
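For context, a minimal standalone sketch (not part of this commit) of the shared-storage behavior the patch untangles: tied input/output embedding weights report the same storage_ptr, and cloning the data is what allocates new storage and breaks the tie, mirroring the onloaded branch above. The toy Embedding/Linear modules are illustrative only, not the model code touched by this diff.

# Illustrative sketch only: tied weights alias one storage, which
# safetensors' storage_ptr exposes; cloning allocates fresh storage.
import torch
from safetensors.torch import storage_ptr

input_embed = torch.nn.Embedding(8, 4)
output_embed = torch.nn.Linear(4, 8, bias=False)
output_embed.weight = input_embed.weight  # tie: both params share one storage

assert storage_ptr(input_embed.weight) == storage_ptr(output_embed.weight)

# Cloning copies the data into new storage, so the two weights no longer alias
output_embed.weight.data = output_embed.weight.data.clone()
assert storage_ptr(input_embed.weight) != storage_ptr(output_embed.weight)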
