Fix infer_auto_device_map when tied weights share the same prefix name #2324

Merged
merged 2 commits into from
Jan 10, 2024

Conversation

Contributor

@fxmarty fxmarty commented Jan 10, 2024

As per title. Thanks to @Giuseppe5 and @nickfraser for noticing.

On main, the detection of tied_param_goups and tied_params is currently wrong. For example, given a tied group ["compute.weight", "compute.weight_submodule.parameter"], when we process the parameter compute.weight, tied_param_goups is wrongly empty because all(name in k for k in tied_group) is True: compute.weight is a substring of both names in the group.

This results in an error in the following example:

import torch
import accelerate

class SubModule(torch.nn.Module):
  def __init__(self, ref_to_parameter):
    super().__init__()
    self.parameter = ref_to_parameter
  def forward(self, x):
    return x + torch.max(self.parameter)

class LinearModuleAndSubModule(torch.nn.Linear):
  def __init__(self, in_features, out_features):
    super().__init__(in_features, out_features)
    self.weight_submodule = SubModule(self.weight)
  def forward(self, x):
    return torch.nn.functional.linear(self.weight_submodule(x), self.weight)

class Model(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.compute = LinearModuleAndSubModule(3, 8)
  def forward(self, x):
    return self.compute(x)

model = Model()
print(model)

from accelerate.utils import infer_auto_device_map
# Low memory device, just to force splitting and trigger the error
device_memory = {0: 4, 'cpu': 96000}
device_map = infer_auto_device_map(model, device_memory, verbose=True)

The error does not occur if the attribute is named self.brrrweight_submodule instead of self.weight_submodule, since compute.weight is then no longer a substring of the submodule parameter's name.
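To make the name collision concrete, here is a minimal pure-Python sketch of the substring test described above. It does not call accelerate. The helper names substring_matches and dotted_matches are hypothetical, and the trailing-dot comparison is only one assumed way to make the test respect name boundaries, not necessarily what this PR implements.

```python
def substring_matches(name, tied_group):
    # Plain substring test, as in `name in k for k in tied_group`.
    return [name in k for k in tied_group]


def dotted_matches(name, tied_group):
    # Assumed stricter variant: append "." to both sides so that a match
    # must end on a dotted-name boundary.
    return [(name + ".") in (k + ".") for k in tied_group]


name = "compute.weight"
tied_group = ["compute.weight", "compute.weight_submodule.parameter"]
renamed_group = ["compute.weight", "compute.brrrweight_submodule.parameter"]

# Every key matches, so all(...) is True for the original group.
print(substring_matches(name, tied_group))     # [True, True]
# With the renamed attribute the second key no longer matches.
print(substring_matches(name, renamed_group))  # [True, False]
# The boundary-aware variant separates the two names in the original group.
print(dotted_matches(name, tied_group))        # [True, False]
```

With the plain substring test, both entries of the original group match, which is why all(...) evaluates to True and the group is dropped. The renamed attribute avoids the problem only by accident, because the shared prefix disappears.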

fxmarty and others added 2 commits January 10, 2024 15:34
Co-authored-by: Giuseppe Franco <giuseppefranco4@gmail.com>
Co-authored-by: Nick Fraser <icanlosh@gmail.com>
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Collaborator

@muellerzr muellerzr left a comment

Thanks for the fix and the test!

@fxmarty fxmarty merged commit e3e9b87 into main Jan 10, 2024
25 checks passed
@fxmarty fxmarty deleted the fix-tied-param-detection branch January 10, 2024 14:57
Member

@SunMarc SunMarc left a comment

Thx for fixing and adding a test !

statelesshz pushed a commit to statelesshz/accelerate that referenced this pull request Jan 22, 2024
huggingface#2324)

* fix auto device map with tied weights sharing a prefix name

Co-authored-by: Giuseppe Franco <giuseppefranco4@gmail.com>
Co-authored-by: Nick Fraser <icanlosh@gmail.com>

* precise comment

---------

Co-authored-by: Giuseppe Franco <giuseppefranco4@gmail.com>
Co-authored-by: Nick Fraser <icanlosh@gmail.com>
4 participants