diff --git a/scripts/get_assembly_from_mmcif.py b/scripts/get_assembly_from_mmcif.py index b4077f2..136984b 100644 --- a/scripts/get_assembly_from_mmcif.py +++ b/scripts/get_assembly_from_mmcif.py @@ -6,6 +6,11 @@ from tqdm import tqdm from multiprocessing import Pool from unifold.msa.mmcif import parse +import argparse +import gzip +from Bio.PDB import protein_letters_3to1 +import numpy as np +from unifold.data.residue_constants import restype_order_with_x rot_keys = """_pdbx_struct_oper_list.matrix[1][1] _pdbx_struct_oper_list.matrix[1][2] @@ -57,11 +62,15 @@ def process_block_to_dict(content): else: last_val = [] for line in lines: + if line.startswith(";"): + continue t = shlex.split(line) last_val.extend(t) if len(last_val) == 2: ret[last_val[0]] = [last_val[1]] last_val = [] + if len(last_val) > 2: + last_val = [] if last_val: assert len(last_val) == 2 ret[last_val[0]] = [last_val[1]] @@ -69,7 +78,7 @@ def process_block_to_dict(content): def get_transform(data, idx): - idx = int(idx) - 1 + idx = data["_pdbx_struct_oper_list.id"].index(f"{idx}") rot = [] for key in rot_keys: rot.append(float(data[key][idx])) @@ -87,10 +96,23 @@ def get_transform(data, idx): return rot, trans +def mmcif_object_to_fasta(mmcif_object, auth_chain_id: str) -> str: + residues = mmcif_object.seqres_to_structure[auth_chain_id] + residue_names = [residues[t].name for t in range(len(residues))] + residue_letters = [ + protein_letters_3to1[n] if n in protein_letters_3to1.keys() else "X" for n in residue_names + ] + # take care of cases where residue letters are of length 3 + # simply by replacing them as 'X' ('UNK') + filter_out_triple_letters = lambda x: x if len(x) == 1 else "X" + fasta_string = "".join([filter_out_triple_letters(n) for n in residue_letters]) + return fasta_string + + def parse_assembly(mmcif_path): name = os.path.split(mmcif_path)[-1].split(".")[0] - with open(mmcif_path) as f: - mmcif_string = f.read() + with gzip.open(mmcif_path, "rb") as f: + mmcif_string = f.read().decode("utf8") mmcif_lines = mmcif_string.split("\n") parse_result = parse(file_id="", mmcif_string=mmcif_string) if "No protein chains found in this file." in parse_result.errors.values(): @@ -101,7 +123,33 @@ def parse_assembly(mmcif_path): return name, [], [], [], "parse error" mmcif_to_author_chain_id = mmcif_obj.mmcif_to_author_chain_id valid_chains = mmcif_obj.valid_chains.keys() - valid_chains = set(valid_chains) # valid chains is not auth_id + valid_chains_set = set(valid_chains) # valid chains is not auth_id + + resolution = np.array([mmcif_obj.header["resolution"]]) + if resolution > 9: + return name, [], [], [], "resolution" + + invalid_chains = [] + for chain_id in mmcif_obj.chain_to_seqres: + sequence = mmcif_object_to_fasta(mmcif_obj, chain_id) + aatype_idx = np.array( + [ + restype_order_with_x[rn] + if rn in restype_order_with_x + else restype_order_with_x["X"] + for rn in sequence + ] + ) + seq_len = aatype_idx.shape[0] + _, counts = np.unique(aatype_idx, return_counts=True) + freqs = counts.astype(np.float32) / seq_len + max_freq = np.max(freqs) + if max_freq > 0.8: + invalid_chains.append(chain_id) + valid_chains = [] + for chain_id in valid_chains_set: + if mmcif_to_author_chain_id[chain_id] not in invalid_chains: + valid_chains.append(chain_id) new_section = False is_loop = False cur_lines = [] @@ -116,7 +164,9 @@ def parse_assembly(mmcif_path): continue if line == "#": cur_str = "\n".join(cur_lines) - if "revision" in cur_str: + if "revision" in cur_str and ( + assembly is not None and assembly_gen is not None and oper is not None + ): continue if "_pdbx_struct_assembly.id" in cur_str: assembly = process_block_to_dict(cur_str) @@ -140,14 +190,12 @@ def parse_assembly(mmcif_path): chains = [] chains_ops = [] for i, j in enumerate(asym_id): - if j == "1": - sss = ( - op_idx[i] - .replace("(", "") - .replace(")", "") - .replace("'", "") - .replace('"', "") - ) + idxxxx = "1" + if name in ["6lb3"]: + # the first assembly consists of two polydeoxyribonucleotides + idxxxx = "2" + if j == idxxxx: + sss = op_idx[i].replace("(", "").replace(")", "").replace("'", "").replace('"', "") if "-" in sss: s, t = sss.split("-") indices = range(int(s), int(t) + 1) @@ -175,25 +223,41 @@ def parse_assembly(mmcif_path): return name, [], [], [], "index" -input_dir = sys.argv[1] -output_file = sys.argv[2] -input_files = glob.glob(input_dir + "*.cif") +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--input-dir", type=str, default="/you/mmcif") + parser.add_argument("--output-file", type=str, default="/you/mmcif_assembly3.json") + args = parser.parse_args() + print(args) + + input_dir = args.input_dir + output_file = args.output_file + + os.makedirs(os.path.dirname(output_file), exist_ok=True) -file_cnt = len(input_files) -meta_dict = {} -failed = [] -with Pool(64) as pool: - for ret in tqdm(pool.imap(parse_assembly, input_files), total=file_cnt): - name, all_chains, all_chains_label, all_chains_ops, error_type = ret - if all_chains: - meta_dict[name] = {} - meta_dict[name]["chains"] = all_chains - meta_dict[name]["chains_label"] = all_chains_label - meta_dict[name]["opers"] = all_chains_ops - else: - failed.append(name + " " + error_type) + input_files = glob.glob(os.path.join(input_dir, "*.cif.gz")) + + file_cnt = len(input_files) + print(f"len(input_files): {file_cnt}") + meta_dict = {} + failed = {} + + with Pool(80) as pool: + for ret in tqdm(pool.imap(parse_assembly, input_files, chunksize=10), total=file_cnt): + name, all_chains, all_chains_label, all_chains_ops, error_type = ret + if all_chains: + meta_dict[name] = {} + meta_dict[name]["chains"] = all_chains + meta_dict[name]["chains_label"] = all_chains_label + meta_dict[name]["opers"] = all_chains_ops + else: + # failed.append(name + " " + error_type) + failed[name] = error_type -json.dump(meta_dict, open(output_file, "w"), indent=2) + json.dump(meta_dict, open(output_file, "w"), indent=2) + # write_list_to_file(failed, "failed_mmcif.txt") + root = os.path.splitext(output_file)[0] + json.dump(failed, open(f"{root}.failed.json", "w"), indent=2) def write_list_to_file(a, file): @@ -202,4 +266,5 @@ def write_list_to_file(a, file): output.write(str(x) + "\n") -write_list_to_file(failed, "failed_mmcif.txt") +if __name__ == "__main__": + main()