Fix dtype bug when offload_state_dict=True and dtype is specified #2116

Merged
merged 7 commits into from
Dec 5, 2023

fxmarty
Contributor

@fxmarty fxmarty commented Nov 2, 2023

As per title.

I don't have time to provide a full repro & add a test - could do next week.

But for general context, the issue is the following:

When loading a state dict with offload_state_dict=True, we go into the branch

elif param_device == "cpu" and offload_state_dict:

which is missing a check similar to

elif not str(value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):

for the case where the dtype argument is specified.

As a result, all parameters are cast to dtype, including the integer and boolean ones that should keep their original dtype.
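A minimal sketch of the guard described above, not the actual patch. It works on dtype names as plain strings so that it runs without PyTorch installed; the helper name target_dtype_name and the parameter names in the example are hypothetical.

```python
from typing import Optional

# Dtype-name prefixes that should be left untouched, taken from the check
# quoted in the description above.
NON_FLOAT_PREFIXES = ("torch.uint", "torch.int", "torch.bool")


def target_dtype_name(value_dtype: str, requested_dtype: Optional[str]) -> str:
    """Hypothetical helper: pick the dtype name a parameter should be stored with.

    Operates on dtype names (e.g. "torch.float32") instead of real tensors,
    which is an assumption made to keep the sketch dependency-free.
    """
    if requested_dtype is None:
        # No dtype requested: keep whatever the checkpoint holds.
        return value_dtype
    if value_dtype.startswith(NON_FLOAT_PREFIXES):
        # Integer, unsigned integer and boolean values keep their dtype.
        return value_dtype
    # Everything else (in practice, floating-point values) is cast.
    return requested_dtype


# Hypothetical parameter names and dtypes, for illustration only.
state_dict_dtypes = {
    "linear.weight": "torch.float32",
    "embedding.position_ids": "torch.int64",
    "attention.mask": "torch.bool",
}
resolved = {
    name: target_dtype_name(dtype, "torch.float16")
    for name, dtype in state_dict_dtypes.items()
}
print(resolved)
```

Only linear.weight ends up as torch.float16 here; the integer and boolean entries keep their dtype, which is the behavior the offload branch was missing.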

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

Member

@SunMarc SunMarc left a comment

LGTM! Left a few comments. I will run a few tests on bnb CPU offload to check that nothing is broken, then merge.

src/accelerate/utils/modeling.py (3 review threads, outdated and resolved)

github-actions bot commented Dec 2, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@muellerzr
Collaborator

Gentle ping @fxmarty 🤗

@fxmarty fxmarty force-pushed the patch-set_module_tensor_to_device branch from 598a149 to 996107b December 5, 2023 09:51
@@ -1311,7 +1313,7 @@ def get_state_dict_offloaded_model(model: nn.Module):
def load_checkpoint_in_model(
    model: nn.Module,
    checkpoint: Union[str, os.PathLike],
    device_map: Optional[Dict[str, Union[int, str, torch.device]]] = None,
Contributor Author

This type hint is wrong: None is not an accepted value in this function, which raises an error at

if param in device_map:

Contributor Author

I leave it as is because changing to device_map: Dict[str, Union[int, str, torch.device]] raises errors in the CI.

Member

Thx for noticing this. This should be fixed here.
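For readers following this thread, a small sketch of why an Optional hint with a None default is misleading when the body does a membership test. This is an illustration only, not the fix referenced above: device_for and the DeviceMap alias are hypothetical, and torch.device is left out of the Union so the example runs without PyTorch.

```python
from typing import Dict, Optional, Union

# Simplified alias for illustration; the real hint also allows torch.device.
DeviceMap = Dict[str, Union[int, str]]


def device_for(param: str, device_map: Optional[DeviceMap] = None) -> Union[int, str]:
    """Hypothetical lookup with an explicit guard for a missing device_map."""
    if device_map is None:
        # Without this guard, `param in device_map` raises a TypeError.
        raise ValueError("device_map must be provided")
    if param in device_map:
        return device_map[param]
    raise KeyError(f"{param} is not in device_map")


# The unguarded membership test fails when the map is None.
missing_map: Optional[DeviceMap] = None
try:
    "weight" in missing_map
    membership_failed = False
except TypeError:
    membership_failed = True

print(membership_failed)
print(device_for("weight", {"weight": "cpu"}))
```

Either the hint drops Optional, or the function handles None explicitly as sketched; which of the two the linked fix chose is not stated in this thread.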

@fxmarty fxmarty requested a review from SunMarc December 5, 2023 10:42
Collaborator

@muellerzr muellerzr left a comment

Thanks for iterating! Overall looks great to me bar one suggestion. cc @SunMarc for a final look through?

tests/test_modeling_utils.py (2 review threads, outdated and resolved)
@muellerzr
Collaborator

@fxmarty could you run make style; make quality? Thanks!

@fxmarty
Contributor Author

fxmarty commented Dec 5, 2023

Done, thank you!

Member

@SunMarc SunMarc left a comment

LGTM! Thanks for adding a test.

@fxmarty fxmarty merged commit 9569150 into main Dec 5, 2023
25 checks passed
@fxmarty fxmarty deleted the patch-set_module_tensor_to_device branch December 5, 2023 17:04
4 participants