Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add refinery max iterations arg #103

Merged
merged 8 commits into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/test-aviary.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,5 @@ jobs:
run: |
aviary -h
python test/test_assemble.py
python test/test_recover.py
python test/test_run_checkm.py -b
2 changes: 1 addition & 1 deletion aviary/aviary.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ def main():

binning_group.add_argument(
'--refinery-max-iterations', '--refinery_max_iterations',
help='Maximum number of iterations for Rosella refinery',
help='Maximum number of iterations for Rosella refinery. Set to 0 to skip refinery.',
dest='refinery_max_iterations',
default=5
)
Expand Down
35 changes: 8 additions & 27 deletions aviary/modules/binning/binning.smk
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,8 @@ rule checkm_rosella:
pplacer_threads = config["pplacer_threads"],
checkm2_db_path = config["checkm2_db_folder"],
bin_folder = "data/rosella_bins/",
extension = "fna"
extension = "fna",
refinery_max_iterations = config["refinery_max_iterations"],
group: 'binning'
output:
output_folder = directory("data/rosella_bins/checkm2_out/"),
Expand All @@ -372,20 +373,8 @@ rule checkm_rosella:
"../../envs/checkm2.yaml"
threads:
config["max_threads"]
shell:
'touch {output.output_file}; '
'if [ `ls "{params.bin_folder}" |grep .fna$ |wc -l` -eq 0 ]; then '
'echo "No bins found in {params.bin_folder}"; '
'touch {output.output_file}; '
'mkdir -p {output.output_folder}; '
'else '

'export CHECKM2DB={params.checkm2_db_path}/uniref100.KO.1.dmnd; '
'echo "Using CheckM2 database $CHECKM2DB"; '
'checkm2 predict -i {params.bin_folder}/ -x {params.extension} -o {output.output_folder} -t {threads} --force; '
'cp {output.output_folder}/quality_report.tsv {output.output_file}; '

'fi'
script:
"scripts/run_checkm.py"

rule checkm_metabat2:
input:
Expand All @@ -403,12 +392,8 @@ rule checkm_metabat2:
"../../envs/checkm2.yaml"
threads:
config["max_threads"]
shell:
'touch {output.output_file}; '
'export CHECKM2DB={params.checkm2_db_path}/uniref100.KO.1.dmnd; '
'echo "Using CheckM2 database $CHECKM2DB"; '
'checkm2 predict -i {params.bin_folder}/ -x {params.extension} -o {output.output_folder} -t {threads} --force; '
'cp {output.output_folder}/quality_report.tsv {output.output_file}'
script:
"scripts/run_checkm.py"

rule checkm_semibin:
input:
Expand All @@ -426,12 +411,8 @@ rule checkm_semibin:
"../../envs/checkm2.yaml"
threads:
config["max_threads"]
shell:
'touch {output.output_file}; '
'export CHECKM2DB={params.checkm2_db_path}/uniref100.KO.1.dmnd; '
'echo "Using CheckM2 database $CHECKM2DB"; '
'checkm2 predict -i {params.bin_folder}/ -x {params.extension} -o {output.output_folder} -t {threads} --force; '
'cp {output.output_folder}/quality_report.tsv {output.output_file}'
script:
"scripts/run_checkm.py"

rule refine_rosella:
input:
Expand Down
26 changes: 26 additions & 0 deletions aviary/modules/binning/scripts/run_checkm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import subprocess
import shutil
import os
from pathlib import Path

def checkm(checkm2_db, bin_folder, bin_ext, refinery_max_iterations, output_folder, output_file, threads):
    """Run ``checkm2 predict`` on a folder of bins, or write placeholder outputs.

    CheckM2 is only run when ``bin_folder`` contains at least one file ending
    in ``bin_ext`` and ``refinery_max_iterations`` is greater than zero.
    Otherwise ``output_folder`` is created (if missing) and an empty
    ``output_file`` is touched so that both outputs exist.

    Args:
        checkm2_db: Folder containing ``uniref100.KO.1.dmnd``.
        bin_folder: Folder holding the bin FASTA files.
        bin_ext: Filename suffix identifying bins (e.g. ``"fna"``).
        refinery_max_iterations: Iteration limit; ``0`` skips CheckM2. May be
            an int or a numeric string.
        output_folder: Folder CheckM2 writes its results to.
        output_file: Destination for a copy of ``quality_report.tsv``.
        threads: Thread count handed to CheckM2.

    Raises:
        subprocess.CalledProcessError: If ``checkm2`` exits with a non-zero status.
        FileNotFoundError: If ``bin_folder`` or the ``checkm2`` executable is missing.
        ValueError: If ``refinery_max_iterations`` is not an integer value.
    """
    # The value can arrive as a string (the CLI option declares no type), and
    # comparing a str with 0 raises TypeError, so normalise it first.
    refinery_enabled = int(refinery_max_iterations) > 0
    has_bins = any(name.endswith(bin_ext) for name in os.listdir(bin_folder))

    if not (has_bins and refinery_enabled):
        if not has_bins:
            print(f"No bins found in {bin_folder}")
        else:
            print("Refinery max iterations is 0, skipping CheckM2")
        # exist_ok: the folder may be left over from an earlier run.
        os.makedirs(output_folder, exist_ok=True)
        Path(output_file).touch()
        return

    database = f"{checkm2_db}/uniref100.KO.1.dmnd"
    print(f"Using CheckM2 database {database}")

    # Without shell=True a single command string is treated as the name of the
    # executable, so pass an argument list and set CHECKM2DB through the
    # environment instead of a shell-style "VAR=value cmd" prefix.
    command = [
        "checkm2", "predict",
        "-i", f"{bin_folder}/",
        "-x", str(bin_ext),
        "-o", str(output_folder),
        "-t", str(threads),
        "--force",
    ]
    # check=True: fail here with the real exit status rather than later when
    # the quality report turns out to be missing.
    subprocess.run(command, env={**os.environ, "CHECKM2DB": database}, check=True)
    shutil.copy(f"{output_folder}/quality_report.tsv", output_file)


