diff --git a/openfold/utils/loss.py b/openfold/utils/loss.py index 04c0fd7b..86f47b8b 100644 --- a/openfold/utils/loss.py +++ b/openfold/utils/loss.py @@ -1677,7 +1677,7 @@ def get_chain_center_of_mass(pos): # # def kabsch_rotation(P, Q): """ - Use scipy.spatial package to calculate best rotation that minimises + Use procrustes package to calculate best rotation that minimises the RMSD betwee P and Q The optimal rotation matrix was calculated using @@ -1755,19 +1755,6 @@ def compute_rmsd( msd = torch.nan_to_num(msd, nan=1e8) return torch.sqrt(msd + eps) -def kabsch_rmsd( - true_atom_pos: torch.Tensor, - pred_atom_pos: torch.Tensor, - atom_mask: torch.Tensor, -): - r, x = get_optimal_transform( - true_atom_pos, - pred_atom_pos, - atom_mask, - ) - aligned_true_atom_pos = true_atom_pos @ r + x - return compute_rmsd(aligned_true_atom_pos, pred_atom_pos, atom_mask) - def get_least_asym_entity_or_longest_length(batch): """ @@ -1802,6 +1789,11 @@ def get_least_asym_entity_or_longest_length(batch): least_asym_entities = random.choice(least_asym_entities) assert len(least_asym_entities)==1 best_pred_asym = torch.unique(batch["asym_id"][batch["entity_id"] == least_asym_entities[0]]) + + # If there is more than one chain in the predicted output that has the same sequence + # as the chosen ground truth anchor, then randomly pick one + if len(best_pred_asym) > 1: + best_pred_asym = random.choice(best_pred_asym) return least_asym_entities[0], best_pred_asym @@ -2032,65 +2024,49 @@ def multi_chain_perm_align(self,out, batch, labels, shuffle_times=2): per_asym_residue_index[int(cur_asym_id)] = batch["residue_index"][asym_mask] anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch) - print(f"anchor_gt_asym is chosen to be: {anchor_gt_asym}") + print(f"anchor_gt_asym is : {anchor_gt_asym} and anchor_pred_asym is {anchor_pred_asym}") anchor_gt_idx = int(anchor_gt_asym) - 1 - best_rmsd = 1e20 - best_labels = None - unique_entity_ids = 
torch.unique(batch["entity_id"]) entity_2_asym_list = {} for cur_ent_id in unique_entity_ids: ent_mask = batch["entity_id"] == cur_ent_id cur_asym_id = torch.unique(batch["asym_id"][ent_mask]) entity_2_asym_list[int(cur_ent_id)] = cur_asym_id - for cur_asym_id in anchor_pred_asym: - asym_mask = (batch["asym_id"] == cur_asym_id).bool() - anchor_residue_idx = per_asym_residue_index[int(cur_asym_id)] - - anchor_true_pos = true_ca_poses[anchor_gt_idx][anchor_residue_idx] - anchor_pred_pos = pred_ca_pos[asym_mask] - anchor_true_mask = true_ca_masks[anchor_gt_idx][anchor_residue_idx] - anchor_pred_mask = pred_ca_mask[asym_mask] - r, x = get_optimal_transform( - anchor_true_pos, - anchor_pred_pos, - (anchor_true_mask.to('cpu') * anchor_pred_mask.to('cpu')).bool(), + asym_mask = (batch["asym_id"] == anchor_pred_asym).bool() + anchor_residue_idx = per_asym_residue_index[int(anchor_pred_asym)] + + anchor_true_pos = true_ca_poses[anchor_gt_idx][anchor_residue_idx] + anchor_pred_pos = pred_ca_pos[asym_mask] + anchor_true_mask = true_ca_masks[anchor_gt_idx][anchor_residue_idx] + anchor_pred_mask = pred_ca_mask[asym_mask] + r, x = get_optimal_transform( + anchor_true_pos, + anchor_pred_pos, + (anchor_true_mask.to('cpu') * anchor_pred_mask.to('cpu')).bool(), + ) + + aligned_true_ca_poses = [ca.to('cpu') @ r.to('cpu') + x.to('cpu') for ca in true_ca_poses] # apply transforms + align = greedy_align( + batch, + per_asym_residue_index, + unique_asym_ids , + entity_2_asym_list, + pred_ca_pos, + pred_ca_mask, + aligned_true_ca_poses, + true_ca_masks, ) - - aligned_true_ca_poses = [ca.to('cpu') @ r.to('cpu') + x.to('cpu') for ca in true_ca_poses] # apply transforms - for _ in range(shuffle_times): - shuffle_idx = torch.randperm( - unique_asym_ids.shape[0], device=unique_asym_ids.device - ) - shuffled_asym_ids = unique_asym_ids[shuffle_idx] - align = greedy_align( - batch, - per_asym_residue_index, - shuffled_asym_ids, - entity_2_asym_list, - pred_ca_pos, - pred_ca_mask, - 
aligned_true_ca_poses, - true_ca_masks, - ) - merged_labels = merge_labels( - batch, - per_asym_residue_index, - labels, - align, - ) - rmsd = kabsch_rmsd( - merged_labels["all_atom_positions"][..., ca_idx, :].to('cpu') @ r.to('cpu') + x.to('cpu'), - pred_ca_pos, - (pred_ca_mask.to('cpu') * merged_labels["all_atom_mask"][..., ca_idx].to('cpu')).bool(), - ) - - if rmsd < best_rmsd: - best_rmsd = rmsd - best_labels = merged_labels - print(f"finished shuffling and final align is {align}") - return best_labels + merged_labels = merge_labels( + batch, + per_asym_residue_index, + labels, + align, + ) + + print(f"finished multi-chain permutation and final align is {align}") + + return merged_labels def forward(self,out,batch,_return_breakdown=False): """ @@ -2107,6 +2083,8 @@ def forward(self,out,batch,_return_breakdown=False): # then permutate ground truth chains before calculating the loss permutated_labels = self.multi_chain_perm_align(out,features,labels) logger.info("finished multi-chain permutation") + # features.update(permutated_labels) + # self.loss(out,features) return permutated_labels ## TODO next need to check how the ground truth label is used # in loss calculation. \ No newline at end of file