Skip to content

Commit

Permalink
Merge pull request #3 from dingquanyu/modify-assignment-stage
Browse files Browse the repository at this point in the history
Modify assignment stage
  • Loading branch information
dingquanyu authored Jun 27, 2023
2 parents 2a70e08 + eeb035c commit 3d87ef2
Showing 1 changed file with 42 additions and 64 deletions.
106 changes: 42 additions & 64 deletions openfold/utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -1677,7 +1677,7 @@ def get_chain_center_of_mass(pos):
# #
def kabsch_rotation(P, Q):
"""
Use scipy.spatial package to calculate best rotation that minimises
Use procrustes package to calculate best rotation that minimises
the RMSD betwee P and Q
The optimal rotation matrix was calculated using
Expand Down Expand Up @@ -1755,19 +1755,6 @@ def compute_rmsd(
msd = torch.nan_to_num(msd, nan=1e8)
return torch.sqrt(msd + eps)

def kabsch_rmsd(
true_atom_pos: torch.Tensor,
pred_atom_pos: torch.Tensor,
atom_mask: torch.Tensor,
):
r, x = get_optimal_transform(
true_atom_pos,
pred_atom_pos,
atom_mask,
)
aligned_true_atom_pos = true_atom_pos @ r + x
return compute_rmsd(aligned_true_atom_pos, pred_atom_pos, atom_mask)


def get_least_asym_entity_or_longest_length(batch):
"""
Expand Down Expand Up @@ -1802,6 +1789,11 @@ def get_least_asym_entity_or_longest_length(batch):
least_asym_entities = random.choice(least_asym_entities)
assert len(least_asym_entities)==1
best_pred_asym = torch.unique(batch["asym_id"][batch["entity_id"] == least_asym_entities[0]])

# If there is more than one chain in the predicted output that has the same sequence
# as the chosen ground truth anchor, then randomly picke one
if len(best_pred_asym) > 1:
best_pred_asym = random.choice(best_pred_asym)
return least_asym_entities[0], best_pred_asym


Expand Down Expand Up @@ -2032,65 +2024,49 @@ def multi_chain_perm_align(self,out, batch, labels, shuffle_times=2):
per_asym_residue_index[int(cur_asym_id)] = batch["residue_index"][asym_mask]

anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
print(f"anchor_gt_asym is chosen to be: {anchor_gt_asym}")
print(f"anchor_gt_asym is : {anchor_gt_asym} and anchor_pred_asym is {anchor_pred_asym}")
anchor_gt_idx = int(anchor_gt_asym) - 1

best_rmsd = 1e20
best_labels = None

unique_entity_ids = torch.unique(batch["entity_id"])
entity_2_asym_list = {}
for cur_ent_id in unique_entity_ids:
ent_mask = batch["entity_id"] == cur_ent_id
cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
for cur_asym_id in anchor_pred_asym:
asym_mask = (batch["asym_id"] == cur_asym_id).bool()
anchor_residue_idx = per_asym_residue_index[int(cur_asym_id)]

anchor_true_pos = true_ca_poses[anchor_gt_idx][anchor_residue_idx]
anchor_pred_pos = pred_ca_pos[asym_mask]
anchor_true_mask = true_ca_masks[anchor_gt_idx][anchor_residue_idx]
anchor_pred_mask = pred_ca_mask[asym_mask]
r, x = get_optimal_transform(
anchor_true_pos,
anchor_pred_pos,
(anchor_true_mask.to('cpu') * anchor_pred_mask.to('cpu')).bool(),
asym_mask = (batch["asym_id"] == anchor_pred_asym).bool()
anchor_residue_idx = per_asym_residue_index[int(anchor_pred_asym)]

anchor_true_pos = true_ca_poses[anchor_gt_idx][anchor_residue_idx]
anchor_pred_pos = pred_ca_pos[asym_mask]
anchor_true_mask = true_ca_masks[anchor_gt_idx][anchor_residue_idx]
anchor_pred_mask = pred_ca_mask[asym_mask]
r, x = get_optimal_transform(
anchor_true_pos,
anchor_pred_pos,
(anchor_true_mask.to('cpu') * anchor_pred_mask.to('cpu')).bool(),
)

aligned_true_ca_poses = [ca.to('cpu') @ r.to('cpu') + x.to('cpu') for ca in true_ca_poses] # apply transforms
align = greedy_align(
batch,
per_asym_residue_index,
unique_asym_ids ,
entity_2_asym_list,
pred_ca_pos,
pred_ca_mask,
aligned_true_ca_poses,
true_ca_masks,
)

aligned_true_ca_poses = [ca.to('cpu') @ r.to('cpu') + x.to('cpu') for ca in true_ca_poses] # apply transforms
for _ in range(shuffle_times):
shuffle_idx = torch.randperm(
unique_asym_ids.shape[0], device=unique_asym_ids.device
)
shuffled_asym_ids = unique_asym_ids[shuffle_idx]
align = greedy_align(
batch,
per_asym_residue_index,
shuffled_asym_ids,
entity_2_asym_list,
pred_ca_pos,
pred_ca_mask,
aligned_true_ca_poses,
true_ca_masks,
)
merged_labels = merge_labels(
batch,
per_asym_residue_index,
labels,
align,
)
rmsd = kabsch_rmsd(
merged_labels["all_atom_positions"][..., ca_idx, :].to('cpu') @ r.to('cpu') + x.to('cpu'),
pred_ca_pos,
(pred_ca_mask.to('cpu') * merged_labels["all_atom_mask"][..., ca_idx].to('cpu')).bool(),
)

if rmsd < best_rmsd:
best_rmsd = rmsd
best_labels = merged_labels
print(f"finished shuffling and final align is {align}")
return best_labels
merged_labels = merge_labels(
batch,
per_asym_residue_index,
labels,
align,
)

print(f"finished multi-chain permutation and final align is {align}")

return merged_labels

def forward(self,out,batch,_return_breakdown=False):
"""
Expand All @@ -2107,6 +2083,8 @@ def forward(self,out,batch,_return_breakdown=False):
# then permutate ground truth chains before calculating the loss
permutated_labels = self.multi_chain_perm_align(out,features,labels)
logger.info("finished multi-chain permutation")
# features.update(permutated_labels)
# self.loss(out,features)
return permutated_labels
## TODO next need to check how the ground truth label is used
# in loss calculation.

0 comments on commit 3d87ef2

Please sign in to comment.