if __name__ == '__main__':
    # Entry point for the Snakemake `script:` directive: the rule's params,
    # outputs and thread count are read from the `snakemake` object and
    # forwarded to checkm().
    rule_params = snakemake.params
    rule_outputs = snakemake.output

    checkm(
        checkm2_db=rule_params.checkm2_db_path,
        bin_folder=rule_params.bin_folder,
        bin_ext=rule_params.extension,
        refinery_max_iterations=rule_params.refinery_max_iterations,
        output_folder=rule_outputs.output_folder,
        output_file=rule_outputs.output_file,
        threads=snakemake.threads,
    )
59 changes: 59 additions & 0 deletions test/test_run_checkm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#!/usr/bin/env python3

import unittest
import os
import tempfile
from aviary.modules.binning.scripts.run_checkm import checkm
from unittest.mock import patch
import subprocess
from pathlib import Path

def create_output(*_args, **_kwargs):
    """Stand in for ``checkm2``: create ``output_folder/quality_report.tsv``.

    Intended as a ``side_effect`` for a patched ``subprocess.run``. The mock
    forwards every positional and keyword argument of the patched call, so all
    of them are accepted and ignored; only the files are produced.
    """
    # exist_ok so a repeated call in the same working directory does not raise.
    os.makedirs("output_folder", exist_ok=True)
    Path(os.path.join("output_folder", "quality_report.tsv")).touch()

class Tests(unittest.TestCase):
    """Tests for ``checkm`` with the external ``checkm2`` call patched out."""

    def setUp(self):
        """Run each test inside a fresh scratch directory with empty fixtures."""
        scratch = tempfile.TemporaryDirectory()
        # Cleanups run last-in-first-out: the working directory is restored
        # first, then the scratch directory is deleted. Leaving the process
        # inside a deleted directory breaks later os.getcwd() calls.
        self.addCleanup(scratch.cleanup)
        self.addCleanup(os.chdir, os.getcwd())
        os.chdir(scratch.name)

        self.checkm2_db = "checkm2_db"
        os.makedirs(self.checkm2_db)
        self.bin_folder = "bin_folder"
        os.makedirs(self.bin_folder)

    def _add_bin(self):
        """Create a single empty bin file in the bin folder."""
        Path(os.path.join(self.bin_folder, "bin_1.fna")).touch()

    def test_run_checkm(self):
        self._add_bin()

        def fake_checkm2(*_args, **_kwargs):
            # Ignore however subprocess.run was invoked; just produce the report.
            create_output(None)

        with patch.object(subprocess, "run", side_effect=fake_checkm2) as mock_subprocess:
            checkm(self.checkm2_db, self.bin_folder, "fna", 1, "output_folder", "output_file", 1)

        self.assertTrue(os.path.exists("output_file"))
        mock_subprocess.assert_called_once()

        # Assert on what is run rather than on the exact call form, so the
        # test holds whether the command is a single string or an argument
        # list, and whether CHECKM2DB is set inline or via ``env``.
        args, kwargs = mock_subprocess.call_args
        command = args[0]
        command_text = command if isinstance(command, str) else " ".join(command)
        self.assertIn(
            f"checkm2 predict -i {self.bin_folder}/ -x fna -o output_folder -t 1 --force",
            command_text,
        )
        expected_db = f"{self.checkm2_db}/uniref100.KO.1.dmnd"
        env = kwargs.get("env") or {}
        self.assertTrue(
            f"CHECKM2DB={expected_db}" in command_text
            or env.get("CHECKM2DB") == expected_db
        )

    def test_run_checkm_no_bins(self):
        with patch.object(subprocess, "run") as mock_subprocess:
            checkm(self.checkm2_db, self.bin_folder, "fna", 1, "output_folder", "output_file", 1)

        self.assertTrue(os.path.exists("output_file"))
        mock_subprocess.assert_not_called()

    def test_run_checkm_no_refinery(self):
        self._add_bin()

        with patch.object(subprocess, "run") as mock_subprocess:
            checkm(self.checkm2_db, self.bin_folder, "fna", 0, "output_folder", "output_file", 1)

        self.assertTrue(os.path.exists("output_file"))
        mock_subprocess.assert_not_called()


# Allow running this file directly (the CI workflow invokes it as
# `python test/test_run_checkm.py -b`); unittest.main() reads its options
# from the command line.
if __name__ == '__main__':
    unittest.main()