
allow partial loading for pre trained ms2 models #226

Open · wants to merge 2 commits into base: development
Conversation

mo-sameh (Collaborator)

This PR includes two main changes:

  • Allow dynamically changing the charged_frag_types used for MS2 prediction, which in turn changes the dimensions of the MS2 model's output layer.
  • Allow partial loading of the MS2 model, so that a user can reuse a pretrained backbone with a different prediction head, e.g. after changing the fragment types. With this feature, users can finetune/train a model with more fragment types than the ones used for pretraining. In my experiments, using a pretrained backbone gave a significant performance gain over training fully from scratch.
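To illustrate the first change, here is a minimal, torch-free sketch of the idea; `HIDDEN_DIM` and `build_output_head` are hypothetical names for illustration, not the PR's actual API:

```python
# Hedged sketch: the output head's dimensions are derived from
# charged_frag_types, so swapping the list only rebuilds the head.
HIDDEN_DIM = 256  # assumed backbone output width (illustrative value)

def build_output_head(charged_frag_types):
    # Each entry (e.g. 'b_z1', 'y_z2') corresponds to one output column,
    # so the head's weight matrix has shape (n_frag_types, HIDDEN_DIM).
    out_dim = len(charged_frag_types)
    return {"weight_shape": (out_dim, HIDDEN_DIM), "bias_shape": (out_dim,)}

head = build_output_head(["b_z1", "b_z2", "y_z1", "y_z2"])
# a 4-entry fragment-type list yields a 4-row output layer
```

The backbone is untouched by this; only the final layer's shape depends on the fragment-type list, which is what makes partial loading of the remaining weights possible.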

From an alphaDia finetuning experiment where I want to predict the following fragment types: ['b','y','c','a','x','z']

  • Performance when training for 50 epochs from scratch:
Model tested on test dataset with the following metrics:

l1_loss                       : 0.0228
PCC-mean                      : 0.8015
COS-mean                      : 0.8140
SA-mean                       : 0.6054
SPC-mean                      : -0.2333
  • Performance when training for 50 epochs after loading a pretrained backbone:
Model tested on test dataset with the following metrics:

l1_loss                       : 0.0122
PCC-mean                      : 0.9462
COS-mean                      : 0.9491
SA-mean                       : 0.7961
SPC-mean                      : -0.3072

@mo-sameh mo-sameh changed the title feat: allow partial loading for pre trained ms2 models allow partial loading for pre trained ms2 models Dec 30, 2024
@@ -425,6 +425,53 @@ def _prepare_train_data_df(
# if np.all(precursor_df['nce'].values > 1):
# precursor_df['nce'] = precursor_df['nce']*self.NCE_factor

def _load_model_from_stream(self, stream: IO):

Collaborator:

There may be two cases:

  1. we save partial model params into the file, and this method allows us to load that partial model.
  2. we save the full model but only need to load partial params, by specifying param names or the first K layers.

@mo-sameh (Collaborator, Author) · Jan 11, 2025:
Hi, I’m not sure I fully understand your concern. The current interface saves the complete model weights, so I’m unclear why users would want to save only partial models. Are you suggesting we add this functionality?

As for loading partial weights, the current implementation (in this PR) should already handle this automatically: it matches parameter keys and sizes, loading the matching weights while initializing the remaining parameters from scratch. Are you suggesting we modify the interface to let users explicitly specify which layers to load?
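The key-and-size matching described above can be sketched roughly as follows. This is a torch-free stand-in that represents parameters as `{name: shape}` dicts, and `filter_loadable` is a hypothetical name, not the PR's actual function:

```python
# Hedged sketch of key/shape matching for partial weight loading:
# parameters are modeled as {name: shape} dicts instead of real tensors.
def filter_loadable(pretrained, model):
    """Split pretrained params into loadable keys plus a mismatch report."""
    loadable = {k: s for k, s in pretrained.items()
                if k in model and model[k] == s}
    size_mismatches = [k for k, s in pretrained.items()
                       if k in model and model[k] != s]
    unexpected_keys = [k for k in pretrained if k not in model]
    missing_keys = [k for k in model if k not in pretrained]
    return loadable, size_mismatches, unexpected_keys, missing_keys

# Example: backbone shapes match; the head grew from 2 to 6 fragment types.
pretrained = {"hidden.weight": (256, 256), "output.weight": (2, 256)}
model = {"hidden.weight": (256, 256),
         "output.weight": (6, 256), "output.bias": (6,)}
loadable, mismatched, unexpected, missing = filter_loadable(pretrained, model)
# only hidden.weight is loadable; the mismatched head stays freshly initialized
```

In the real code, the filtered dict would then go to `load_state_dict(..., strict=False)` so that unmatched parameters keep their fresh initialization.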

Collaborator:

No, I mean the use cases.

self.model.load_state_dict(filtered_params, strict=False)
if size_mismatches or unexpected_keys or missing_keys:
warning_msg = "Some layers might be randomly initialized due to a mismatch between the loaded weights and the model architecture. Make sure to train the model or load different weights before prediction."
warning_msg += (

Contributor:

(nit) This more concise format might be easier on the eye:

    warning_msg += "".join(
        [
            f"\nKeys with size mismatches: {size_mismatches}" if size_mismatches else "",
            f"\nUnexpected keys: {unexpected_keys}" if unexpected_keys else "",
            f"\nMissing keys: {missing_keys}" if missing_keys else "",
        ]
    )
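For illustration, the suggested snippet runs standalone with made-up values; the variable names mirror the diff, but the list contents below are invented examples, not real model output:

```python
# Standalone demo of the suggested join-based warning formatting.
# The mismatch lists are made-up examples for demonstration only.
size_mismatches = ["output_nn.weight"]
unexpected_keys = []
missing_keys = ["output_nn.bias"]

warning_msg = (
    "Some layers might be randomly initialized due to a mismatch between "
    "the loaded weights and the model architecture."
)
warning_msg += "".join(
    [
        f"\nKeys with size mismatches: {size_mismatches}" if size_mismatches else "",
        f"\nUnexpected keys: {unexpected_keys}" if unexpected_keys else "",
        f"\nMissing keys: {missing_keys}" if missing_keys else "",
    ]
)
# empty lists contribute no line at all, keeping the warning compact
```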

@@ -270,6 +270,7 @@ def __init__(
self,
mask_modloss: bool = False,
device: str = "gpu",
charged_frag_types: list[str] = None,

Contributor:

we still have `requires-python = ">=3.8.0"` .. so `Optional[List[str]]` it is ..

@jalew188 should we drop support for 3.8?


Collaborator:

We can, in the next release.
