Merge branch 'multimer' into permutation
dingquanyu authored Aug 29, 2023
2 parents c737664 + ab09ded commit 87fbff7
Showing 13 changed files with 301 additions and 206 deletions.
118 changes: 42 additions & 76 deletions openfold/data/data_modules.py
@@ -426,15 +426,15 @@ def __init__(self,
_shuffle_top_k_prefiltered=shuffle_top_k_prefiltered,
)

self.data_pipeline = data_pipeline.DataPipeline(
data_processor = data_pipeline.DataPipeline(
template_featurizer=template_featurizer,
)
self.multimer_data_pipeline = data_pipeline.DataPipelineMultimer(
monomer_data_pipeline=self.data_pipeline
self.data_pipeline = data_pipeline.DataPipelineMultimer(
monomer_data_pipeline=data_processor
)
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)

def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, alignment_index):
def _parse_mmcif(self, path, file_id, alignment_dir, alignment_index):
with open(path, 'r') as f:
mmcif_string = f.read()

@@ -452,7 +452,6 @@ def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, alignment_index):
data = self.data_pipeline.process_mmcif(
mmcif=mmcif_object,
alignment_dir=alignment_dir,
chain_id=chain_id,
alignment_index=alignment_index
)
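
For readers skimming the diff: the `__init__` refactor above makes `DataPipelineMultimer` the single public pipeline, delegating per-chain work to a private monomer `DataPipeline`. A minimal sketch of that wrapping pattern, using hypothetical stand-in classes rather than the real OpenFold API:

```python
# Minimal sketch of the delegation pattern above; the stub classes and
# their process() methods are hypothetical stand-ins, not OpenFold's API.
class MonomerPipeline:
    def process(self, sequence):
        # Per-chain featurization (MSAs, templates, ...) would happen here.
        return {"sequence": sequence}

class MultimerPipeline:
    def __init__(self, monomer_data_pipeline):
        # The multimer pipeline owns a monomer pipeline and reuses it per chain.
        self._monomer = monomer_data_pipeline

    def process(self, sequences):
        return [self._monomer.process(s) for s in sequences]

# Mirrors the new wiring: callers only ever see self.data_pipeline.
data_processor = MonomerPipeline()
data_pipeline = MultimerPipeline(monomer_data_pipeline=data_processor)
```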

@@ -468,82 +467,49 @@ def __getitem__(self, idx):
mmcif_id = self.idx_to_mmcif_id(idx)
chains = self.mmcif_data_cache[mmcif_id]['chain_ids']
print(f"mmcif_id is :{mmcif_id} idx:{idx} and has {len(chains)}chains")
seqs = self.mmcif_data_cache[mmcif_id]['seqs']
fasta_str = ""
for c,s in zip(chains,seqs):
fasta_str+=f">{mmcif_id}_{c}\n{s}\n"
with temp_fasta_file(fasta_str) as fasta_file:
all_chain_features = self.multimer_data_pipeline.process_fasta(fasta_file,self.alignment_dir)

# process all_chain_features
all_chain_features = self.feature_pipeline.process_features(all_chain_features,
mode=self.mode,
is_multimer=True)


alignment_index = None
ground_truth=[]
if(self.mode == 'train' or self.mode == 'eval'):
for chain in chains:
path = os.path.join(self.data_dir, f"{mmcif_id}")
ext = None
for e in self.supported_exts:
if(os.path.exists(path + e)):
ext = e
break
path = os.path.join(self.data_dir, f"{mmcif_id}")
ext = None
for e in self.supported_exts:
if(os.path.exists(path + e)):
ext = e
break

if(ext is None):
raise ValueError("Invalid file type")
if(ext is None):
raise ValueError("Invalid file type")

path += ext
alignment_dir = os.path.join(self.alignment_dir,f"{mmcif_id}_{chain.upper()}")
if(ext == ".cif"):
data = self._parse_mmcif(
path, mmcif_id, chain, alignment_dir, alignment_index,
)
ground_truth_feats = self.feature_pipeline.process_features(data, "train",
is_multimer=False)
#remove recycling dimension
ground_truth_feats = tensor_tree_map(lambda t: t[..., -1], ground_truth_feats)
ground_truth.append(ground_truth_feats)
elif(ext == ".core"):
data = self.data_pipeline.process_core(
path, alignment_dir, alignment_index,
)
ground_truth_feats = self.feature_pipeline.process_features(data, "train",
is_multimer=False)
ground_truth_feats = tensor_tree_map(lambda t: t[..., -1], ground_truth_feats)
ground_truth.append(ground_truth_feats)
elif(ext == ".pdb"):
structure_index = None
data = self.data_pipeline.process_pdb(
pdb_path=path,
alignment_dir=alignment_dir,
is_distillation=self.treat_pdb_as_distillation,
chain_id=chain,
alignment_index=alignment_index,
_structure_index=structure_index,
)
ground_truth_feats = self.feature_pipeline.process_features(data, "train",
is_multimer=False)
ground_truth_feats = tensor_tree_map(lambda t: t[..., -1], ground_truth_feats)
ground_truth.append(ground_truth_feats)
else:
raise ValueError("Extension branch missing")

all_chain_features["batch_idx"] = torch.tensor(
[idx for _ in range(all_chain_features["aatype"].shape[-1])],
dtype=torch.int64,
device=all_chain_features["aatype"].device)
# if it's training now, then return both all_chain_features and ground_truth
return all_chain_features,ground_truth

#TODO: Add pdb and core exts to data_pipeline for multimer
path += ext
if(ext == ".cif"):
data = self._parse_mmcif(
path, mmcif_id, self.alignment_dir, alignment_index,
)
else:
raise ValueError("Extension branch missing")
else:
# if it's inference mode, only need all_chain_features
all_chain_features["batch_idx"] = torch.tensor(
[idx for _ in range(all_chain_features["aatype"].shape[-1])],
dtype=torch.int64,
device=all_chain_features["aatype"].device)
return all_chain_features
path = os.path.join(self.data_dir, f"{mmcif_id}.fasta")
data = self.data_pipeline.process_fasta(
fasta_path=path,
alignment_dir=self.alignment_dir
)

if (self._output_raw):
return data

# process all_chain_features
data = self.feature_pipeline.process_features(data,
mode=self.mode,
is_multimer=True)

# if it's inference mode, only need all_chain_features
data["batch_idx"] = torch.tensor(
[idx for _ in range(data["aatype"].shape[-1])],
dtype=torch.int64,
device=data["aatype"].device)

return data


def __len__(self):
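One detail worth calling out from the new `__getitem__` above: every example gets a per-residue `batch_idx` tensor tagging it with its dataset index, so examples can be told apart after collation. A toy reproduction of that bookkeeping (the `aatype` shape here is invented for illustration):

```python
import torch

# Toy stand-in for the example's aatype feature, shape [*, num_res].
aatype = torch.zeros(4, 128, dtype=torch.int64)
idx = 7  # hypothetical dataset index

# Same construction as in __getitem__: one copy of idx per residue.
batch_idx = torch.tensor(
    [idx for _ in range(aatype.shape[-1])],
    dtype=torch.int64,
    device=aatype.device,
)
assert batch_idx.shape == (128,) and int(batch_idx[0]) == 7
```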
71 changes: 71 additions & 0 deletions openfold/data/data_pipeline.py
@@ -1188,4 +1188,75 @@ def process_fasta(self,
# Pad MSA to avoid zero-sized extra_msa.
np_example = pad_msa(np_example, 512)

return np_example

def get_mmcif_features(
self, mmcif_object: mmcif_parsing.MmcifObject, chain_id: str
) -> FeatureDict:
mmcif_feats = {}

all_atom_positions, all_atom_mask = mmcif_parsing.get_atom_coords(
mmcif_object=mmcif_object, chain_id=chain_id
)
mmcif_feats["all_atom_positions"] = all_atom_positions
mmcif_feats["all_atom_mask"] = all_atom_mask

mmcif_feats["resolution"] = np.array(
mmcif_object.header["resolution"], dtype=np.float32
)

mmcif_feats["release_date"] = np.array(
[mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_
)

mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32)

return mmcif_feats

def process_mmcif(
self,
mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path
alignment_dir: str,
alignment_index: Optional[str] = None,
) -> FeatureDict:

all_chain_features = {}
sequence_features = {}
is_homomer_or_monomer = len(set(list(mmcif.chain_to_seqres.values()))) == 1
for chain_id, seq in mmcif.chain_to_seqres.items():
desc= "_".join([mmcif.file_id, chain_id])

if seq in sequence_features:
all_chain_features[desc] = copy.deepcopy(
sequence_features[seq]
)
continue

chain_features = self._process_single_chain(
chain_id=desc,
sequence=seq,
description=desc,
chain_alignment_dir=os.path.join(alignment_dir, desc),
is_homomer_or_monomer=is_homomer_or_monomer
)

chain_features = convert_monomer_features(
chain_features,
chain_id=desc
)

mmcif_feats = self.get_mmcif_features(mmcif, chain_id)
chain_features.update(mmcif_feats)
all_chain_features[desc] = chain_features
sequence_features[seq] = chain_features

all_chain_features = add_assembly_features(all_chain_features)

np_example = feature_processing_multimer.pair_and_merge(
all_chain_features=all_chain_features,
)

# Pad MSA to avoid zero-sized extra_msa.
np_example = pad_msa(np_example, 512)

return np_example
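
The homomer shortcut in `process_mmcif` above is worth spelling out: chains with an identical sequence reuse a deep copy of already-computed features instead of re-running the expensive per-chain featurization. A self-contained sketch of that caching pattern, with a hypothetical `compute_features` callback:

```python
import copy

def featurize_chains(chain_to_seqres, compute_features):
    # compute_features(chain_id, seq) is a hypothetical stand-in for the
    # expensive per-chain featurization (MSAs, templates, ...).
    all_chain_features, sequence_cache = {}, {}
    for chain_id, seq in chain_to_seqres.items():
        if seq in sequence_cache:
            # Deep copy so later per-chain edits don't alias the cached dict.
            all_chain_features[chain_id] = copy.deepcopy(sequence_cache[seq])
            continue
        feats = compute_features(chain_id, seq)
        all_chain_features[chain_id] = feats
        sequence_cache[seq] = feats
    return all_chain_features

# Example: chains A and B share one sequence, so B reuses A's features.
out = featurize_chains({"A": "MKV", "B": "MKV"}, lambda c, s: {"seq": s})
assert out["A"] == out["B"] and out["A"] is not out["B"]
```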
58 changes: 34 additions & 24 deletions openfold/data/data_transforms_multimer.py
@@ -305,15 +305,25 @@ def make_msa_profile(batch):
return batch


def randint(lower, upper, generator, device):
return int(torch.randint(
lower,
upper + 1,
(1,),
device=device,
generator=generator,
)[0])
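
One convention worth noting in the helper just added: unlike `torch.randint`, whose upper bound is exclusive, this `randint` treats `upper` as inclusive (it passes `upper + 1` through). A quick usage sketch, assuming the helper above is in scope:

```python
import torch

# Draws land in the inclusive range [0, 2]; torch.randint(0, 2, ...) alone
# would exclude 2. Repeated sampling is just to exercise the bound.
g = torch.Generator().manual_seed(0)
draws = {randint(0, 2, generator=g, device="cpu") for _ in range(200)}
assert draws <= {0, 1, 2}  # 2 is a possible draw because upper is inclusive
```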


def get_interface_residues(positions, atom_mask, asym_id, interface_threshold):
coord_diff = positions[..., None, :, :] - positions[..., None, :, :, :]
pairwise_dists = torch.sqrt(torch.sum(coord_diff ** 2, dim=-1))

diff_chain_mask = (asym_id[..., None, :] != asym_id[..., :, None]).float()
pair_mask = atom_mask[..., None, :] * atom_mask[..., None, :, :]
mask = diff_chain_mask[..., None] * pair_mask
mask = (diff_chain_mask[..., None] * pair_mask).bool()

min_dist_per_res = torch.where(mask, pairwise_dists, torch.inf).min(dim=-1)
min_dist_per_res, _ = torch.where(mask, pairwise_dists, torch.inf).min(dim=-1)

valid_interfaces = torch.sum((min_dist_per_res < interface_threshold).float(), dim=-1)
interface_residues_idxs = torch.nonzero(valid_interfaces, as_tuple=True)[0]
@@ -334,8 +344,12 @@ def get_spatial_crop_idx(protein, crop_size, interface_threshold, generator):
if not torch.any(interface_residues):
return get_contiguous_crop_idx(protein, crop_size, generator)

target_res = interface_residues[int(torch.randint(0, interface_residues.shape[-1], (1,),
device=positions.device, generator=generator)[0])]
target_res_idx = randint(lower=0,
upper=interface_residues.shape[-1] - 1,
generator=generator,
device=positions.device)

target_res = interface_residues[target_res_idx]

ca_idx = rc.atom_order["CA"]
ca_positions = positions[..., ca_idx, :]
@@ -351,33 +365,24 @@ def get_spatial_crop_idx(protein, crop_size, interface_threshold, generator):
).float()
* 1e-3
)
to_target_distances = torch.where(ca_mask[..., None], to_target_distances, torch.inf) + break_tie
to_target_distances = torch.where(ca_mask, to_target_distances, torch.inf) + break_tie

ret = torch.argsort(to_target_distances)[:crop_size]
return ret.sort().values
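
The final two lines of `get_spatial_crop_idx` encode the crop semantics: take the `crop_size` residues nearest the target, then re-sort the winners so the crop preserves the original residue ordering. In miniature:

```python
import torch

# Toy distances from each residue to the target residue.
to_target = torch.tensor([4.0, 0.5, 3.0, 0.1, 2.0])
crop_size = 3
nearest = torch.argsort(to_target)[:crop_size]  # tensor([3, 1, 4])
crop_idxs = nearest.sort().values               # tensor([1, 3, 4])
```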


def randint(lower, upper, generator, device):
return int(torch.randint(
lower,
upper + 1,
(1,),
device=device,
generator=generator,
)[0])


def get_contiguous_crop_idx(protein, crop_size, generator):
num_res = protein["aatype"].shape[0]
if num_res <= crop_size:
return torch.arange(num_res)

_, chain_lens = protein["asym_id"].unique(return_counts=True)
unique_asym_ids, chain_lens = protein["asym_id"].unique(return_counts=True)
shuffle_idx = torch.randperm(chain_lens.shape[-1], device=chain_lens.device, generator=generator)
num_remaining = int(chain_lens.sum())
num_budget = crop_size
crop_idxs = []
asym_offset = torch.tensor(0, dtype=torch.int64)

per_asym_residue_index = {}
for cur_asym_id in unique_asym_ids:
asym_mask = (protein["asym_id"]== cur_asym_id).bool()
per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(protein["asym_id"], asym_mask)[0]

for j, idx in enumerate(shuffle_idx):
this_len = int(chain_lens[idx])
num_remaining -= this_len
@@ -394,6 +399,8 @@ def get_contiguous_crop_idx(protein, crop_size, generator):
upper=this_len - chain_crop_size + 1,
generator=generator,
device=chain_lens.device)

asym_offset = per_asym_residue_index[int(idx)]
crop_idxs.append(
torch.arange(asym_offset + chain_start, asym_offset + chain_start + chain_crop_size)
)
@@ -424,7 +431,11 @@ def random_crop_to_size(
use_spatial_crop = torch.rand((1,),
device=protein["seq_length"].device,
generator=g) < spatial_crop_prob
if use_spatial_crop:

num_res = protein["aatype"].shape[0]
if num_res <= crop_size:
crop_idxs = torch.arange(num_res)
elif use_spatial_crop:
crop_idxs = get_spatial_crop_idx(protein, crop_size, interface_threshold, g)
else:
crop_idxs = get_contiguous_crop_idx(protein, crop_size, g)
@@ -466,9 +477,8 @@ def random_crop_to_size(
for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
is_num_res = dim_size == NUM_RES
if i == 0 and k.startswith("template"):
crop_size = num_templates_crop_size
crop_start = templates_crop_start
v = v[slice(crop_start, crop_start + crop_size)]
v = v[slice(crop_start, crop_start + num_templates_crop_size)]
elif is_num_res:
v = torch.index_select(v, i, crop_idxs)
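
For context, the loop above crops each feature according to its shape schema: the leading template dimension is sliced directly, and any dimension labelled `NUM_RES` is gathered with the crop indices. A toy version, with a placeholder token standing in for the real `NUM_RES` constant:

```python
import torch

NUM_RES = "num_res"                 # placeholder for the real schema constant
v = torch.randn(10, 4)              # hypothetical per-residue feature
shape_schema = (NUM_RES, None)      # dim 0 is residue-indexed
crop_idxs = torch.tensor([1, 3, 4])

for i, dim_size in enumerate(shape_schema):
    if dim_size == NUM_RES:
        # Gather only the residues selected by the crop.
        v = torch.index_select(v, i, crop_idxs)

assert v.shape == (3, 4)
```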

13 changes: 7 additions & 6 deletions openfold/data/feature_processing_multimer.py
@@ -228,12 +228,13 @@ def process_unmerged_features(
chain_features['deletion_matrix'], axis=0
)

# Add all_atom_mask and dummy all_atom_positions based on aatype.
all_atom_mask = residue_constants.STANDARD_ATOM_MASK[
chain_features['aatype']]
chain_features['all_atom_mask'] = all_atom_mask
chain_features['all_atom_positions'] = np.zeros(
list(all_atom_mask.shape) + [3])
if 'all_atom_positions' not in chain_features:
# Add all_atom_mask and dummy all_atom_positions based on aatype.
all_atom_mask = residue_constants.STANDARD_ATOM_MASK[
chain_features['aatype']]
chain_features['all_atom_mask'] = all_atom_mask.astype(dtype=np.float32)
chain_features['all_atom_positions'] = np.zeros(
list(all_atom_mask.shape) + [3])

# Add assembly_num_chains.
chain_features['assembly_num_chains'] = np.asarray(num_chains)
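
The guard added above matters during training: chains that arrive with real `all_atom_positions` keep them, whereas the old code unconditionally overwrote them with zeros. A standalone sketch of the guarded behaviour, with a made-up atom-mask table in place of the real `residue_constants.STANDARD_ATOM_MASK`:

```python
import numpy as np

# Fake aatype -> per-atom mask table; the real one is residue-specific.
STANDARD_ATOM_MASK = np.ones((21, 37), dtype=np.float32)
chain_features = {"aatype": np.array([0, 4, 7])}

if "all_atom_positions" not in chain_features:
    # Only synthesize dummy coordinates when none were supplied upstream.
    all_atom_mask = STANDARD_ATOM_MASK[chain_features["aatype"]]
    chain_features["all_atom_mask"] = all_atom_mask.astype(np.float32)
    chain_features["all_atom_positions"] = np.zeros(
        list(all_atom_mask.shape) + [3])  # [num_res, 37, 3]

assert chain_features["all_atom_positions"].shape == (3, 37, 3)
```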
12 changes: 12 additions & 0 deletions openfold/data/input_pipeline_multimer.py
@@ -31,6 +31,18 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
data_transforms.make_atom14_masks,
]

if mode_cfg.supervised:
transforms.extend(
[
data_transforms.make_atom14_positions,
data_transforms.atom37_to_frames,
data_transforms.atom37_to_torsion_angles(""),
data_transforms.make_pseudo_beta(""),
data_transforms.get_backbone_frames,
data_transforms.get_chi_angles,
]
)

return transforms
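
The pattern above is plain list assembly: supervised (training) mode appends the ground-truth-dependent transforms after the base ones. Schematically, with strings standing in for the real `data_transforms` callables:

```python
def nonensembled_transforms(supervised):
    # Strings stand in for the real transform callables.
    transforms = ["make_atom14_masks"]
    if supervised:
        transforms.extend([
            "make_atom14_positions",
            "atom37_to_frames",
            "get_backbone_frames",
        ])
    return transforms

assert "atom37_to_frames" in nonensembled_transforms(supervised=True)
assert "atom37_to_frames" not in nonensembled_transforms(supervised=False)
```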

