Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify assignment stage #3

Merged
merged 2 commits into from
Jun 27, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 42 additions & 64 deletions openfold/utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -1677,7 +1677,7 @@ def get_chain_center_of_mass(pos):
# #
def kabsch_rotation(P, Q):
"""
Use scipy.spatial package to calculate best rotation that minimises
Use procrustes package to calculate best rotation that minimises
the RMSD between P and Q

The optimal rotation matrix was calculated using
Expand Down Expand Up @@ -1755,19 +1755,6 @@ def compute_rmsd(
msd = torch.nan_to_num(msd, nan=1e8)
return torch.sqrt(msd + eps)

def kabsch_rmsd(
    true_atom_pos: torch.Tensor,
    pred_atom_pos: torch.Tensor,
    atom_mask: torch.Tensor,
):
    """Compute the RMSD between predicted and ground-truth coordinates after
    optimally superposing the ground truth onto the prediction.

    Args:
        true_atom_pos: ground-truth atom coordinates.
        pred_atom_pos: predicted atom coordinates.
        atom_mask: mask selecting the atoms that participate in the alignment
            and the RMSD computation.

    Returns:
        The masked RMSD between the aligned ground truth and the prediction.
    """
    # Best-fit rotation and translation mapping true positions onto predictions.
    rotation, translation = get_optimal_transform(
        true_atom_pos,
        pred_atom_pos,
        atom_mask,
    )
    # Apply the rigid transform, then measure the residual deviation.
    superposed_true = true_atom_pos @ rotation + translation
    return compute_rmsd(superposed_true, pred_atom_pos, atom_mask)


def get_least_asym_entity_or_longest_length(batch):
"""
Expand Down Expand Up @@ -1802,6 +1789,11 @@ def get_least_asym_entity_or_longest_length(batch):
least_asym_entities = random.choice(least_asym_entities)
assert len(least_asym_entities)==1
best_pred_asym = torch.unique(batch["asym_id"][batch["entity_id"] == least_asym_entities[0]])

# If there is more than one chain in the predicted output that has the same sequence
# as the chosen ground truth anchor, then randomly pick one
if len(best_pred_asym) > 1:
best_pred_asym = random.choice(best_pred_asym)
return least_asym_entities[0], best_pred_asym


Expand Down Expand Up @@ -2032,65 +2024,49 @@ def multi_chain_perm_align(self,out, batch, labels, shuffle_times=2):
per_asym_residue_index[int(cur_asym_id)] = batch["residue_index"][asym_mask]

anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
print(f"anchor_gt_asym is chosen to be: {anchor_gt_asym}")
print(f"anchor_gt_asym is : {anchor_gt_asym} and anchor_pred_asym is {anchor_pred_asym}")
anchor_gt_idx = int(anchor_gt_asym) - 1

best_rmsd = 1e20
best_labels = None

unique_entity_ids = torch.unique(batch["entity_id"])
entity_2_asym_list = {}
for cur_ent_id in unique_entity_ids:
ent_mask = batch["entity_id"] == cur_ent_id
cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
for cur_asym_id in anchor_pred_asym:
asym_mask = (batch["asym_id"] == cur_asym_id).bool()
anchor_residue_idx = per_asym_residue_index[int(cur_asym_id)]

anchor_true_pos = true_ca_poses[anchor_gt_idx][anchor_residue_idx]
anchor_pred_pos = pred_ca_pos[asym_mask]
anchor_true_mask = true_ca_masks[anchor_gt_idx][anchor_residue_idx]
anchor_pred_mask = pred_ca_mask[asym_mask]
r, x = get_optimal_transform(
anchor_true_pos,
anchor_pred_pos,
(anchor_true_mask.to('cpu') * anchor_pred_mask.to('cpu')).bool(),
asym_mask = (batch["asym_id"] == anchor_pred_asym).bool()
anchor_residue_idx = per_asym_residue_index[int(anchor_pred_asym)]

anchor_true_pos = true_ca_poses[anchor_gt_idx][anchor_residue_idx]
anchor_pred_pos = pred_ca_pos[asym_mask]
anchor_true_mask = true_ca_masks[anchor_gt_idx][anchor_residue_idx]
anchor_pred_mask = pred_ca_mask[asym_mask]
r, x = get_optimal_transform(
anchor_true_pos,
anchor_pred_pos,
(anchor_true_mask.to('cpu') * anchor_pred_mask.to('cpu')).bool(),
)

aligned_true_ca_poses = [ca.to('cpu') @ r.to('cpu') + x.to('cpu') for ca in true_ca_poses] # apply transforms
align = greedy_align(
batch,
per_asym_residue_index,
unique_asym_ids ,
entity_2_asym_list,
pred_ca_pos,
pred_ca_mask,
aligned_true_ca_poses,
true_ca_masks,
)

aligned_true_ca_poses = [ca.to('cpu') @ r.to('cpu') + x.to('cpu') for ca in true_ca_poses] # apply transforms
for _ in range(shuffle_times):
shuffle_idx = torch.randperm(
unique_asym_ids.shape[0], device=unique_asym_ids.device
)
shuffled_asym_ids = unique_asym_ids[shuffle_idx]
align = greedy_align(
batch,
per_asym_residue_index,
shuffled_asym_ids,
entity_2_asym_list,
pred_ca_pos,
pred_ca_mask,
aligned_true_ca_poses,
true_ca_masks,
)
merged_labels = merge_labels(
batch,
per_asym_residue_index,
labels,
align,
)
rmsd = kabsch_rmsd(
merged_labels["all_atom_positions"][..., ca_idx, :].to('cpu') @ r.to('cpu') + x.to('cpu'),
pred_ca_pos,
(pred_ca_mask.to('cpu') * merged_labels["all_atom_mask"][..., ca_idx].to('cpu')).bool(),
)

if rmsd < best_rmsd:
best_rmsd = rmsd
best_labels = merged_labels
print(f"finished shuffling and final align is {align}")
return best_labels
merged_labels = merge_labels(
batch,
per_asym_residue_index,
labels,
align,
)

print(f"finished multi-chain permutation and final align is {align}")

return merged_labels

def forward(self,out,batch,_return_breakdown=False):
"""
Expand All @@ -2107,6 +2083,8 @@ def forward(self,out,batch,_return_breakdown=False):
# then permutate ground truth chains before calculating the loss
permutated_labels = self.multi_chain_perm_align(out,features,labels)
logger.info("finished multi-chain permutation")
# features.update(permutated_labels)
# self.loss(out,features)
return permutated_labels
## TODO next need to check how the ground truth label is used
# in loss calculation.