[MRG] standardize and simplify search, prefetch, gather results by using dataclasses #1955

Merged · 31 commits · Apr 20, 2022

Commits
ec09882
add some search/gather/prefetch columns to enable ANI estimation
bluegenes Apr 15, 2022
e630498
Merge branch 'latest' into add-cols
bluegenes Apr 15, 2022
a983d48
Merge branch 'latest' into add-cols
bluegenes Apr 15, 2022
29fd59d
Merge branch 'latest' into add-cols
bluegenes Apr 15, 2022
479d315
fix introduced err
bluegenes Apr 15, 2022
36fd866
init SearchResult dataclass
bluegenes Apr 15, 2022
71569a1
define generic write_cols
bluegenes Apr 15, 2022
4ddf304
add prefetchresult class, clean up post_init repetitiveness later
bluegenes Apr 15, 2022
970f7a0
clean up
bluegenes Apr 15, 2022
2fb306a
add gatherresult dataclass
bluegenes Apr 15, 2022
0f53be0
rm unused line
bluegenes Apr 15, 2022
6308224
upd
bluegenes Apr 15, 2022
367b8c0
init searchresult tests
bluegenes Apr 16, 2022
644a4dd
use query_n_hashes; remove num
bluegenes Apr 16, 2022
95c925d
Merge branch 'add-cols' into searchresult-dataclass
bluegenes Apr 16, 2022
3039fbc
add basic gatherresult test
bluegenes Apr 16, 2022
6d2cb5a
Merge branch 'latest' into searchresult-dataclass
bluegenes Apr 16, 2022
40097a2
save in progress changes
bluegenes Apr 17, 2022
ac4387c
closer...
bluegenes Apr 17, 2022
c1c044b
closer still...
bluegenes Apr 18, 2022
22dbf1a
handle num sketches; clean up unnecessary sig comparison cls
bluegenes Apr 19, 2022
602d68f
use base classes properly to simplify
bluegenes Apr 19, 2022
d98c2c5
add tests for multiple rounds of downsampling in prefetch and gather …
ctb Apr 19, 2022
0d7b3a5
split sketchcomparison to new file; clean up *Result
bluegenes Apr 19, 2022
304ca24
add minhash tests
bluegenes Apr 19, 2022
3cfe41f
init sketchcomparison tests
bluegenes Apr 19, 2022
5160376
test incompatible sketch comparisons
bluegenes Apr 19, 2022
54db2f4
test failing *Results
bluegenes Apr 19, 2022
7633eaf
test num SearchResult
bluegenes Apr 19, 2022
a21c035
fix calcs for gather
bluegenes Apr 20, 2022
4e06116
upd with suggestions from code review
bluegenes Apr 20, 2022
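The thread running through these commits: each result type (search, prefetch, gather) becomes a dataclass that knows its own CSV columns and its own row serialization, so the CSV-writing code in commands.py no longer strips signature objects out of each row by hand. A minimal sketch of that pattern — `BaseResult` and its fields here are illustrative assumptions, not the PR's exact code; the real classes in src/sourmash/search.py carry many more columns, including the new ANI-related ones:

```python
from dataclasses import dataclass, asdict

@dataclass
class BaseResult:
    # hypothetical, much-reduced stand-in for sourmash's *Result dataclasses
    query_name: str
    match_name: str
    similarity: float

    # class attribute (not a dataclass field): the CSV columns to emit.
    # heavyweight members like the query/match signature objects are
    # simply left off this list instead of being deleted per-row.
    write_cols = ['query_name', 'match_name', 'similarity']

    @property
    def writedict(self):
        """Return a dict restricted to the declared CSV columns."""
        d = asdict(self)
        return {k: d[k] for k in self.write_cols}
```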
39 changes: 12 additions & 27 deletions src/sourmash/commands.py
@@ -16,7 +16,7 @@
 from .logging import notify, error, print_results, set_quiet
 from .sourmash_args import (FileOutput, FileOutputCSV,
                             SaveSignaturesToLocation)
-from .search import SearchResult, prefetch_database, PrefetchResult, GatherResult, calculate_prefetch_info
+from .search import prefetch_database, SearchResult, PrefetchResult, GatherResult
 from .index import LazyLinearIndex

 WATERMARK_SIZE = 10000
@@ -533,17 +533,13 @@ def search(args):
         notify("** reporting only one match because --best-only was set")

     if args.output:
-        fieldnames = SearchResult._fields
-
+        fieldnames = SearchResult.search_write_cols
         with FileOutputCSV(args.output) as fp:
             w = csv.DictWriter(fp, fieldnames=fieldnames)

             w.writeheader()
             for sr in results:
-                d = dict(sr._asdict())
-                del d['match']
-                del d['query']
-                w.writerow(d)
+                w.writerow(sr.writedict)

     # save matching signatures upon request
     if args.save_matches:
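After this change, every CSV writer in commands.py follows the same three-line shape. A usage sketch against the hypothetical `BaseResult` defined above:

```python
import csv

results = [BaseResult('queryA', 'matchX', 0.91),
           BaseResult('queryA', 'matchY', 0.42)]

with open('search_results.csv', 'w', newline='') as fp:
    w = csv.DictWriter(fp, fieldnames=BaseResult.write_cols)
    w.writeheader()
    for sr in results:
        w.writerow(sr.writedict)   # one call instead of dict/del/del/write
```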
@@ -688,7 +684,7 @@ def gather(args):
     prefetch_csvout_fp = None
     prefetch_csvout_w = None
     if args.save_prefetch_csv:
-        fieldnames = PrefetchResult._fields
+        fieldnames = PrefetchResult.prefetch_write_cols
         prefetch_csvout_fp = FileOutput(args.save_prefetch_csv, 'wt').open()
         prefetch_csvout_w = csv.DictWriter(prefetch_csvout_fp, fieldnames=fieldnames)
         prefetch_csvout_w.writeheader()
@@ -717,12 +713,8 @@ def gather(args):
         if prefetch_csvout_fp:
             assert scaled
             # calculate intersection stats and info
-            prefetch_result = calculate_prefetch_info(prefetch_query, found_sig, scaled, args.threshold_bp)
-            # remove match and query signatures; write result to prefetch csv
-            d = dict(prefetch_result._asdict())
-            del d['match']
-            del d['query']
-            prefetch_csvout_w.writerow(d)
+            prefetch_result = PrefetchResult(prefetch_query, found_sig, cmp_scaled=scaled, threshold_bp=args.threshold_bp)
+            prefetch_csvout_w.writerow(prefetch_result.writedict)

     counters.append(counter)
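Note how `calculate_prefetch_info` plus the dict-pruning dance collapses into a single `PrefetchResult(...)` constructor: derived statistics now live in the dataclass's `__post_init__`. A hedged sketch of that idea — the field names and formulas below are illustrative assumptions, not the PR's exact code:

```python
from dataclasses import dataclass, field

@dataclass
class TinyPrefetchResult:
    # hypothetical, much-reduced stand-in for sourmash's PrefetchResult
    intersect_hashes: int
    cmp_scaled: int
    threshold_bp: int = 0
    intersect_bp: int = field(init=False)
    passes_threshold: bool = field(init=False)

    def __post_init__(self):
        # derived stats are computed once, at construction time
        self.intersect_bp = self.intersect_hashes * self.cmp_scaled
        self.passes_threshold = self.intersect_bp >= self.threshold_bp

r = TinyPrefetchResult(intersect_hashes=120, cmp_scaled=1000, threshold_bp=50_000)
print(r.intersect_bp, r.passes_threshold)   # 120000 True
```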

@@ -803,14 +795,12 @@ def gather(args):

     # save CSV?
     if found and args.output:
-        fieldnames = GatherResult._fields
+        fieldnames = GatherResult.gather_write_cols
         with FileOutputCSV(args.output) as fp:
             w = csv.DictWriter(fp, fieldnames=fieldnames)
             w.writeheader()
             for result in found:
-                d = dict(result._asdict())
-                del d['match']  # actual signature not in CSV.
-                w.writerow(d)
+                w.writerow(result.writedict)

     # save matching signatures?
     if found and args.save_matches:
@@ -970,14 +960,12 @@ def multigather(args):

         output_base = os.path.basename(query_filename)
         output_csv = output_base + '.csv'
-        fieldnames = GatherResult._fields
+        fieldnames = GatherResult.gather_write_cols
         with FileOutputCSV(output_csv) as fp:
             w = csv.DictWriter(fp, fieldnames=fieldnames)
             w.writeheader()
             for result in found:
-                d = dict(result._asdict())
-                del d['match']  # actual signature not output to CSV!
-                w.writerow(d)
+                w.writerow(result.writedict)

         output_matches = output_base + '.matches.sig'
         with open(output_matches, 'wt') as fp:
@@ -1174,7 +1162,7 @@ def prefetch(args):
     csvout_fp = None
     csvout_w = None
     if args.output:
-        fieldnames = PrefetchResult._fields
+        fieldnames = PrefetchResult.prefetch_write_cols
        csvout_fp = FileOutput(args.output, 'wt').open()
        csvout_w = csv.DictWriter(csvout_fp, fieldnames=fieldnames)
        csvout_w.writeheader()
@@ -1231,10 +1219,7 @@ def prefetch(args):

             # output match info as we go
             if csvout_fp:
-                d = dict(result._asdict())
-                del d['match']  # actual signatures not in CSV.
-                del d['query']
-                csvout_w.writerow(d)
+                csvout_w.writerow(result.writedict)

             # output match signatures as we go (maybe)
             matches_out.add(match)
37 changes: 36 additions & 1 deletion src/sourmash/minhash.py
@@ -8,6 +8,9 @@ class FrozenMinHash - read-only MinHash class.
 from __future__ import unicode_literals, division
 from .distance_utils import jaccard_to_distance, containment_to_distance

+import numpy as np
+
+
 __all__ = ['get_minhash_default_seed',
            'get_minhash_max_hash',
            'hash_murmur',
@@ -686,6 +689,8 @@ def similarity(self, other, ignore_abundance=False, downsample=False):

     def angular_similarity(self, other):
         "Calculate the angular similarity."
+        if not (self.track_abundance and other.track_abundance):
+            raise TypeError("Error: Angular (cosine) similarity requires both sketches to track hash abundance.")
         return self._methodcall(lib.kmerminhash_angular_similarity,
                                 other._get_objptr())
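The new guard makes the abundance requirement fail fast with a readable message instead of surfacing as a confusing result from the Rust layer. A usage sketch, assuming sourmash's public `MinHash` constructor:

```python
from sourmash import MinHash

flat = MinHash(n=0, ksize=21, scaled=10)                         # no abundances
abund = MinHash(n=0, ksize=21, scaled=10, track_abundance=True)  # with abundances

try:
    flat.angular_similarity(abund)
except TypeError as e:
    print(e)   # both sketches must track hash abundance
```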

@@ -854,7 +859,37 @@ def inflate(self, from_mh):

             return abund_mh
         else:
-            raise ValueError("inflate operates on a flat MinHash and takes a MinHash object with track_abundance=True")
+            raise ValueError("inflate operates on a flat MinHash and takes a MinHash object with track_abundance=True")
+
+    @property
+    def sum_abundances(self):
+        if self.track_abundance:
+            return sum(v for v in self.hashes.values())
+        return None
+
+    @property
+    def mean_abundance(self):
+        if self.track_abundance:
+            return np.mean(list(self.hashes.values()))
+        return None
+
+    @property
+    def median_abundance(self):
+        if self.track_abundance:
+            return np.median(list(self.hashes.values()))
+        return None
+
+    @property
+    def std_abundance(self):
+        if self.track_abundance:
+            return np.std(list(self.hashes.values()))
+        return None
+
+    @property
+    def covered_bp(self):
+        if not self.scaled:
+            raise TypeError("can only calculate bp for scaled MinHashes")
+        return len(self.hashes) * self.scaled


 class FrozenMinHash(MinHash):
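A short usage sketch of the new abundance helpers; the hash values and abundances below are arbitrary:

```python
from sourmash import MinHash

mh = MinHash(n=0, ksize=21, scaled=1, track_abundance=True)
mh.set_abundances({1: 3, 2: 5, 3: 7})   # hash value -> abundance

print(mh.sum_abundances)    # 15
print(mh.mean_abundance)    # 5.0
print(mh.median_abundance)  # 5.0
print(mh.std_abundance)     # ~1.63
print(mh.covered_bp)        # 3 hashes * scaled=1 -> 3
```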