allow partial loading for pre-trained ms2 models #226
base: development
Conversation
@@ -425,6 +425,53 @@ def _prepare_train_data_df(
    # if np.all(precursor_df['nce'].values > 1):
    #     precursor_df['nce'] = precursor_df['nce']*self.NCE_factor

    def _load_model_from_stream(self, stream: IO):
There may be two cases:
- we save partial model params into the file, and this method allows us to load that partial model.
- we save the full model but only need to load partial params, by specifying parameter names or the first K layers.
Hi, I’m not sure I fully understand your concerns. The current interface saves the complete model weights, so I’m unclear why users would want to save only partial models. Are you suggesting we add this functionality?
As for loading partial weights, the current implementation (in this PR) should already handle this automatically. It matches parameter keys and sizes, loading the matching weights while initializing the remaining parameters from scratch. Are you suggesting we modify the interface to let users explicitly specify which layers to load?
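To illustrate the matching behavior described above, here is a minimal, torch-free sketch of the idea: keep only checkpoint entries whose name and shape match the current model, and report everything else. All names (`filter_matching_params`, the layer names, the dict-of-shapes representation) are illustrative, not the actual PR code, which operates on real `state_dict` tensors.

```python
# Simplified sketch of key/shape matching for partial weight loading.
# State dicts are simulated as {name: {"shape": tuple}} so the example
# runs without torch; the real code compares tensor shapes instead.

def filter_matching_params(model_state, loaded_state):
    filtered = {}
    size_mismatches, unexpected_keys = [], []
    for key, value in loaded_state.items():
        if key not in model_state:
            unexpected_keys.append(key)           # key absent from the model
        elif value["shape"] != model_state[key]["shape"]:
            size_mismatches.append(key)           # same name, wrong shape
        else:
            filtered[key] = value                 # safe to load
    # anything not loaded stays randomly initialized
    missing_keys = [k for k in model_state if k not in filtered]
    return filtered, size_mismatches, unexpected_keys, missing_keys

model = {"encoder.weight": {"shape": (8, 4)}, "head.weight": {"shape": (6, 8)}}
ckpt = {"encoder.weight": {"shape": (8, 4)},
        "head.weight": {"shape": (2, 8)},       # shape changed
        "old_layer.bias": {"shape": (3,)}}      # removed layer

filtered, mism, unexpected, missing = filter_matching_params(model, ckpt)
print(sorted(filtered), mism, unexpected, missing)
```

The filtered dict would then be passed to `load_state_dict(..., strict=False)`, with the three lists driving the warning shown further down in this diff.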
No, I mean the use cases.
self.model.load_state_dict(filtered_params, strict=False)
if size_mismatches or unexpected_keys or missing_keys:
    warning_msg = "Some layers might be randomly initialized due to a mismatch between the loaded weights and the model architecture. Make sure to train the model or load different weights before prediction."
    warning_msg += (
(nit) This more concise format might be easier on the eye:
warning_msg += "".join(
[
f"\nKeys with size mismatches: {size_mismatches}" if size_mismatches else "",
f"\nUnexpected keys: {unexpected_keys}" if unexpected_keys else "",
f"\nMissing keys: {missing_keys}" if missing_keys else "",
]
)
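For reference, the suggested `"".join` pattern can be exercised standalone; the sample key lists below are made up for the demo:

```python
# Demo of the suggested concise warning construction: empty lists
# contribute an empty string, so only the relevant sections appear.
size_mismatches = ["head.weight"]
unexpected_keys = []
missing_keys = ["head.bias"]

warning_msg = "Some layers might be randomly initialized."
warning_msg += "".join(
    [
        f"\nKeys with size mismatches: {size_mismatches}" if size_mismatches else "",
        f"\nUnexpected keys: {unexpected_keys}" if unexpected_keys else "",
        f"\nMissing keys: {missing_keys}" if missing_keys else "",
    ]
)
print(warning_msg)
```

Since `unexpected_keys` is empty here, no "Unexpected keys" line is emitted.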
@@ -270,6 +270,7 @@ def __init__(
    self,
    mask_modloss: bool = False,
    device: str = "gpu",
    charged_frag_types: list[str] = None,
We still have requires-python = ">=3.8.0", so it should be Optional[List[str]].
@jalew188 should we drop support for 3.8?
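The background for this nit: subscripting the builtin (`list[str]`) only works at runtime from Python 3.9 (PEP 585), so a codebase supporting 3.8 needs `typing.List`; and since the default is `None`, `Optional[...]` is the accurate annotation. A small sketch with a hypothetical helper (the real code is the `__init__` signature in the diff):

```python
from typing import List, Optional

def resolve_frag_types(charged_frag_types: Optional[List[str]] = None) -> List[str]:
    # Optional[List[str]] is 3.8-compatible; `list[str] = None` would also
    # be inaccurate, since None is a valid value for the parameter.
    # The ["b", "y"] fallback here is an illustrative default, not the PR's.
    return ["b", "y"] if charged_frag_types is None else charged_frag_types

print(resolve_frag_types())
print(resolve_frag_types(["b", "y", "c", "z"]))
```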
We can, in the next release.
This PR includes two main changes:
From an alphaDia fine-tuning experiment where I want to predict the following fragment types:
['b','y','c','a','x','z']