Skip to content

Commit

Permalink
Update with suggestions from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
bluegenes committed Apr 20, 2022
1 parent a21c035 commit 4e06116
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 66 deletions.
39 changes: 20 additions & 19 deletions src/sourmash/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,14 +200,14 @@ def get_cmpinfo(self):
self.match_name = self.match.name
self.match_filename = self.match.filename
# sometimes filename is not set in sig (match_filename is None),
# and `search` is able to pass in the filename...
# and `search` is able to pass in the filename.
if self.filename is None and self.match_filename is not None:
self.filename = self.match_filename
self.match_md5 = self.match.md5sum()
# set these from self.match_*
self.md5= self.match_md5
self.name = self.match_name
# do we actually need these here? Def need for prefetch, but maybe define there?
# we may not need these here - could define in PrefetchResult instead
self.query_abundance = self.mh1.track_abundance
self.match_abundance = self.mh2.track_abundance
self.query_n_hashes = len(self.mh1.hashes)
Expand Down Expand Up @@ -250,17 +250,14 @@ def init_sigcomparison(self):

def __post_init__(self):
self.init_sigcomparison() # build sketch comparison
self.check_similarity() # set similarity (if not passed in)
self.check_similarity()

def check_similarity(self):
# Require similarity for SearchResult
# for now, require similarity for SearchResult
# future: consider returning SearchResult *during* search, and passing SearchType in.
# then allow similarity to be calculated here according to SearchType.
if self.similarity is None:
raise ValueError("Error: Must provide 'similarity' for SearchResult.")
#OR, if don't pass in similarity, return jaccard?
#if self.cmp.cmp_scaled is not None:
# self.similarity = self.cmp.mh1_containment
#else:
# self.similarity = self.cmp.jaccard

@property
def writedict(self):
Expand Down Expand Up @@ -320,14 +317,11 @@ def writedict(self):
return self.prefetchresultdict()



@dataclass
class GatherResult(PrefetchResult):
gather_querymh: MinHash = None
gather_result_rank: int = None
total_abund: int = None
# orig_query_len includes len(noident_mh), which has been subtracted out of query, right? This subtraction is an issue for query_covered bp,etc!
# can we set some original query info so we can get the right vals to output?
orig_query_len: int = None
orig_query_abunds: list = None

Expand Down Expand Up @@ -355,8 +349,15 @@ def check_gatherresult_input(self):
raise ValueError("Error: must provide original query abundances ('orig_query_abunds') to GatherResult")

def build_gather_result(self):
# build gather specific attributes
# build gather-specific attributes

# the 'query' that is passed into gather is all _matched_ hashes, after subtracting noident_mh
# this affects estimation of original query information, and requires us to pass in orig_query_len and orig_query_abunds.
# we also need to overwrite self.query_bp, self.query_n_hashes, and self.query_abundance
# todo: find a better solution?
self.query_bp = self.orig_query_len * self.query.minhash.scaled
self.query_n_hashes = self.orig_query_len

# calculate intersection with query hashes:
self.unique_intersect_bp = self.gather_comparison.intersect_bp

Expand All @@ -382,6 +383,8 @@ def build_gather_result(self):
self.average_abund = self.query_weighted_unique_intersection.mean_abundance
self.median_abund = self.query_weighted_unique_intersection.median_abundance
self.std_abund = self.query_weighted_unique_intersection.std_abundance
# 'query' will be flattened by default. reset track abundance if we have abunds
self.query_abundance = self.query_weighted_unique_intersection.track_abundance
# calculate scores weighted by abundances
self.f_unique_weighted = float(self.query_weighted_unique_intersection.sum_abundances) / self.total_abund
else:
Expand All @@ -396,17 +399,15 @@ def __post_init__(self):
def gatherresultdict(self):
# for gather, we only shorten the query_md5
self.query_md5 = self.shorten_md5(self.query_md5)
#self.md5 = self.shorten_md5(self.md5)
#self.match_md5 = self.shorten_md5(self.match_md5)
return self.to_write(columns=self.gather_write_cols)

@property
def writedict(self):
return self.gatherresultdict()

# we can write prefetch results from a GatherResult
@property
def prefetchwritedict(self):
# enable writing prefetch csv from a GatherResult
self.build_prefetch_result()
return self.prefetchresultdict()

Expand Down Expand Up @@ -457,7 +458,7 @@ def search_databases_with_abund_query(query, databases, **kwargs):
raise TypeError("containment searches cannot be done with abund sketches")

