From 241dbc550a0a9ab623d6235161d5633f3e5a93a1 Mon Sep 17 00:00:00 2001 From: "C. Titus Brown" Date: Fri, 8 Jul 2022 07:30:53 -0700 Subject: [PATCH 01/20] move most CounterGather tests over to index protocol tests --- tests/test_index.py | 539 +--------------------------------- tests/test_index_protocol.py | 554 +++++++++++++++++++++++++++++++++++ 2 files changed, 557 insertions(+), 536 deletions(-) diff --git a/tests/test_index.py b/tests/test_index.py index bd216b8f32..0f27e974c6 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -1543,545 +1543,12 @@ def is_found(ss, xx): assert is_found(ss63, results) ### -### CounterGather tests +### CounterGather tests - implementation specific. See test_index_protocol +### for protocol tests. ### - -def _consume_all(query_mh, counter, threshold_bp=0): - results = [] - query_mh = query_mh.to_mutable() - - last_intersect_size = None - while 1: - result = counter.peek(query_mh, threshold_bp) - if not result: - break - - sr, intersect_mh = result - print(sr.signature.name, len(intersect_mh)) - if last_intersect_size: - assert len(intersect_mh) <= last_intersect_size - - last_intersect_size = len(intersect_mh) - - counter.consume(intersect_mh) - query_mh.remove_many(intersect_mh.hashes) - - results.append((sr, len(intersect_mh))) - - return results - - -def test_counter_gather_1(): - # check a contrived set of non-overlapping gather results, - # generated via CounterGather - query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) - query_mh.add_many(range(0, 20)) - query_ss = SourmashSignature(query_mh, name='query') - - match_mh_1 = query_mh.copy_and_clear() - match_mh_1.add_many(range(0, 10)) - match_ss_1 = SourmashSignature(match_mh_1, name='match1') - - match_mh_2 = query_mh.copy_and_clear() - match_mh_2.add_many(range(10, 15)) - match_ss_2 = SourmashSignature(match_mh_2, name='match2') - - match_mh_3 = query_mh.copy_and_clear() - match_mh_3.add_many(range(15, 17)) - match_ss_3 = SourmashSignature(match_mh_3, 
name='match3') - - # load up the counter - counter = CounterGather(query_ss.minhash) - counter.add(match_ss_1) - counter.add(match_ss_2) - counter.add(match_ss_3) - - results = _consume_all(query_ss.minhash, counter) - - expected = (['match1', 10], - ['match2', 5], - ['match3', 2],) - assert len(results) == len(expected), results - - for (sr, size), (exp_name, exp_size) in zip(results, expected): - sr_name = sr.signature.name.split()[0] - - assert sr_name == exp_name - assert size == exp_size - - -def test_counter_gather_1_b(): - # check a contrived set of somewhat-overlapping gather results, - # generated via CounterGather. Here the overlaps are structured - # so that the gather results are the same as those in - # test_counter_gather_1(), even though the overlaps themselves are - # larger. - query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) - query_mh.add_many(range(0, 20)) - query_ss = SourmashSignature(query_mh, name='query') - - match_mh_1 = query_mh.copy_and_clear() - match_mh_1.add_many(range(0, 10)) - match_ss_1 = SourmashSignature(match_mh_1, name='match1') - - match_mh_2 = query_mh.copy_and_clear() - match_mh_2.add_many(range(7, 15)) - match_ss_2 = SourmashSignature(match_mh_2, name='match2') - - match_mh_3 = query_mh.copy_and_clear() - match_mh_3.add_many(range(13, 17)) - match_ss_3 = SourmashSignature(match_mh_3, name='match3') - - # load up the counter - counter = CounterGather(query_ss.minhash) - counter.add(match_ss_1) - counter.add(match_ss_2) - counter.add(match_ss_3) - - results = _consume_all(query_ss.minhash, counter) - - expected = (['match1', 10], - ['match2', 5], - ['match3', 2],) - assert len(results) == len(expected), results - - for (sr, size), (exp_name, exp_size) in zip(results, expected): - sr_name = sr.signature.name.split()[0] - - assert sr_name == exp_name - assert size == exp_size - - -def test_counter_gather_1_c_with_threshold(): - # check a contrived set of somewhat-overlapping gather results, - # generated via CounterGather. 
Here the overlaps are structured - # so that the gather results are the same as those in - # test_counter_gather_1(), even though the overlaps themselves are - # larger. - # use a threshold, here. - - query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) - query_mh.add_many(range(0, 20)) - query_ss = SourmashSignature(query_mh, name='query') - - match_mh_1 = query_mh.copy_and_clear() - match_mh_1.add_many(range(0, 10)) - match_ss_1 = SourmashSignature(match_mh_1, name='match1') - - match_mh_2 = query_mh.copy_and_clear() - match_mh_2.add_many(range(7, 15)) - match_ss_2 = SourmashSignature(match_mh_2, name='match2') - - match_mh_3 = query_mh.copy_and_clear() - match_mh_3.add_many(range(13, 17)) - match_ss_3 = SourmashSignature(match_mh_3, name='match3') - - # load up the counter - counter = CounterGather(query_ss.minhash) - counter.add(match_ss_1) - counter.add(match_ss_2) - counter.add(match_ss_3) - - results = _consume_all(query_ss.minhash, counter, - threshold_bp=3) - - expected = (['match1', 10], - ['match2', 5]) - assert len(results) == len(expected), results - - for (sr, size), (exp_name, exp_size) in zip(results, expected): - sr_name = sr.signature.name.split()[0] - - assert sr_name == exp_name - assert size == exp_size - - -def test_counter_gather_1_d_diff_scaled(): - # test as above, but with different scaled. 
- query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) - query_mh.add_many(range(0, 20)) - query_ss = SourmashSignature(query_mh, name='query') - - match_mh_1 = query_mh.copy_and_clear().downsample(scaled=10) - match_mh_1.add_many(range(0, 10)) - match_ss_1 = SourmashSignature(match_mh_1, name='match1') - - match_mh_2 = query_mh.copy_and_clear().downsample(scaled=20) - match_mh_2.add_many(range(7, 15)) - match_ss_2 = SourmashSignature(match_mh_2, name='match2') - - match_mh_3 = query_mh.copy_and_clear().downsample(scaled=30) - match_mh_3.add_many(range(13, 17)) - match_ss_3 = SourmashSignature(match_mh_3, name='match3') - - # load up the counter - counter = CounterGather(query_ss.minhash) - counter.add(match_ss_1) - counter.add(match_ss_2) - counter.add(match_ss_3) - - results = _consume_all(query_ss.minhash, counter) - - expected = (['match1', 10], - ['match2', 5], - ['match3', 2],) - assert len(results) == len(expected), results - - for (sr, size), (exp_name, exp_size) in zip(results, expected): - sr_name = sr.signature.name.split()[0] - - assert sr_name == exp_name - assert size == exp_size - - -def test_counter_gather_1_d_diff_scaled_query(): - # test as above, but with different scaled for QUERY. 
- query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) - query_mh.add_many(range(0, 20)) - - match_mh_1 = query_mh.copy_and_clear().downsample(scaled=10) - match_mh_1.add_many(range(0, 10)) - match_ss_1 = SourmashSignature(match_mh_1, name='match1') - - match_mh_2 = query_mh.copy_and_clear().downsample(scaled=20) - match_mh_2.add_many(range(7, 15)) - match_ss_2 = SourmashSignature(match_mh_2, name='match2') - - match_mh_3 = query_mh.copy_and_clear().downsample(scaled=30) - match_mh_3.add_many(range(13, 17)) - match_ss_3 = SourmashSignature(match_mh_3, name='match3') - - # downsample query now - - query_ss = SourmashSignature(query_mh.downsample(scaled=100), name='query') - - # load up the counter - counter = CounterGather(query_ss.minhash) - counter.add(match_ss_1) - counter.add(match_ss_2) - counter.add(match_ss_3) - - results = _consume_all(query_ss.minhash, counter) - - expected = (['match1', 10], - ['match2', 5], - ['match3', 2],) - assert len(results) == len(expected), results - - for (sr, size), (exp_name, exp_size) in zip(results, expected): - sr_name = sr.signature.name.split()[0] - - assert sr_name == exp_name - assert size == exp_size - - -def test_counter_gather_1_e_abund_query(): - # test as above, but abund query - query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1, track_abundance=1) - query_mh.add_many(range(0, 20)) - query_ss = SourmashSignature(query_mh, name='query') - - match_mh_1 = query_mh.copy_and_clear().flatten() - match_mh_1.add_many(range(0, 10)) - match_ss_1 = SourmashSignature(match_mh_1, name='match1') - - match_mh_2 = query_mh.copy_and_clear().flatten() - match_mh_2.add_many(range(7, 15)) - match_ss_2 = SourmashSignature(match_mh_2, name='match2') - - match_mh_3 = query_mh.copy_and_clear().flatten() - match_mh_3.add_many(range(13, 17)) - match_ss_3 = SourmashSignature(match_mh_3, name='match3') - - # load up the counter - counter = CounterGather(query_ss.minhash) - counter.add(match_ss_1) - counter.add(match_ss_2) - 
counter.add(match_ss_3) - - # must flatten before peek! - results = _consume_all(query_ss.minhash.flatten(), counter) - - expected = (['match1', 10], - ['match2', 5], - ['match3', 2],) - assert len(results) == len(expected), results - - for (sr, size), (exp_name, exp_size) in zip(results, expected): - sr_name = sr.signature.name.split()[0] - - assert sr_name == exp_name - assert size == exp_size - - -def test_counter_gather_1_f_abund_match(): - # test as above, but abund query - query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1, track_abundance=1) - query_mh.add_many(range(0, 20)) - query_ss = SourmashSignature(query_mh.flatten(), name='query') - - match_mh_1 = query_mh.copy_and_clear() - match_mh_1.add_many(range(0, 10)) - match_ss_1 = SourmashSignature(match_mh_1, name='match1') - - match_mh_2 = query_mh.copy_and_clear() - match_mh_2.add_many(range(7, 15)) - match_ss_2 = SourmashSignature(match_mh_2, name='match2') - - match_mh_3 = query_mh.copy_and_clear() - match_mh_3.add_many(range(13, 17)) - match_ss_3 = SourmashSignature(match_mh_3, name='match3') - - # load up the counter - counter = CounterGather(query_ss.minhash) - counter.add(match_ss_1) - counter.add(match_ss_2) - counter.add(match_ss_3) - - # must flatten before peek! 
- results = _consume_all(query_ss.minhash.flatten(), counter) - - expected = (['match1', 10], - ['match2', 5], - ['match3', 2],) - assert len(results) == len(expected), results - - for (sr, size), (exp_name, exp_size) in zip(results, expected): - sr_name = sr.signature.name.split()[0] - - assert sr_name == exp_name - assert size == exp_size - - -def test_counter_gather_2(): - # check basic set of gather results on semi-real data, - # generated via CounterGather - testdata_combined = utils.get_test_data('gather/combined.sig') - testdata_glob = utils.get_test_data('gather/GCF*.sig') - testdata_sigs = glob.glob(testdata_glob) - - query_ss = sourmash.load_one_signature(testdata_combined, ksize=21) - subject_sigs = [ (sourmash.load_one_signature(t, ksize=21), t) - for t in testdata_sigs ] - - # load up the counter - counter = CounterGather(query_ss.minhash) - for ss, loc in subject_sigs: - counter.add(ss, loc) - - results = _consume_all(query_ss.minhash, counter) - - expected = (['NC_003198.1', 487], - ['NC_000853.1', 192], - ['NC_011978.1', 169], - ['NC_002163.1', 157], - ['NC_003197.2', 152], - ['NC_009486.1', 92], - ['NC_006905.1', 76], - ['NC_011080.1', 59], - ['NC_011274.1', 42], - ['NC_006511.1', 31], - ['NC_011294.1', 7], - ['NC_004631.1', 2]) - assert len(results) == len(expected) - - for (sr, size), (exp_name, exp_size) in zip(results, expected): - sr_name = sr.signature.name.split()[0] - print(sr_name, size) - - assert sr_name == exp_name - assert size == exp_size - - -def test_counter_gather_exact_match(): - # query == match - query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) - query_mh.add_many(range(0, 20)) - query_ss = SourmashSignature(query_mh, name='query') - - # load up the counter - counter = CounterGather(query_ss.minhash) - counter.add(query_ss, 'somewhere over the rainbow') - - results = _consume_all(query_ss.minhash, counter) - assert len(results) == 1 - (sr, intersect_mh) = results[0] - - assert sr.score == 1.0 - assert sr.signature == 
query_ss - assert sr.location == 'somewhere over the rainbow' - - -def test_counter_gather_add_after_peek(): - # cannot add after peek or consume - query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) - query_mh.add_many(range(0, 20)) - query_ss = SourmashSignature(query_mh, name='query') - - # load up the counter - counter = CounterGather(query_ss.minhash) - counter.add(query_ss, 'somewhere over the rainbow') - - counter.peek(query_ss.minhash) - - with pytest.raises(ValueError): - counter.add(query_ss, "try again") - - -def test_counter_gather_add_after_consume(): - # cannot add after peek or consume - query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) - query_mh.add_many(range(0, 20)) - query_ss = SourmashSignature(query_mh, name='query') - - # load up the counter - counter = CounterGather(query_ss.minhash) - counter.add(query_ss, 'somewhere over the rainbow') - - counter.consume(query_ss.minhash) - - with pytest.raises(ValueError): - counter.add(query_ss, "try again") - - -def test_counter_gather_consume_empty_intersect(): - # check that consume works fine when there is an empty signature. 
- query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) - query_mh.add_many(range(0, 20)) - query_ss = SourmashSignature(query_mh, name='query') - - # load up the counter - counter = CounterGather(query_ss.minhash) - counter.add(query_ss, 'somewhere over the rainbow') - - # nothing really happens here :laugh:, just making sure there's no error - counter.consume(query_ss.minhash.copy_and_clear()) - - -def test_counter_gather_empty_initial_query(): - # check empty initial query - query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) - query_ss = SourmashSignature(query_mh, name='query') - - match_mh_1 = query_mh.copy_and_clear() - match_mh_1.add_many(range(0, 10)) - match_ss_1 = SourmashSignature(match_mh_1, name='match1') - - # load up the counter - counter = CounterGather(query_ss.minhash) - counter.add(match_ss_1, require_overlap=False) - - assert counter.peek(query_ss.minhash) == [] - - -def test_counter_gather_num_query(): - # check num query - query_mh = sourmash.MinHash(n=500, ksize=31) - query_mh.add_many(range(0, 10)) - query_ss = SourmashSignature(query_mh, name='query') - - with pytest.raises(ValueError): - counter = CounterGather(query_ss.minhash) - - -def test_counter_gather_empty_cur_query(): - # test empty cur query - query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) - query_mh.add_many(range(0, 20)) - query_ss = SourmashSignature(query_mh, name='query') - - # load up the counter - counter = CounterGather(query_ss.minhash) - counter.add(query_ss, 'somewhere over the rainbow') - - cur_query_mh = query_ss.minhash.copy_and_clear() - results = _consume_all(cur_query_mh, counter) - assert results == [] - - -def test_counter_gather_add_num_matchy(): - # test add num query - query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) - query_mh.add_many(range(0, 20)) - query_ss = SourmashSignature(query_mh, name='query') - - match_mh = sourmash.MinHash(n=500, ksize=31) - match_mh.add_many(range(0, 20)) - match_ss = SourmashSignature(match_mh, name='query') - - # 
load up the counter - counter = CounterGather(query_ss.minhash) - with pytest.raises(ValueError): - counter.add(match_ss, 'somewhere over the rainbow') - - -def test_counter_gather_bad_cur_query(): - # test cur query that is not subset of original query - query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) - query_mh.add_many(range(0, 20)) - query_ss = SourmashSignature(query_mh, name='query') - - # load up the counter - counter = CounterGather(query_ss.minhash) - counter.add(query_ss, 'somewhere over the rainbow') - - cur_query_mh = query_ss.minhash.copy_and_clear() - cur_query_mh.add_many(range(20, 30)) - with pytest.raises(ValueError): - counter.peek(cur_query_mh) - - -def test_counter_gather_add_no_overlap(): - # check adding match with no overlap w/query - query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) - query_mh.add_many(range(0, 10)) - query_ss = SourmashSignature(query_mh, name='query') - - match_mh_1 = query_mh.copy_and_clear() - match_mh_1.add_many(range(10, 20)) - match_ss_1 = SourmashSignature(match_mh_1, name='match1') - - # load up the counter - counter = CounterGather(query_ss.minhash) - with pytest.raises(ValueError): - counter.add(match_ss_1) - - assert counter.peek(query_ss.minhash) == [] - - -def test_counter_gather_big_threshold(): - # check 'peek' with a huge threshold - query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) - query_mh.add_many(range(0, 20)) - query_ss = SourmashSignature(query_mh, name='query') - - match_mh_1 = query_mh.copy_and_clear() - match_mh_1.add_many(range(0, 10)) - match_ss_1 = SourmashSignature(match_mh_1, name='match1') - - # load up the counter - counter = CounterGather(query_ss.minhash) - counter.add(match_ss_1) - - # impossible threshold: - threshold_bp=30*query_ss.minhash.scaled - results = counter.peek(query_ss.minhash, threshold_bp=threshold_bp) - assert results == [] - - -def test_counter_gather_empty_counter(): - # check empty counter - query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) - query_ss = 
SourmashSignature(query_mh, name='query') - - # empty counter! - counter = CounterGather(query_ss.minhash) - - assert counter.peek(query_ss.minhash) == [] - - def test_counter_gather_3_test_consume(): - # open-box testing of consume(...) + # open-box testing of CounterGather.consume(...) query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) query_mh.add_many(range(0, 20)) query_ss = SourmashSignature(query_mh, name='query') diff --git a/tests/test_index_protocol.py b/tests/test_index_protocol.py index 15fd70aad0..f81f736e9a 100644 --- a/tests/test_index_protocol.py +++ b/tests/test_index_protocol.py @@ -4,6 +4,7 @@ """ import pytest +import glob import sourmash from sourmash import SourmashSignature @@ -463,3 +464,556 @@ def test_gather_threshold_5(index_obj): containment, match_sig, name = results[0] assert containment == 1.0 assert match_sig.minhash == ss2.minhash + + +### +### CounterGather tests +### + + +def basic_counter_gather(runtmp): + "build a basic CounterGather class" + from sourmash.index import CounterGather + return CounterGather + +@pytest.fixture(params=[basic_counter_gather, + ] +) +def counter_gather_constructor(request, runtmp): + build_fn = request.param + + # build on demand + return build_fn(runtmp) + + +def _consume_all(query_mh, counter, threshold_bp=0): + results = [] + query_mh = query_mh.to_mutable() + + last_intersect_size = None + while 1: + result = counter.peek(query_mh, threshold_bp) + if not result: + break + + sr, intersect_mh = result + print(sr.signature.name, len(intersect_mh)) + if last_intersect_size: + assert len(intersect_mh) <= last_intersect_size + + last_intersect_size = len(intersect_mh) + + counter.consume(intersect_mh) + query_mh.remove_many(intersect_mh.hashes) + + results.append((sr, len(intersect_mh))) + + return results + + +def test_counter_gather_1(counter_gather_constructor): + # check a contrived set of non-overlapping gather results, + # generated via CounterGather + query_mh = sourmash.MinHash(n=0, ksize=31, 
scaled=1) + query_mh.add_many(range(0, 20)) + query_ss = SourmashSignature(query_mh, name='query') + + match_mh_1 = query_mh.copy_and_clear() + match_mh_1.add_many(range(0, 10)) + match_ss_1 = SourmashSignature(match_mh_1, name='match1') + + match_mh_2 = query_mh.copy_and_clear() + match_mh_2.add_many(range(10, 15)) + match_ss_2 = SourmashSignature(match_mh_2, name='match2') + + match_mh_3 = query_mh.copy_and_clear() + match_mh_3.add_many(range(15, 17)) + match_ss_3 = SourmashSignature(match_mh_3, name='match3') + + # load up the counter + counter = counter_gather_constructor(query_ss.minhash) + counter.add(match_ss_1) + counter.add(match_ss_2) + counter.add(match_ss_3) + + results = _consume_all(query_ss.minhash, counter) + + expected = (['match1', 10], + ['match2', 5], + ['match3', 2],) + assert len(results) == len(expected), results + + for (sr, size), (exp_name, exp_size) in zip(results, expected): + sr_name = sr.signature.name.split()[0] + + assert sr_name == exp_name + assert size == exp_size + + +def test_counter_gather_1_b(counter_gather_constructor): + # check a contrived set of somewhat-overlapping gather results, + # generated via CounterGather. Here the overlaps are structured + # so that the gather results are the same as those in + # test_counter_gather_1(), even though the overlaps themselves are + # larger. 
+ query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) + query_mh.add_many(range(0, 20)) + query_ss = SourmashSignature(query_mh, name='query') + + match_mh_1 = query_mh.copy_and_clear() + match_mh_1.add_many(range(0, 10)) + match_ss_1 = SourmashSignature(match_mh_1, name='match1') + + match_mh_2 = query_mh.copy_and_clear() + match_mh_2.add_many(range(7, 15)) + match_ss_2 = SourmashSignature(match_mh_2, name='match2') + + match_mh_3 = query_mh.copy_and_clear() + match_mh_3.add_many(range(13, 17)) + match_ss_3 = SourmashSignature(match_mh_3, name='match3') + + # load up the counter + counter = counter_gather_constructor(query_ss.minhash) + counter.add(match_ss_1) + counter.add(match_ss_2) + counter.add(match_ss_3) + + results = _consume_all(query_ss.minhash, counter) + + expected = (['match1', 10], + ['match2', 5], + ['match3', 2],) + assert len(results) == len(expected), results + + for (sr, size), (exp_name, exp_size) in zip(results, expected): + sr_name = sr.signature.name.split()[0] + + assert sr_name == exp_name + assert size == exp_size + + +def test_counter_gather_1_c_with_threshold(counter_gather_constructor): + # check a contrived set of somewhat-overlapping gather results, + # generated via CounterGather. Here the overlaps are structured + # so that the gather results are the same as those in + # test_counter_gather_1(), even though the overlaps themselves are + # larger. + # use a threshold, here. 
+ + query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) + query_mh.add_many(range(0, 20)) + query_ss = SourmashSignature(query_mh, name='query') + + match_mh_1 = query_mh.copy_and_clear() + match_mh_1.add_many(range(0, 10)) + match_ss_1 = SourmashSignature(match_mh_1, name='match1') + + match_mh_2 = query_mh.copy_and_clear() + match_mh_2.add_many(range(7, 15)) + match_ss_2 = SourmashSignature(match_mh_2, name='match2') + + match_mh_3 = query_mh.copy_and_clear() + match_mh_3.add_many(range(13, 17)) + match_ss_3 = SourmashSignature(match_mh_3, name='match3') + + # load up the counter + counter = counter_gather_constructor(query_ss.minhash) + counter.add(match_ss_1) + counter.add(match_ss_2) + counter.add(match_ss_3) + + results = _consume_all(query_ss.minhash, counter, + threshold_bp=3) + + expected = (['match1', 10], + ['match2', 5]) + assert len(results) == len(expected), results + + for (sr, size), (exp_name, exp_size) in zip(results, expected): + sr_name = sr.signature.name.split()[0] + + assert sr_name == exp_name + assert size == exp_size + + +def test_counter_gather_1_d_diff_scaled(counter_gather_constructor): + # test as above, but with different scaled. 
+ query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) + query_mh.add_many(range(0, 20)) + query_ss = SourmashSignature(query_mh, name='query') + + match_mh_1 = query_mh.copy_and_clear().downsample(scaled=10) + match_mh_1.add_many(range(0, 10)) + match_ss_1 = SourmashSignature(match_mh_1, name='match1') + + match_mh_2 = query_mh.copy_and_clear().downsample(scaled=20) + match_mh_2.add_many(range(7, 15)) + match_ss_2 = SourmashSignature(match_mh_2, name='match2') + + match_mh_3 = query_mh.copy_and_clear().downsample(scaled=30) + match_mh_3.add_many(range(13, 17)) + match_ss_3 = SourmashSignature(match_mh_3, name='match3') + + # load up the counter + counter = counter_gather_constructor(query_ss.minhash) + counter.add(match_ss_1) + counter.add(match_ss_2) + counter.add(match_ss_3) + + results = _consume_all(query_ss.minhash, counter) + + expected = (['match1', 10], + ['match2', 5], + ['match3', 2],) + assert len(results) == len(expected), results + + for (sr, size), (exp_name, exp_size) in zip(results, expected): + sr_name = sr.signature.name.split()[0] + + assert sr_name == exp_name + assert size == exp_size + + +def test_counter_gather_1_d_diff_scaled_query(counter_gather_constructor): + # test as above, but with different scaled for QUERY. 
+ query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) + query_mh.add_many(range(0, 20)) + + match_mh_1 = query_mh.copy_and_clear().downsample(scaled=10) + match_mh_1.add_many(range(0, 10)) + match_ss_1 = SourmashSignature(match_mh_1, name='match1') + + match_mh_2 = query_mh.copy_and_clear().downsample(scaled=20) + match_mh_2.add_many(range(7, 15)) + match_ss_2 = SourmashSignature(match_mh_2, name='match2') + + match_mh_3 = query_mh.copy_and_clear().downsample(scaled=30) + match_mh_3.add_many(range(13, 17)) + match_ss_3 = SourmashSignature(match_mh_3, name='match3') + + # downsample query now - + query_ss = SourmashSignature(query_mh.downsample(scaled=100), name='query') + + # load up the counter + counter = counter_gather_constructor(query_ss.minhash) + counter.add(match_ss_1) + counter.add(match_ss_2) + counter.add(match_ss_3) + + results = _consume_all(query_ss.minhash, counter) + + expected = (['match1', 10], + ['match2', 5], + ['match3', 2],) + assert len(results) == len(expected), results + + for (sr, size), (exp_name, exp_size) in zip(results, expected): + sr_name = sr.signature.name.split()[0] + + assert sr_name == exp_name + assert size == exp_size + + +def test_counter_gather_1_e_abund_query(counter_gather_constructor): + # test as above, but abund query + query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1, track_abundance=1) + query_mh.add_many(range(0, 20)) + query_ss = SourmashSignature(query_mh, name='query') + + match_mh_1 = query_mh.copy_and_clear().flatten() + match_mh_1.add_many(range(0, 10)) + match_ss_1 = SourmashSignature(match_mh_1, name='match1') + + match_mh_2 = query_mh.copy_and_clear().flatten() + match_mh_2.add_many(range(7, 15)) + match_ss_2 = SourmashSignature(match_mh_2, name='match2') + + match_mh_3 = query_mh.copy_and_clear().flatten() + match_mh_3.add_many(range(13, 17)) + match_ss_3 = SourmashSignature(match_mh_3, name='match3') + + # load up the counter + counter = counter_gather_constructor(query_ss.minhash) + 
counter.add(match_ss_1) + counter.add(match_ss_2) + counter.add(match_ss_3) + + # must flatten before peek! + results = _consume_all(query_ss.minhash.flatten(), counter) + + expected = (['match1', 10], + ['match2', 5], + ['match3', 2],) + assert len(results) == len(expected), results + + for (sr, size), (exp_name, exp_size) in zip(results, expected): + sr_name = sr.signature.name.split()[0] + + assert sr_name == exp_name + assert size == exp_size + + +def test_counter_gather_1_f_abund_match(counter_gather_constructor): + # test as above, but abund query + query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1, track_abundance=1) + query_mh.add_many(range(0, 20)) + query_ss = SourmashSignature(query_mh.flatten(), name='query') + + match_mh_1 = query_mh.copy_and_clear() + match_mh_1.add_many(range(0, 10)) + match_ss_1 = SourmashSignature(match_mh_1, name='match1') + + match_mh_2 = query_mh.copy_and_clear() + match_mh_2.add_many(range(7, 15)) + match_ss_2 = SourmashSignature(match_mh_2, name='match2') + + match_mh_3 = query_mh.copy_and_clear() + match_mh_3.add_many(range(13, 17)) + match_ss_3 = SourmashSignature(match_mh_3, name='match3') + + # load up the counter + counter = counter_gather_constructor(query_ss.minhash) + counter.add(match_ss_1) + counter.add(match_ss_2) + counter.add(match_ss_3) + + # must flatten before peek! 
+ results = _consume_all(query_ss.minhash.flatten(), counter) + + expected = (['match1', 10], + ['match2', 5], + ['match3', 2],) + assert len(results) == len(expected), results + + for (sr, size), (exp_name, exp_size) in zip(results, expected): + sr_name = sr.signature.name.split()[0] + + assert sr_name == exp_name + assert size == exp_size + + +def test_counter_gather_2(counter_gather_constructor): + # check basic set of gather results on semi-real data, + # generated via CounterGather + testdata_combined = utils.get_test_data('gather/combined.sig') + testdata_glob = utils.get_test_data('gather/GCF*.sig') + testdata_sigs = glob.glob(testdata_glob) + + query_ss = sourmash.load_one_signature(testdata_combined, ksize=21) + subject_sigs = [ (sourmash.load_one_signature(t, ksize=21), t) + for t in testdata_sigs ] + + # load up the counter + counter = counter_gather_constructor(query_ss.minhash) + for ss, loc in subject_sigs: + counter.add(ss, loc) + + results = _consume_all(query_ss.minhash, counter) + + expected = (['NC_003198.1', 487], + ['NC_000853.1', 192], + ['NC_011978.1', 169], + ['NC_002163.1', 157], + ['NC_003197.2', 152], + ['NC_009486.1', 92], + ['NC_006905.1', 76], + ['NC_011080.1', 59], + ['NC_011274.1', 42], + ['NC_006511.1', 31], + ['NC_011294.1', 7], + ['NC_004631.1', 2]) + assert len(results) == len(expected) + + for (sr, size), (exp_name, exp_size) in zip(results, expected): + sr_name = sr.signature.name.split()[0] + print(sr_name, size) + + assert sr_name == exp_name + assert size == exp_size + + +def test_counter_gather_exact_match(counter_gather_constructor): + # query == match + query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) + query_mh.add_many(range(0, 20)) + query_ss = SourmashSignature(query_mh, name='query') + + # load up the counter + counter = counter_gather_constructor(query_ss.minhash) + counter.add(query_ss, 'somewhere over the rainbow') + + results = _consume_all(query_ss.minhash, counter) + assert len(results) == 1 + (sr, 
intersect_mh) = results[0] + + assert sr.score == 1.0 + assert sr.signature == query_ss + assert sr.location == 'somewhere over the rainbow' + + +def test_counter_gather_add_after_peek(counter_gather_constructor): + # cannot add after peek or consume + query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) + query_mh.add_many(range(0, 20)) + query_ss = SourmashSignature(query_mh, name='query') + + # load up the counter + counter = counter_gather_constructor(query_ss.minhash) + counter.add(query_ss, 'somewhere over the rainbow') + + counter.peek(query_ss.minhash) + + with pytest.raises(ValueError): + counter.add(query_ss, "try again") + + +def test_counter_gather_add_after_consume(counter_gather_constructor): + # cannot add after peek or consume + query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) + query_mh.add_many(range(0, 20)) + query_ss = SourmashSignature(query_mh, name='query') + + # load up the counter + counter = counter_gather_constructor(query_ss.minhash) + counter.add(query_ss, 'somewhere over the rainbow') + + counter.consume(query_ss.minhash) + + with pytest.raises(ValueError): + counter.add(query_ss, "try again") + + +def test_counter_gather_consume_empty_intersect(counter_gather_constructor): + # check that consume works fine when there is an empty signature. 
+ query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) + query_mh.add_many(range(0, 20)) + query_ss = SourmashSignature(query_mh, name='query') + + # load up the counter + counter = counter_gather_constructor(query_ss.minhash) + counter.add(query_ss, 'somewhere over the rainbow') + + # nothing really happens here :laugh:, just making sure there's no error + counter.consume(query_ss.minhash.copy_and_clear()) + + +def test_counter_gather_empty_initial_query(counter_gather_constructor): + # check empty initial query + query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) + query_ss = SourmashSignature(query_mh, name='query') + + match_mh_1 = query_mh.copy_and_clear() + match_mh_1.add_many(range(0, 10)) + match_ss_1 = SourmashSignature(match_mh_1, name='match1') + + # load up the counter + counter = counter_gather_constructor(query_ss.minhash) + counter.add(match_ss_1, require_overlap=False) + + assert counter.peek(query_ss.minhash) == [] + + +def test_counter_gather_num_query(counter_gather_constructor): + # check num query + query_mh = sourmash.MinHash(n=500, ksize=31) + query_mh.add_many(range(0, 10)) + query_ss = SourmashSignature(query_mh, name='query') + + with pytest.raises(ValueError): + counter = counter_gather_constructor(query_ss.minhash) + + +def test_counter_gather_empty_cur_query(counter_gather_constructor): + # test empty cur query + query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) + query_mh.add_many(range(0, 20)) + query_ss = SourmashSignature(query_mh, name='query') + + # load up the counter + counter = counter_gather_constructor(query_ss.minhash) + counter.add(query_ss, 'somewhere over the rainbow') + + cur_query_mh = query_ss.minhash.copy_and_clear() + results = _consume_all(cur_query_mh, counter) + assert results == [] + + +def test_counter_gather_add_num_matchy(counter_gather_constructor): + # test add num query + query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) + query_mh.add_many(range(0, 20)) + query_ss = SourmashSignature(query_mh, 
name='query') + + match_mh = sourmash.MinHash(n=500, ksize=31) + match_mh.add_many(range(0, 20)) + match_ss = SourmashSignature(match_mh, name='query') + + # load up the counter + counter = counter_gather_constructor(query_ss.minhash) + with pytest.raises(ValueError): + counter.add(match_ss, 'somewhere over the rainbow') + + +def test_counter_gather_bad_cur_query(counter_gather_constructor): + # test cur query that is not subset of original query + query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) + query_mh.add_many(range(0, 20)) + query_ss = SourmashSignature(query_mh, name='query') + + # load up the counter + counter = counter_gather_constructor(query_ss.minhash) + counter.add(query_ss, 'somewhere over the rainbow') + + cur_query_mh = query_ss.minhash.copy_and_clear() + cur_query_mh.add_many(range(20, 30)) + with pytest.raises(ValueError): + counter.peek(cur_query_mh) + + +def test_counter_gather_add_no_overlap(counter_gather_constructor): + # check adding match with no overlap w/query + query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) + query_mh.add_many(range(0, 10)) + query_ss = SourmashSignature(query_mh, name='query') + + match_mh_1 = query_mh.copy_and_clear() + match_mh_1.add_many(range(10, 20)) + match_ss_1 = SourmashSignature(match_mh_1, name='match1') + + # load up the counter + counter = counter_gather_constructor(query_ss.minhash) + with pytest.raises(ValueError): + counter.add(match_ss_1) + + assert counter.peek(query_ss.minhash) == [] + + +def test_counter_gather_big_threshold(counter_gather_constructor): + # check 'peek' with a huge threshold + query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) + query_mh.add_many(range(0, 20)) + query_ss = SourmashSignature(query_mh, name='query') + + match_mh_1 = query_mh.copy_and_clear() + match_mh_1.add_many(range(0, 10)) + match_ss_1 = SourmashSignature(match_mh_1, name='match1') + + # load up the counter + counter = counter_gather_constructor(query_ss.minhash) + counter.add(match_ss_1) + + # 
impossible threshold: + threshold_bp=30*query_ss.minhash.scaled + results = counter.peek(query_ss.minhash, threshold_bp=threshold_bp) + assert results == [] + + +def test_counter_gather_empty_counter(counter_gather_constructor): + # check empty counter + query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) + query_ss = SourmashSignature(query_mh, name='query') + + # empty counter! + counter = counter_gather_constructor(query_ss.minhash) + + assert counter.peek(query_ss.minhash) == [] From 66490a437f0dea01e8d0340e476ea9f587461d12 Mon Sep 17 00:00:00 2001 From: "C. Titus Brown" Date: Fri, 8 Jul 2022 08:25:41 -0700 Subject: [PATCH 02/20] add LinearIndex wrapper --- tests/test_index_protocol.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/test_index_protocol.py b/tests/test_index_protocol.py index f81f736e9a..d81b059c5b 100644 --- a/tests/test_index_protocol.py +++ b/tests/test_index_protocol.py @@ -476,7 +476,29 @@ def basic_counter_gather(runtmp): from sourmash.index import CounterGather return CounterGather + +def linear_index_as_counter_gather(runtmp): + "test CounterGather API from LinearIndex" + + class LinearIndexWrapper: + def __init__(self, mh): + self.idx = LinearIndex() + self.mh = mh + + def add(self, ss): + self.idx.insert(ss) + + def peek(self, *args, **kwargs): + return self.idx.peek(*args, **kwargs) + + def consume(self, *args, **kwargs): + return self.idx.consume(*args, **kwargs) + + return LinearIndexWrapper + + @pytest.fixture(params=[basic_counter_gather, + linear_index_as_counter_gather, ] ) def counter_gather_constructor(request, runtmp): From ebb00eaf4413fac30b3bc3663ada83d8a07d303a Mon Sep 17 00:00:00 2001 From: "C. 
Titus Brown" Date: Fri, 8 Jul 2022 09:00:35 -0700 Subject: [PATCH 03/20] getting closer --- src/sourmash/index/__init__.py | 4 ++-- tests/test_index_protocol.py | 44 +++++++++++++++++++++------------- 2 files changed, 30 insertions(+), 18 deletions(-) diff --git a/src/sourmash/index/__init__.py b/src/sourmash/index/__init__.py index 5ef8304fd9..b4720e1f3c 100644 --- a/src/sourmash/index/__init__.py +++ b/src/sourmash/index/__init__.py @@ -277,7 +277,7 @@ def gather(self, query, threshold_bp=None, **kwargs): return results[:1] - def peek(self, query_mh, threshold_bp=0): + def peek(self, query_mh, *, threshold_bp=0): "Mimic CounterGather.peek() on top of Index. Yes, this is backwards." from sourmash import SourmashSignature @@ -764,7 +764,7 @@ def calc_threshold(self, threshold_bp, scaled, query_size): return threshold, n_threshold_hashes - def peek(self, cur_query_mh, threshold_bp=0): + def peek(self, cur_query_mh, *, threshold_bp=0): "Get next 'gather' result for this database, w/o changing counters." 
self.query_started = 1 scaled = cur_query_mh.scaled diff --git a/tests/test_index_protocol.py b/tests/test_index_protocol.py index d81b059c5b..c2ad1a6ba4 100644 --- a/tests/test_index_protocol.py +++ b/tests/test_index_protocol.py @@ -481,15 +481,27 @@ def linear_index_as_counter_gather(runtmp): "test CounterGather API from LinearIndex" class LinearIndexWrapper: - def __init__(self, mh): + def __init__(self, orig_query_mh): + if orig_query_mh.scaled == 0: + raise ValueError + self.idx = LinearIndex() - self.mh = mh + self.orig_query_mh = orig_query_mh + + def add(self, ss, *, location=None, require_overlap=True): + if not self.orig_query_mh & ss.minhash and require_overlap: + raise ValueError - def add(self, ss): self.idx.insert(ss) - def peek(self, *args, **kwargs): - return self.idx.peek(*args, **kwargs) + def peek(self, cur_query_mh, *, threshold_bp=0): + if not self.orig_query_mh: + return [] + + if cur_query_mh.contained_by(self.orig_query_mh, downsample=True) < 1: + raise ValueError + + return self.idx.peek(cur_query_mh, threshold_bp=threshold_bp) def consume(self, *args, **kwargs): return self.idx.consume(*args, **kwargs) @@ -514,7 +526,7 @@ def _consume_all(query_mh, counter, threshold_bp=0): last_intersect_size = None while 1: - result = counter.peek(query_mh, threshold_bp) + result = counter.peek(query_mh, threshold_bp=threshold_bp) if not result: break @@ -828,7 +840,7 @@ def test_counter_gather_2(counter_gather_constructor): # load up the counter counter = counter_gather_constructor(query_ss.minhash) for ss, loc in subject_sigs: - counter.add(ss, loc) + counter.add(ss, location=loc) results = _consume_all(query_ss.minhash, counter) @@ -862,7 +874,7 @@ def test_counter_gather_exact_match(counter_gather_constructor): # load up the counter counter = counter_gather_constructor(query_ss.minhash) - counter.add(query_ss, 'somewhere over the rainbow') + counter.add(query_ss, location='somewhere over the rainbow') results = _consume_all(query_ss.minhash, 
counter) assert len(results) == 1 @@ -881,12 +893,12 @@ def test_counter_gather_add_after_peek(counter_gather_constructor): # load up the counter counter = counter_gather_constructor(query_ss.minhash) - counter.add(query_ss, 'somewhere over the rainbow') + counter.add(query_ss, location='somewhere over the rainbow') counter.peek(query_ss.minhash) with pytest.raises(ValueError): - counter.add(query_ss, "try again") + counter.add(query_ss, location="try again") def test_counter_gather_add_after_consume(counter_gather_constructor): @@ -897,12 +909,12 @@ def test_counter_gather_add_after_consume(counter_gather_constructor): # load up the counter counter = counter_gather_constructor(query_ss.minhash) - counter.add(query_ss, 'somewhere over the rainbow') + counter.add(query_ss, location='somewhere over the rainbow') counter.consume(query_ss.minhash) with pytest.raises(ValueError): - counter.add(query_ss, "try again") + counter.add(query_ss, location="try again") def test_counter_gather_consume_empty_intersect(counter_gather_constructor): @@ -913,7 +925,7 @@ def test_counter_gather_consume_empty_intersect(counter_gather_constructor): # load up the counter counter = counter_gather_constructor(query_ss.minhash) - counter.add(query_ss, 'somewhere over the rainbow') + counter.add(query_ss, location='somewhere over the rainbow') # nothing really happens here :laugh:, just making sure there's no error counter.consume(query_ss.minhash.copy_and_clear()) @@ -953,7 +965,7 @@ def test_counter_gather_empty_cur_query(counter_gather_constructor): # load up the counter counter = counter_gather_constructor(query_ss.minhash) - counter.add(query_ss, 'somewhere over the rainbow') + counter.add(query_ss, location='somewhere over the rainbow') cur_query_mh = query_ss.minhash.copy_and_clear() results = _consume_all(cur_query_mh, counter) @@ -973,7 +985,7 @@ def test_counter_gather_add_num_matchy(counter_gather_constructor): # load up the counter counter = 
counter_gather_constructor(query_ss.minhash) with pytest.raises(ValueError): - counter.add(match_ss, 'somewhere over the rainbow') + counter.add(match_ss, location='somewhere over the rainbow') def test_counter_gather_bad_cur_query(counter_gather_constructor): @@ -984,7 +996,7 @@ def test_counter_gather_bad_cur_query(counter_gather_constructor): # load up the counter counter = counter_gather_constructor(query_ss.minhash) - counter.add(query_ss, 'somewhere over the rainbow') + counter.add(query_ss, location='somewhere over the rainbow') cur_query_mh = query_ss.minhash.copy_and_clear() cur_query_mh.add_many(range(20, 30)) From a8a4dd9b022e9be517b959cbab7498d2dbc99daa Mon Sep 17 00:00:00 2001 From: "C. Titus Brown" Date: Fri, 8 Jul 2022 16:46:35 -0700 Subject: [PATCH 04/20] fix a bunch of the tests --- tests/test_index_protocol.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/tests/test_index_protocol.py b/tests/test_index_protocol.py index c2ad1a6ba4..abf08a80ab 100644 --- a/tests/test_index_protocol.py +++ b/tests/test_index_protocol.py @@ -486,24 +486,33 @@ def __init__(self, orig_query_mh): raise ValueError self.idx = LinearIndex() - self.orig_query_mh = orig_query_mh + self.orig_query_mh = orig_query_mh.copy().flatten() + self.query_started = 0 def add(self, ss, *, location=None, require_overlap=True): - if not self.orig_query_mh & ss.minhash and require_overlap: + if self.query_started: + raise ValueError("cannot add more signatures to counter after peek/consume") + + add_mh = ss.minhash.flatten() + if not self.orig_query_mh & add_mh and require_overlap: raise ValueError self.idx.insert(ss) def peek(self, cur_query_mh, *, threshold_bp=0): - if not self.orig_query_mh: + self.query_started = 1 + if not self.orig_query_mh or not cur_query_mh: return [] + cur_query_mh = cur_query_mh.flatten() + if cur_query_mh.contained_by(self.orig_query_mh, downsample=True) < 1: raise ValueError return self.idx.peek(cur_query_mh, 
threshold_bp=threshold_bp) def consume(self, *args, **kwargs): + self.query_started = 1 return self.idx.consume(*args, **kwargs) return LinearIndexWrapper From b444a68792f414fa83a77a57e4699d4f5d46c02e Mon Sep 17 00:00:00 2001 From: "C. Titus Brown" Date: Sat, 9 Jul 2022 03:58:18 -0700 Subject: [PATCH 05/20] fix call to 'peek' --- src/sourmash/search.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sourmash/search.py b/src/sourmash/search.py index 5a86fa8d85..54eba2b19c 100644 --- a/src/sourmash/search.py +++ b/src/sourmash/search.py @@ -621,7 +621,7 @@ def _find_best(counters, query, threshold_bp): # find the best score across multiple counters, without consuming for counter in counters: - result = counter.peek(query.minhash, threshold_bp) + result = counter.peek(query.minhash, threshold_bp=threshold_bp) if result: (sr, intersect_mh) = result From f87c9d49e5e44bb0eeb4b188a2da2963447ec3c9 Mon Sep 17 00:00:00 2001 From: "C. Titus Brown" Date: Sat, 9 Jul 2022 04:14:03 -0700 Subject: [PATCH 06/20] adjust 'counter.add' call signature --- src/sourmash/index/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/sourmash/index/__init__.py b/src/sourmash/index/__init__.py index b4720e1f3c..08fb3a3c49 100644 --- a/src/sourmash/index/__init__.py +++ b/src/sourmash/index/__init__.py @@ -323,7 +323,7 @@ def counter_gather(self, query, threshold_bp, **kwargs): # find all matches and construct a CounterGather object. counter = CounterGather(prefetch_query.minhash) for result in self.prefetch(prefetch_query, threshold_bp, **kwargs): - counter.add(result.signature, result.location) + counter.add(result.signature, location=result.location) # tada! return counter @@ -725,8 +725,10 @@ def __init__(self, query_mh): # cannot add matches once query has started. 
self.query_started = 0 - def add(self, ss, location=None, require_overlap=True): + def add(self, ss, *, location=None, require_overlap=True): "Add this signature in as a potential match." +# if location is not None: # @CTB +# raise Exception if self.query_started: raise ValueError("cannot add more signatures to counter after peek/consume") From 68458cf0a43a9c9e3318d39c38beda40cbbba91d Mon Sep 17 00:00:00 2001 From: "C. Titus Brown" Date: Sat, 9 Jul 2022 04:23:51 -0700 Subject: [PATCH 07/20] add CounterGather_LCA --- src/sourmash/index/__init__.py | 60 ++++++++++++++++++++++++++++++++++ tests/test_index_protocol.py | 14 +++++--- 2 files changed, 70 insertions(+), 4 deletions(-) diff --git a/src/sourmash/index/__init__.py b/src/sourmash/index/__init__.py index 08fb3a3c49..acece865f8 100644 --- a/src/sourmash/index/__init__.py +++ b/src/sourmash/index/__init__.py @@ -854,6 +854,66 @@ def consume(self, intersect_mh): del counter[dataset_id] +class CounterGather_LCA: + # @CTB + def __init__(self, mh): + from sourmash.lca.lca_db import LCA_Database + if mh.scaled == 0: + raise ValueError("must use scaled MinHash") + + self.orig_query_mh = mh + lca_db = LCA_Database(mh.ksize, mh.scaled, mh.moltype) + self.db = lca_db + self.siglist = [] + self.locations = [] + self.query_started = 0 + + def add(self, ss, *, location=None, require_overlap=True): + if self.query_started: + raise ValueError("cannot add more signatures to counter after peek/consume") + + overlap = self.orig_query_mh.count_common(ss.minhash, True) + if not overlap and require_overlap: + raise ValueError("no overlap between query and signature!?") + + self.db.insert(ss) + self.siglist.append(ss) + self.locations.append(location) + + def peek(self, query_mh, *, threshold_bp=0): + from sourmash import SourmashSignature + + self.query_started = 1 + if not self.orig_query_mh or not query_mh: + return [] + + if query_mh.contained_by(self.orig_query_mh, downsample=True) < 1: + raise ValueError("current query not a 
subset of original query") + + query_ss = SourmashSignature(query_mh) + + # returns search_result, intersect_mh + try: + result = self.db.gather(query_ss, threshold_bp=threshold_bp) + except ValueError: + result = None + + if not result: + return [] + + sr = result[0] + match_mh = sr.signature.minhash + scaled = max(query_mh.scaled, match_mh.scaled) + match_mh = match_mh.downsample(scaled=scaled).flatten() + query_mh = query_mh.downsample(scaled=scaled) + intersect_mh = match_mh & query_mh + + return [sr, intersect_mh] + + def consume(self, intersect_mh): + self.query_started = 1 + + class MultiIndex(Index): """ Load a collection of signatures, and retain their original locations. diff --git a/tests/test_index_protocol.py b/tests/test_index_protocol.py index abf08a80ab..a343ec9b0a 100644 --- a/tests/test_index_protocol.py +++ b/tests/test_index_protocol.py @@ -471,13 +471,13 @@ def test_gather_threshold_5(index_obj): ### -def basic_counter_gather(runtmp): +def create_basic_counter_gather(runtmp): "build a basic CounterGather class" from sourmash.index import CounterGather return CounterGather -def linear_index_as_counter_gather(runtmp): +def create_linear_index_as_counter_gather(runtmp): "test CounterGather API from LinearIndex" class LinearIndexWrapper: @@ -518,8 +518,14 @@ def consume(self, *args, **kwargs): return LinearIndexWrapper -@pytest.fixture(params=[basic_counter_gather, - linear_index_as_counter_gather, +def create_counter_gather_lca(runtmp): + from sourmash.index import CounterGather_LCA + return CounterGather_LCA + + +@pytest.fixture(params=[create_basic_counter_gather, + create_linear_index_as_counter_gather, + create_counter_gather_lca, ] ) def counter_gather_constructor(request, runtmp): From b835c96d9f78eaeb3460676cc93ba14dc1c5d03f Mon Sep 17 00:00:00 2001 From: "C. 
Titus Brown" Date: Sat, 9 Jul 2022 04:48:47 -0700 Subject: [PATCH 08/20] move CounterGather.calc_threshold into search.py --- src/sourmash/index/__init__.py | 30 ++++++++---------------------- src/sourmash/search.py | 19 +++++++++++++++++++ 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/src/sourmash/index/__init__.py b/src/sourmash/index/__init__.py index acece865f8..98d1392b08 100644 --- a/src/sourmash/index/__init__.py +++ b/src/sourmash/index/__init__.py @@ -40,10 +40,11 @@ from collections import namedtuple, Counter from collections import defaultdict -from ..search import make_jaccard_search_query, make_gather_query -from ..manifest import CollectionManifest -from ..logging import debug_literal -from ..signature import load_signatures, save_signatures +from sourmash.search import (make_jaccard_search_query, make_gather_query, + calc_threshold_from_bp) +from sourmash.manifest import CollectionManifest +from sourmash.logging import debug_literal +from sourmash.signature import load_signatures, save_signatures # generic return tuple for Index.search and Index.gather IndexSearchResult = namedtuple('Result', 'score, signature, location') @@ -751,21 +752,6 @@ def downsample(self, scaled): if scaled > self.scaled: self.scaled = scaled - def calc_threshold(self, threshold_bp, scaled, query_size): - # CTB: this code doesn't need to be in this class. - threshold = 0.0 - n_threshold_hashes = 0 - - if threshold_bp: - # if we have a threshold_bp of N, then that amounts to N/scaled - # hashes: - n_threshold_hashes = float(threshold_bp) / scaled - - # that then requires the following containment: - threshold = n_threshold_hashes / query_size - - return threshold, n_threshold_hashes - def peek(self, cur_query_mh, *, threshold_bp=0): "Get next 'gather' result for this database, w/o changing counters." 
self.query_started = 1 @@ -790,9 +776,9 @@ def peek(self, cur_query_mh, *, threshold_bp=0): raise ValueError("current query not a subset of original query") # are we setting a threshold? - threshold, n_threshold_hashes = self.calc_threshold(threshold_bp, - scaled, - len(cur_query_mh)) + threshold, n_threshold_hashes = calc_threshold_from_bp(threshold_bp, + scaled, + len(cur_query_mh)) # is it too high to ever match? if so, exit. if threshold > 1.0: return [] diff --git a/src/sourmash/search.py b/src/sourmash/search.py index 54eba2b19c..3e03951978 100644 --- a/src/sourmash/search.py +++ b/src/sourmash/search.py @@ -11,6 +11,25 @@ from .sketchcomparison import FracMinHashComparison, NumMinHashComparison +def calc_threshold_from_bp(threshold_bp, scaled, query_size): + """ + Convert threshold_bp (threshold in estimated bp) to + fraction of query & minimum number of hashes needed. + """ + threshold = 0.0 + n_threshold_hashes = 0 + + if threshold_bp: + # if we have a threshold_bp of N, then that amounts to N/scaled + # hashes: + n_threshold_hashes = float(threshold_bp) / scaled + + # that then requires the following containment: + threshold = n_threshold_hashes / query_size + + return threshold, n_threshold_hashes + + class SearchType(Enum): JACCARD = 1 CONTAINMENT = 2 From 1903920794074f9dbb9502a3709211932b6fa86e Mon Sep 17 00:00:00 2001 From: "C. Titus Brown" Date: Sat, 9 Jul 2022 05:20:54 -0700 Subject: [PATCH 09/20] minor refactoring --- src/sourmash/index/__init__.py | 5 ++--- tests/test_index_protocol.py | 21 ++++++++++++++++----- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/src/sourmash/index/__init__.py b/src/sourmash/index/__init__.py index 98d1392b08..c592d8f056 100644 --- a/src/sourmash/index/__init__.py +++ b/src/sourmash/index/__init__.py @@ -751,11 +751,11 @@ def downsample(self, scaled): "Track highest scaled across all possible matches." 
if scaled > self.scaled: self.scaled = scaled + return self.scaled def peek(self, cur_query_mh, *, threshold_bp=0): "Get next 'gather' result for this database, w/o changing counters." self.query_started = 1 - scaled = cur_query_mh.scaled # empty? nothing to search. counter = self.counter @@ -765,8 +765,7 @@ def peek(self, cur_query_mh, *, threshold_bp=0): siglist = self.siglist assert siglist - self.downsample(scaled) - scaled = self.scaled + scaled = self.downsample(cur_query_mh.scaled) cur_query_mh = cur_query_mh.downsample(scaled=scaled) if not cur_query_mh: # empty query? quit. diff --git a/tests/test_index_protocol.py b/tests/test_index_protocol.py index a343ec9b0a..94f06029db 100644 --- a/tests/test_index_protocol.py +++ b/tests/test_index_protocol.py @@ -488,26 +488,37 @@ def __init__(self, orig_query_mh): self.idx = LinearIndex() self.orig_query_mh = orig_query_mh.copy().flatten() self.query_started = 0 + self.scaled = orig_query_mh.scaled def add(self, ss, *, location=None, require_overlap=True): if self.query_started: raise ValueError("cannot add more signatures to counter after peek/consume") add_mh = ss.minhash.flatten() - if not self.orig_query_mh & add_mh and require_overlap: - raise ValueError + if self.orig_query_mh & add_mh: + self.downsample(ss.minhash.scaled) + elif require_overlap: + raise ValueError("no overlap between query and signature!?") self.idx.insert(ss) + def downsample(self, scaled): + "Track highest scaled across all possible matches." 
+ if scaled > self.scaled: + self.scaled = scaled + return self.scaled + def peek(self, cur_query_mh, *, threshold_bp=0): self.query_started = 1 + cur_query_mh = cur_query_mh.flatten() + scaled = self.downsample(cur_query_mh.scaled) + cur_query_mh = cur_query_mh.downsample(scaled=scaled) + if not self.orig_query_mh or not cur_query_mh: return [] - cur_query_mh = cur_query_mh.flatten() - if cur_query_mh.contained_by(self.orig_query_mh, downsample=True) < 1: - raise ValueError + raise ValueError("current query not a subset of original query") return self.idx.peek(cur_query_mh, threshold_bp=threshold_bp) From 5099d5a3b90da58ddb2bb7c37d625c41d1e4e7b1 Mon Sep 17 00:00:00 2001 From: "C. Titus Brown" Date: Sat, 9 Jul 2022 05:26:07 -0700 Subject: [PATCH 10/20] resolve downsampling for linear index wrapper --- tests/test_index_protocol.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_index_protocol.py b/tests/test_index_protocol.py index 94f06029db..7311f8b817 100644 --- a/tests/test_index_protocol.py +++ b/tests/test_index_protocol.py @@ -495,8 +495,10 @@ def add(self, ss, *, location=None, require_overlap=True): raise ValueError("cannot add more signatures to counter after peek/consume") add_mh = ss.minhash.flatten() - if self.orig_query_mh & add_mh: - self.downsample(ss.minhash.scaled) + overlap = self.orig_query_mh.count_common(add_mh, downsample=True) + + if overlap: + self.downsample(add_mh.scaled) elif require_overlap: raise ValueError("no overlap between query and signature!?") From a8125b4584696104e67c680751295706d8895821 Mon Sep 17 00:00:00 2001 From: "C. 
Titus Brown" Date: Sat, 9 Jul 2022 05:39:05 -0700 Subject: [PATCH 11/20] fix downsampling for LCA-based CounterGather --- src/sourmash/index/__init__.py | 13 ++++++++++++- src/sourmash/lca/lca_db.py | 2 +- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/sourmash/index/__init__.py b/src/sourmash/index/__init__.py index c592d8f056..8e0947840a 100644 --- a/src/sourmash/index/__init__.py +++ b/src/sourmash/index/__init__.py @@ -858,17 +858,28 @@ def add(self, ss, *, location=None, require_overlap=True): raise ValueError("cannot add more signatures to counter after peek/consume") overlap = self.orig_query_mh.count_common(ss.minhash, True) - if not overlap and require_overlap: + if overlap: + self.downsample(ss.minhash.scaled) + elif require_overlap: raise ValueError("no overlap between query and signature!?") self.db.insert(ss) self.siglist.append(ss) self.locations.append(location) + def downsample(self, scaled): + "Track highest scaled across all possible matches." + if scaled > self.db.scaled: + self.db.downsample_scaled(scaled) + return self.db.scaled + def peek(self, query_mh, *, threshold_bp=0): from sourmash import SourmashSignature self.query_started = 1 + scaled = self.downsample(query_mh.scaled) + query_mh = query_mh.downsample(scaled=scaled) + if not self.orig_query_mh or not query_mh: return [] diff --git a/src/sourmash/lca/lca_db.py b/src/sourmash/lca/lca_db.py index 280f810426..21ff15a2c1 100644 --- a/src/sourmash/lca/lca_db.py +++ b/src/sourmash/lca/lca_db.py @@ -460,7 +460,7 @@ def downsample_scaled(self, scaled): max_hash = _get_max_hash_for_scaled(scaled) # filter out all hashes over max_hash in value. - new_hashvals = {} + new_hashvals = defaultdict(set) for k, v in self._hashval_to_idx.items(): if k < max_hash: new_hashvals[k] = v From 1760ada3e7c3c755e32a2094f1d386cec44093e6 Mon Sep 17 00:00:00 2001 From: "C. 
Titus Brown" Date: Sat, 9 Jul 2022 06:20:23 -0700 Subject: [PATCH 12/20] fix location foo --- src/sourmash/index/__init__.py | 16 +++++++++++++--- tests/test_index_protocol.py | 19 +++++++++++++++++-- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/src/sourmash/index/__init__.py b/src/sourmash/index/__init__.py index 8e0947840a..6f8dc30cc2 100644 --- a/src/sourmash/index/__init__.py +++ b/src/sourmash/index/__init__.py @@ -850,7 +850,7 @@ def __init__(self, mh): lca_db = LCA_Database(mh.ksize, mh.scaled, mh.moltype) self.db = lca_db self.siglist = [] - self.locations = [] + self.locations = {} self.query_started = 0 def add(self, ss, *, location=None, require_overlap=True): @@ -865,7 +865,10 @@ def add(self, ss, *, location=None, require_overlap=True): self.db.insert(ss) self.siglist.append(ss) - self.locations.append(location) + #self.locations.append(location) + + md5 = ss.md5sum() + self.locations[md5] = location def downsample(self, scaled): "Track highest scaled across all possible matches." 
@@ -898,13 +901,20 @@ def peek(self, query_mh, *, threshold_bp=0): return [] sr = result[0] + cont = sr.score + match = sr.signature + match_mh = sr.signature.minhash scaled = max(query_mh.scaled, match_mh.scaled) match_mh = match_mh.downsample(scaled=scaled).flatten() query_mh = query_mh.downsample(scaled=scaled) intersect_mh = match_mh & query_mh - return [sr, intersect_mh] + md5 = sr.signature.md5sum() + location = self.locations[md5] + + new_sr = IndexSearchResult(cont, match, location) + return [new_sr, intersect_mh] def consume(self, intersect_mh): self.query_started = 1 diff --git a/tests/test_index_protocol.py b/tests/test_index_protocol.py index 7311f8b817..910de0bd23 100644 --- a/tests/test_index_protocol.py +++ b/tests/test_index_protocol.py @@ -489,6 +489,7 @@ def __init__(self, orig_query_mh): self.orig_query_mh = orig_query_mh.copy().flatten() self.query_started = 0 self.scaled = orig_query_mh.scaled + self.locations = {} def add(self, ss, *, location=None, require_overlap=True): if self.query_started: @@ -504,6 +505,9 @@ def add(self, ss, *, location=None, require_overlap=True): self.idx.insert(ss) + md5 = ss.md5sum() + self.locations[md5] = location + def downsample(self, scaled): "Track highest scaled across all possible matches." 
if scaled > self.scaled: @@ -522,7 +526,17 @@ def peek(self, cur_query_mh, *, threshold_bp=0): if cur_query_mh.contained_by(self.orig_query_mh, downsample=True) < 1: raise ValueError("current query not a subset of original query") - return self.idx.peek(cur_query_mh, threshold_bp=threshold_bp) + res = self.idx.peek(cur_query_mh, threshold_bp=threshold_bp) + if not res: + return [] + sr, intersect_mh = res + + from sourmash.index import IndexSearchResult + match = sr.signature + md5 = match.md5sum() + location = self.locations[md5] + new_sr = IndexSearchResult(sr.score, match, location) + return new_sr, intersect_mh def consume(self, *args, **kwargs): self.query_started = 1 @@ -900,7 +914,8 @@ def test_counter_gather_exact_match(counter_gather_constructor): query_mh.add_many(range(0, 20)) query_ss = SourmashSignature(query_mh, name='query') - # load up the counter + # load up the counter; provide a location override, too. + # @CTB split out into a separate test? counter = counter_gather_constructor(query_ss.minhash) counter.add(query_ss, location='somewhere over the rainbow') From 5c9748a80a8797e6c5d4276701d00427f8aa70d1 Mon Sep 17 00:00:00 2001 From: "C. Titus Brown" Date: Sat, 9 Jul 2022 07:32:45 -0700 Subject: [PATCH 13/20] fix remaining test --- tests/test_index.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_index.py b/tests/test_index.py index 0f27e974c6..2f633668e4 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -1567,9 +1567,9 @@ def test_counter_gather_3_test_consume(): # load up the counter counter = CounterGather(query_ss.minhash) - counter.add(match_ss_1, 'loc a') - counter.add(match_ss_2, 'loc b') - counter.add(match_ss_3, 'loc c') + counter.add(match_ss_1, location='loc a') + counter.add(match_ss_2, location='loc b') + counter.add(match_ss_3, location='loc c') ### ok, dig into actual counts... import pprint From c2d2637a1fc82457fb67a434607e22c24154ffd6 Mon Sep 17 00:00:00 2001 From: "C. 
Titus Brown" Date: Sun, 10 Jul 2022 06:10:37 -0700 Subject: [PATCH 14/20] minor cleanup --- src/sourmash/index/__init__.py | 2 -- tests/test_index.py | 7 ++----- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/src/sourmash/index/__init__.py b/src/sourmash/index/__init__.py index 6f8dc30cc2..ead5a52a37 100644 --- a/src/sourmash/index/__init__.py +++ b/src/sourmash/index/__init__.py @@ -728,8 +728,6 @@ def __init__(self, query_mh): def add(self, ss, *, location=None, require_overlap=True): "Add this signature in as a potential match." -# if location is not None: # @CTB -# raise Exception if self.query_started: raise ValueError("cannot add more signatures to counter after peek/consume") diff --git a/tests/test_index.py b/tests/test_index.py index 2f633668e4..607830df87 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -1542,13 +1542,10 @@ def is_found(ss, xx): assert not is_found(ss2, results) assert is_found(ss63, results) -### -### CounterGather tests - implementation specific. See test_index_protocol -### for protocol tests. -### -def test_counter_gather_3_test_consume(): +def test_counter_gather_test_consume(): # open-box testing of CounterGather.consume(...) + # (see test_index_protocol.py for generic CounterGather tests.) query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) query_mh.add_many(range(0, 20)) query_ss = SourmashSignature(query_mh, name='query') From 6f9eb787d5581ea4e3cf739bb50719dd8b782534 Mon Sep 17 00:00:00 2001 From: "C. 
Titus Brown" Date: Sun, 10 Jul 2022 06:16:07 -0700 Subject: [PATCH 15/20] add doc --- src/sourmash/index/__init__.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/sourmash/index/__init__.py b/src/sourmash/index/__init__.py index ead5a52a37..7ae78c54fe 100644 --- a/src/sourmash/index/__init__.py +++ b/src/sourmash/index/__init__.py @@ -279,7 +279,11 @@ def gather(self, query, threshold_bp=None, **kwargs): return results[:1] def peek(self, query_mh, *, threshold_bp=0): - "Mimic CounterGather.peek() on top of Index. Yes, this is backwards." + """Mimic CounterGather.peek() on top of Index. + + This is implemented for situations where we don't want to use + 'prefetch' functionality. + """ from sourmash import SourmashSignature # build a signature to use with self.gather... From f82e1d7850e88bb4c688a1b4f860c93acec8e0d0 Mon Sep 17 00:00:00 2001 From: "C. Titus Brown" Date: Sun, 10 Jul 2022 06:23:44 -0700 Subject: [PATCH 16/20] test multiple identical matches --- tests/test_index_protocol.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/tests/test_index_protocol.py b/tests/test_index_protocol.py index 910de0bd23..c452c2210f 100644 --- a/tests/test_index_protocol.py +++ b/tests/test_index_protocol.py @@ -915,7 +915,6 @@ def test_counter_gather_exact_match(counter_gather_constructor): query_ss = SourmashSignature(query_mh, name='query') # load up the counter; provide a location override, too. - # @CTB split out into a separate test? 
counter = counter_gather_constructor(query_ss.minhash) counter.add(query_ss, location='somewhere over the rainbow') @@ -928,6 +927,34 @@ def test_counter_gather_exact_match(counter_gather_constructor): assert sr.location == 'somewhere over the rainbow' +def test_counter_gather_multiple_identical_matches(counter_gather_constructor): + # query == match + query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) + query_mh.add_many(range(0, 20)) + query_ss = SourmashSignature(query_mh, name='query') + + # create counter... + counter = counter_gather_constructor(query_ss.minhash) + + # now add multiple identical matches. + match_mh = query_mh.copy_and_clear() + match_mh.add_many(range(5, 15)) + + for name in 'match1', 'match2', 'match3': + match_ss = SourmashSignature(match_mh, name=name) + counter.add(match_ss, location=name) + + results = _consume_all(query_ss.minhash, counter) + assert len(results) == 1 + + sr, overlap_count = results[0] + assert sr.score == 0.5 + assert overlap_count == 10 + + # any one of the three is valid + assert sr.location in ('match1', 'match2', 'match3') + + def test_counter_gather_add_after_peek(counter_gather_constructor): # cannot add after peek or consume query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) From d9472ed084e4ae04be3beff8e8f04fd520262e80 Mon Sep 17 00:00:00 2001 From: "C. 
Titus Brown" Date: Sun, 10 Jul 2022 06:25:54 -0700 Subject: [PATCH 17/20] adjust LinearIndex implementation to skip identical matches --- tests/test_index_protocol.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/test_index_protocol.py b/tests/test_index_protocol.py index c452c2210f..561ba6afbb 100644 --- a/tests/test_index_protocol.py +++ b/tests/test_index_protocol.py @@ -495,6 +495,11 @@ def add(self, ss, *, location=None, require_overlap=True): if self.query_started: raise ValueError("cannot add more signatures to counter after peek/consume") + # skip duplicates + md5 = ss.md5sum() + if md5 in self.locations: + return + add_mh = ss.minhash.flatten() overlap = self.orig_query_mh.count_common(add_mh, downsample=True) @@ -504,8 +509,6 @@ def add(self, ss, *, location=None, require_overlap=True): raise ValueError("no overlap between query and signature!?") self.idx.insert(ss) - - md5 = ss.md5sum() self.locations[md5] = location def downsample(self, scaled): @@ -928,7 +931,7 @@ def test_counter_gather_exact_match(counter_gather_constructor): def test_counter_gather_multiple_identical_matches(counter_gather_constructor): - # query == match + # test multiple identical matches being inserted, with only one return query_mh = sourmash.MinHash(n=0, ksize=31, scaled=1) query_mh.add_many(range(0, 20)) query_ss = SourmashSignature(query_mh, name='query') From 4c14e01e9ac068efdae44f7164e44dce1471465f Mon Sep 17 00:00:00 2001 From: "C. 
Titus Brown" Date: Mon, 11 Jul 2022 05:51:01 -0700 Subject: [PATCH 18/20] cleanup protocol tests --- src/sourmash/index/__init__.py | 117 +++++------------------ tests/test_index_protocol.py | 169 +++++++++++++++++---------------- 2 files changed, 114 insertions(+), 172 deletions(-) diff --git a/src/sourmash/index/__init__.py b/src/sourmash/index/__init__.py index 7ae78c54fe..2334715bb0 100644 --- a/src/sourmash/index/__init__.py +++ b/src/sourmash/index/__init__.py @@ -282,7 +282,8 @@ def peek(self, query_mh, *, threshold_bp=0): """Mimic CounterGather.peek() on top of Index. This is implemented for situations where we don't want to use - 'prefetch' functionality. + 'prefetch' functionality. It is a light wrapper around the + 'gather'/search-by-containment method. """ from sourmash import SourmashSignature @@ -706,13 +707,24 @@ def select(self, **kwargs): class CounterGather: - """ - Track and summarize matches for efficient 'gather' protocol. This - could be used downstream of prefetch (for example). + """This is an ancillary class that is used to implement "fast + gather", post-prefetch. It tracks and summarizes matches for + efficient min-set-cov/'gather'. + + The class constructor takes a query MinHash that must be scaled, and + then takes signatures that have overlaps with the query (via 'add'). - The public interface is `peek(...)` and `consume(...)` only. + After all overlapping signatures have been loaded, the 'peek' + method is then used at each stage of the 'gather' procedure to + find the best match, and the 'consume' method is used to remove + a match from this counter. + + This particular implementation maintains a collections.Counter that + is used to quickly find the best match when 'peek' is called, but + other implementations are possible ;). """ def __init__(self, query_mh): + "Constructor - takes a query FracMinHash." 
if not query_mh.scaled: raise ValueError('gather requires scaled signatures') @@ -720,14 +732,14 @@ def __init__(self, query_mh): self.orig_query_mh = query_mh.copy().flatten() self.scaled = query_mh.scaled - # track matching signatures & their locations + # use these to track loaded matches & their locations self.siglist = [] self.locations = [] - # ...and overlaps with query + # ...and also track overlaps with the progressive query self.counter = Counter() - # cannot add matches once query has started. + # fence to make sure we do not add matches once query has started. self.query_started = 0 def add(self, ss, *, location=None, require_overlap=True): @@ -773,6 +785,7 @@ def peek(self, cur_query_mh, *, threshold_bp=0): if not cur_query_mh: # empty query? quit. return [] + # CTB: could probably remove this check unless debug requested. if cur_query_mh.contained_by(self.orig_query_mh, downsample=True) < 1: raise ValueError("current query not a subset of original query") @@ -784,7 +797,7 @@ def peek(self, cur_query_mh, *, threshold_bp=0): if threshold > 1.0: return [] - # Find the best match - + # Find the best match using the internal Counter. most_common = counter.most_common() dataset_id, match_size = most_common[0] @@ -792,12 +805,13 @@ def peek(self, cur_query_mh, *, threshold_bp=0): if match_size < n_threshold_hashes: return [] - ## at this point, we must have a legitimate match above threshold! + ## at this point, we have a legitimate match above threshold! # pull match and location. match = siglist[dataset_id] # calculate containment + # CTB: this check is probably redundant with intersect_mh calc, below. cont = cur_query_mh.contained_by(match.minhash, downsample=True) assert cont assert cont >= threshold @@ -811,7 +825,7 @@ def peek(self, cur_query_mh, *, threshold_bp=0): return (IndexSearchResult(cont, match, location), intersect_mh) def consume(self, intersect_mh): - "Remove the given hashes from this counter." 
+ "Maintain the internal counter by removing the given hashes." self.query_started = 1 if not intersect_mh: @@ -841,87 +855,6 @@ def consume(self, intersect_mh): del counter[dataset_id] -class CounterGather_LCA: - # @CTB - def __init__(self, mh): - from sourmash.lca.lca_db import LCA_Database - if mh.scaled == 0: - raise ValueError("must use scaled MinHash") - - self.orig_query_mh = mh - lca_db = LCA_Database(mh.ksize, mh.scaled, mh.moltype) - self.db = lca_db - self.siglist = [] - self.locations = {} - self.query_started = 0 - - def add(self, ss, *, location=None, require_overlap=True): - if self.query_started: - raise ValueError("cannot add more signatures to counter after peek/consume") - - overlap = self.orig_query_mh.count_common(ss.minhash, True) - if overlap: - self.downsample(ss.minhash.scaled) - elif require_overlap: - raise ValueError("no overlap between query and signature!?") - - self.db.insert(ss) - self.siglist.append(ss) - #self.locations.append(location) - - md5 = ss.md5sum() - self.locations[md5] = location - - def downsample(self, scaled): - "Track highest scaled across all possible matches." 
- if scaled > self.db.scaled: - self.db.downsample_scaled(scaled) - return self.db.scaled - - def peek(self, query_mh, *, threshold_bp=0): - from sourmash import SourmashSignature - - self.query_started = 1 - scaled = self.downsample(query_mh.scaled) - query_mh = query_mh.downsample(scaled=scaled) - - if not self.orig_query_mh or not query_mh: - return [] - - if query_mh.contained_by(self.orig_query_mh, downsample=True) < 1: - raise ValueError("current query not a subset of original query") - - query_ss = SourmashSignature(query_mh) - - # returns search_result, intersect_mh - try: - result = self.db.gather(query_ss, threshold_bp=threshold_bp) - except ValueError: - result = None - - if not result: - return [] - - sr = result[0] - cont = sr.score - match = sr.signature - - match_mh = sr.signature.minhash - scaled = max(query_mh.scaled, match_mh.scaled) - match_mh = match_mh.downsample(scaled=scaled).flatten() - query_mh = query_mh.downsample(scaled=scaled) - intersect_mh = match_mh & query_mh - - md5 = sr.signature.md5sum() - location = self.locations[md5] - - new_sr = IndexSearchResult(cont, match, location) - return [new_sr, intersect_mh] - - def consume(self, intersect_mh): - self.query_started = 1 - - class MultiIndex(Index): """ Load a collection of signatures, and retain their original locations. diff --git a/tests/test_index_protocol.py b/tests/test_index_protocol.py index 561ba6afbb..c94172c09a 100644 --- a/tests/test_index_protocol.py +++ b/tests/test_index_protocol.py @@ -472,97 +472,106 @@ def test_gather_threshold_5(index_obj): def create_basic_counter_gather(runtmp): - "build a basic CounterGather class" - from sourmash.index import CounterGather + "Construct a CounterGather class." return CounterGather +class CounterGather_LinearIndex: + """ + Provides an (inefficient) CounterGather-style class, for + protocol testing purposes. + """ + def __init__(self, orig_query_mh): + "Constructor - take a FracMinHash that is the original query." 
+ if orig_query_mh.scaled == 0: + raise ValueError + + # Index object used to actually track matches. + self.idx = LinearIndex() + self.orig_query_mh = orig_query_mh.copy().flatten() + self.query_started = 0 + self.scaled = orig_query_mh.scaled + self.locations = {} + + def add(self, ss, *, location=None, require_overlap=True): + "Insert potential match." + if self.query_started: + raise ValueError("cannot add more signatures to counter after peek/consume") + + # skip duplicates + md5 = ss.md5sum() + if md5 in self.locations: + return + + # confirm that this match has an overlap... + add_mh = ss.minhash.flatten() + overlap = self.orig_query_mh.count_common(add_mh, downsample=True) + + # ...figure out what scaled we are operating at now... + if overlap: + self.downsample(add_mh.scaled) + elif require_overlap: + raise ValueError("no overlap between query and signature!?") + + # ...and add to the Index, while also tracking location! + self.idx.insert(ss) + self.locations[md5] = location + + def downsample(self, scaled): + "Track highest scaled across all possible matches." + if scaled > self.scaled: + self.scaled = scaled + return self.scaled + + def peek(self, cur_query_mh, *, threshold_bp=0): + """ + Find best match to current query within this CounterGather object. + """ + self.query_started = 1 + cur_query_mh = cur_query_mh.flatten() + scaled = self.downsample(cur_query_mh.scaled) + cur_query_mh = cur_query_mh.downsample(scaled=scaled) + + # no match? exit. + if not self.orig_query_mh or not cur_query_mh: + return [] + + # verify current query is a subset of the original. + if cur_query_mh.contained_by(self.orig_query_mh, downsample=True) < 1: + raise ValueError("current query not a subset of original query") + + # did we get a match? 
+ res = self.idx.peek(cur_query_mh, threshold_bp=threshold_bp) + if not res: + return [] + sr, intersect_mh = res + + from sourmash.index import IndexSearchResult + match = sr.signature + md5 = match.md5sum() + location = self.locations[md5] + new_sr = IndexSearchResult(sr.score, match, location) + return new_sr, intersect_mh + + def consume(self, *args, **kwargs): + self.query_started = 1 + return self.idx.consume(*args, **kwargs) + + +from sourmash.index import CounterGather def create_linear_index_as_counter_gather(runtmp): "test CounterGather API from LinearIndex" + return CounterGather_LinearIndex - class LinearIndexWrapper: - def __init__(self, orig_query_mh): - if orig_query_mh.scaled == 0: - raise ValueError - - self.idx = LinearIndex() - self.orig_query_mh = orig_query_mh.copy().flatten() - self.query_started = 0 - self.scaled = orig_query_mh.scaled - self.locations = {} - - def add(self, ss, *, location=None, require_overlap=True): - if self.query_started: - raise ValueError("cannot add more signatures to counter after peek/consume") - - # skip duplicates - md5 = ss.md5sum() - if md5 in self.locations: - return - - add_mh = ss.minhash.flatten() - overlap = self.orig_query_mh.count_common(add_mh, downsample=True) - - if overlap: - self.downsample(add_mh.scaled) - elif require_overlap: - raise ValueError("no overlap between query and signature!?") - - self.idx.insert(ss) - self.locations[md5] = location - - def downsample(self, scaled): - "Track highest scaled across all possible matches." 
- if scaled > self.scaled: - self.scaled = scaled - return self.scaled - - def peek(self, cur_query_mh, *, threshold_bp=0): - self.query_started = 1 - cur_query_mh = cur_query_mh.flatten() - scaled = self.downsample(cur_query_mh.scaled) - cur_query_mh = cur_query_mh.downsample(scaled=scaled) - - if not self.orig_query_mh or not cur_query_mh: - return [] - - if cur_query_mh.contained_by(self.orig_query_mh, downsample=True) < 1: - raise ValueError("current query not a subset of original query") - - res = self.idx.peek(cur_query_mh, threshold_bp=threshold_bp) - if not res: - return [] - sr, intersect_mh = res - - from sourmash.index import IndexSearchResult - match = sr.signature - md5 = match.md5sum() - location = self.locations[md5] - new_sr = IndexSearchResult(sr.score, match, location) - return new_sr, intersect_mh - - def consume(self, *args, **kwargs): - self.query_started = 1 - return self.idx.consume(*args, **kwargs) - - return LinearIndexWrapper - - -def create_counter_gather_lca(runtmp): - from sourmash.index import CounterGather_LCA - return CounterGather_LCA - - -@pytest.fixture(params=[create_basic_counter_gather, - create_linear_index_as_counter_gather, - create_counter_gather_lca, +@pytest.fixture(params=[CounterGather, + CounterGather_LinearIndex, ] ) -def counter_gather_constructor(request, runtmp): +def counter_gather_constructor(request): build_fn = request.param # build on demand - return build_fn(runtmp) + return build_fn def _consume_all(query_mh, counter, threshold_bp=0): From 3df8c66ecc8a0585ec384a84f0ce43635b41955d Mon Sep 17 00:00:00 2001 From: "C. 
Titus Brown" Date: Mon, 11 Jul 2022 05:57:00 -0700 Subject: [PATCH 19/20] revert LCA_Database fix --- src/sourmash/lca/lca_db.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sourmash/lca/lca_db.py b/src/sourmash/lca/lca_db.py index 21ff15a2c1..280f810426 100644 --- a/src/sourmash/lca/lca_db.py +++ b/src/sourmash/lca/lca_db.py @@ -460,7 +460,7 @@ def downsample_scaled(self, scaled): max_hash = _get_max_hash_for_scaled(scaled) # filter out all hashes over max_hash in value. - new_hashvals = defaultdict(set) + new_hashvals = {} for k, v in self._hashval_to_idx.items(): if k < max_hash: new_hashvals[k] = v From 1a4e01b5599e9bab2b22cbce8c180017cc5fc33f Mon Sep 17 00:00:00 2001 From: "C. Titus Brown" Date: Mon, 11 Jul 2022 06:06:23 -0700 Subject: [PATCH 20/20] cleanup --- src/sourmash/index/__init__.py | 1 - tests/test_index.py | 1 - tests/test_index_protocol.py | 16 ++-------------- 3 files changed, 2 insertions(+), 16 deletions(-) diff --git a/src/sourmash/index/__init__.py b/src/sourmash/index/__init__.py index 2334715bb0..eb8a55a94c 100644 --- a/src/sourmash/index/__init__.py +++ b/src/sourmash/index/__init__.py @@ -38,7 +38,6 @@ import sourmash from abc import abstractmethod, ABC from collections import namedtuple, Counter -from collections import defaultdict from sourmash.search import (make_jaccard_search_query, make_gather_query, calc_threshold_from_bp) diff --git a/tests/test_index.py b/tests/test_index.py index 607830df87..e36275b092 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -8,7 +8,6 @@ import shutil import sourmash -from sourmash import index from sourmash import load_one_signature, SourmashSignature from sourmash.index import (LinearIndex, ZipFileLinearIndex, make_jaccard_search_query, CounterGather, diff --git a/tests/test_index_protocol.py b/tests/test_index_protocol.py index c94172c09a..19f27788c8 100644 --- a/tests/test_index_protocol.py +++ b/tests/test_index_protocol.py @@ -11,6 +11,7 @@ from sourmash.index 
import (LinearIndex, ZipFileLinearIndex, LazyLinearIndex, MultiIndex, StandaloneManifestIndex) +from sourmash.index import CounterGather from sourmash.index.sqlite_index import SqliteIndex from sourmash.index.revindex import RevIndex from sourmash.sbt import SBT, GraphFactory @@ -129,14 +130,6 @@ def build_lca_index_save_load(runtmp): return sourmash.load_file_as_index(outfile) -def build_lca_index_save_load(runtmp): - db = build_lca_index(runtmp) - outfile = runtmp.output('db.lca.json') - db.save(outfile) - - return sourmash.load_file_as_index(outfile) - - def build_sqlite_index(runtmp): filename = runtmp.output('idx.sqldb') db = SqliteIndex.create(filename) @@ -558,11 +551,6 @@ def consume(self, *args, **kwargs): return self.idx.consume(*args, **kwargs) -from sourmash.index import CounterGather -def create_linear_index_as_counter_gather(runtmp): - "test CounterGather API from LinearIndex" - return CounterGather_LinearIndex - @pytest.fixture(params=[CounterGather, CounterGather_LinearIndex, ] @@ -1036,7 +1024,7 @@ def test_counter_gather_num_query(counter_gather_constructor): query_ss = SourmashSignature(query_mh, name='query') with pytest.raises(ValueError): - counter = counter_gather_constructor(query_ss.minhash) + counter_gather_constructor(query_ss.minhash) def test_counter_gather_empty_cur_query(counter_gather_constructor):