Skip to content

Commit

Permalink
update get assembly from mmcif (dptech-corp#96)
Browse files Browse the repository at this point in the history
* update get assembly from mmcif

* add comment
  • Loading branch information
teslacool committed Feb 27, 2023
1 parent 0779e9c commit 4e66e16
Showing 1 changed file with 96 additions and 31 deletions.
127 changes: 96 additions & 31 deletions scripts/get_assembly_from_mmcif.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
from tqdm import tqdm
from multiprocessing import Pool
from unifold.msa.mmcif import parse
import argparse
import gzip
from Bio.PDB import protein_letters_3to1
import numpy as np
from unifold.data.residue_constants import restype_order_with_x

rot_keys = """_pdbx_struct_oper_list.matrix[1][1]
_pdbx_struct_oper_list.matrix[1][2]
Expand Down Expand Up @@ -57,19 +62,23 @@ def process_block_to_dict(content):
else:
last_val = []
for line in lines:
if line.startswith(";"):
continue
t = shlex.split(line)
last_val.extend(t)
if len(last_val) == 2:
ret[last_val[0]] = [last_val[1]]
last_val = []
if len(last_val) > 2:
last_val = []
if last_val:
assert len(last_val) == 2
ret[last_val[0]] = [last_val[1]]
return ret


def get_transform(data, idx):
idx = int(idx) - 1
idx = data["_pdbx_struct_oper_list.id"].index(f"{idx}")
rot = []
for key in rot_keys:
rot.append(float(data[key][idx]))
Expand All @@ -87,10 +96,23 @@ def get_transform(data, idx):
return rot, trans


def mmcif_object_to_fasta(mmcif_object, auth_chain_id: str) -> str:
residues = mmcif_object.seqres_to_structure[auth_chain_id]
residue_names = [residues[t].name for t in range(len(residues))]
residue_letters = [
protein_letters_3to1[n] if n in protein_letters_3to1.keys() else "X" for n in residue_names
]
# take care of cases where residue letters are of length 3
# simply by replacing them as 'X' ('UNK')
filter_out_triple_letters = lambda x: x if len(x) == 1 else "X"
fasta_string = "".join([filter_out_triple_letters(n) for n in residue_letters])
return fasta_string


def parse_assembly(mmcif_path):
name = os.path.split(mmcif_path)[-1].split(".")[0]
with open(mmcif_path) as f:
mmcif_string = f.read()
with gzip.open(mmcif_path, "rb") as f:
mmcif_string = f.read().decode("utf8")
mmcif_lines = mmcif_string.split("\n")
parse_result = parse(file_id="", mmcif_string=mmcif_string)
if "No protein chains found in this file." in parse_result.errors.values():
Expand All @@ -101,7 +123,33 @@ def parse_assembly(mmcif_path):
return name, [], [], [], "parse error"
mmcif_to_author_chain_id = mmcif_obj.mmcif_to_author_chain_id
valid_chains = mmcif_obj.valid_chains.keys()
valid_chains = set(valid_chains) # valid chains is not auth_id
valid_chains_set = set(valid_chains) # valid chains is not auth_id

resolution = np.array([mmcif_obj.header["resolution"]])
if resolution > 9:
return name, [], [], [], "resolution"

invalid_chains = []
for chain_id in mmcif_obj.chain_to_seqres:
sequence = mmcif_object_to_fasta(mmcif_obj, chain_id)
aatype_idx = np.array(
[
restype_order_with_x[rn]
if rn in restype_order_with_x
else restype_order_with_x["X"]
for rn in sequence
]
)
seq_len = aatype_idx.shape[0]
_, counts = np.unique(aatype_idx, return_counts=True)
freqs = counts.astype(np.float32) / seq_len
max_freq = np.max(freqs)
if max_freq > 0.8:
invalid_chains.append(chain_id)
valid_chains = []
for chain_id in valid_chains_set:
if mmcif_to_author_chain_id[chain_id] not in invalid_chains:
valid_chains.append(chain_id)
new_section = False
is_loop = False
cur_lines = []
Expand All @@ -116,7 +164,9 @@ def parse_assembly(mmcif_path):
continue
if line == "#":
cur_str = "\n".join(cur_lines)
if "revision" in cur_str:
if "revision" in cur_str and (
assembly is not None and assembly_gen is not None and oper is not None
):
continue
if "_pdbx_struct_assembly.id" in cur_str:
assembly = process_block_to_dict(cur_str)
Expand All @@ -140,14 +190,12 @@ def parse_assembly(mmcif_path):
chains = []
chains_ops = []
for i, j in enumerate(asym_id):
if j == "1":
sss = (
op_idx[i]
.replace("(", "")
.replace(")", "")
.replace("'", "")
.replace('"', "")
)
idxxxx = "1"
if name in ["6lb3"]:
# the first assembly consists of two polydeoxyribonucleotides
idxxxx = "2"
if j == idxxxx:
sss = op_idx[i].replace("(", "").replace(")", "").replace("'", "").replace('"', "")
if "-" in sss:
s, t = sss.split("-")
indices = range(int(s), int(t) + 1)
Expand Down Expand Up @@ -175,25 +223,41 @@ def parse_assembly(mmcif_path):
return name, [], [], [], "index"


input_dir = sys.argv[1]
output_file = sys.argv[2]
input_files = glob.glob(input_dir + "*.cif")
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--input-dir", type=str, default="/you/mmcif")
parser.add_argument("--output-file", type=str, default="/you/mmcif_assembly3.json")
args = parser.parse_args()
print(args)

input_dir = args.input_dir
output_file = args.output_file

os.makedirs(os.path.dirname(output_file), exist_ok=True)

file_cnt = len(input_files)
meta_dict = {}
failed = []
with Pool(64) as pool:
for ret in tqdm(pool.imap(parse_assembly, input_files), total=file_cnt):
name, all_chains, all_chains_label, all_chains_ops, error_type = ret
if all_chains:
meta_dict[name] = {}
meta_dict[name]["chains"] = all_chains
meta_dict[name]["chains_label"] = all_chains_label
meta_dict[name]["opers"] = all_chains_ops
else:
failed.append(name + " " + error_type)
input_files = glob.glob(os.path.join(input_dir, "*.cif.gz"))

file_cnt = len(input_files)
print(f"len(input_files): {file_cnt}")
meta_dict = {}
failed = {}

with Pool(80) as pool:
for ret in tqdm(pool.imap(parse_assembly, input_files, chunksize=10), total=file_cnt):
name, all_chains, all_chains_label, all_chains_ops, error_type = ret
if all_chains:
meta_dict[name] = {}
meta_dict[name]["chains"] = all_chains
meta_dict[name]["chains_label"] = all_chains_label
meta_dict[name]["opers"] = all_chains_ops
else:
# failed.append(name + " " + error_type)
failed[name] = error_type

json.dump(meta_dict, open(output_file, "w"), indent=2)
json.dump(meta_dict, open(output_file, "w"), indent=2)
# write_list_to_file(failed, "failed_mmcif.txt")
root = os.path.splitext(output_file)[0]
json.dump(failed, open(f"{root}.failed.json", "w"), indent=2)


def write_list_to_file(a, file):
Expand All @@ -202,4 +266,5 @@ def write_list_to_file(a, file):
output.write(str(x) + "\n")


write_list_to_file(failed, "failed_mmcif.txt")
if __name__ == "__main__":
main()

0 comments on commit 4e66e16

Please sign in to comment.