diff --git a/binchicken/workflow/scripts/sketch_samples.py b/binchicken/workflow/scripts/sketch_samples.py index e582bb42..bab5d6fd 100644 --- a/binchicken/workflow/scripts/sketch_samples.py +++ b/binchicken/workflow/scripts/sketch_samples.py @@ -10,6 +10,7 @@ from sourmash.sourmash_args import SaveSignaturesToLocation from concurrent.futures import ProcessPoolExecutor import extern +import re SINGLEM_OTU_TABLE_SCHEMA = { "gene": str, @@ -30,38 +31,62 @@ def process_groups(groups, output_path): signature = SourmashSignature(mh, name=sample) save_sigs.add(signature) -def processing(unbinned_path, output_path, taxa_of_interest=None, threads=1): - unbinned = pl.read_csv(unbinned_path, separator="\t", schema_overrides=SINGLEM_OTU_TABLE_SCHEMA) - - if taxa_of_interest: - logging.info(f"Filtering for taxa of interest: {taxa_of_interest}") - unbinned = unbinned.filter( - pl.col("taxonomy").str.contains(taxa_of_interest) - ) - +def processing(unbinned_path, output_path, taxa_of_interest=None, threads=1, samples_per_group=1000): output_dir = os.path.dirname(output_path) - logging.info("Grouping samples") - groups = [(s[0], d.get_column("sequence").to_list()) for s,d in unbinned.set_sorted("sample").select("sample", "sequence").group_by(["sample"])] - threads = min(threads, len(groups)) - - # Distribute groups among threads more evenly - grouped = [[] for _ in range(threads)] - for i, group in enumerate(groups): - grouped[i % threads].append(group) - - del groups - - logging.info("Generating sketches in separate threads") - with ProcessPoolExecutor(max_workers=threads) as executor: - futures = [] - for i, group_subset in enumerate(grouped): - output_subpath = os.path.join(output_dir, f"signatures_thread_{i}.sig") - future = executor.submit(process_groups, group_subset, output_subpath) - futures.append(future) - - for future in futures: - future.result() + with open(unbinned_path) as f: + logging.info(f"Reading unbinned OTU table from {unbinned_path}") + if taxa_of_interest: + logging.info(f"Filtering for taxa of interest: {taxa_of_interest}") + + logging.info("Generating sketches in separate threads") + with ProcessPoolExecutor(max_workers=threads) as executor: + current_sample = "" + current_sequences = [] + i = 0 + group_subset = [] + futures = [] + for line in f: + line = line.strip() + if line == "gene\tsample\tsequence\tnum_hits\tcoverage\ttaxonomy": + continue + + sample = line.split("\t")[1] + sequence = line.split("\t")[2] + + if taxa_of_interest: + taxonomy = line.split("\t")[5] + if not re.search(taxa_of_interest, taxonomy): + continue + + if not current_sample: + current_sample = sample + + if sample != current_sample: + group_subset.append((current_sample, current_sequences)) + current_sample = sample + current_sequences = [sequence] + + if len(group_subset) == samples_per_group: + output_subpath = os.path.join(output_dir, f"signatures_thread_{i}.sig") + future = executor.submit(process_groups, group_subset, output_subpath) + futures.append(future) + i += 1 + group_subset = [] + else: + current_sequences.append(sequence) + + # Submit any remaining groups that are smaller than samples_per_group + if sample == current_sample: + group_subset.append((current_sample, current_sequences)) + + if group_subset: + output_subpath = os.path.join(output_dir, f"signatures_thread_{i}.sig") + future = executor.submit(process_groups, group_subset, output_subpath) + futures.append(future) + + for future in futures: + future.result() logging.info("Concatenating sketches") extern.run(f"sourmash sig cat {os.path.join(output_dir, 'signatures_thread_*.sig')} -o {output_path}") diff --git a/test/test_sketch_samples.py b/test/test_sketch_samples.py index b3074a8f..8e5d8cd3 100644 --- a/test/test_sketch_samples.py +++ b/test/test_sketch_samples.py @@ -108,6 +108,94 @@ def test_sketch_samples_taxa_of_interest(self): self.assertEqual(sample_2_sig.jaccard(sample_4_sig), 0.25) self.assertEqual(sample_3_sig.jaccard(sample_4_sig), 0.5) + def test_sketch_samples_exact_groups(self): + with in_tempdir(): + unbinned = pl.DataFrame([ + ["S3.1", "sample_1", "ATGACTAGTCATAGCTAGATTTGAGGCAGCAGGAGTTAGGAAAGCCCCCGGAGTTAGCTA", 5, 10, "Root"], # 1 + ["S3.1", "sample_1", "TGACTAGCTGGGCTAGCTATATTCTTTTTACGAGCGCGAGGAAAGCGACAGCGGCCAGGC", 5, 10, "Root"], # 2 + + ["S3.1", "sample_2", "ATGACTAGTCATAGCTAGATTTGAGGCAGCAGGAGTTAGGAAAGCCCCCGGAGTTAGCTA", 5, 10, "Root"], # 1 + ["S3.1", "sample_2", "TGACTAGCTGGGCTAGCTATATTCTTTTTACGAGCGCGAGGAAAGCGACAGCGGCCAGGC", 5, 10, "Root"], # 2 + + ["S3.1", "sample_3", "ATCGACTGACTTGATCGATCTTTGACGACGAGAGAGAGAGCGACGCGCCGAGAGGTTTCA", 5, 10, "Root"], # 3 + ["S3.1", "sample_3", "TACGAGCGGATCGTGCACGTAGTCAGTCGTTATATATCGAAAGCTCATGCGGCCATATCG", 5, 10, "Root"], # 4 + ["S3.1", "sample_3", "TACGAGCGGATCG---------------GTTATATATCGAAAGCTCATGCGGCCATATCG", 5, 10, "Root"], # 5 + + ["S3.1", "sample_4", "ATGACTAGTCATAGCTAGATTTGAGGCAGCAGGAGTTAGGAAAGCCCCCGGAGTTAGCTA", 5, 10, "Root"], # 1 + ["S3.1", "sample_4", "TACGAGCGGATCGTGCACGTAGTCAGTCGTTATATATCGAAAGCTCATGCGGCCATATCG", 5, 10, "Root"], # 4 + ["S3.1", "sample_4", "TACGAGCGGATCG---------------GTTATATATCGAAAGCTCATGCGGCCATATCG", 5, 10, "Root"], # 5 + ], orient="row", schema=OTU_TABLE_COLUMNS) + unbinned_path = "unbinned.otu_table.tsv" + unbinned.write_csv(unbinned_path, separator="\t") + + expected_names = [ + "sample_1", + "sample_2", + "sample_3", + "sample_4", + ] + + signatures_path = processing(unbinned_path, output_path="./signatures.sig", threads=1, samples_per_group=4) + signatures = [s for s in load_file_as_signatures(signatures_path)] + observed_names = [s.name for s in signatures] + self.assertEqual(sorted(expected_names), sorted(observed_names)) + + sample_1_sig = signatures[observed_names.index("sample_1")] + sample_2_sig = signatures[observed_names.index("sample_2")] + sample_3_sig = signatures[observed_names.index("sample_3")] + sample_4_sig = signatures[observed_names.index("sample_4")] + + self.assertEqual(sample_1_sig.jaccard(sample_2_sig), 1.0) + self.assertEqual(sample_1_sig.jaccard(sample_3_sig), 0.0) + self.assertEqual(sample_1_sig.jaccard(sample_4_sig), 0.25) + self.assertEqual(sample_2_sig.jaccard(sample_3_sig), 0.0) + self.assertEqual(sample_2_sig.jaccard(sample_4_sig), 0.25) + self.assertEqual(sample_3_sig.jaccard(sample_4_sig), 0.5) + + def test_sketch_samples_small_groups(self): + with in_tempdir(): + unbinned = pl.DataFrame([ + ["S3.1", "sample_1", "ATGACTAGTCATAGCTAGATTTGAGGCAGCAGGAGTTAGGAAAGCCCCCGGAGTTAGCTA", 5, 10, "Root"], # 1 + ["S3.1", "sample_1", "TGACTAGCTGGGCTAGCTATATTCTTTTTACGAGCGCGAGGAAAGCGACAGCGGCCAGGC", 5, 10, "Root"], # 2 + + ["S3.1", "sample_2", "ATGACTAGTCATAGCTAGATTTGAGGCAGCAGGAGTTAGGAAAGCCCCCGGAGTTAGCTA", 5, 10, "Root"], # 1 + ["S3.1", "sample_2", "TGACTAGCTGGGCTAGCTATATTCTTTTTACGAGCGCGAGGAAAGCGACAGCGGCCAGGC", 5, 10, "Root"], # 2 + + ["S3.1", "sample_3", "ATCGACTGACTTGATCGATCTTTGACGACGAGAGAGAGAGCGACGCGCCGAGAGGTTTCA", 5, 10, "Root"], # 3 + ["S3.1", "sample_3", "TACGAGCGGATCGTGCACGTAGTCAGTCGTTATATATCGAAAGCTCATGCGGCCATATCG", 5, 10, "Root"], # 4 + ["S3.1", "sample_3", "TACGAGCGGATCG---------------GTTATATATCGAAAGCTCATGCGGCCATATCG", 5, 10, "Root"], # 5 + + ["S3.1", "sample_4", "ATGACTAGTCATAGCTAGATTTGAGGCAGCAGGAGTTAGGAAAGCCCCCGGAGTTAGCTA", 5, 10, "Root"], # 1 + ["S3.1", "sample_4", "TACGAGCGGATCGTGCACGTAGTCAGTCGTTATATATCGAAAGCTCATGCGGCCATATCG", 5, 10, "Root"], # 4 + ["S3.1", "sample_4", "TACGAGCGGATCG---------------GTTATATATCGAAAGCTCATGCGGCCATATCG", 5, 10, "Root"], # 5 + ], orient="row", schema=OTU_TABLE_COLUMNS) + unbinned_path = "unbinned.otu_table.tsv" + unbinned.write_csv(unbinned_path, separator="\t") + + expected_names = [ + "sample_1", + "sample_2", + "sample_3", + "sample_4", + ] + + signatures_path = processing(unbinned_path, output_path="./signatures.sig", threads=1, samples_per_group=2) + signatures = [s for s in load_file_as_signatures(signatures_path)] + observed_names = [s.name for s in signatures] + self.assertEqual(sorted(expected_names), sorted(observed_names)) + + sample_1_sig = signatures[observed_names.index("sample_1")] + sample_2_sig = signatures[observed_names.index("sample_2")] + sample_3_sig = signatures[observed_names.index("sample_3")] + sample_4_sig = signatures[observed_names.index("sample_4")] + + self.assertEqual(sample_1_sig.jaccard(sample_2_sig), 1.0) + self.assertEqual(sample_1_sig.jaccard(sample_3_sig), 0.0) + self.assertEqual(sample_1_sig.jaccard(sample_4_sig), 0.25) + self.assertEqual(sample_2_sig.jaccard(sample_3_sig), 0.0) + self.assertEqual(sample_2_sig.jaccard(sample_4_sig), 0.25) + self.assertEqual(sample_3_sig.jaccard(sample_4_sig), 0.5) + if __name__ == '__main__': unittest.main()