Commit b161414: add backbone-sequence scoring
jdauparas committed Feb 20, 2024 (1 parent: 1fcad04)
Showing 7 changed files with 766 additions and 71 deletions.
73 changes: 72 additions & 1 deletion README.md
@@ -78,7 +78,7 @@ To run the model of your choice specify `--model_type` and optionally the model
--checkpoint_per_residue_label_membrane_mpnn "./model_params/per_residue_label_membrane_mpnn_v_48_020.pt" #noised with 0.20A Gaussian noise
```

## Design examples
### 1 default
Default settings will run ProteinMPNN.
```
@@ -461,6 +461,77 @@ python run.py \
--parse_atoms_with_zero_occupancy 1
```

## Scoring examples
### Output dictionary
```
out_dict = {}
out_dict["logits"] - raw logits from the model
out_dict["probs"] - softmax(logits)
out_dict["log_probs"] - log_softmax(logits)
out_dict["decoding_order"] - decoding order used (logits will depend on the decoding order)
out_dict["native_sequence"] - parsed input sequence in integers
out_dict["mask"] - mask for missing residues (usually all ones)
out_dict["chain_mask"] - controls which residues are decoded first
out_dict["alphabet"] - amino acid alphabet used
out_dict["residue_names"] - dictionary to map integers to residue_names, e.g. {0: "C10", 1: "C11"}
out_dict["sequence"] - parsed input sequence in alphabet
out_dict["mean_of_probs"] - averaged over batch_size*number_of_batches probabilities, [protein_length, 21]
out_dict["std_of_probs"] - same as above, but std
```
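The scoring script writes this dictionary to `--out_folder` as a `.pt` file named after the input PDB (this commit adds examples such as `outputs/autoregressive_score_w_seq/1BC8_1.pt`). A minimal sketch for inspecting one, assuming the dictionary is stored with `torch.save`:

```
# Load a saved scoring output and sanity-check the documented shapes.
import torch

out = torch.load("./outputs/autoregressive_score_w_seq/1BC8_1.pt")
print(sorted(out.keys()))      # should include the keys documented above
probs = out["mean_of_probs"]   # [protein_length, 21]
print(probs.shape)
print(probs.sum(-1))           # each row is an averaged softmax, so ~1.0
```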

### 1 autoregressive with sequence info
Get probabilities/scores for backbone-sequence pairs using autoregressive probabilities: p(AA_1|backbone), p(AA_2|backbone, AA_1), etc. These probabilities depend on the decoding order, so it is recommended to set `--number_of_batches` to at least 10 and average over the samples (see the sketch after the command).
```
python score.py \
--model_type "ligand_mpnn" \
--seed 111 \
--autoregressive_score 1 \
--pdb_path "./outputs/ligandmpnn_default/backbones/1BC8_1.pdb" \
--out_folder "./outputs/autoregressive_score_w_seq" \
--use_sequence 1 \
--batch_size 1 \
--number_of_batches 10
```
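To reduce the autoregressive outputs to a single backbone-sequence score, one option is the average log-probability of the native residues across positions and sampled decoding orders. This is a hedged sketch, not part of `score.py`; it assumes `log_probs` keeps the sample dimension, shape `[batch_size*number_of_batches, L, 21]`, and `native_sequence` has shape `[L]`:

```
# Average native-residue log-probability over positions and decoding orders.
import torch

out = torch.load("./outputs/autoregressive_score_w_seq/1BC8_1.pt")
log_probs = out["log_probs"]                            # assumed [num_samples, L, 21]
native = torch.as_tensor(out["native_sequence"]).long() # assumed [L], integer encoded
idx = native[None, :, None].expand(log_probs.shape[0], -1, 1)
per_residue = torch.gather(log_probs, 2, idx).squeeze(-1)  # [num_samples, L]
print(per_residue.mean().item())  # closer to 0.0 means a better backbone-sequence fit
```

Averaging over `number_of_batches` samples is what smooths out the decoding-order dependence noted above.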
### 2 autoregressive with backbone info only
Get probabilities/scores for the backbone only, using probabilities p(AA_1|backbone), p(AA_2|backbone), etc. These probabilities depend on the decoding order, so it is recommended to set `--number_of_batches` to at least 10.
```
python score.py \
--model_type "ligand_mpnn" \
--seed 111 \
--autoregressive_score 1 \
--pdb_path "./outputs/ligandmpnn_default/backbones/1BC8_1.pdb" \
--out_folder "./outputs/autoregressive_score_wo_seq" \
--use_sequence 0 \
--batch_size 1 \
--number_of_batches 10
```
### 3 single amino acid score with sequence info
Get probabilities/scores for backbone-sequence pairs using single-amino-acid probabilities: p(AA_1|backbone, AA_{all except AA_1}), p(AA_2|backbone, AA_{all except AA_2}), etc. These probabilities depend on the decoding order, so it is recommended to set `--number_of_batches` to at least 10.
```
python score.py \
--model_type "ligand_mpnn" \
--seed 111 \
--single_aa_score 1 \
--pdb_path "./outputs/ligandmpnn_default/backbones/1BC8_1.pdb" \
--out_folder "./outputs/single_aa_score_w_seq" \
--use_sequence 1 \
--batch_size 1 \
--number_of_batches 10
```
### 4 single amino acid score with backbone info only
Get probabilities/scores for the backbone using single-amino-acid probabilities: p(AA_1|backbone), p(AA_2|backbone), etc. These probabilities depend on the decoding order, so it is recommended to set `--number_of_batches` to at least 10. A sketch comparing these outputs with example 3 follows the command.
```
python score.py \
--model_type "ligand_mpnn" \
--seed 111 \
--single_aa_score 1 \
--pdb_path "./outputs/ligandmpnn_default/backbones/1BC8_1.pdb" \
--out_folder "./outputs/single_aa_score_wo_seq" \
--use_sequence 0 \
--batch_size 1 \
--number_of_batches 10
```
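Since examples 3 and 4 differ only in `--use_sequence`, comparing their outputs shows where sequence context changes the model's prediction. A hedged sketch, with shapes assumed as documented above:

```
# Per-position KL divergence between sequence-conditioned and
# backbone-only single-amino-acid probabilities.
import torch

w_seq = torch.load("./outputs/single_aa_score_w_seq/1BC8_1.pt")["mean_of_probs"]
wo_seq = torch.load("./outputs/single_aa_score_wo_seq/1BC8_1.pt")["mean_of_probs"]
eps = 1e-9  # guard against log(0)
kl = (w_seq * ((w_seq + eps) / (wo_seq + eps)).log()).sum(-1)  # [protein_length]
print(kl.topk(5).indices)  # positions most affected by sequence context
```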

### Things to add
- Support for ProteinMPNN CA-only model.
- Examples for scoring sequences only.
215 changes: 145 additions & 70 deletions model_utils.py
```diff
@@ -468,72 +468,112 @@ def sample(self, feature_dict):
         }
         return output_dict
 
-    def unconditional_probs(self, feature_dict):
-        # xyz_37 = feature_dict["xyz_37"] #[B,L,37,3] - xyz coordinates for all atoms if needed
-        # xyz_37_m = feature_dict["xyz_37_m"] #[B,L,37] - mask for all coords
-        # Y = feature_dict["Y"] #[B,L,num_context_atoms,3] - for ligandMPNN coords
-        # Y_t = feature_dict["Y_t"] #[B,L,num_context_atoms] - element type
-        # Y_m = feature_dict["Y_m"] #[B,L,num_context_atoms] - mask
-        X = feature_dict["X"]  # [B,L,4,3] - backbone xyz coordinates for N,CA,C,O
+    def single_aa_score(self, feature_dict, use_sequence: bool):
+        """
+        feature_dict - input features
+        use_sequence - False means score using backbone information only
+        """
         B_decoder = feature_dict["batch_size"]
-        # R_idx = feature_dict["R_idx"] #[B,L] - primary sequence residue index
-        mask = feature_dict["mask"]  # [B,L] - mask for missing regions - should be removed! all ones most of the time
-        # chain_labels = feature_dict["chain_labels"] #[B,L] - integer labels for chain letters
-        device = X.device
-
-        h_V, h_E, E_idx = self.encode(feature_dict)
-        order_mask_backward = torch.zeros(
-            [X.shape[0], X.shape[1], X.shape[1]], device=device
-        )
-        mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
-        mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
-        mask_fw = mask_1D * (1.0 - mask_attend)
-
-        h_V = h_V.repeat(B_decoder, 1, 1)
-        h_E = h_E.repeat(B_decoder, 1, 1, 1)
-        E_idx = E_idx.repeat(B_decoder, 1, 1)
-        mask_fw = mask_fw.repeat(B_decoder, 1, 1, 1)
-        mask = mask.repeat(B_decoder, 1)
-
-        h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_V), h_E, E_idx)
-        h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
-
-        h_EXV_encoder_fw = mask_fw * h_EXV_encoder
-        for layer in self.decoder_layers:
-            h_V = layer(h_V, h_EXV_encoder_fw, mask)
-
-        logits = self.W_out(h_V)
-        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
-        output_dict = {"log_probs": log_probs}
+        S_true_enc = feature_dict["S"]
+        mask_enc = feature_dict["mask"]
+        chain_mask_enc = feature_dict["chain_mask"]
+        randn = feature_dict["randn"]
+        B, L = S_true_enc.shape
+        device = S_true_enc.device
+
+        h_V_enc, h_E_enc, E_idx_enc = self.encode(feature_dict)
+        log_probs_out = torch.zeros([B_decoder, L, 21], device=device).float()
+        logits_out = torch.zeros([B_decoder, L, 21], device=device).float()
+        decoding_order_out = torch.zeros([B_decoder, L, L], device=device).float()
+
+        for idx in range(L):
+            h_V = torch.clone(h_V_enc)
+            E_idx = torch.clone(E_idx_enc)
+            mask = torch.clone(mask_enc)
+            S_true = torch.clone(S_true_enc)
+            # order_mask controls where position idx falls in the decoding order
+            if not use_sequence:
+                order_mask = torch.zeros(chain_mask_enc.shape[1], device=device).float()
+                order_mask[idx] = 1.0
+            else:
+                order_mask = torch.ones(chain_mask_enc.shape[1], device=device).float()
+                order_mask[idx] = 0.0
+            decoding_order = torch.argsort(
+                (order_mask + 0.0001) * (torch.abs(randn))
+            )  # numbers are smaller where order_mask is 0.0; those positions decode first
+            E_idx = E_idx.repeat(B_decoder, 1, 1)
+            permutation_matrix_reverse = torch.nn.functional.one_hot(
+                decoding_order, num_classes=L
+            ).float()
+            order_mask_backward = torch.einsum(
+                "ij, biq, bjp->bqp",
+                (1 - torch.triu(torch.ones(L, L, device=device))),
+                permutation_matrix_reverse,
+                permutation_matrix_reverse,
+            )
+            mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
+            mask_1D = mask.view([B, L, 1, 1])
+            mask_bw = mask_1D * mask_attend
+            mask_fw = mask_1D * (1.0 - mask_attend)
+            S_true = S_true.repeat(B_decoder, 1)
+            h_V = h_V.repeat(B_decoder, 1, 1)
+            h_E = h_E_enc.repeat(B_decoder, 1, 1, 1)
+            mask = mask.repeat(B_decoder, 1)
+
+            h_S = self.W_s(S_true)
+            h_ES = cat_neighbors_nodes(h_S, h_E, E_idx)
+
+            # Build encoder embeddings
+            h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
+            h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
+
+            h_EXV_encoder_fw = mask_fw * h_EXV_encoder
+            for layer in self.decoder_layers:
+                # Masked positions attend to encoder information; unmasked
+                # positions also see previously decoded residues.
+                h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx)
+                h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw
+                h_V = layer(h_V, h_ESV, mask)
+
+            logits = self.W_out(h_V)
+            log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
+
+            log_probs_out[:, idx, :] = log_probs[:, idx, :]
+            logits_out[:, idx, :] = logits[:, idx, :]
+            decoding_order_out[:, idx, :] = decoding_order
+
+        output_dict = {
+            "S": S_true,
+            "log_probs": log_probs_out,
+            "logits": logits_out,
+            "decoding_order": decoding_order_out,
+        }
         return output_dict
 
-    def score(self, feature_dict):
-        # check if score matches - sample log probs
-
-        # xyz_37 = feature_dict["xyz_37"] #[B,L,37,3] - xyz coordinates for all atoms if needed
-        # xyz_37_m = feature_dict["xyz_37_m"] #[B,L,37] - mask for all coords
-        # Y = feature_dict["Y"] #[B,L,num_context_atoms,3] - for ligandMPNN coords
-        # Y_t = feature_dict["Y_t"] #[B,L,num_context_atoms] - element type
-        # Y_m = feature_dict["Y_m"] #[B,L,num_context_atoms] - mask
-        # X = feature_dict["X"] #[B,L,4,3] - backbone xyz coordinates for N,CA,C,O
+    def score(self, feature_dict, use_sequence: bool):
         B_decoder = feature_dict["batch_size"]
-        S_true = feature_dict["S"]  # [B,L] - integer protein sequence encoded using "restype_STRtoINT"
-        # R_idx = feature_dict["R_idx"] #[B,L] - primary sequence residue index
-        mask = feature_dict["mask"]  # [B,L] - mask for missing regions - should be removed! all ones most of the time
-        chain_mask = feature_dict["chain_mask"]  # [B,L] - mask for which residues need to be fixed; 0.0 - fixed; 1.0 - will be designed
-        # chain_labels = feature_dict["chain_labels"] #[B,L] - integer labels for chain letters
-        randn = feature_dict["randn"]  # [B,L] - random numbers for decoding order; only the first entry is used since decoding within a batch needs to match for symmetry
+        S_true = feature_dict["S"]
+        mask = feature_dict["mask"]
+        chain_mask = feature_dict["chain_mask"]
+        randn = feature_dict["randn"]
+        symmetry_list_of_lists = feature_dict["symmetry_residues"]
         B, L = S_true.shape
         device = S_true.device
```
```diff
@@ -543,27 +583,57 @@ def score(self, feature_dict):
         decoding_order = torch.argsort(
             (chain_mask + 0.0001) * (torch.abs(randn))
         )  # numbers will be smaller where chain_mask = 0.0; those positions decode first
-        permutation_matrix_reverse = torch.nn.functional.one_hot(
-            decoding_order, num_classes=L
-        ).float()
-        order_mask_backward = torch.einsum(
-            "ij, biq, bjp->bqp",
-            (1 - torch.triu(torch.ones(L, L, device=device))),
-            permutation_matrix_reverse,
-            permutation_matrix_reverse,
-        )
-        mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
-        mask_1D = mask.view([B, L, 1, 1])
-        mask_bw = mask_1D * mask_attend
-        mask_fw = mask_1D * (1.0 - mask_attend)
+        if len(symmetry_list_of_lists[0]) == 0 and len(symmetry_list_of_lists) == 1:
+            E_idx = E_idx.repeat(B_decoder, 1, 1)
+            permutation_matrix_reverse = torch.nn.functional.one_hot(
+                decoding_order, num_classes=L
+            ).float()
+            order_mask_backward = torch.einsum(
+                "ij, biq, bjp->bqp",
+                (1 - torch.triu(torch.ones(L, L, device=device))),
+                permutation_matrix_reverse,
+                permutation_matrix_reverse,
+            )
+            mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
+            mask_1D = mask.view([B, L, 1, 1])
+            mask_bw = mask_1D * mask_attend
+            mask_fw = mask_1D * (1.0 - mask_attend)
+        else:
+            # group symmetry-tied residues so that they are decoded together
+            new_decoding_order = []
+            for t_dec in list(decoding_order[0,].cpu().data.numpy()):
+                if t_dec not in list(itertools.chain(*new_decoding_order)):
+                    list_a = [item for item in symmetry_list_of_lists if t_dec in item]
+                    if list_a:
+                        new_decoding_order.append(list_a[0])
+                    else:
+                        new_decoding_order.append([t_dec])
+
+            decoding_order = torch.tensor(
+                list(itertools.chain(*new_decoding_order)), device=device
+            )[None,].repeat(B, 1)
+
+            permutation_matrix_reverse = torch.nn.functional.one_hot(
+                decoding_order, num_classes=L
+            ).float()
+            order_mask_backward = torch.einsum(
+                "ij, biq, bjp->bqp",
+                (1 - torch.triu(torch.ones(L, L, device=device))),
+                permutation_matrix_reverse,
+                permutation_matrix_reverse,
+            )
+            mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
+            mask_1D = mask.view([B, L, 1, 1])
+            mask_bw = mask_1D * mask_attend
+            mask_fw = mask_1D * (1.0 - mask_attend)
+            E_idx = E_idx.repeat(B_decoder, 1, 1)
+
+        mask_fw = mask_fw.repeat(B_decoder, 1, 1, 1)
+        mask_bw = mask_bw.repeat(B_decoder, 1, 1, 1)
+        decoding_order = decoding_order.repeat(B_decoder, 1)
 
-        # repeat for decoding
         S_true = S_true.repeat(B_decoder, 1)
         h_V = h_V.repeat(B_decoder, 1, 1)
         h_E = h_E.repeat(B_decoder, 1, 1, 1)
-        E_idx = E_idx.repeat(B_decoder, 1, 1)
         chain_mask = chain_mask.repeat(B_decoder, 1)
         mask = mask.repeat(B_decoder, 1)
 
         h_S = self.W_s(S_true)
```
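The `order_mask_backward` einsum above encodes, for every pair of positions, whether one is decoded before the other. A standalone toy example (not repo code) makes the construction concrete:

```
# order_mask_backward[q, p] = 1 iff position p is decoded strictly before q.
import torch

L = 3
decoding_order = torch.tensor([[2, 0, 1]])  # decode position 2, then 0, then 1
perm = torch.nn.functional.one_hot(decoding_order, num_classes=L).float()
order_mask_backward = torch.einsum(
    "ij, biq, bjp->bqp",
    (1 - torch.triu(torch.ones(L, L))),  # strictly lower-triangular step mask
    perm,
    perm,
)
print(order_mask_backward[0])
# tensor([[0., 0., 1.],   position 0 may attend to position 2
#         [1., 0., 1.],   position 1 may attend to positions 0 and 2
#         [0., 0., 0.]])  position 2, decoded first, attends to none
```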
```diff
@@ -574,19 +644,24 @@ def score(self, feature_dict):
         h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
 
         h_EXV_encoder_fw = mask_fw * h_EXV_encoder
-        for layer in self.decoder_layers:
-            # Masked positions attend to encoder information, unmasked see.
-            h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx)
-            h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw
-            h_V = layer(h_V, h_ESV, mask)
+        if not use_sequence:
+            # backbone only: the decoder never sees the sequence embeddings
+            for layer in self.decoder_layers:
+                h_V = layer(h_V, h_EXV_encoder_fw, mask)
+        else:
+            for layer in self.decoder_layers:
+                # Masked positions attend to encoder information; unmasked
+                # positions also see previously decoded residues.
+                h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx)
+                h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw
+                h_V = layer(h_V, h_ESV, mask)
 
         logits = self.W_out(h_V)
         log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
 
         output_dict = {
             "S": S_true,
             "log_probs": log_probs,
-            "decoding_order": decoding_order[0],
+            "logits": logits,
+            "decoding_order": decoding_order,
         }
         return output_dict
```
Binary file added outputs/autoregressive_score_w_seq/1BC8_1.pt
Binary file added outputs/autoregressive_score_wo_seq/1BC8_1.pt
Binary file added outputs/single_aa_score_w_seq/1BC8_1.pt
Binary file added outputs/single_aa_score_wo_seq/1BC8_1.pt