diff --git a/cath_alphaflow/cli.py b/cath_alphaflow/cli.py
index fdf8c33..59de3b9 100644
--- a/cath_alphaflow/cli.py
+++ b/cath_alphaflow/cli.py
@@ -7,6 +7,7 @@
 from .commands import optimise_domain_boundaries
 from .commands import convert_dssp_to_sse_summary
 from .commands import convert_cif_to_dssp
+from .commands import extract_plddt_and_lur
 
 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s"
@@ -47,3 +48,4 @@ def dump_config():
 cli.add_command(optimise_domain_boundaries.optimise_domain_boundaries)
 cli.add_command(convert_dssp_to_sse_summary.convert_dssp_to_sse_summary)
 cli.add_command(convert_cif_to_dssp.convert_cif_to_dssp)
+cli.add_command(extract_plddt_and_lur.convert_cif_to_plddt_summary)
diff --git a/cath_alphaflow/commands/convert_dssp_to_sse_summary.py b/cath_alphaflow/commands/convert_dssp_to_sse_summary.py
index 46b5905..604d774 100644
--- a/cath_alphaflow/commands/convert_dssp_to_sse_summary.py
+++ b/cath_alphaflow/commands/convert_dssp_to_sse_summary.py
@@ -8,11 +8,7 @@
     get_sse_summary_writer,
 )
 from cath_alphaflow.models import SecStrSummary
-from cath_alphaflow.constants import (
-    DEFAULT_DSSP_SUFFIX,
-    DEFAULT_HELIX_MIN_LENGTH,
-    DEFAULT_STRAND_MIN_LENGTH,
-)
+from cath_alphaflow.constants import DEFAULT_DSSP_SUFFIX
 
 
 @click.command()
@@ -60,8 +56,6 @@ def get_sse_summary_from_dssp(
 
     dssp_string = []
     read_headers = False
-    domain_length = 0
-    ss_total = 0
 
     if acc_id is None:
         acc_id = dssp_path.stem
diff --git a/cath_alphaflow/commands/extract_plddt_and_lur.py b/cath_alphaflow/commands/extract_plddt_and_lur.py
new file mode 100644
index 0000000..5fa2b8d
--- /dev/null
+++ b/cath_alphaflow/commands/extract_plddt_and_lur.py
@@ -0,0 +1,141 @@
+from pathlib import Path
+import gzip
+from Bio.PDB import MMCIF2Dict
+import logging
+import click
+from cath_alphaflow.io_utils import (
+    yield_first_col,
+    get_plddt_summary_writer,
+)
+from cath_alphaflow.models import LURSummary, pLDDTSummary
+from cath_alphaflow.constants import MIN_LENGTH_LUR
+
+LOG = logging.getLogger()
+
+
+@click.command()
+@click.option(
+    "--cif_in_dir",
+    type=click.Path(exists=True, file_okay=False, dir_okay=True, resolve_path=True),
+    required=True,
+    help="Input: directory of CIF files",
+)
+@click.option(
+    "--id_file",
+    type=click.File("rt"),
+    required=True,
+    help="Input: CSV file containing list of ids to process from CIF to pLDDT",
+)
+@click.option(
+    "--plddt_stats_file",
+    type=click.File("wt"),
+    required=True,
+    help="Output: pLDDT and LUR output file",
+)
+@click.option(
+    "--cif_suffix",
+    type=str,
+    default=".cif",
+    help="Option: suffix to use for mmCIF files (default: .cif)",
+)
+def convert_cif_to_plddt_summary(
+    cif_in_dir,
+    id_file,
+    plddt_stats_file,
+    cif_suffix,
+):
+    "Creates a summary of pLDDT and LUR values from mmCIF files"
+
+    plddt_out_writer = get_plddt_summary_writer(plddt_stats_file)
+
+    for file_stub in yield_first_col(id_file):
+        cif_path = Path(cif_in_dir) / f"{file_stub}{cif_suffix}"
+        if not cif_path.exists():
+            msg = f"failed to locate CIF input file {cif_path}"
+            LOG.error(msg)
+            raise FileNotFoundError(msg)
+
+        avg_plddt = get_average_plddt_from_plddt_string(cif_path)
+        perc_LUR_summary = get_LUR_residues_percentage(cif_path)
+        plddt_stats = pLDDTSummary(
+            af_domain_id=file_stub,
+            avg_plddt=avg_plddt,
+            perc_LUR=perc_LUR_summary.LUR_perc,
+        )
+        plddt_out_writer.writerow(plddt_stats.__dict__)
+
+    click.echo("DONE")
+
+
+def get_average_plddt_from_plddt_string(
+    cif_path: Path, *, chopping=None, acc_id=None
+) -> float:
+    if acc_id is None:
+        acc_id = cif_path.stem
+    open_func = open
+    if cif_path.name.endswith(".gz"):
+        open_func = gzip.open
+    with open_func(str(cif_path), mode="rt") as cif_fh:
+        mmcif_dict = MMCIF2Dict.MMCIF2Dict(cif_fh)
+    chain_plddt = mmcif_dict["_ma_qa_metric_global.metric_value"][0]
+    plddt_strings = mmcif_dict["_ma_qa_metric_local.metric_value"]
+    chopping_plddt = []
+    if chopping:
+        for segment in chopping.segments:
+            segment_plddt = [
+                float(plddt)
+                for plddt in plddt_strings[int(segment.start) - 1 : int(segment.end)]
+            ]
+            chopping_plddt += segment_plddt
+        domain_length = len(chopping_plddt)
+        average_plddt = round((sum(chopping_plddt) / domain_length), 2)
+
+    else:
+        average_plddt = chain_plddt
+    return average_plddt
+
+
+def get_LUR_residues_percentage(cif_path: Path, *, chopping=None, acc_id=None):
+    if acc_id is None:
+        acc_id = cif_path.stem
+    open_func = open
+    if cif_path.name.endswith(".gz"):
+        open_func = gzip.open
+    with open_func(str(cif_path), mode="rt") as cif_fh:
+        mmcif_dict = MMCIF2Dict.MMCIF2Dict(cif_fh)
+    plddt_strings = mmcif_dict["_ma_qa_metric_local.metric_value"]
+    chopping_plddt = []
+    if chopping:
+        for segment in chopping.segments:
+            segment_plddt = [
+                float(plddt)
+                for plddt in plddt_strings[int(segment.start) - 1 : int(segment.end)]
+            ]
+            chopping_plddt += segment_plddt
+    else:
+        chopping_plddt = plddt_strings
+    # Calculate LUR
+    LUR_perc = 0
+    LUR_total = 0
+    LUR_res = 0
+    LUR_stretch = False
+    min_res_lur = MIN_LENGTH_LUR
+    for residue in segment_plddt:
+        plddt_res = float(residue)
+        if plddt_res < 70:
+            LUR_res += 1
+            if LUR_stretch:
+                LUR_total += 1
+
+            if LUR_res == min_res_lur and not LUR_stretch:
+                LUR_stretch = True
+                LUR_total += min_res_lur
+
+        else:
+            LUR_stretch = False
+            LUR_res = 0
+    LUR_perc = round(LUR_total / len(chopping_plddt) * 100, 2)
+
+    return LURSummary(
+        LUR_perc=LUR_perc, LUR_total=LUR_total, residues_total=len(chopping_plddt)
+    )
diff --git a/cath_alphaflow/constants.py b/cath_alphaflow/constants.py
index 255a607..46b5a42 100644
--- a/cath_alphaflow/constants.py
+++ b/cath_alphaflow/constants.py
@@ -2,3 +2,4 @@
 DEFAULT_DSSP_SUFFIX = ".dssp"
 DEFAULT_HELIX_MIN_LENGTH = 3
 DEFAULT_STRAND_MIN_LENGTH = 2
+MIN_LENGTH_LUR = 5
diff --git a/cath_alphaflow/io_utils.py b/cath_alphaflow/io_utils.py
index c99d58e..62caa1e 100644
--- a/cath_alphaflow/io_utils.py
+++ b/cath_alphaflow/io_utils.py
@@ -53,6 +53,19 @@ def get_sse_summary_writer(csvfile):
     return writer
 
 
+def get_plddt_summary_writer(csvfile):
+    writer = get_csv_dictwriter(
+        csvfile,
+        fieldnames=[
+            "af_domain_id",
+            "avg_plddt",
+            "perc_LUR",
+        ],
+    )
+    writer.writeheader()
+    return writer
+
+
 class AFDomainIDReader(csv.DictReader):
     def __init__(self, *args):
         self._seen_header = False
diff --git a/cath_alphaflow/models.py b/cath_alphaflow/models.py
index 6521079..43bf794 100644
--- a/cath_alphaflow/models.py
+++ b/cath_alphaflow/models.py
@@ -124,6 +124,13 @@ def to_str(self):
         return self.af_domain_id
 
 
+@dataclass
+class LURSummary:
+    LUR_perc: float
+    LUR_total: int
+    residues_total: int
+
+
 @dataclass
 class SecStrSummary:
     af_domain_id: str
@@ -170,13 +177,13 @@ def new_from_dssp_str(
         for residue in dssp_str:
             if residue == "H":
                 sse_H_res += 1
-                if sse_H_res >= min_helix_length and sse_H == False:
+                if sse_H_res >= min_helix_length and not sse_H:
                     sse_H = True
                     sse_H_num += 1
 
             if residue == "E":
                 sse_E_res += 1
-                if sse_E_res >= min_strand_length and sse_E == False:
+                if sse_E_res >= min_strand_length and not sse_E:
                     sse_E = True
                     sse_E_num += 1
 
@@ -194,3 +201,12 @@ def new_from_dssp_str(
         )
         return ss_sum
+
+
+@dataclass
+class pLDDTSummary:
+    af_domain_id: str
+    avg_plddt: float
+    perc_LUR: float
+    LUR_residues: int
+    total_residues: int
diff --git a/tests/test_extract_plddt_and_lur.py b/tests/test_extract_plddt_and_lur.py
new file mode 100644
index 0000000..e90ec54
--- /dev/null
+++ b/tests/test_extract_plddt_and_lur.py
@@ -0,0 +1,90 @@
+import os
+from pathlib import Path
+from click.testing import CliRunner
+from cath_alphaflow.cli import cli
+from cath_alphaflow.commands.extract_plddt_and_lur import (
+    get_average_plddt_from_plddt_string,
+    get_LUR_residues_percentage,
+)
+from cath_alphaflow.models import Chopping, LURSummary, Segment
+
+
+UNIPROT_IDS = ["P00520"]
+FIXTURE_PATH = Path(__file__).parent / "fixtures"
+EXAMPLE_CIF_FILE = FIXTURE_PATH / "cif" / "AF-P00520-F1-model_v3.cif.gz"
+
+SUBCOMMAND = "convert-cif-to-plddt-summary"
+
+
+def test_cli_usage():
+    runner = CliRunner()
+    with runner.isolated_filesystem():
+        result = runner.invoke(cli, [SUBCOMMAND, "--help"])
+        assert result.exit_code == 0
+        assert "Usage:" in result.output
+
+
+def create_fake_cif_path(dirname, cif_id, cif_src=EXAMPLE_CIF_FILE):
+    dir_path = Path(dirname)
+    dir_path.mkdir()
+
+    cif_path_dest = dir_path / f"{cif_id}.cif.gz"
+    os.symlink(cif_src, f"{cif_path_dest}")
+    return cif_path_dest
+
+
+def test_extract_plddt_summary(tmp_path):
+    acc_id = "test1"
+    cif_path = create_fake_cif_path(tmp_path.name, acc_id)
+    chopping = Chopping(segments=[Segment("10", "20")])
+
+    average_plddt = get_average_plddt_from_plddt_string(
+        cif_path, chopping=chopping, acc_id=acc_id
+    )
+
+    assert average_plddt == 33.71
+
+    chopping = Chopping(segments=[Segment("10", "20"), Segment("20", "35")])
+
+    average_plddt = get_average_plddt_from_plddt_string(
+        cif_path, chopping=chopping, acc_id=acc_id
+    )
+
+    assert average_plddt == 32.88
+
+
+def get_total_residues_from_chopping(chopping):
+    return sum([int(seg.end) - int(seg.start) + 1 for seg in chopping.segments])
+
+
+def test_extract_LUR_summary(tmp_path):
+    acc_id = "test1"
+    cif_path = create_fake_cif_path(tmp_path.name, acc_id)
+    chopping = Chopping(segments=[Segment("10", "20")])
+
+    lur_summary = get_LUR_residues_percentage(
+        cif_path, chopping=chopping, acc_id=acc_id
+    )
+
+    assert lur_summary == LURSummary(
+        LUR_perc=100.0,
+        LUR_total=11,
+        residues_total=get_total_residues_from_chopping(chopping),
+    )
+    # clean up after test
+    del chopping
+    del lur_summary
+
+    chopping = Chopping(segments=[Segment("1", "200"), Segment("200", "1120")])
+
+    lur_summary = get_LUR_residues_percentage(
+        cif_path, chopping=chopping, acc_id=acc_id
+    )
+
+    assert lur_summary == LURSummary(
+        LUR_perc=47.55,
+        LUR_total=533,
+        residues_total=get_total_residues_from_chopping(chopping),
+    )
+    del chopping
+    del lur_summary
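Usage sketch (not part of the diff): the two helper functions added in extract_plddt_and_lur.py can be combined to get per-domain pLDDT/LUR statistics for a chopped region of an AlphaFold model. The fixture file and residue range below are the ones used in the new tests; the relative path assumes the snippet is run from the repository root.

    # sketch: per-domain pLDDT and LUR stats for a chopped region of an AlphaFold model
    from pathlib import Path

    from cath_alphaflow.models import Chopping, Segment
    from cath_alphaflow.commands.extract_plddt_and_lur import (
        get_average_plddt_from_plddt_string,
        get_LUR_residues_percentage,
    )

    # gzipped AlphaFold model used as a test fixture in this PR
    cif_path = Path("tests/fixtures/cif/AF-P00520-F1-model_v3.cif.gz")
    chopping = Chopping(segments=[Segment("10", "20")])

    avg_plddt = get_average_plddt_from_plddt_string(cif_path, chopping=chopping)
    lur = get_LUR_residues_percentage(cif_path, chopping=chopping)
    print(avg_plddt, lur.LUR_perc, lur.LUR_total, lur.residues_total)

On the command line, the same summary is exposed through the new convert-cif-to-plddt-summary subcommand registered in cli.py, driven by the --cif_in_dir, --id_file, --plddt_stats_file and --cif_suffix options defined above.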