
Add permutation unittest #339

Closed

Conversation

dingquanyu
Contributor

  1. Fixed some naming errors in config.py, e.g. template_single_embedder should have the keys c_in and c_out, but in the previous version it had c_in and c_m.
  2. Fixed incompatibilities between embedders.py and config.py, mainly caused by some keys in config.py not being present in embedders.py.
  3. Moved some matrix slicing and processing steps to numpy in loss.py to solve the integer overflow issue caused by GPU memory usage (see the sketch after this list).
  4. Wrote an lddt_ca_multimer function in loss.py specifically for calculating the lDDT loss during multimer training, so that it can accommodate the multimer input data structure.
  5. Overrode the validation metrics in train_openfold.py to handle the data structures used in multimer training mode.
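
For item 3, a minimal sketch of the pattern being described, not the actual loss.py code: the helper name process_on_cpu and the slicing step are placeholders for illustration only.

import numpy as np
import torch

def process_on_cpu(t: torch.Tensor) -> torch.Tensor:
    # Detach and move the tensor off the GPU to host memory,
    # do the memory-heavy slicing/processing in numpy,
    # then return the result as a tensor on the original device.
    arr = t.detach().to('cpu').numpy()
    sliced = arr[..., :1]                  # placeholder slicing step
    out = np.ascontiguousarray(sliced)
    return torch.from_numpy(out).to(t.device)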

@dingquanyu force-pushed the permutation branch 2 times, most recently from 84f1b1a to e46acc2 on August 21, 2023 at 11:22
class OpenFoldDataLoader(torch.utils.data.DataLoader):
    def __init__(self, *args, config, stage="train", generator=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.config = config
        self.stage = stage
Collaborator

So these changes will overwrite my fixes to the seeding, I think I need to commit my other changes first and then you can rebase. It might be a bit nasty because of my refactoring :/.

Contributor Author

I see. So this torch.Generator() is going to overwrite your new seeding steps?
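
For context on the seeding discussion, a minimal, self-contained sketch of how an explicitly seeded torch.Generator controls DataLoader shuffling; this is illustrative only and not the OpenFold seeding code being discussed.

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(8).float())

# Seed a generator explicitly and hand it to the DataLoader.
gen = torch.Generator()
gen.manual_seed(42)
loader = DataLoader(dataset, batch_size=2, shuffle=True, generator=gen)

# Re-seeding the generator reproduces the same shuffle order, which is
# exactly what a competing seeding scheme could silently override.
gen.manual_seed(42)
loader_again = DataLoader(dataset, batch_size=2, shuffle=True, generator=gen)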

@@ -784,6 +784,45 @@ def read_template(start, size):

return all_hits

def _parse_template_hits(
Collaborator

Should also be gone after rebase

Contributor Author

Ah sorry, this file is outdated. I will update data_pipeline.py with the version on the multimer branch now and push it.

per_residue=per_residue,
)

def lddt_ca_multimer(
Collaborator

Could you explain what is happening with the masking here?

Contributor Author

If it uses the same approach as lddt_ca, i.e. all_atom_mask[..., ca_pos : (ca_pos + 1)], then the dimensions are not correct. I have to perform all_atom_mask[..., None] - all_atom_mask[..., None, :] so that all_atom_mask ends up with a shape of [batch_size, n_res, n_res].
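
A minimal shape check of the broadcasting described above, using the expression quoted in the comment; the actual masking inside lddt_ca_multimer may combine the masks differently.

import torch

batch_size, n_res = 2, 5
all_atom_mask = torch.ones(batch_size, n_res)

# [batch_size, n_res, 1] broadcast against [batch_size, 1, n_res]
# yields the pairwise shape [batch_size, n_res, n_res].
pair_mask = all_atom_mask[..., None] - all_atom_mask[..., None, :]
print(pair_mask.shape)  # torch.Size([2, 5, 5])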


c_n_bonds = (
    neighbour_mask
    * c_one_hot[..., None, None, :, None]
    * n_one_hot[..., None, None, None, :]
    * c_one_hot.detach().to('cpu').numpy()[..., None, None, :, None]
Collaborator

When I add my changes and you rebase, just make sure to remove all the cpu stuff before submitting the PR.

Contributor Author

Sure

@@ -2021,7 +2052,8 @@ def __init__(self, config):
        super(AlphaFoldMultimerLoss, self).__init__(config)
        self.config = config

    def multi_chain_perm_align(self, out, batch, labels, shuffle_times=2):
    @staticmethod
Collaborator

I'm going to add your changes to the permutation loss in my commit, because otherwise it will be very annoying to combine our changes

Contributor Author

ok I see

pred_coords = outputs["final_atom_positions"]
all_atom_mask = batch["all_atom_mask"][..., -1]
all_atom_mask = all_atom_mask.unsqueeze(-1).expand(*all_atom_mask.shape, 3)
# In the case of multimer training, no need to introduce an empty dimension to all_atom_mask
Collaborator

Could you explain this part, like why there is a difference for multimer?

Contributor Author

Ah sorry, I wrongly copy-pasted _compute_validation_metrics here again at line 64. I'll remove this chunk.

@@ -272,6 +329,63 @@ def __init__(self, config):
        self.cached_weights = None
        self.last_lr_step = -1

    def _compute_validation_metrics(self,
Contributor Author

@christinaflo This _compute_validation_metrics overrides the parent class method so that it uses lddt_ca_multimer instead of lddt_ca, and I set all_atom_mask, ca_pos, and all_atom_mask_ca to None once they are no longer used, just to free up GPU memory.
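
A hedged, self-contained sketch of the override described above; ParentModule and lddt_ca_multimer_stub are placeholders, and the real method in train_openfold.py computes more metrics using the actual lddt_ca_multimer from loss.py.

import torch

def lddt_ca_multimer_stub(pred, gt, mask):
    # Stand-in for the real lddt_ca_multimer in loss.py.
    return torch.zeros(pred.shape[0])

class ParentModule:
    def _compute_validation_metrics(self, batch, outputs):
        raise NotImplementedError

class MultimerModule(ParentModule):
    def _compute_validation_metrics(self, batch, outputs):
        # Overridden so the multimer-shaped inputs go through the multimer lDDT.
        pred_coords = outputs["final_atom_positions"]
        gt_coords = batch["all_atom_positions"]
        all_atom_mask = batch["all_atom_mask"]

        metrics = {"lddt_ca": lddt_ca_multimer_stub(pred_coords, gt_coords, all_atom_mask)}

        # Drop large intermediates once they are no longer needed so the
        # corresponding GPU tensors can be freed during validation
        # (the real method also drops ca_pos and all_atom_mask_ca).
        all_atom_mask = None
        return metrics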
