Add permutation unittest #339
Conversation
dingquanyu commented on Aug 10, 2023
- Fixed some naming errors in config.py, e.g. template_single_embedder should have the keys c_in and c_out, but the previous version had c_in and c_m (see the sketch after this list).
- Fixed incompatibilities between embedders.py and config.py, mainly caused by keys that appear in config.py but not in embedders.py.
- Moved some matrix slicing and processing steps to NumPy in loss.py to solve the integer overflow issue caused by GPU memory usage.
- Wrote a lddt_ca_multimer function in loss.py specifically for calculating the lDDT loss during multimer training, so that it can accommodate the multimer input data structure.
- Overrode the validation metrics in train_openfold.py for the data structures used in multimer training mode.
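For illustration, a minimal sketch of the config fix from the first bullet. The key names c_in, c_out, and c_m come from the description above; the nesting and the numeric values are assumptions, not the actual config.py contents:

```python
# Hypothetical excerpt of openfold/config.py after the fix: the
# template_single_embedder block now uses c_in/c_out, where the old
# version used c_in/c_m, mismatching the keys read in embedders.py.
config = {
    "model": {
        "template": {
            "template_single_embedder": {
                "c_in": 34,    # input feature dimension (value assumed)
                "c_out": 384,  # output channel dimension (value assumed)
            },
        },
    },
}
```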
Force-pushed from 84f1b1a to e46acc2
openfold/data/data_modules.py
Outdated
class OpenFoldDataLoader(torch.utils.data.DataLoader):
    def __init__(self, *args, config, stage="train", generator=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.config = config
        self.stage = stage
So these changes will overwrite my fixes to the seeding. I think I need to commit my other changes first, and then you can rebase. It might be a bit nasty because of my refactoring :/
I see. So this torch.Generator() usage is going to overwrite your new seeding steps?
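For context, a minimal runnable sketch of the pattern under discussion: an explicitly constructed torch.Generator passed to a DataLoader controls the shuffling order and bypasses seeding done elsewhere in the training script. The toy dataset and seed value are illustrative:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10).float())

# If the DataLoader builds and uses its own generator like this, it takes
# precedence over global seeding (e.g. torch.manual_seed) set up elsewhere.
generator = torch.Generator()
generator.manual_seed(42)

loader = DataLoader(dataset, batch_size=2, shuffle=True, generator=generator)
for batch in loader:
    print(batch)
```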
openfold/data/data_pipeline.py
Outdated
@@ -784,6 +784,45 @@ def read_template(start, size):
        return all_hits

def _parse_template_hits(
This should also be gone after the rebase.
Ah sorry, this file is outdated. I will update data_pipeline.py with the version on the multimer branch now and push it.
openfold/utils/loss.py
Outdated
    per_residue=per_residue,
)

def lddt_ca_multimer(
Could you explain what is happening with the masking here?
If it uses the same approach as lddt_ca, i.e. all_atom_mask[..., ca_pos : (ca_pos + 1)], then the dimensions are not correct. I have to compute all_atom_mask[..., None] * all_atom_mask[..., None, :] so that all_atom_mask ends up with a shape of [batch_size, n_res, n_res].
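A minimal runnable sketch of the shape logic described above; the batch and residue sizes are illustrative:

```python
import torch

batch_size, n_res = 2, 5

# Per-residue CA mask, shape [batch_size, n_res]
all_atom_mask = torch.ones(batch_size, n_res)

# Outer product of the mask with itself yields a pairwise mask of shape
# [batch_size, n_res, n_res]: entry (i, j) is 1 only when both residue i
# and residue j are resolved.
pair_mask = all_atom_mask[..., None] * all_atom_mask[..., None, :]
print(pair_mask.shape)  # torch.Size([2, 5, 5])
```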
openfold/utils/loss.py
Outdated
c_n_bonds = (
    neighbour_mask
    * c_one_hot[..., None, None, :, None]
    * n_one_hot[..., None, None, None, :]
    * c_one_hot.detach().to('cpu').numpy()[..., None, None, :, None]
When I add my changes and you rebase, just make sure to remove all the cpu stuff before submitting the PR.
Sure
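For reference, a hedged sketch of the kind of "cpu stuff" being discussed: detaching tensors and computing a broadcasted product on the host to sidestep GPU memory blow-ups. The tensors here are stand-ins, not the actual ones from loss.py:

```python
import torch

# Stand-ins for the one-hot encodings used in the bond-loss computation.
c_one_hot = torch.randn(4, 37)
n_one_hot = torch.randn(4, 37)

# GPU/torch version: the broadcasted product materializes a large
# intermediate tensor on the device.
bonds_gpu = (
    c_one_hot[..., None, None, :, None]
    * n_one_hot[..., None, None, None, :]
)

# CPU/NumPy detour used in the draft: detach, move to host memory, and
# compute there. This avoids GPU OOM but breaks autograd, which is why
# the reviewer asks for it to be removed before the PR is submitted.
bonds_cpu = (
    c_one_hot.detach().cpu().numpy()[..., None, None, :, None]
    * n_one_hot.detach().cpu().numpy()[..., None, None, None, :]
)
```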
openfold/utils/loss.py
Outdated
@@ -2021,7 +2052,8 @@ def __init__(self, config):
        super(AlphaFoldMultimerLoss, self).__init__(config)
        self.config = config

    @staticmethod
    def multi_chain_perm_align(self, out, batch, labels, shuffle_times=2):
I'm going to add your changes to the permutation loss in my commit, because otherwise it will be very annoying to combine our changes
ok I see
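For clarity, the decorator change in the diff above looks roughly like this; a sketch only, since the real method body lives in loss.py and is omitted here:

```python
class AlphaFoldMultimerLoss:
    # Before: an instance method, even though the body never needs
    # instance state.
    # def multi_chain_perm_align(self, out, batch, labels, shuffle_times=2):
    #     ...

    # After: a staticmethod, so the self parameter is dropped and the
    # function can be called without an AlphaFoldMultimerLoss instance.
    @staticmethod
    def multi_chain_perm_align(out, batch, labels, shuffle_times=2):
        ...
```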
train_openfold.py
Outdated
pred_coords = outputs["final_atom_positions"]
all_atom_mask = batch["all_atom_mask"][..., -1]
all_atom_mask = all_atom_mask.unsqueeze(-1).expand(*all_atom_mask.shape, 3)
# In the case of multimer training, no need to introduce an empty dimension to all_atom_mask
Could you explain this part, like why there is a difference for multimer?
Ah sorry, I wrongly copy-pasted _compute_validation_metrics here again at line 64. I'll remove this chunk.
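For completeness, a runnable sketch of the monomer-path shape manipulation shown in the diff above, with illustrative sizes; in the multimer path this expansion is skipped because the mask already has the expected shape:

```python
import torch

n_res, n_atom = 5, 37

# Monomer path: take the last atom-type channel of the mask ...
batch = {"all_atom_mask": torch.ones(n_res, n_atom)}
all_atom_mask = batch["all_atom_mask"][..., -1]  # shape [n_res]

# ... then add a trailing dimension and broadcast it over x/y/z so it
# lines up with the [n_res, 3] coordinate tensors. Note that
# all_atom_mask.shape inside expand() is evaluated before reassignment.
all_atom_mask = all_atom_mask.unsqueeze(-1).expand(*all_atom_mask.shape, 3)
print(all_atom_mask.shape)  # torch.Size([5, 3])
```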
train_openfold.py
Outdated
@@ -272,6 +329,63 @@ def __init__(self, config):
        self.cached_weights = None
        self.last_lr_step = -1

    def _compute_validation_metrics(self,
@christinaflo This _compute_validation_metrics overrides the parent class method so that it uses lddt_ca_multimer instead of lddt_ca, and I set all_atom_mask, ca_pos, and all_atom_mask_ca to None once they are no longer used, just to free up memory on the GPU.
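A hedged sketch of the override pattern described here; the method and function names come from the thread, while the stub classes, call signatures, and metric details are assumptions:

```python
import torch

def lddt_ca_multimer(pred, true, mask):
    # Stub standing in for the multimer lDDT function from loss.py.
    return torch.tensor(0.0)

class BaseWrapper:
    # Stub for the parent training module; the real class lives in
    # train_openfold.py.
    def _compute_validation_metrics(self, batch, outputs):
        raise NotImplementedError

class MultimerWrapper(BaseWrapper):
    def _compute_validation_metrics(self, batch, outputs):
        all_atom_mask = batch["all_atom_mask"]
        metrics = {
            # Multimer-aware lDDT in place of the monomer lddt_ca.
            "lddt_ca": lddt_ca_multimer(
                outputs["final_atom_positions"],
                batch["all_atom_positions"],
                all_atom_mask,
            )
        }
        # Drop the reference so the GPU allocator can reclaim the memory.
        all_atom_mask = None
        return metrics
```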
Force-pushed from 87fbff7 to ab09ded