PoC: Better reordering #129

Draft: wants to merge 9 commits into main
22 changes: 11 additions & 11 deletions unifold/loss.py
@@ -40,10 +40,10 @@ def forward(self, model, batch, reduce=True):
# return config in model.
out, config = model(batch)
num_recycling = batch["msa_feat"].shape[0]

# remove recycling dim
batch = tensor_tree_map(lambda t: t[-1, ...], batch)

loss, sample_size, logging_output = self.loss(out, batch, config)
logging_output["num_recycling"] = num_recycling
return loss, sample_size, logging_output
@@ -57,7 +57,7 @@ def loss(self, out, batch, config):
if "renamed_atom14_gt_positions" not in out.keys():
batch.update(
compute_renamed_ground_truth(batch, out["sm"]["positions"]))

loss_dict = {}
loss_fns = {
"chain_centre_mass": lambda: chain_centre_mass_loss(
@@ -143,11 +143,11 @@ def loss(self, out, batch, config):
with torch.no_grad():
seq_len = torch.sum(batch["seq_mask"].float(), dim=-1)
seq_length_weight = seq_len**0.5

assert (
len(seq_length_weight.shape) == 1 and seq_length_weight.shape[0] == bsz
), seq_length_weight.shape

for loss_name, loss_fn in loss_fns.items():
weight = config[loss_name].weight
if weight > 0.:
@@ -159,7 +159,7 @@ def loss(self, out, batch, config):
if any(torch.isnan(loss)) or any(torch.isinf(loss)):
logging.warning(f"{loss_name} loss is NaN. Skipping...")
loss = loss.new_tensor(0.0, requires_grad=True)

cum_loss = cum_loss + weight * loss

for key in loss_dict:
@@ -208,10 +208,10 @@ def forward(self, model, batch, reduce=True):
# return config in model.
out, config = model(features)
num_recycling = features["msa_feat"].shape[0]

# remove recycling dim
features = tensor_tree_map(lambda t: t[-1, ...], features)

# perform multi-chain permutation alignment.
if labels:
with torch.no_grad():
@@ -221,7 +221,7 @@ def forward(self, model, batch, reduce=True):
cur_out = {
k: out[k][batch_idx]
for k in out
if k in {"final_atom_positions", "final_atom_mask"}
if k in {"final_atom_positions", "final_atom_mask", "pred_frame_tensor"}
}
cur_feature = {k: features[k][batch_idx] for k in features}
cur_label = labels[batch_idx]
@@ -230,12 +230,12 @@ def forward(self, model, batch, reduce=True):
)
new_labels.append(cur_new_labels)
new_labels = data_utils.collate_dict(new_labels, dim=0)

# check for consistency of label and feature.
assert (new_labels["aatype"] == features["aatype"]).all()
features.update(new_labels)

loss, sample_size, logging_output = self.loss(out, features, config)
logging_output["num_recycling"] = num_recycling

return loss, sample_size, logging_output
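
Both forward passes above read the recycling count from the leading dimension of msa_feat and then strip that dimension with tensor_tree_map before computing the loss. As a reference, here is a minimal sketch of that pattern; the helper below is an illustrative stand-in written under the assumption that tensor_tree_map simply applies a function to every tensor leaf of a nested dict, and the toy shapes are made up.

import torch

def tensor_tree_map_sketch(fn, tree):
    # apply fn to every tensor leaf of a (possibly nested) dict
    if isinstance(tree, dict):
        return {k: tensor_tree_map_sketch(fn, v) for k, v in tree.items()}
    return fn(tree)

# toy batch with a leading recycling dimension of size 4
batch = {"msa_feat": torch.zeros(4, 8, 16, 49), "seq_mask": torch.ones(4, 16)}
num_recycling = batch["msa_feat"].shape[0]                    # 4
batch = tensor_tree_map_sketch(lambda t: t[-1, ...], batch)   # keep only the last recycling iteration
print(num_recycling, batch["msa_feat"].shape)                 # 4 torch.Size([8, 16, 49])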
217 changes: 201 additions & 16 deletions unifold/losses/chain_align.py
@@ -1,9 +1,184 @@
import torch
from unifold.data import residue_constants as rc
import torch as th
from unifold.modules.frame import Frame
from scipy.optimize import linear_sum_assignment
from .geometry import kabsch_rmsd, get_optimal_transform, compute_rmsd
from typing import List, Tuple, Dict, Optional


def multi_chain_perm_align(out, batch, labels, shuffle_times=2):
def compute_xx_fape(
pred_frames: Frame,
target_frames: Frame,
pred_points: th.Tensor,
target_points: th.Tensor,
frames_mask: Optional[th.Tensor] = None,
points_mask: Optional[th.Tensor] = None,
) -> th.Tensor:
""" FAPE cross-matrix from frames to the cross-matrix of points,
used to find a permutation which gives the optimal loss for a symmetric structure.

Notation for use with n chains of length p, under f=k frames:
- n: number of protein chains
- p: number of points (length of chain)
- d: dimension of points = 3
- f: arbitrary number of frames
- ': frames dimension
(..., n', f, 1, 1) @ (..., 1, 1, n, p, d) -> (..., n', f, n, p, d)
(..., ni', f, ni, p, d) - (..., nj', f, nj, p, d) -> (..., ni', nj', ni, nj)

Args:
pred_frames: (..., n(i'), f)
target_frames: (..., n(j'), f)
pred_points: (..., n(i), p, d)
target_points: (..., n(j), p, d)
frames_mask: (..., n', f) float tensor
points_mask: (..., n, p) float tensor

Returns:
(..., n(i'), n(j'), n(i), n(j)) th.Tensor
"""
# define masks for reduction, mask is (ni', nj', f, ni, nj, p)
mask = 1.
if frames_mask is not None:
mask = mask * (
frames_mask[..., :, None, :, None, None, None] + frames_mask[..., None, :, :, None, None, None]
).bool().float()
if points_mask is not None:
mask = mask * (
points_mask[..., None, None, None, :, None, :] + points_mask[..., None, None, None, None, :, :]
).bool().float()

# (..., n', f) · (..., n, p, d) -> (..., n', f, n, p, d)
local_pred_pos = pred_frames[..., None, None].invert().apply(
pred_points[..., None, None, :, :, :].float(),
)
# (..., n', f) · (..., n, p, d) -> (..., n', f, n, p, d)
local_target_pos = target_frames[..., None, None].invert().apply(
target_points[..., None, None, :, :, :].float(),
)
# chunk in ni, nj to avoid memory errors
n_, n = local_pred_pos.shape[-5], local_pred_pos.shape[-3]
xx_fape = local_pred_pos.new_zeros(*local_pred_pos.shape[:-5], n_, n_, n, n)
for i_ in range(n_):
for j_ in range(n_):
# (..., ni', f, ni, p, d) - (..., nj', f, nj, p, d) -> (..., ni', nj', f, ni, nj, p)
d_pt2 = (
local_pred_pos[..., i_:i_ + 1, None, :, :, None, :, :] -
local_target_pos[..., None, j_:j_ + 1, :, None, :, :, :]
).square().sum(-1)
d_pt = d_pt2.add_(1e-5).sqrt()
# (..., ni', nj', f, ni, nj, p) -> (..., ni', nj', ni, nj)
if frames_mask is not None or points_mask is not None:
mask_ = mask[..., i_:i_+1, j_:j_+1, :, :, :, :]
x_fape_ij = (d_pt * mask_).sum(dim=(-1, -4)) / mask_.sum(dim=(-1, -4))
else:
x_fape_ij = d_pt.mean(dim=(-1, -4))
xx_fape[..., i_, j_, :, :] = x_fape_ij
# save memory
del d_pt2, d_pt, x_fape_ij

return xx_fape
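
To make the shape bookkeeping in the docstring above concrete, here is a small self-contained sketch of the same broadcasting pattern with frames reduced to plain translations (no rotations, no masks, invented toy sizes); it only illustrates how the (ni', nj', ni, nj) cross-matrix arises and is not a substitute for compute_xx_fape.

import torch as th

n, f, p, d = 3, 4, 5, 3                       # chains, frames per chain, points per chain, xyz
pred_centers = th.randn(n, f, d)              # stand-in for predicted frame origins, (n', f, d)
true_centers = th.randn(n, f, d)              # stand-in for target frame origins,    (n', f, d)
pred_points = th.randn(n, p, d)               # predicted points, (n, p, d)
true_points = th.randn(n, p, d)               # target points,    (n, p, d)

# express every chain's points in every frame (translation only here):
# (1, 1, n, p, d) - (n', f, 1, 1, d) -> (n', f, n, p, d)
local_pred = pred_points[None, None] - pred_centers[:, :, None, None]
local_true = true_points[None, None] - true_centers[:, :, None, None]

# cross-differences over both chain axes:
# (ni', 1, f, ni, 1, p, d) - (1, nj', f, 1, nj, p, d) -> (ni', nj', f, ni, nj, p, d)
diff = local_pred[:, None, :, :, None] - local_true[None, :, :, None, :]

# average distances over points (p) and frames (f) -> (ni', nj', ni, nj)
xx = diff.square().sum(-1).sqrt().mean(dim=(-1, -4))
print(xx.shape)                               # torch.Size([3, 3, 3, 3])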


def multi_chain_perm_align(out: Dict, batch: Dict, labels: List[Dict]) -> Dict:
""" Permutes labels so that a structural loss wrt preds is minimized.
Framed as a linear assignment problem, loss is sum of "individual" losses
and the permutation is found by the Hungarian algorithm on the cross matrix.

WARNING! All keys in `out` have no batch size
"""
assert isinstance(labels, list)
# get all unique chains - remove padding tokens with no labels
unique_asym_ids = th.unique(batch["asym_id"])
if len(unique_asym_ids) == len(labels) + 1:
unique_asym_ids = th.tensor(list((set(unique_asym_ids.tolist()) - {0}))).to(batch["asym_id"])
assert len(unique_asym_ids) == len(labels)
best_global_curr = th.clone(unique_asym_ids)
best_global_perm = th.clone(unique_asym_ids)
best_global_perm_list = best_global_perm.tolist()

# all indices associated with a chain (asymmetric unit)
per_asym_residue_index = {}
for cur_asym_id in unique_asym_ids:
asym_mask = (batch["asym_id"] == cur_asym_id).bool()
per_asym_residue_index[int(cur_asym_id)] = batch["residue_index"][asym_mask]

# Get values to compute cross-matrix, use reference frames
true_frames, true_frames_mask = [], []
for l, cur_asym_id in zip(labels, unique_asym_ids):
asym_res_idx = per_asym_residue_index[int(cur_asym_id)]
true_frames.append(Frame.from_tensor_4x4(l["true_frame_tensor"][asym_res_idx]))
true_frames_mask.append(l["frame_mask"][asym_res_idx])

# [bsz, nres, d=3]
true_frames = Frame.cat(true_frames, dim=0)
# [bsz, nres]
true_frames_mask = th.cat(true_frames_mask, dim=0)
# [bsz, nres, d=3]
pred_frames = Frame.from_tensor_4x4(out["pred_frame_tensor"])
# [bsz, nres]
pred_frames_mask = batch["atom14_gt_exists"][..., [0, 1, 2]].float().prod(dim=-1)

# rename symmetric chains (works for abab, abb, abbcc, ...)
unique_ent_ids = th.unique(batch["entity_id"])
for ent_id in unique_ent_ids:
# see how many chains for the entity, if just 1, continue
ent_mask = batch["entity_id"] == ent_id
asym_ids = th.unique(batch["asym_id"][ent_mask])
if len(asym_ids) == 1:
continue
# create placeholders for the per-chain frames and their corresponding masks, each of shape (n, l)
local_perm_idxs = []
ent_mask = batch["entity_id"] == ent_id
ent_res_idx = batch["residue_index"][ent_mask]
min_res, max_res = ent_res_idx.amin().item(), ent_res_idx.amax().item()
n = len(asym_ids)
l = max_res - min_res + 1
ph_frames_pred = Frame.identity((n, l), device=pred_frames.device, dtype=pred_frames.dtype)
ph_frames_true = Frame.identity((n, l), device=pred_frames.device, dtype=pred_frames.dtype)
frames_mask = points_mask = true_frames_mask.new_zeros((n, l,))
# fill placeholders with points and masks
for i, ni in enumerate(asym_ids):
local_perm_idxs.append(best_global_perm_list[ni.item()])
asym_mask = batch["asym_id"] == ni
asym_res_idx = batch["residue_index"][asym_mask]
ph_frames_pred[i, asym_res_idx - min_res] = pred_frames[asym_mask].clone()
ph_frames_true[i, asym_res_idx - min_res] = true_frames[asym_mask].clone()
points_mask[i, asym_res_idx - min_res] = pred_frames_mask[asym_mask] * true_frames_mask[asym_mask]
frames_mask[i, asym_res_idx - min_res] = pred_frames_mask[asym_mask] * true_frames_mask[asym_mask]

# cross-matrix and hungarian algorithm finds the best permutation
# (n=N, f=L), (n=N, f=L), (n=N, p=L), (n=N, p=L) -> (ni'=N, nj'=N, ni=N, nj=N)
x_mat = compute_xx_fape(
pred_frames=ph_frames_pred,
target_frames=ph_frames_true,
pred_points=ph_frames_pred._t,
target_points=ph_frames_true._t,
frames_mask=frames_mask,
points_mask=points_mask,
).detach().cpu()
# (ni'=N, nj'=N, ni=N, nj=N) -> (N, N)
x_mat_frames = x_mat.sum(dim=(-1, -2)) / (x_mat.shape[-1] * x_mat.shape[-2])
x_mat_points = x_mat.sum(dim=(-3, -4)) / (x_mat.shape[-3] * x_mat.shape[-4])
rows, cols = linear_sum_assignment((x_mat_frames + x_mat_points).numpy())

# map the local (row -> col) assignment back onto the global chain permutation
global_rows = local_perm_idxs
global_cols = [local_perm_idxs[c] for c in cols]
best_global_perm[global_rows] = best_global_perm[global_cols]

# (N,) -> (2, N) and match indices of labels list
ij_label_align = th.stack((best_global_curr, best_global_perm), dim=0).long().T
ij_label_align = (ij_label_align - ij_label_align.amin()).tolist()
best_labels = merge_labels(
batch=batch,
per_asym_residue_index=per_asym_residue_index,
labels=labels,
align=ij_label_align
)
return best_labels
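
The final matching above is a standard linear-sum-assignment problem. As a standalone illustration with a hand-made cost matrix in place of the FAPE cross-matrix, the Hungarian step looks like this:

import numpy as np
from scipy.optimize import linear_sum_assignment

# cost[i, j]: structural cost of matching predicted chain i to label chain j
cost = np.array([
    [0.2, 5.0, 4.0],
    [6.0, 0.1, 3.0],
    [5.0, 4.0, 0.3],
])
rows, cols = linear_sum_assignment(cost)              # Hungarian algorithm
print(list(zip(rows.tolist(), cols.tolist())))        # [(0, 0), (1, 1), (2, 2)] for this cost
print(cost[rows, cols].sum())                         # minimal total cost: 0.6

linear_sum_assignment returns the row and column indices of the minimum-cost matching, which is how rows and cols are consumed above to build the global permutation.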


def multi_chain_perm_align_outdated(out, batch, labels, shuffle_times=2):
assert isinstance(labels, list)
ca_idx = rc.atom_order["CA"]
pred_ca_pos = out["final_atom_positions"][..., ca_idx, :].float() # [bsz, nres, 3]
@@ -15,7 +190,7 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2):
l["all_atom_mask"][..., ca_idx].float() for l in labels
] # list([nres,])

unique_asym_ids = torch.unique(batch["asym_id"])
unique_asym_ids = th.unique(batch["asym_id"])

per_asym_residue_index = {}
for cur_asym_id in unique_asym_ids:
@@ -30,11 +205,11 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2):
best_rmsd = 1e9
best_labels = None

unique_entity_ids = torch.unique(batch["entity_id"])
unique_entity_ids = th.unique(batch["entity_id"])
entity_2_asym_list = {}
for cur_ent_id in unique_entity_ids:
ent_mask = batch["entity_id"] == cur_ent_id
cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
cur_asym_id = th.unique(batch["asym_id"][ent_mask])
entity_2_asym_list[int(cur_ent_id)] = cur_asym_id

for cur_asym_id in anchor_pred_asym:
@@ -53,7 +228,7 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2):

aligned_true_ca_poses = [ca @ r + x for ca in true_ca_poses] # apply transforms
for _ in range(shuffle_times):
shuffle_idx = torch.randperm(
shuffle_idx = th.randperm(
unique_asym_ids.shape[0], device=unique_asym_ids.device
)
shuffled_asym_ids = unique_asym_ids[shuffle_idx]
@@ -91,7 +266,7 @@ def get_anchor_candidates(batch, per_asym_residue_index, true_masks):
def find_by_num_sym(min_num_sym):
best_len = -1
best_gt_asym = None
asym_ids = torch.unique(batch["asym_id"][batch["num_sym"] == min_num_sym])
asym_ids = th.unique(batch["asym_id"][batch["num_sym"] == min_num_sym])
for cur_asym_id in asym_ids:
assert cur_asym_id > 0
cur_residue_index = per_asym_residue_index[int(cur_asym_id)]
@@ -116,7 +291,7 @@ def find_by_num_sym(min_num_sym):
if best_len >= 3:
break
best_entity = batch["entity_id"][batch["asym_id"] == best_gt_asym][0]
best_pred_asym = torch.unique(batch["asym_id"][batch["entity_id"] == best_entity])
best_pred_asym = th.unique(batch["asym_id"][batch["entity_id"] == best_entity])
return best_gt_asym, best_pred_asym


@@ -172,12 +347,22 @@ def greedy_align(

return align
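
The outdated path above aligns candidate anchors with get_optimal_transform and scores them with kabsch_rmsd, both imported from .geometry and not shown in this diff. Under the assumption that get_optimal_transform returns a least-squares rigid transform (r, x) applied as ca @ r + x, a minimal Kabsch-style sketch would be:

import math
import torch as th

def kabsch_sketch(mobile: th.Tensor, target: th.Tensor):
    # least-squares rigid transform (r, x) such that mobile @ r + x ~ target; both inputs are (n, 3)
    mob_c, tgt_c = mobile.mean(0), target.mean(0)
    cov = (mobile - mob_c).T @ (target - tgt_c)
    u, _, vt = th.linalg.svd(cov)
    sign = th.sign(th.linalg.det(u @ vt)).item()      # guard against reflections
    r = u @ th.diag(th.tensor([1.0, 1.0, sign])) @ vt
    x = tgt_c - mob_c @ r
    return r, x

# sanity check: recover a known rotation about z plus a translation
c, s = math.cos(0.7), math.sin(0.7)
r_true = th.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
a = th.randn(10, 3)
b = a @ r_true + th.tensor([1.0, 2.0, 3.0])
r, x = kabsch_sketch(a, b)
print(((a @ r + x) - b).abs().max())                  # ~0 up to numerical error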

def merge_labels(
batch: Dict,
per_asym_residue_index: Dict,
labels: List[Dict],
align: List[Tuple]
) -> Dict:
""" Reorders the labels

def merge_labels(batch, per_asym_residue_index, labels, align):
"""
batch:
labels: list of label dicts, each with shape [nk, *]
align: list of int, such as [2, None, 0, 1], each entry specify the corresponding label of the asym.
Args:
batch: dict of tensors
per_asym_residue_index: dict mapping every asym_id to a list of its residue indices
labels: list of label dicts, each with shape [nk, *]
align: list of int tuples (i, j) such that label j is assigned to asym slot i, e.g. [(1, 2), (2, 1)]

Returns:
merged_labels: dict of tensors with reordered labels
"""
num_res = batch["msa_mask"].shape[-1]
outs = {}
@@ -193,14 +378,14 @@ def merge_labels(batch, per_asym_residue_index, labels, align):
cur_residue_index = per_asym_residue_index[i + 1]
cur_out[i] = label[cur_residue_index]
cur_out = [x[1] for x in sorted(cur_out.items())]
new_v = torch.concat(cur_out, dim=0)
new_v = th.cat(cur_out, dim=0)
merged_nres = new_v.shape[0]
assert (
merged_nres <= num_res
), f"bad merged num res: {merged_nres} > {num_res}. something is wrong."
if merged_nres < num_res: # must pad
pad_dim = new_v.shape[1:]
pad_v = new_v.new_zeros((num_res - merged_nres, *pad_dim))
new_v = torch.concat((new_v, pad_v), dim=0)
new_v = th.cat((new_v, pad_v), dim=0)
outs[k] = new_v
return outs
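
For readers following the align format, here is a toy sketch of the reordering idea behind merge_labels, using 0-indexed (i, j) pairs and a single key; it mirrors the loop above but is not the function itself.

import torch as th

# two single-chain label dicts
labels = [
    {"aatype": th.tensor([1, 1, 1])},     # label 0
    {"aatype": th.tensor([2, 2, 2])},     # label 1
]
align = [(0, 1), (1, 0)]                  # (i, j): label j is assigned to slot i

reordered = {}
for k in labels[0]:
    chunks = {i: labels[j][k] for i, j in align}                       # pick label j for slot i
    reordered[k] = th.cat([chunks[i] for i in sorted(chunks)], dim=0)  # concatenate in slot order

print(reordered["aatype"])                # tensor([2, 2, 2, 1, 1, 1])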