for db in databases:
search_iter = db.search_abund(query, **kwargs)
search_iter = db.search_abund(query, **kwargs) # could return SearchResult here instead of tuple?
for (score, match, filename) in search_iter:
md5 = match.md5sum()
if md5 not in found_md5:
Expand All @@ -470,7 +471,7 @@ def search_databases_with_abund_query(query, databases, **kwargs):
x = []
for (score, match, filename) in results:
x.append(SearchResult(query, match,
similarity=score, # this is actually cosine sim (abund). do we want to specify this in SearchResult somehow?
similarity=score,
filename = filename))
return x

Expand Down Expand Up @@ -663,7 +664,7 @@ def prefetch_database(query, database, threshold_bp):
assert scaled

# iterate over all signatures in database, find matches
for result in database.prefetch(query, threshold_bp):
for result in database.prefetch(query, threshold_bp): # future: could return PrefetchResult directly here
#result = calculate_prefetch_info(query, result.signature, threshold_bp)
result = PrefetchResult(query, result.signature, threshold_bp=threshold_bp)
assert result.pass_threshold
Expand Down
62 changes: 26 additions & 36 deletions src/sourmash/sketchcomparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,6 @@ class BaseMinHashComparison:
mh2: MinHash
ignore_abundance: bool = False # optionally ignore abundances

def check_comparison_compatibility(self):
# do we need this check + error? Minhash functions should already complain appropriately...
k1 = self.mh1.ksize
k2 = self.mh2.ksize
if k1 != k2:
raise TypeError(f"Error: Invalid Comparison, ksizes: {k1}, {k2}. Must compare sketches of the same ksize.")
self.ksize = self.mh1.ksize
m1 = self.mh1.moltype
m2= self.mh2.moltype
if m1 != m2:
raise TypeError(f"Error: Invalid Comparison, moltypes: {m1}, {m2}. Must compare sketches of the same moltype.")
self.moltype= self.mh1.moltype
# check num, scaled
if not any([(self.mh1.num and self.mh2.num), (self.mh1.scaled and self.mh2.scaled)]):
raise TypeError("Error: Both sketches must be 'num' or 'scaled'.")

