Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] update Index protocol tests to include tests for peek and consume #2111

Merged
merged 26 commits into from
Jul 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
241dbc5
move most CounterGather tests over to index protocol tests
ctb Jul 8, 2022
66490a4
add LinearIndex wrapper
ctb Jul 8, 2022
ebb00ea
getting closer
ctb Jul 8, 2022
a8a4dd9
fix a bunch of the tests
ctb Jul 8, 2022
fdc8d4f
Merge branch 'latest' of https://github.com/sourmash-bio/sourmash int…
ctb Jul 8, 2022
ba114dd
Merge branch 'latest' of https://github.com/sourmash-bio/sourmash int…
ctb Jul 9, 2022
b444a68
fix call to 'peek'
ctb Jul 9, 2022
f87c9d4
adjust 'counter.add' call signature
ctb Jul 9, 2022
68458cf
add CounterGather_LCA
ctb Jul 9, 2022
b835c96
move CounterGather.calc_threshold into search.py
ctb Jul 9, 2022
1903920
minor refactoring
ctb Jul 9, 2022
5099d5a
resolve downsampling for linear index wrapper
ctb Jul 9, 2022
a8125b4
fix downsampling for LCA-based CounterGather
ctb Jul 9, 2022
1760ada
fix location foo
ctb Jul 9, 2022
5c9748a
fix remaining test
ctb Jul 9, 2022
c2d2637
minor cleanup
ctb Jul 10, 2022
6f9eb78
add doc
ctb Jul 10, 2022
f82e1d7
test multiple identical matches
ctb Jul 10, 2022
d9472ed
adjust LinearIndex implementation to skip identical matches
ctb Jul 10, 2022
4c14e01
cleanup protocol tests
ctb Jul 11, 2022
3df8c66
revert LCA_Database fix
ctb Jul 11, 2022
36d4c2c
Merge branch 'latest' into refactor/counter_gather_tests
ctb Jul 11, 2022
1a4e01b
cleanup
ctb Jul 11, 2022
ee0fd18
Merge branch 'refactor/counter_gather_tests' of https://github.com/so…
ctb Jul 11, 2022
dbabfe9
Merge branch 'latest' into refactor/counter_gather_tests
ctb Jul 11, 2022
c6078a6
Merge branch 'latest' into refactor/counter_gather_tests
ctb Jul 13, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 43 additions & 41 deletions src/sourmash/index/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@
import sourmash
from abc import abstractmethod, ABC
from collections import namedtuple, Counter
from collections import defaultdict

from ..search import make_jaccard_search_query, make_gather_query
from ..manifest import CollectionManifest
from ..logging import debug_literal
from ..signature import load_signatures, save_signatures
from sourmash.search import (make_jaccard_search_query, make_gather_query,
calc_threshold_from_bp)
from sourmash.manifest import CollectionManifest
from sourmash.logging import debug_literal
from sourmash.signature import load_signatures, save_signatures

# generic return tuple for Index.search and Index.gather
IndexSearchResult = namedtuple('Result', 'score, signature, location')
Expand Down Expand Up @@ -277,8 +277,13 @@ def gather(self, query, threshold_bp=None, **kwargs):

return results[:1]

def peek(self, query_mh, threshold_bp=0):
"Mimic CounterGather.peek() on top of Index. Yes, this is backwards."
def peek(self, query_mh, *, threshold_bp=0):
"""Mimic CounterGather.peek() on top of Index.

This is implemented for situations where we don't want to use
'prefetch' functionality. It is a light wrapper around the
'gather'/search-by-containment method.
"""
from sourmash import SourmashSignature

# build a signature to use with self.gather...
Expand Down Expand Up @@ -323,7 +328,7 @@ def counter_gather(self, query, threshold_bp, **kwargs):
# find all matches and construct a CounterGather object.
counter = CounterGather(prefetch_query.minhash)
for result in self.prefetch(prefetch_query, threshold_bp, **kwargs):
counter.add(result.signature, result.location)
counter.add(result.signature, location=result.location)

# tada!
return counter
Expand Down Expand Up @@ -701,31 +706,42 @@ def select(self, **kwargs):


class CounterGather:
"""
Track and summarize matches for efficient 'gather' protocol. This
could be used downstream of prefetch (for example).
"""This is an ancillary class that is used to implement "fast
gather", post-prefetch. It tracks and summarizes matches for
efficient min-set-cov/'gather'.

The class constructor takes a query MinHash that must be scaled, and
then takes signatures that have overlaps with the query (via 'add').

After all overlapping signatures have been loaded, the 'peek'
method is then used at each stage of the 'gather' procedure to
find the best match, and the 'consume' method is used to remove
a match from this counter.

