Central function to canonicalize state dicts #40

Merged
merged 4 commits into from
Nov 22, 2023

RunDevelopment
Member

I wanted to add support for another arch today and noticed the pretrained models are checkpoints saved as .pth files. Since they are .pth files, our code for simplifying .ckpt files does not run, and the loaded state dict is a mess.

So I combined the code for cleaning up .ckpt files and the code for unwrapping nested dicts into one function: canonicalize_state_dict. The job of this function is to bring all state dicts into a common form.
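To illustrate what "a common form" might mean here, below is a minimal sketch, not the code from this PR. It assumes canonicalizing means unwrapping a state dict nested under a well-known key and stripping a prefix shared by every key. The wrapper keys, the prefix list, and the name canonicalize_state_dict_sketch are all hypothetical and chosen only for illustration.

```python
from typing import Any, Dict

# Hypothetical wrapper keys; the list used by the actual function may differ.
WRAPPER_KEYS = ("state_dict", "params_ema", "params", "model")
# Hypothetical prefixes that some training setups prepend to every key.
COMMON_PREFIXES = ("module.",)


def canonicalize_state_dict_sketch(state: Dict[str, Any]) -> Dict[str, Any]:
    """Illustrative only: unwrap nested dicts, then strip shared key prefixes."""
    # Unwrap repeatedly, since a checkpoint may nest the weights more than once.
    unwrapped = True
    while unwrapped:
        unwrapped = False
        for key in WRAPPER_KEYS:
            inner = state.get(key)
            if isinstance(inner, dict):
                state = inner
                unwrapped = True
                break

    # Strip a prefix only if every key carries it, so no key is renamed alone.
    for prefix in COMMON_PREFIXES:
        if state and all(k.startswith(prefix) for k in state):
            state = {k[len(prefix):]: v for k, v in state.items()}

    return state
```

Under these assumptions, an already-flat state dict passes through unchanged, so calling the function twice gives the same result as calling it once.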

Open question: Should this function be public? The load functions of individual archs expect a canonicalized state dict, so users must go through an ArchRegistry if they load .pth (or similar) files themselves. Passing model.state_dict() into a load function will continue to work though.

@joeyballentine
Member

> Should this function be public?

Sure, I guess. I don't see why it shouldn't be.

@RunDevelopment
Member Author

Another question: Should ModelLoader.load_state_dict_from_file return a canonicalized state dict? Yes, no, or should there be a parameter canonicalized: bool? If so, what should the default for that parameter be?

I would like the following code to always work:

state = ModelLoader().load_state_dict_from_file(file)
model = SomeArch.load(state)

So I think we should make the function load_state_dict_from_file(self, path: str | Path, canonicalized: bool = True). What do you think?
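A rough sketch of how the proposed signature could behave, for discussion only. ModelLoaderSketch, the read_raw hook, and the _canonicalize stand-in are hypothetical names introduced here; the real ModelLoader reads files itself and would call canonicalize_state_dict.

```python
from pathlib import Path
from typing import Any, Callable, Dict, Union

StateDict = Dict[str, Any]


def _canonicalize(state: StateDict) -> StateDict:
    # Stand-in for canonicalize_state_dict. As a simplifying assumption,
    # it only unwraps a top-level "state_dict" key.
    inner = state.get("state_dict")
    return inner if isinstance(inner, dict) else state


class ModelLoaderSketch:
    def __init__(self, read_raw: Callable[[Path], StateDict]) -> None:
        # read_raw is a hypothetical hook standing in for the format-specific
        # file readers (.pth, .ckpt, ...), so the sketch runs without torch.
        self._read_raw = read_raw

    def load_state_dict_from_file(
        self, path: Union[str, Path], canonicalized: bool = True
    ) -> StateDict:
        raw = self._read_raw(Path(path))
        # Defaulting to True keeps load_state_dict_from_file(file) usable
        # directly as input to an arch's load function.
        return _canonicalize(raw) if canonicalized else raw
```

With the default, callers get a state dict that can go straight into a load function. Passing canonicalized=False would return whatever the file contained.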

@joeyballentine
Member

If you wanna load a state dict without doing anything special, you can just use torch.load(). Our stuff should always return usable state dicts, otherwise what's the point.

At least, that's my opinion

@RunDevelopment
Member Author

RunDevelopment commented Nov 22, 2023

We do handle different file formats, but I agree with your argument. Then let's just say that load_state_dict_from_file always returns a canonicalized state dict. We can add a parameter to control this behavior later if needed.

@joeyballentine joeyballentine merged commit 49f4494 into main Nov 22, 2023
7 checks passed
@joeyballentine joeyballentine deleted the canonicalize_state_dict branch November 22, 2023 15:41
2 participants