Skip to content

Commit

Permalink
rename ops to data_ops
Browse files Browse the repository at this point in the history
  • Loading branch information
guolinke committed Aug 1, 2022
1 parent b2b80b0 commit 2cc9792
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 42 deletions.
File renamed without changes.
2 changes: 1 addition & 1 deletion unifold/data/msa_pairing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from typing import Dict, Iterable, List, Sequence

from .residue_constants import restypes_with_x_and_gap
from .ops import NumpyDict
from .data_ops import NumpyDict
import numpy as np
import pandas as pd
import scipy.linalg
Expand Down
82 changes: 41 additions & 41 deletions unifold/data/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,37 +3,37 @@
import torch
import numpy as np

from unifold.data import ops
from unifold.data import data_ops


def nonensembled_fns(common_cfg, mode_cfg):
"""Input pipeline data transformers that are not ensembled."""
v2_feature = common_cfg.v2_feature
operators = []
if mode_cfg.random_delete_msa:
operators.append(ops.random_delete_msa(common_cfg.random_delete_msa))
operators.append(data_ops.random_delete_msa(common_cfg.random_delete_msa))
operators.extend(
[
ops.cast_to_64bit_ints,
ops.correct_msa_restypes,
ops.squeeze_features,
ops.randomly_replace_msa_with_unknown(0.0),
ops.make_seq_mask,
ops.make_msa_mask,
data_ops.cast_to_64bit_ints,
data_ops.correct_msa_restypes,
data_ops.squeeze_features,
data_ops.randomly_replace_msa_with_unknown(0.0),
data_ops.make_seq_mask,
data_ops.make_msa_mask,
]
)
operators.append(
ops.make_hhblits_profile_v2 if v2_feature else ops.make_hhblits_profile
data_ops.make_hhblits_profile_v2 if v2_feature else data_ops.make_hhblits_profile
)
if common_cfg.use_templates:
operators.extend(
[
ops.make_template_mask,
ops.make_pseudo_beta("template_"),
data_ops.make_template_mask,
data_ops.make_pseudo_beta("template_"),
]
)
operators.append(
ops.crop_templates(
data_ops.crop_templates(
max_templates=mode_cfg.max_templates,
subsample_templates=mode_cfg.subsample_templates,
)
Expand All @@ -42,12 +42,12 @@ def nonensembled_fns(common_cfg, mode_cfg):
if common_cfg.use_template_torsion_angles:
operators.extend(
[
ops.atom37_to_torsion_angles("template_"),
data_ops.atom37_to_torsion_angles("template_"),
]
)

operators.append(ops.make_atom14_masks)
operators.append(ops.make_target_feat)
operators.append(data_ops.make_atom14_masks)
operators.append(data_ops.make_target_feat)

return operators

Expand All @@ -62,25 +62,25 @@ def crop_and_fix_size_fns(common_cfg, mode_cfg, crop_and_fix_size_seed):
if mode_cfg.fixed_size:
if mode_cfg.crop:
if common_cfg.is_multimer:
crop_fn = ops.crop_to_size_multimer(
crop_fn = data_ops.crop_to_size_multimer(
crop_size=mode_cfg.crop_size,
shape_schema=crop_feats,
seed=crop_and_fix_size_seed,
spatial_crop_prob=mode_cfg.spatial_crop_prob,
ca_ca_threshold=mode_cfg.ca_ca_threshold,
)
else:
crop_fn = ops.crop_to_size_single(
crop_fn = data_ops.crop_to_size_single(
crop_size=mode_cfg.crop_size,
shape_schema=crop_feats,
seed=crop_and_fix_size_seed,
)
operators.append(crop_fn)

operators.append(ops.select_feat(crop_feats))
operators.append(data_ops.select_feat(crop_feats))

operators.append(
ops.make_fixed_size(
data_ops.make_fixed_size(
crop_feats,
pad_msa_clusters,
common_cfg.max_extra_msa,
Expand All @@ -98,10 +98,10 @@ def ensembled_fns(common_cfg, mode_cfg):
v2_feature = common_cfg.v2_feature
# multimer mode does not use block delete msa
if mode_cfg.block_delete_msa and not multimer_mode:
operators.append(ops.block_delete_msa(common_cfg.block_delete_msa))
operators.append(data_ops.block_delete_msa(common_cfg.block_delete_msa))
if "max_distillation_msa_clusters" in mode_cfg:
operators.append(
ops.sample_msa_distillation(mode_cfg.max_distillation_msa_clusters)
data_ops.sample_msa_distillation(mode_cfg.max_distillation_msa_clusters)
)

if common_cfg.reduce_msa_clusters_by_max_templates:
Expand All @@ -115,7 +115,7 @@ def ensembled_fns(common_cfg, mode_cfg):
assert common_cfg.resample_msa_in_recycling
gumbel_sample = common_cfg.gumbel_sample
operators.append(
ops.sample_msa(
data_ops.sample_msa(
max_msa_clusters,
keep_extra=True,
gumbel_sample=gumbel_sample,
Expand All @@ -128,7 +128,7 @@ def ensembled_fns(common_cfg, mode_cfg):
# the clustering and full MSA profile do not leak information about
# the masked locations and secret corrupted locations.
operators.append(
ops.make_masked_msa(
data_ops.make_masked_msa(
common_cfg.masked_msa,
mode_cfg.masked_msa_replace_fraction,
gumbel_sample=gumbel_sample,
Expand All @@ -138,23 +138,23 @@ def ensembled_fns(common_cfg, mode_cfg):

if common_cfg.msa_cluster_features:
if v2_feature:
operators.append(ops.nearest_neighbor_clusters_v2())
operators.append(data_ops.nearest_neighbor_clusters_v2())
else:
operators.append(ops.nearest_neighbor_clusters())
operators.append(ops.summarize_clusters)
operators.append(data_ops.nearest_neighbor_clusters())
operators.append(data_ops.summarize_clusters)

if v2_feature:
operators.append(ops.make_msa_feat_v2)
operators.append(data_ops.make_msa_feat_v2)
else:
operators.append(ops.make_msa_feat)
operators.append(data_ops.make_msa_feat)
# Crop after creating the cluster profiles.
if max_extra_msa:
if v2_feature:
operators.append(ops.make_extra_msa_feat(max_extra_msa))
operators.append(data_ops.make_extra_msa_feat(max_extra_msa))
else:
operators.append(ops.crop_extra_msa(max_extra_msa))
operators.append(data_ops.crop_extra_msa(max_extra_msa))
else:
operators.append(ops.delete_extra_msa)
operators.append(data_ops.delete_extra_msa)
# operators.append(data_operators.select_feat(common_cfg.recycling_features))
return operators

Expand All @@ -179,11 +179,11 @@ def wrap_ensemble_fn(data, i):
)
new_d = compose(fns)(d)
if not multimer_mode or is_distillation:
new_d = ops.select_feat(common_cfg.recycling_features)(new_d)
new_d = data_ops.select_feat(common_cfg.recycling_features)(new_d)
return compose(crop_fn)(new_d)
else: # select after crop for spatial cropping
d = compose(crop_fn)(d)
d = ops.select_feat(common_cfg.recycling_features)(d)
d = data_ops.select_feat(common_cfg.recycling_features)(d)
return d

nonensembled = nonensembled_fns(common_cfg, mode_cfg)
Expand All @@ -207,7 +207,7 @@ def wrap_ensemble_fn(data, i):
return tensors


@ops.curry1
@data_ops.curry1
def compose(x, fs):
for f in fs:
x = f(x)
Expand Down Expand Up @@ -243,11 +243,11 @@ def process_labels(labels_list, num_ensemble: Optional[int] = None):

def label_transform_fn():
return [
ops.make_atom14_masks,
ops.make_atom14_positions,
ops.atom37_to_frames,
ops.atom37_to_torsion_angles(""),
ops.make_pseudo_beta(""),
ops.get_backbone_frames,
ops.get_chi_angles,
data_ops.make_atom14_masks,
data_ops.make_atom14_positions,
data_ops.atom37_to_frames,
data_ops.atom37_to_torsion_angles(""),
data_ops.make_pseudo_beta(""),
data_ops.get_backbone_frames,
data_ops.get_chi_angles,
]

0 comments on commit 2cc9792

Please sign in to comment.