The public interface is `peek(...)` and `consume(...)` only.
This particular implementation maintains a collections.Counter that
is used to quickly find the best match when 'peek' is called, but
other implementations are possible ;).
"""
def __init__(self, query_mh):
"Constructor - takes a query FracMinHash."
ctb marked this conversation as resolved.
Show resolved Hide resolved
if not query_mh.scaled:
raise ValueError('gather requires scaled signatures')

# track query
self.orig_query_mh = query_mh.copy().flatten()
self.scaled = query_mh.scaled

# track matching signatures & their locations
# use these to track loaded matches & their locations
self.siglist = []
self.locations = []

# ...and overlaps with query
# ...and also track overlaps with the progressive query
self.counter = Counter()

# cannot add matches once query has started.
# fence to make sure we do not add matches once query has started.
self.query_started = 0

def add(self, ss, location=None, require_overlap=True):
def add(self, ss, *, location=None, require_overlap=True):
"Add this signature in as a potential match."
if self.query_started:
raise ValueError("cannot add more signatures to counter after peek/consume")
Expand All @@ -748,26 +764,11 @@ def downsample(self, scaled):
"Track highest scaled across all possible matches."
if scaled > self.scaled:
self.scaled = scaled
return self.scaled

def calc_threshold(self, threshold_bp, scaled, query_size):
# CTB: this code doesn't need to be in this class.
threshold = 0.0
n_threshold_hashes = 0

if threshold_bp:
# if we have a threshold_bp of N, then that amounts to N/scaled
# hashes:
n_threshold_hashes = float(threshold_bp) / scaled

# that then requires the following containment:
threshold = n_threshold_hashes / query_size

return threshold, n_threshold_hashes

def peek(self, cur_query_mh, threshold_bp=0):
def peek(self, cur_query_mh, *, threshold_bp=0):
"Get next 'gather' result for this database, w/o changing counters."
self.query_started = 1
scaled = cur_query_mh.scaled

# empty? nothing to search.
counter = self.counter
Expand All @@ -777,38 +778,39 @@ def peek(self, cur_query_mh, threshold_bp=0):
siglist = self.siglist
assert siglist

self.downsample(scaled)
scaled = self.scaled
scaled = self.downsample(cur_query_mh.scaled)
cur_query_mh = cur_query_mh.downsample(scaled=scaled)

if not cur_query_mh: # empty query? quit.
return []

# CTB: could probably remove this check unless debug requested.
if cur_query_mh.contained_by(self.orig_query_mh, downsample=True) < 1:
raise ValueError("current query not a subset of original query")

# are we setting a threshold?
threshold, n_threshold_hashes = self.calc_threshold(threshold_bp,
scaled,
len(cur_query_mh))
threshold, n_threshold_hashes = calc_threshold_from_bp(threshold_bp,
scaled,
len(cur_query_mh))
# is it too high to ever match? if so, exit.
if threshold > 1.0:
return []

# Find the best match -
# Find the best match using the internal Counter.
most_common = counter.most_common()
dataset_id, match_size = most_common[0]

# below threshold? no match!
if match_size < n_threshold_hashes:
return []

## at this point, we must have a legitimate match above threshold!
## at this point, we have a legitimate match above threshold!

# pull match and location.
match = siglist[dataset_id]

# calculate containment
# CTB: this check is probably redundant with intersect_mh calc, below.
cont = cur_query_mh.contained_by(match.minhash, downsample=True)
assert cont
assert cont >= threshold
Expand All @@ -822,7 +824,7 @@ def peek(self, cur_query_mh, threshold_bp=0):
return (IndexSearchResult(cont, match, location), intersect_mh)

def consume(self, intersect_mh):
"Remove the given hashes from this counter."
"Maintain the internal counter by removing the given hashes."
self.query_started = 1

if not intersect_mh:
Expand Down
21 changes: 20 additions & 1 deletion src/sourmash/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,25 @@
from .sketchcomparison import FracMinHashComparison, NumMinHashComparison


def calc_threshold_from_bp(threshold_bp, scaled, query_size):
    """
    Convert threshold_bp (threshold in estimated bp) to
    fraction of query & minimum number of hashes needed.

    Parameters:
        threshold_bp: threshold in estimated bp; a falsy value (0 or None)
            means "no threshold".
        scaled: the scaled value used to convert bp into a hash count.
        query_size: number of hashes in the query.

    Returns:
        A tuple `(threshold, n_threshold_hashes)`. With no threshold set
        this is `(0.0, 0)`; otherwise `n_threshold_hashes` is
        `threshold_bp / scaled` as a float and `threshold` is that count
        divided by `query_size`. The returned threshold may exceed 1.0;
        it is up to the caller to decide what to do in that case.

    Raises:
        ZeroDivisionError: if `threshold_bp` is set and either `scaled`
            or `query_size` is zero.
    """
    threshold = 0.0
    n_threshold_hashes = 0

    if threshold_bp:
        # if we have a threshold_bp of N, then that amounts to N/scaled
        # hashes:
        n_threshold_hashes = float(threshold_bp) / scaled

        # that then requires the following containment:
        threshold = n_threshold_hashes / query_size

    return threshold, n_threshold_hashes


class SearchType(Enum):
JACCARD = 1
CONTAINMENT = 2
Expand Down Expand Up @@ -621,7 +640,7 @@ def _find_best(counters, query, threshold_bp):

# find the best score across multiple counters, without consuming
for counter in counters:
result = counter.peek(query.minhash, threshold_bp)
result = counter.peek(query.minhash, threshold_bp=threshold_bp)
if result:
(sr, intersect_mh) = result

Expand Down
Loading