-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Central function to canonicalize state dicts (#40)
* Central function to canonicalize state dicts * fixes * Make public * Return a canonicalized state dict
- Loading branch information
1 parent
6e29e02
commit 49f4494
Showing
7 changed files
with
44 additions
and
43 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from .model_descriptor import StateDict | ||
|
||
|
||
def canonicalize_state_dict(state_dict: StateDict) -> StateDict: | ||
""" | ||
Canonicalize a state dict. | ||
This function is used to canonicalize a state dict, so that it can be | ||
used for architecture detection and loading. | ||
This function is not intended to be used in production code. | ||
""" | ||
|
||
# the real state dict might be inside a dict with a known key | ||
unwrap_keys = ["state_dict", "params_ema", "params-ema", "params", "model", "net"] | ||
for unwrap_key in unwrap_keys: | ||
if unwrap_key in state_dict and isinstance(state_dict[unwrap_key], dict): | ||
state_dict = state_dict[unwrap_key] | ||
break | ||
|
||
# remove known common prefixes | ||
if len(state_dict) > 0: | ||
for prefix in ["module.", "netG."]: | ||
if all(i.startswith(prefix) for i in state_dict.keys()): | ||
state_dict = {k[len(prefix) :]: v for k, v in state_dict.items()} | ||
|
||
return state_dict |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters