Skip to content

Commit

Permalink
[MRG] do not report untrusted jaccard ANI (#2011)
Browse files Browse the repository at this point in the history
* do not report untrusted jaccard ANI from search

* by default, do not return ANI from jaccardANIResult if untrustworthy

* clean up

* add sig and mh tests too
  • Loading branch information
bluegenes authored May 2, 2022
1 parent d425f6f commit adf35ea
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 48 deletions.
8 changes: 8 additions & 0 deletions src/sourmash/distance_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class jaccardANIResult(ANIResult):
"""Class for distance/ANI from jaccard (includes jaccard_error)."""
jaccard_error: float = None
je_threshold: float = 1e-4
return_ani_despite_threshold: bool = False

def __post_init__(self):
# check values
Expand All @@ -72,6 +73,13 @@ def __post_init__(self):
else:
raise ValueError("Error: jaccard_error cannot be None.")

@property
def ani(self):
    """Return the ANI estimate, or ``""`` when it should not be trusted.

    The estimate is ``1 - self.dist``. If ``self.je_exceeds_threshold`` is
    true and ``self.return_ani_despite_threshold`` is false, an empty
    string is returned instead of a number, so callers must be prepared
    for either a float or ``""``.
    """
    # if jaccard error is too high (exceeds threshold), do not trust ANI estimate
    if self.je_exceeds_threshold and not self.return_ani_despite_threshold:
        return ""
    return 1 - self.dist


@dataclass
class ciANIResult(ANIResult):
Expand Down
3 changes: 0 additions & 3 deletions src/sourmash/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,9 +321,6 @@ def estimate_search_ani(self):
elif self.searchtype == SearchType.JACCARD:
self.cmp.estimate_jaccard_ani(jaccard=self.similarity)
self.ani = self.cmp.jaccard_ani
# Jaccard error was too high for ANI estimation.
# Just report, or do we want to do something else?
self.ani_untrustworthy = self.cmp.jaccard_ani_untrustworthy
# this can be set from any of the above
self.potential_false_negative = self.cmp.potential_false_negative

Expand Down
23 changes: 13 additions & 10 deletions tests/test_distance_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,18 @@ def test_aniresult_bad_distance():


def test_jaccard_aniresult():
res = jaccardANIResult(0.4, 0.1, jaccard_error=0.03)
assert res.dist == 0.4
res = jaccardANIResult(0.4, 0.1, jaccard_error=0.03, return_ani_despite_threshold=True)
res2 = jaccardANIResult(0.4, 0.1, jaccard_error=0.03)
assert res.dist == res2.dist == 0.4
assert res.ani == 0.6
assert res.p_nothing_in_common == 0.1
assert res2.ani == ""
assert res.p_nothing_in_common == res2.p_nothing_in_common == 0.1
assert res.jaccard_error == 0.03
assert res.p_exceeds_threshold ==True
assert res.je_exceeds_threshold ==True
res2 = jaccardANIResult(0.4, 0.1, jaccard_error=0.03, je_threshold=0.1)
assert res2.je_exceeds_threshold ==False
res3 = jaccardANIResult(0.4, 0.1, jaccard_error=0.03, je_threshold=0.1)
assert res3.je_exceeds_threshold ==False
assert res3.ani == 0.6


def test_jaccard_aniresult_nojaccarderror():
Expand Down Expand Up @@ -260,8 +263,8 @@ def test_jaccard_to_distance_scaled():
res = jaccard_to_distance(jaccard,ksize,scaled,n_unique_kmers=nkmers)
print(res)
# check results
assert res.dist == 0.019122659390482077
assert res.ani == 0.9808773406095179
assert round(res.dist, 3) == round(0.019122659390482077, 3)
assert res.ani == ""
assert res.p_exceeds_threshold == False
assert res.jaccard_error == 0.00018351337045518042
assert res.je_exceeds_threshold ==True
Expand All @@ -282,12 +285,12 @@ def test_jaccard_to_distance_k31():
res = jaccard_to_distance(jaccard,ksize,scaled,n_unique_kmers=nkmers)
print(res)
# check results
assert res.ani == 0.9870056455892898
assert res.p_exceeds_threshold == False
assert res.je_exceeds_threshold ==True
assert res.ani == ""
assert res.p_exceeds_threshold == False
res2 = jaccard_to_distance(jaccard,ksize,scaled,n_unique_kmers=nkmers, err_threshold=0.1)
assert res2.ani == res.ani
assert res2.je_exceeds_threshold == False
assert res2.ani == 0.9870056455892898


def test_jaccard_to_distance_k31_2():
Expand Down
14 changes: 14 additions & 0 deletions tests/test_minhash.py
Original file line number Diff line number Diff line change
Expand Up @@ -2914,6 +2914,20 @@ def test_jaccard_ANI():
assert (m1_jani_m2.ani, m1_jani_m2.p_nothing_in_common, m1_jani_m2.jaccard_error) == (0.9783711630110239, 0.0, 3.891666770716877e-07)


def test_jaccard_ANI_untrustworthy():
    """MinHash.jaccard_ani() yields ani == "" when the error threshold is exceeded."""
    f1 = utils.get_test_data('2.fa.sig')
    f2 = utils.get_test_data('2+63.fa.sig')
    mh1 = sourmash.load_one_signature(f1, ksize=31).minhash
    mh2 = sourmash.load_one_signature(f2).minhash

    print("\nJACCARD_ANI", mh1.jaccard_ani(mh2))

    # A very small err_threshold is passed; the assertions below expect the
    # jaccard error to exceed it, so no ANI value is reported.
    m1_jani_m2 = mh1.jaccard_ani(mh2, err_threshold=1e-7)
    assert m1_jani_m2.ani == ""
    assert m1_jani_m2.je_exceeds_threshold == True
    assert m1_jani_m2.je_threshold == 1e-7


def test_jaccard_ANI_precalc_jaccard():
f1 = utils.get_test_data('2.fa.sig')
f2 = utils.get_test_data('2+63.fa.sig')
Expand Down
14 changes: 14 additions & 0 deletions tests/test_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,20 @@ def test_jaccard_ANI():
assert (s1_jani_s2.ani, s1_jani_s2.p_nothing_in_common, s1_jani_s2.jaccard_error) == (0.9783711630110239, 0.0, 3.891666770716877e-07)


def test_jaccard_ANI_untrustworthy():
    """Signature.jaccard_ani() yields ani == "" when the error threshold is exceeded."""
    f1 = utils.get_test_data('2.fa.sig')
    f2 = utils.get_test_data('2+63.fa.sig')
    ss1 = sourmash.load_one_signature(f1, ksize=31)
    ss2 = sourmash.load_one_signature(f2)

    print("\nJACCARD_ANI", ss1.jaccard_ani(ss2))

    # A very small err_threshold is passed; the assertions below expect the
    # jaccard error to exceed it, so no ANI value is reported.
    s1_jani_s2 = ss1.jaccard_ani(ss2, err_threshold=1e-7)
    assert s1_jani_s2.ani == ""
    assert s1_jani_s2.je_exceeds_threshold == True
    assert s1_jani_s2.je_threshold == 1e-7


def test_jaccard_ANI_precalc_jaccard():
f1 = utils.get_test_data('2.fa.sig')
f2 = utils.get_test_data('2+63.fa.sig')
Expand Down
96 changes: 61 additions & 35 deletions tests/test_sourmash.py
Original file line number Diff line number Diff line change
Expand Up @@ -5321,6 +5321,32 @@ def test_standalone_manifest_search_fail(runtmp):

@utils.in_tempdir
def test_search_ani_jaccard(c):
    """`sourmash search` on 47 vs 47+63 writes the expected CSV row, including ANI."""
    sig47 = utils.get_test_data('47.fa.sig')
    sig4763 = utils.get_test_data('47+63.fa.sig')

    c.run_sourmash('search', sig47, sig4763, '-o', 'xxx.csv')
    print(c.last_result.status, c.last_result.out, c.last_result.err)

    search_result_names = SearchResult.search_write_cols

    csv_file = c.output('xxx.csv')

    # Only the first result row is inspected.
    with open(csv_file) as fp:
        reader = csv.DictReader(fp)
        row = next(reader)

    print(row)
    # Column names and order must match the plain (non-CI) search output.
    assert search_result_names == list(row.keys())
    assert float(row['similarity']) == 0.6564798376870403
    assert row['filename'].endswith('47+63.fa.sig')
    assert row['md5'] == '491c0a81b2cfb0188c0d3b46837c2f42'
    assert row['query_filename'].endswith('47.fa')
    assert row['query_name'] == 'NC_009665.1 Shewanella baltica OS185, complete genome'
    assert row['query_md5'] == '09a08691'
    # CSV values are strings, so the ANI is compared as text.
    assert row['ani'] == "0.992530907924384"


@utils.in_tempdir
def test_search_ani_jaccard_error_too_high(c):
testdata1 = utils.get_test_data('short.fa')
testdata2 = utils.get_test_data('short2.fa')
c.run_sourmash('sketch', 'dna', '-p', 'k=31,scaled=1', testdata1, testdata2)
Expand All @@ -5343,7 +5369,8 @@ def test_search_ani_jaccard(c):
assert row['query_filename'].endswith('short.fa')
assert row['query_name'] == ''
assert row['query_md5'] == '9191284a'
assert row['ani'] == "0.9987884602947684"
#assert row['ani'] == "0.9987884602947684"
assert row['ani'] == ""


@utils.in_tempdir
Expand Down Expand Up @@ -5522,54 +5549,53 @@ def test_search_ani_max_containment_estimate_ci(c):

@utils.in_tempdir
def test_search_jaccard_ani_downsample(c):
testdata1 = utils.get_test_data('short.fa')
testdata2 = utils.get_test_data('short2.fa')
sig1_out = c.output('short.fa.sig')
sig2_out = c.output('short2.fa.sig')
c.run_sourmash('sketch', 'dna', '-p', 'k=31,scaled=2', '--force', testdata1, '-o', sig1_out)
c.run_sourmash('sketch', 'dna', '-p', 'k=31,scaled=1', '--force', testdata1, '-o', sig2_out)
sig1 = sourmash.load_one_signature(sig1_out)
sig2 = sourmash.load_one_signature(sig2_out)
print(f"SCALED: sig1: {sig1.minhash.scaled}, sig2: {sig2.minhash.scaled}") # if don't change name, just reads prior sigfile!!?

sig1F = c.output('sig1.sig')
sig2F = c.output('sig2.sig')
c.run_sourmash('sketch', 'dna', '-p', 'k=31,scaled=2', '--force', testdata1, '-o', sig1F)
c.run_sourmash('sketch', 'dna', '-p', 'k=31,scaled=1', '--force', testdata2, '-o', sig2F)

sig1 = sourmash.load_one_signature(sig1F)
sig2 = sourmash.load_one_signature(sig2F)
print(f"SCALED: sig1: {sig1.minhash.scaled}, sig2: {sig2.minhash.scaled}")

c.run_sourmash('search', sig1F, sig2F, '-o', 'xdx.csv')
sig47 = utils.get_test_data('47.fa.sig')
sig4763 = utils.get_test_data('47+63.fa.sig')
ss47 = sourmash.load_one_signature(sig47)
ss4763 = sourmash.load_one_signature(sig4763)
print(f"SCALED: sig1: {ss47.minhash.scaled}, sig2: {ss4763.minhash.scaled}")

c.run_sourmash('search', sig47, sig4763, '-o', 'xxx.csv')
print(c.last_result.status, c.last_result.out, c.last_result.err)

csv_file = c.output('xdx.csv')
search_result_names = SearchResult.search_write_cols
search_result_names_ci = SearchResult.search_write_cols_ci

csv_file = c.output('xxx.csv')

with open(csv_file) as fp:
reader = csv.DictReader(fp)
row = next(reader)
print(row)
assert search_result_names == list(row.keys())
assert search_result_names_ci != list(row.keys())
assert float(row['similarity']) == 0.9296066252587992
assert row['md5'] == 'bf752903d635b1eb83c53fe4aae951db'
assert row['filename'].endswith('sig2.sig')
assert row['query_filename'].endswith('short.fa')
assert row['query_name'] == ''
assert row['query_md5'] == '8f74b0b8'
assert row['ani'] == "0.9988019200011651"
assert float(row['similarity']) == 0.6564798376870403
assert row['filename'].endswith('47+63.fa.sig')
assert row['md5'] == '491c0a81b2cfb0188c0d3b46837c2f42'
assert row['query_filename'].endswith('47.fa')
assert row['query_name'] == 'NC_009665.1 Shewanella baltica OS185, complete genome'
assert row['query_md5'] == '09a08691'
assert row['ani'] == "0.992530907924384"

# downsample one and check similarity and ANI
ds_sig47 = c.output("ds_sig47.sig")
c.run_sourmash('sig', "downsample", sig47, "--scaled", "2000", '-o', ds_sig47)
c.run_sourmash('search', ds_sig47, sig4763, '-o', 'xxx.csv')

csv_file = c.output('xxx.csv')
with open(csv_file) as fp:
reader = csv.DictReader(fp)
row = next(reader)
print(row)
assert round(float(row['similarity']), 3) == round(0.6634517766497462, 3)
assert round(float(row['ani']), 3) == round(0.992530907924384, 3)

#downsample manually and assert same ANI
mh1 = sig1.minhash
mh2 = sig2.minhash
mh2_sc2 = mh2.downsample(scaled=mh1.scaled)
print("SCALED:", mh1.scaled, mh2_sc2.scaled)
ani_info = mh1.jaccard_ani(mh2_sc2)
ss47_ds = signature.load_one_signature(ds_sig47)
print("SCALED:", ss47_ds.minhash.scaled, ss4763.minhash.scaled)
ani_info = ss47_ds.jaccard_ani(ss4763, downsample=True)
print(ani_info)
assert ani_info.ani == 0.9988019200011651
assert round(ani_info.ani, 3) == round(0.992530907924384, 3)


def test_gather_ani_csv(runtmp, linear_gather, prefetch_gather):
Expand Down

0 comments on commit adf35ea

Please sign in to comment.