Skip to content

Commit

Permalink
Minor fixes/reformatting for recent multimer training PR
Browse files Browse the repository at this point in the history
  • Loading branch information
christinaflo committed Aug 3, 2023
1 parent 31051cf commit 30764cf
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 92 deletions.
33 changes: 9 additions & 24 deletions openfold/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,10 @@ def model_config(
c.loss.masked_msa.num_classes = 22
c.data.common.max_recycling_iters = 20

for k,v in multimer_model_config_update['model'].items():
for k, v in multimer_model_config_update['model'].items():
c.model[k] = v

for k,v in multimer_model_config_update['loss'].items():
for k, v in multimer_model_config_update['loss'].items():
c.loss[k] = v

# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model
Expand Down Expand Up @@ -683,24 +683,11 @@ def model_config(
)

multimer_model_config_update = {
'model':{"input_embedder": {
"tf_dim": 21,
"msa_dim": 49,
#"num_msa": 508,
"c_z": c_z,
"c_m": c_m,
"relpos_k": 32,
"max_relative_chain": 2,
"max_relative_idx": 32,
"use_chain_relative": True,
},
"template": {
"distogram": {
"min_bin": 3.25,
"max_bin": 50.75,
"no_bins": 39,
},
"template_pair_embedder": {
'model': {
"input_embedder": {
"tf_dim": 21,
"msa_dim": 49,
#"num_msa": 508,
"c_z": c_z,
"c_m": c_m,
"relpos_k": 32,
Expand Down Expand Up @@ -841,8 +828,6 @@ def model_config(
},
"recycle_early_stop_tolerance": 0.5
},
"recycle_early_stop_tolerance": 0.5
},
"loss": {
"distogram": {
"min_bin": 2.3125,
Expand All @@ -863,7 +848,7 @@ def model_config(
"loss_unit_distance": 10.0,
"weight": 0.5,
},
"interface": {
"interface_backbone": {
"clamp_distance": 30.0,
"loss_unit_distance": 20.0,
"weight": 0.5,
Expand Down Expand Up @@ -918,5 +903,5 @@ def model_config(
"enabled": True,
},
"eps": eps,
},
}
}
22 changes: 15 additions & 7 deletions openfold/data/data_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from typing import Optional, Sequence, List, Any

import ml_collections as mlc
import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import RandomSampler
Expand All @@ -18,7 +17,7 @@
mmcif_parsing,
templates,
)
from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap
from openfold.utils.tensor_utils import dict_multimap
import contextlib
import tempfile
from openfold.utils.tensor_utils import (
Expand All @@ -34,6 +33,7 @@ def temp_fasta_file(sequence_str):
fasta_file.seek(0)
yield fasta_file.name


class OpenFoldSingleDataset(torch.utils.data.Dataset):
def __init__(self,
data_dir: str,
Expand Down Expand Up @@ -296,6 +296,7 @@ def __getitem__(self, idx):
def __len__(self):
return len(self._chain_ids)


class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
def __init__(self,
data_dir: str,
Expand Down Expand Up @@ -549,10 +550,10 @@ def __getitem__(self, idx):
device=all_chain_features["aatype"].device)
return all_chain_features


def __len__(self):
return len(self._chain_ids)


def deterministic_train_filter(
chain_data_cache_entry: Any,
max_resolution: float = 9.,
Expand All @@ -575,6 +576,7 @@ def deterministic_train_filter(

return True


def deterministic_multimer_train_filter(
mmcif_data_cache_entry,
max_resolution:float= 9.,
Expand Down Expand Up @@ -613,9 +615,10 @@ def deterministic_multimer_train_filter(

return True


def get_stochastic_train_filter_prob(
chain_data_cache_entry: Any,
) -> List[float]:
) -> float:
# Stochastic filters
probabilities = []

Expand Down Expand Up @@ -723,8 +726,8 @@ def reroll(self):
datapoint_idx = next(samples)
self.datapoints.append((dataset_idx, datapoint_idx))


class OpenFoldMultimerDataset(torch.utils.data.Dataset):

"""
Create a torch Dataset object for multimer training and
add filtering steps described in AlphaFold Multimer's paper:
Expand Down Expand Up @@ -753,7 +756,8 @@ def filter_samples(self,dataset_idx):
chains = mmcif_data_cache[mmcif_id]['chain_ids']
mmcif_data_cache_entry = mmcif_data_cache[mmcif_id]
if deterministic_multimer_train_filter(mmcif_data_cache_entry,
max_resolution=9,minimum_number_of_residues=5):
max_resolution=9,
minimum_number_of_residues=5):
selected_idx.append(i)

return selected_idx
Expand Down Expand Up @@ -781,11 +785,13 @@ def reroll(self):
logging.info(f"self.epoch_len is {self.epoch_len}")
self.datapoints += [(dataset_idx, selected_idx[i]) for i in range(self.epoch_len) ]


class OpenFoldBatchCollator:
def __call__(self, prots):
stack_fn = partial(torch.stack, dim=0)
return dict_multimap(stack_fn, prots)


class OpenFoldDataLoader(torch.utils.data.DataLoader):
def __init__(self, *args, config, stage="train", generator=None, **kwargs):
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -873,6 +879,7 @@ def _batch_prop_gen(iterator):

return _batch_prop_gen(it)


class OpenFoldMultimerDataLoader(torch.utils.data.DataLoader):
def __init__(self, *args, config, stage="train", generator=None, **kwargs):
super(OpenFoldMultimerDataLoader,self).__init__(*args, **kwargs)
Expand Down Expand Up @@ -1110,7 +1117,8 @@ def val_dataloader(self):

def predict_dataloader(self):
return self._gen_dataloader("predict")



class OpenFoldMultimerDataModule(OpenFoldDataModule):
"""
Create a datamodule specifically for multimer training
Expand Down
65 changes: 16 additions & 49 deletions openfold/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,45 +784,6 @@ def read_template(start, size):

return all_hits

def _parse_template_hits(
self,
alignment_dir: str,
alignment_index: Optional[Any] = None,
input_sequence=None,
) -> Mapping[str, Any]:
all_hits = {}
if (alignment_index is not None):
fp = open(os.path.join(alignment_dir, alignment_index["db"]), 'rb')

def read_template(start, size):
fp.seek(start)
return fp.read(size).decode("utf-8")

for (name, start, size) in alignment_index["files"]:
ext = os.path.splitext(name)[-1]

if (ext == ".hhr"):
hits = parsers.parse_hhr(read_template(start, size))
all_hits[name] = hits

fp.close()
else:
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
ext = os.path.splitext(f)[-1]

if (ext == ".hhr"):
with open(path, "r") as fp:
hits = parsers.parse_hhr(fp.read())
all_hits[f] = hits

elif (ext =='.sto') and (f.startswith("hmm")):
with open(path,"r") as fp:
hits = parsers.parse_hmmsearch_sto(fp.read(),input_sequence)
all_hits[f] = hits
fp.close()
return all_hits

def _get_msas(self,
alignment_dir: str,
input_sequence: Optional[str] = None,
Expand Down Expand Up @@ -879,9 +840,9 @@ def process_fasta(
num_res = len(input_sequence)

hits = self._parse_template_hit_files(
alignment_dir,
input_sequence,
alignment_index,
alignment_dir=alignment_dir,
input_sequence=input_sequence,
alignment_index=alignment_index,
)

template_features = make_template_features(
Expand Down Expand Up @@ -928,8 +889,9 @@ def process_mmcif(

input_sequence = mmcif.chain_to_seqres[chain_id]
hits = self._parse_template_hit_files(
alignment_dir,
alignment_index,input_sequence)
alignment_dir=alignment_dir,
input_sequence=input_sequence,
alignment_index=alignment_index)

template_features = make_template_features(
input_sequence,
Expand Down Expand Up @@ -976,8 +938,9 @@ def process_pdb(
)

hits = self._parse_template_hit_files(
alignment_dir,
alignment_index,input_sequence
alignment_dir=alignment_dir,
input_sequence=input_sequence,
alignment_index=alignment_index,
)

template_features = make_template_features(
Expand Down Expand Up @@ -1008,8 +971,9 @@ def process_core(
core_feats = make_protein_features(protein_object, description)

hits = self._parse_template_hit_files(
alignment_dir,
alignment_index,input_sequence
alignment_dir=alignment_dir,
input_sequence=input_sequence,
alignment_index=alignment_index,
)

template_features = make_template_features(
Expand Down Expand Up @@ -1098,7 +1062,10 @@ def process_multiseq_fasta(self,
alignment_dir = os.path.join(
super_alignment_dir, desc
)
hits = self._parse_template_hits(alignment_dir, alignment_index=None,input_sequence=input_sequence)
hits = self._parse_template_hit_files(alignment_dir=alignment_dir,
input_sequence=seq,
alignment_index=None)

template_features = make_template_features(
seq,
hits,
Expand Down
21 changes: 14 additions & 7 deletions openfold/utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,10 +310,10 @@ def fape_loss(
interface_bb_loss = backbone_loss(
traj=traj,
pair_mask=1. - intra_chain_mask,
**{**batch, **config.interface},
**{**batch, **config.interface_backbone},
)
weighted_bb_loss = (intra_chain_bb_loss * config.intra_chain_backbone.weight
+ interface_bb_loss * config.interface.weight)
+ interface_bb_loss * config.interface_backbone.weight)
else:
bb_loss = backbone_loss(
traj=traj,
Expand Down Expand Up @@ -541,8 +541,11 @@ def lddt_loss(
cutoff=cutoff,
eps=eps
)
score = torch.nan_to_num(score,nan=torch.nanmean(score))

# TODO: Remove after initial pipeline testing
score = torch.nan_to_num(score, nan=torch.nanmean(score))
score[score<0] = 0

score = score.detach()
bin_index = torch.floor(score * no_bins).long()
bin_index = torch.clamp(bin_index, max=(no_bins - 1))
Expand Down Expand Up @@ -1233,7 +1236,7 @@ def find_structural_violations(
batch["atom14_atom_exists"]
* atomtype_radius[batch["residx_atom14_to_atom37"]]
)
torch.cuda.memory_summary()

# Compute the between residue clash loss.
between_residue_clashes = between_residue_clash_loss(
atom14_pred_positions=atom14_pred_positions,
Expand Down Expand Up @@ -1665,9 +1668,11 @@ def chain_center_of_mass_loss(
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)] # keep dim
chains, _ = asym_id.unique(return_counts=True)
one_hot = torch.nn.functional.one_hot(asym_id.to(torch.int64)-1, # have to reduce asym_id by one because class values must be smaller than num_classes
num_classes=chains.shape[0]).to(dtype=all_atom_mask.dtype) # make sure asym_id dtype is int
chains = asym_id.unique()

# Reduce asym_id by one because class values must be smaller than num_classes and asym_ids start at 1
one_hot = torch.nn.functional.one_hot(asym_id.long() - 1,
num_classes=chains.shape[0]).to(dtype=all_atom_mask.dtype)
one_hot = one_hot * all_atom_mask
chain_pos_mask = one_hot.transpose(-2, -1)
chain_exists = torch.any(chain_pos_mask, dim=-1).float()
Expand All @@ -1688,6 +1693,7 @@ def get_chain_center_of_mass(pos):
loss = masked_mean(loss_mask, losses, dim=(-1, -2))
return loss


# #
# below are the functions required for permutations
# #
Expand Down Expand Up @@ -1715,6 +1721,7 @@ def kabsch_rotation(P, Q):
assert rotation.shape == torch.Size([3,3])
return rotation.to('cuda')


def get_optimal_transform(
src_atoms: torch.Tensor,
tgt_atoms: torch.Tensor,
Expand Down
3 changes: 2 additions & 1 deletion tests/compare_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def get_alphafold_config():
return config


_param_path = f"openfold/resources/params/params_{consts.model}.npz"
dir_path = os.path.dirname(os.path.realpath(__file__))
_param_path = os.path.join(dir_path, "..", f"openfold/resources/params/params_{consts.model}.npz")
_model = None


Expand Down
8 changes: 4 additions & 4 deletions tests/test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def test_template_embedding(pair, batch, mask_2d, mc_mask_2d):

template_feats = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
if consts.is_multimer:
out_repro = model.template_embedder(
out_repro_all = model.template_embedder(
template_feats,
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
Expand All @@ -267,7 +267,7 @@ def test_template_embedding(pair, batch, mask_2d, mc_mask_2d):
inplace_safe=False
)
else:
out_repro = model.template_embedder(
out_repro_all = model.template_embedder(
template_feats,
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
Expand All @@ -277,10 +277,10 @@ def test_template_embedding(pair, batch, mask_2d, mc_mask_2d):
inplace_safe=False
)

out_repro = out_repro["template_pair_embedding"]
out_repro = out_repro_all["template_pair_embedding"]
out_repro = out_repro.cpu()

self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)


if __name__ == "__main__":
Expand Down

0 comments on commit 30764cf

Please sign in to comment.