def downsample_and_handle_ignore_abundance(self, cmp_num=None, cmp_scaled=None):
"""
Downsample and/or flatten minhashes for comparison
Expand All @@ -49,6 +33,17 @@ def downsample_and_handle_ignore_abundance(self, cmp_num=None, cmp_scaled=None):
else:
raise ValueError("Error: must pass in a comparison scaled or num value.")

def check_compatibility_and_downsample(self, cmp_num=None, cmp_scaled=None):
    """Verify the two sketches can be compared, then downsample them.

    Raises TypeError when the sketches are not both 'num' or both
    'scaled', or when they are otherwise incompatible (e.g. different
    ksize or moltype, per MinHash.is_compatible).
    """
    both_num = self.mh1.num and self.mh2.num
    both_scaled = self.mh1.scaled and self.mh2.scaled
    if not (both_num or both_scaled):
        raise TypeError("Error: Both sketches must be 'num' or 'scaled'.")

    # downsample before the compatibility check: is_compatible compares
    # scaled values (though it does not check num), so the sketches must
    # already be at a common scaled.
    self.downsample_and_handle_ignore_abundance(cmp_num=cmp_num, cmp_scaled=cmp_scaled)
    if not self.mh1_cmp.is_compatible(self.mh2_cmp):
        raise TypeError("Error: Cannot compare incompatible sketches.")
    self.ksize = self.mh1.ksize
    self.moltype = self.mh1.moltype

@property
def intersect_mh(self):
# flatten and intersect
Expand All @@ -60,12 +55,8 @@ def jaccard(self):

@property
def angular_similarity(self):
# Note: this currently throws TypeError if self.ignore_abundance.
return self.mh1_cmp.angular_similarity(self.mh2_cmp)
# do we want to shield against error here? Or let TypeError through?
#if not (self.mh1_cmp.track_abundance and self.mh2_cmp.track_abundance):
# return self.mh1_cmp.angular_similarity(self.mh2_cmp)
#else:
# return ""

@property
def cosine_similarity(self):
Expand All @@ -74,31 +65,27 @@ def cosine_similarity(self):

@dataclass
class NumMinHashComparison(BaseMinHashComparison):
"""Class for standard comparison between two scaled minhashes"""
"""Class for standard comparison between two num minhashes"""
cmp_num: int = None

def __post_init__(self):
"Initialize NumMinHashComparison using values from provided MinHashes"
if self.cmp_num is None: # record the num we're doing this compa#rison on
if self.cmp_num is None: # record the num we're doing this comparison on
self.cmp_num = min(self.mh1.num, self.mh2.num)
self.check_comparison_compatibility()
self.downsample_and_handle_ignore_abundance(cmp_num=self.cmp_num)
self.check_compatibility_and_downsample(cmp_num=self.cmp_num)

@dataclass
class FracMinHashComparison(BaseMinHashComparison):
"""Class for standard comparison between two scaled minhashes"""
cmp_scaled: int = None # scaled value for this comparison (defaults to maximum scaled between the two sigs)
cmp_scaled: int = None # optionally force scaled value for this comparison
threshold_bp: int = 0

def __post_init__(self):
"Initialize ScaledComparison using values from provided FracMinHashes"
if self.cmp_scaled is None: # record the scaled we're doing this comparison on
if self.cmp_scaled is None:
# comparison scaled defaults to maximum scaled between the two sigs
self.cmp_scaled = max(self.mh1.scaled, self.mh2.scaled)
self.check_comparison_compatibility()
self.downsample_and_handle_ignore_abundance(cmp_scaled=self.cmp_scaled)
# for these, do we want the originals, or the cmp_scaled versions?? (or both?). do we need them at all?
self.mh1_scaled = self.mh1.scaled
self.mh2_scaled = self.mh2.scaled
self.check_compatibility_and_downsample(cmp_scaled=self.cmp_scaled)

@property
def pass_threshold(self):
Expand Down Expand Up @@ -128,11 +115,14 @@ def weighted_intersection(self, from_mh=None, from_abundD={}):
# map abundances to all intersection hashes.
abund_mh = self.intersect_mh.copy_and_clear()
abund_mh.track_abundance = True
if from_mh is not None:
# if from_mh is provided, it takes precedence over from_abund dict
if from_mh is not None and from_mh.track_abundance:
from_abundD = from_mh.hashes
if from_abundD is not None:
#abundD[k] for k in intersect_mh.hashes
if from_abundD:
# this sets any hash not present in abundD to 1. Is that desired? Or should we return 0?
abunds = {k: from_abundD.get(k, 1) for k in self.intersect_mh.hashes }
abund_mh.set_abundances(abunds)
return abund_mh
return self.intersect_mh # or do we want to set all abundances to 1?
# if no abundances are passed in, return intersect_mh
# future note: do we want to return 1 as abundance instead?
return self.intersect_mh
8 changes: 4 additions & 4 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,20 +577,20 @@ def test_GatherResult_incomplete_input_orig_query_abunds():
GatherResult(ss47, ss4763, cmp_scaled=1000,
gather_querymh=ss47.minhash,
gather_result_rank=1,
total_abund = None,
total_abund = 1,
orig_query_len=len(ss47.minhash),
orig_query_abunds=orig_query_abunds)
print(str(exc))
assert "Error: must provide sum of all abundances ('total_abund') to GatherResult" in str(exc)
assert "Error: must provide original query abundances ('orig_query_abunds') to GatherResult" in str(exc)

orig_query_abunds = {}

with pytest.raises(ValueError) as exc:
GatherResult(ss47, ss4763, cmp_scaled=1000,
gather_querymh=ss47.minhash,
gather_result_rank=1,
total_abund = 0,
total_abund = 1,
orig_query_len=len(ss47.minhash),
orig_query_abunds=orig_query_abunds)
print(str(exc))
assert "Error: must provide sum of all abundances ('total_abund') to GatherResult" in str(exc)
assert "Error: must provide original query abundances ('orig_query_abunds') to GatherResult" in str(exc)
55 changes: 48 additions & 7 deletions tests/test_sketchcomparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


from sourmash.minhash import MinHash
from sourmash.sketchcomparison import BaseMinHashComparison, FracMinHashComparison, NumMinHashComparison
from sourmash.sketchcomparison import FracMinHashComparison, NumMinHashComparison

import sourmash_tst_utils as utils

Expand Down Expand Up @@ -287,7 +287,7 @@ def test_FracMinHashComparison_incompatible_ksize(track_abundance):
with pytest.raises(TypeError) as exc:
FracMinHashComparison(a, b)
print(str(exc))
assert "Error: Invalid Comparison, ksizes: 31, 21. Must compare sketches of the same ksize." in str(exc)
assert "Error: Cannot compare incompatible sketches." in str(exc)


def test_FracMinHashComparison_incompatible_moltype(track_abundance):
Expand All @@ -307,8 +307,7 @@ def test_FracMinHashComparison_incompatible_moltype(track_abundance):
with pytest.raises(TypeError) as exc:
FracMinHashComparison(a, b)
print(str(exc))
assert "Error: Invalid Comparison, moltypes: DNA, protein. Must compare sketches of the same moltype." in str(exc)

assert "Error: Cannot compare incompatible sketches." in str(exc)


def test_FracMinHashComparison_incompatible_sketchtype(track_abundance):
Expand All @@ -331,6 +330,27 @@ def test_FracMinHashComparison_incompatible_sketchtype(track_abundance):
assert "Error: Both sketches must be 'num' or 'scaled'." in str(exc)


def test_FracMinHashComparison_incompatible_cmp_scaled(track_abundance):
    # a cmp_scaled below the sketches' own scaled values cannot be used to downsample
    mh_a = MinHash(0, 31, scaled=1, track_abundance=track_abundance)
    mh_b = MinHash(0, 31, scaled=10, track_abundance=track_abundance)

    abunds_a = {1: 5, 3: 3, 5: 2, 8: 2}
    abunds_b = {1: 3, 3: 2, 5: 1, 6: 1, 8: 1, 10: 1}

    if track_abundance:
        mh_a.set_abundances(abunds_a)
        mh_b.set_abundances(abunds_b)
    else:
        mh_a.add_many(abunds_a.keys())
        mh_b.add_many(abunds_b.keys())

    with pytest.raises(ValueError) as exc:
        FracMinHashComparison(mh_a, mh_b, cmp_scaled=1)
    print(str(exc))
    assert "new scaled 1 is lower than current sample scaled 10" in str(exc)


def test_FracMinHashComparison_redownsample_without_scaled(track_abundance):
a = MinHash(0, 31, scaled=1, track_abundance=track_abundance)
b = MinHash(0, 31, scaled=10, track_abundance=track_abundance)
Expand All @@ -349,7 +369,7 @@ def test_FracMinHashComparison_redownsample_without_scaled(track_abundance):
assert cmp.cmp_scaled == 10

with pytest.raises(ValueError) as exc:
# try to redownsample without passing in cmp_num
# try to redownsample without passing in cmp_scaled
cmp.downsample_and_handle_ignore_abundance()
print(str(exc))
assert "Error: must pass in a comparison scaled or num value." in str(exc)
Expand Down Expand Up @@ -505,7 +525,7 @@ def test_NumMinHashComparison_incompatible_ksize(track_abundance):
with pytest.raises(TypeError) as exc:
NumMinHashComparison(a_num, b_num)
print(str(exc))
assert "Error: Invalid Comparison, ksizes: 31, 21. Must compare sketches of the same ksize." in str(exc)
assert "Error: Cannot compare incompatible sketches." in str(exc)


def test_NumMinHashComparison_incompatible_moltype(track_abundance):
Expand All @@ -525,7 +545,7 @@ def test_NumMinHashComparison_incompatible_moltype(track_abundance):
with pytest.raises(TypeError) as exc:
NumMinHashComparison(a_num, b_num)
print(str(exc))
assert "Error: Invalid Comparison, moltypes: DNA, protein. Must compare sketches of the same moltype." in str(exc)
assert "Error: Cannot compare incompatible sketches." in str(exc)


def test_NumMinHashComparison_incompatible_sketchtype(track_abundance):
Expand Down Expand Up @@ -569,3 +589,24 @@ def test_NumMinHashComparison_redownsample_without_num(track_abundance):
cmp.downsample_and_handle_ignore_abundance()
print(str(exc))
assert "Error: must pass in a comparison scaled or num value." in str(exc)


def test_NumMinHashComparison_incompatible_cmp_num(track_abundance):
    # a cmp_num above the sketches' own num values cannot be used to downsample
    mh_a = MinHash(200, 31, track_abundance=track_abundance)
    mh_b = MinHash(100, 31, track_abundance=track_abundance)

    abunds_a = {1: 5, 3: 3, 5: 2, 8: 2}
    abunds_b = {1: 3, 3: 2, 5: 1, 6: 1, 8: 1, 10: 1}

    if track_abundance:
        mh_a.set_abundances(abunds_a)
        mh_b.set_abundances(abunds_b)
    else:
        mh_a.add_many(abunds_a.keys())
        mh_b.add_many(abunds_b.keys())

    with pytest.raises(ValueError) as exc:
        NumMinHashComparison(mh_a, mh_b, cmp_num=150)
    print(str(exc))
    assert "new sample num is higher than current sample num" in str(exc)

0 comments on commit 4e06116

Please sign in to comment.