change sketch_samples to submit with for loop
AroneyS committed Feb 10, 2025
1 parent e683836 commit 32d8ac4
Showing 2 changed files with 143 additions and 30 deletions.
85 changes: 55 additions & 30 deletions binchicken/workflow/scripts/sketch_samples.py
@@ -10,6 +10,7 @@
from sourmash.sourmash_args import SaveSignaturesToLocation
from concurrent.futures import ProcessPoolExecutor
import extern
+import re

SINGLEM_OTU_TABLE_SCHEMA = {
"gene": str,
@@ -30,38 +31,62 @@ def process_groups(groups, output_path):
signature = SourmashSignature(mh, name=sample)
save_sigs.add(signature)

-def processing(unbinned_path, output_path, taxa_of_interest=None, threads=1):
-    unbinned = pl.read_csv(unbinned_path, separator="\t", schema_overrides=SINGLEM_OTU_TABLE_SCHEMA)
-
-    if taxa_of_interest:
-        logging.info(f"Filtering for taxa of interest: {taxa_of_interest}")
-        unbinned = unbinned.filter(
-            pl.col("taxonomy").str.contains(taxa_of_interest)
-        )
-
+def processing(unbinned_path, output_path, taxa_of_interest=None, threads=1, samples_per_group=1000):
output_dir = os.path.dirname(output_path)

-    logging.info("Grouping samples")
-    groups = [(s[0], d.get_column("sequence").to_list()) for s,d in unbinned.set_sorted("sample").select("sample", "sequence").group_by(["sample"])]
-    threads = min(threads, len(groups))
-
-    # Distribute groups among threads more evenly
-    grouped = [[] for _ in range(threads)]
-    for i, group in enumerate(groups):
-        grouped[i % threads].append(group)
-
-    del groups
-
-    logging.info("Generating sketches in separate threads")
-    with ProcessPoolExecutor(max_workers=threads) as executor:
-        futures = []
-        for i, group_subset in enumerate(grouped):
-            output_subpath = os.path.join(output_dir, f"signatures_thread_{i}.sig")
-            future = executor.submit(process_groups, group_subset, output_subpath)
-            futures.append(future)
-
-        for future in futures:
-            future.result()
+    with open(unbinned_path) as f:
+        logging.info(f"Reading unbinned OTU table from {unbinned_path}")
+        if taxa_of_interest:
+            logging.info(f"Filtering for taxa of interest: {taxa_of_interest}")
+
+        logging.info("Generating sketches in separate threads")
+        with ProcessPoolExecutor(max_workers=threads) as executor:
+            current_sample = ""
+            current_sequences = []
+            i = 0
+            group_subset = []
+            futures = []
+            for line in f:
+                line = line.strip()
+                if line == "gene\tsample\tsequence\tnum_hits\tcoverage\ttaxonomy":
+                    continue
+
+                sample = line.split("\t")[1]
+                sequence = line.split("\t")[2]
+
+                if taxa_of_interest:
+                    taxonomy = line.split("\t")[5]
+                    if not re.search(taxa_of_interest, taxonomy):
+                        continue
+
+                if not current_sample:
+                    current_sample = sample
+
+                if sample != current_sample:
+                    group_subset.append((current_sample, current_sequences))
+                    current_sample = sample
+                    current_sequences = [sequence]
+
+                    if len(group_subset) == samples_per_group:
+                        output_subpath = os.path.join(output_dir, f"signatures_thread_{i}.sig")
+                        future = executor.submit(process_groups, group_subset, output_subpath)
+                        futures.append(future)
+                        i += 1
+                        group_subset = []
+                else:
+                    current_sequences.append(sequence)
+
+            # Submit any remaining groups that are smaller than samples_per_group
+            if sample == current_sample:
+                group_subset.append((current_sample, current_sequences))
+
+            if group_subset:
+                output_subpath = os.path.join(output_dir, f"signatures_thread_{i}.sig")
+                future = executor.submit(process_groups, group_subset, output_subpath)
+                futures.append(future)
+
+            for future in futures:
+                future.result()

logging.info("Concatenating sketches")
extern.run(f"sourmash sig cat {os.path.join(output_dir, 'signatures_thread_*.sig')} -o {output_path}")
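
For context, the new processing() above replaces the polars group_by with a single streaming pass over the sample-sorted OTU table, submitting each block of samples_per_group completed samples to the process pool as one job. The standalone sketch below illustrates that batching pattern in isolation; it is an assumption-laden illustration, not binchicken code, and the names handle_batch, submit_in_batches and example.tsv are made up for the example.

from concurrent.futures import ProcessPoolExecutor
import csv

def handle_batch(batch):
    # Placeholder worker: the real script builds sourmash sketches here.
    return [(sample, len(seqs)) for sample, seqs in batch]

def submit_in_batches(tsv_path, samples_per_group=1000, threads=2):
    # Stream a sample-sorted TSV with "sample" and "sequence" columns,
    # gather each sample's sequences, and submit every samples_per_group
    # completed samples to the pool as a single job.
    results = []
    with open(tsv_path) as f, ProcessPoolExecutor(max_workers=threads) as executor:
        reader = csv.DictReader(f, delimiter="\t")
        futures, batch = [], []
        current_sample, current_seqs = None, []
        for row in reader:
            if row["sample"] != current_sample:
                if current_sample is not None:
                    batch.append((current_sample, current_seqs))
                    if len(batch) == samples_per_group:
                        futures.append(executor.submit(handle_batch, batch))
                        batch = []
                current_sample, current_seqs = row["sample"], []
            current_seqs.append(row["sequence"])
        # Flush the final sample and any partially filled batch.
        if current_sample is not None:
            batch.append((current_sample, current_seqs))
        if batch:
            futures.append(executor.submit(handle_batch, batch))
        for future in futures:
            results.extend(future.result())
    return results

Called as submit_in_batches("example.tsv", samples_per_group=1000, threads=4), it mirrors the batching behaviour of the loop above; when run as a script, the call should sit under an if __name__ == "__main__": guard so ProcessPoolExecutor can spawn workers safely.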
88 changes: 88 additions & 0 deletions test/test_sketch_samples.py
@@ -108,6 +108,94 @@ def test_sketch_samples_taxa_of_interest(self):
self.assertEqual(sample_2_sig.jaccard(sample_4_sig), 0.25)
self.assertEqual(sample_3_sig.jaccard(sample_4_sig), 0.5)

def test_sketch_samples_exact_groups(self):
with in_tempdir():
unbinned = pl.DataFrame([
["S3.1", "sample_1", "ATGACTAGTCATAGCTAGATTTGAGGCAGCAGGAGTTAGGAAAGCCCCCGGAGTTAGCTA", 5, 10, "Root"], # 1
["S3.1", "sample_1", "TGACTAGCTGGGCTAGCTATATTCTTTTTACGAGCGCGAGGAAAGCGACAGCGGCCAGGC", 5, 10, "Root"], # 2

["S3.1", "sample_2", "ATGACTAGTCATAGCTAGATTTGAGGCAGCAGGAGTTAGGAAAGCCCCCGGAGTTAGCTA", 5, 10, "Root"], # 1
["S3.1", "sample_2", "TGACTAGCTGGGCTAGCTATATTCTTTTTACGAGCGCGAGGAAAGCGACAGCGGCCAGGC", 5, 10, "Root"], # 2

["S3.1", "sample_3", "ATCGACTGACTTGATCGATCTTTGACGACGAGAGAGAGAGCGACGCGCCGAGAGGTTTCA", 5, 10, "Root"], # 3
["S3.1", "sample_3", "TACGAGCGGATCGTGCACGTAGTCAGTCGTTATATATCGAAAGCTCATGCGGCCATATCG", 5, 10, "Root"], # 4
["S3.1", "sample_3", "TACGAGCGGATCG---------------GTTATATATCGAAAGCTCATGCGGCCATATCG", 5, 10, "Root"], # 5

["S3.1", "sample_4", "ATGACTAGTCATAGCTAGATTTGAGGCAGCAGGAGTTAGGAAAGCCCCCGGAGTTAGCTA", 5, 10, "Root"], # 1
["S3.1", "sample_4", "TACGAGCGGATCGTGCACGTAGTCAGTCGTTATATATCGAAAGCTCATGCGGCCATATCG", 5, 10, "Root"], # 4
["S3.1", "sample_4", "TACGAGCGGATCG---------------GTTATATATCGAAAGCTCATGCGGCCATATCG", 5, 10, "Root"], # 5
], orient="row", schema=OTU_TABLE_COLUMNS)
unbinned_path = "unbinned.otu_table.tsv"
unbinned.write_csv(unbinned_path, separator="\t")

expected_names = [
"sample_1",
"sample_2",
"sample_3",
"sample_4",
]

signatures_path = processing(unbinned_path, output_path="./signatures.sig", threads=1, samples_per_group=4)
signatures = [s for s in load_file_as_signatures(signatures_path)]
observed_names = [s.name for s in signatures]
self.assertEqual(sorted(expected_names), sorted(observed_names))

sample_1_sig = signatures[observed_names.index("sample_1")]
sample_2_sig = signatures[observed_names.index("sample_2")]
sample_3_sig = signatures[observed_names.index("sample_3")]
sample_4_sig = signatures[observed_names.index("sample_4")]

self.assertEqual(sample_1_sig.jaccard(sample_2_sig), 1.0)
self.assertEqual(sample_1_sig.jaccard(sample_3_sig), 0.0)
self.assertEqual(sample_1_sig.jaccard(sample_4_sig), 0.25)
self.assertEqual(sample_2_sig.jaccard(sample_3_sig), 0.0)
self.assertEqual(sample_2_sig.jaccard(sample_4_sig), 0.25)
self.assertEqual(sample_3_sig.jaccard(sample_4_sig), 0.5)

def test_sketch_samples_small_groups(self):
with in_tempdir():
unbinned = pl.DataFrame([
["S3.1", "sample_1", "ATGACTAGTCATAGCTAGATTTGAGGCAGCAGGAGTTAGGAAAGCCCCCGGAGTTAGCTA", 5, 10, "Root"], # 1
["S3.1", "sample_1", "TGACTAGCTGGGCTAGCTATATTCTTTTTACGAGCGCGAGGAAAGCGACAGCGGCCAGGC", 5, 10, "Root"], # 2

["S3.1", "sample_2", "ATGACTAGTCATAGCTAGATTTGAGGCAGCAGGAGTTAGGAAAGCCCCCGGAGTTAGCTA", 5, 10, "Root"], # 1
["S3.1", "sample_2", "TGACTAGCTGGGCTAGCTATATTCTTTTTACGAGCGCGAGGAAAGCGACAGCGGCCAGGC", 5, 10, "Root"], # 2

["S3.1", "sample_3", "ATCGACTGACTTGATCGATCTTTGACGACGAGAGAGAGAGCGACGCGCCGAGAGGTTTCA", 5, 10, "Root"], # 3
["S3.1", "sample_3", "TACGAGCGGATCGTGCACGTAGTCAGTCGTTATATATCGAAAGCTCATGCGGCCATATCG", 5, 10, "Root"], # 4
["S3.1", "sample_3", "TACGAGCGGATCG---------------GTTATATATCGAAAGCTCATGCGGCCATATCG", 5, 10, "Root"], # 5

["S3.1", "sample_4", "ATGACTAGTCATAGCTAGATTTGAGGCAGCAGGAGTTAGGAAAGCCCCCGGAGTTAGCTA", 5, 10, "Root"], # 1
["S3.1", "sample_4", "TACGAGCGGATCGTGCACGTAGTCAGTCGTTATATATCGAAAGCTCATGCGGCCATATCG", 5, 10, "Root"], # 4
["S3.1", "sample_4", "TACGAGCGGATCG---------------GTTATATATCGAAAGCTCATGCGGCCATATCG", 5, 10, "Root"], # 5
], orient="row", schema=OTU_TABLE_COLUMNS)
unbinned_path = "unbinned.otu_table.tsv"
unbinned.write_csv(unbinned_path, separator="\t")

expected_names = [
"sample_1",
"sample_2",
"sample_3",
"sample_4",
]

signatures_path = processing(unbinned_path, output_path="./signatures.sig", threads=1, samples_per_group=2)
signatures = [s for s in load_file_as_signatures(signatures_path)]
observed_names = [s.name for s in signatures]
self.assertEqual(sorted(expected_names), sorted(observed_names))

sample_1_sig = signatures[observed_names.index("sample_1")]
sample_2_sig = signatures[observed_names.index("sample_2")]
sample_3_sig = signatures[observed_names.index("sample_3")]
sample_4_sig = signatures[observed_names.index("sample_4")]

self.assertEqual(sample_1_sig.jaccard(sample_2_sig), 1.0)
self.assertEqual(sample_1_sig.jaccard(sample_3_sig), 0.0)
self.assertEqual(sample_1_sig.jaccard(sample_4_sig), 0.25)
self.assertEqual(sample_2_sig.jaccard(sample_3_sig), 0.0)
self.assertEqual(sample_2_sig.jaccard(sample_4_sig), 0.25)
self.assertEqual(sample_3_sig.jaccard(sample_4_sig), 0.5)


if __name__ == '__main__':
unittest.main()
