Skip to content

Commit

Permalink
Merge branch 'add/picklist_zf_manifests' into add/manifest_lazy_sigfile
Browse files Browse the repository at this point in the history
  • Loading branch information
ctb committed Jun 22, 2021
2 parents 287cb7b + c3f1a3d commit 6ebec9c
Show file tree
Hide file tree
Showing 13 changed files with 1,171 additions and 124 deletions.
12 changes: 12 additions & 0 deletions doc/command-line.md
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,18 @@ One way to build a picklist is to use `sourmash sig describe --csv
out.csv <signatures>` to construct an initial CSV file that you can
then edit further.

The picklist functionality also supports excluding (rather than
including) signatures matching the picklist arguments. To specify a
picklist for exclusion, add `:exclude` to the `--picklist` argument
string, e.g. `pickfile:colname:coltype:exclude`.

For example,
```
sourmash sig extract --picklist list.csv:md5sum:md5:exclude <signatures>
```
will extract only the signatures that have md5sums that **do not** match
entries in the column `md5sum` in the CSV file `list.csv`.

In addition to `sig extract`, the following commands support
`--picklist` selection: `index`, `search`, `gather`, `prefetch`,
`compare`, and `lca index`.
Expand Down
90 changes: 65 additions & 25 deletions src/sourmash/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,7 @@ def _yield_all_sigs(queries, ksize, moltype):


def gather(args):
from .search import gather_databases, format_bp
from .search import GatherDatabases, format_bp

set_quiet(args.quiet, args.debug)
moltype = sourmash_args.calculate_moltype(args)
Expand Down Expand Up @@ -676,36 +676,47 @@ def gather(args):
notify("Starting prefetch sweep across databases.")
prefetch_query = query.copy()
prefetch_query.minhash = prefetch_query.minhash.flatten()
noident_mh = prefetch_query.minhash.to_mutable()
save_prefetch = SaveSignaturesToLocation(args.save_prefetch)
save_prefetch.open()

counters = []
for db in databases:
counter = None
try:
counter = db.counter_gather(prefetch_query, args.threshold_bp)
except ValueError:
if picklist:
# catch "no signatures to search" ValueError...
continue
else:
raise # re-raise other errors, if no picklist.

save_prefetch.add_many(counter.siglist)
# subtract found hashes as we can.
for found_sig in counter.siglist:
noident_mh.remove_many(found_sig.minhash)
counters.append(counter)

notify(f"Found {len(save_prefetch)} signatures via prefetch; now doing gather.")
save_prefetch.close()
else:
counters = databases
# we can't track unidentified hashes w/o prefetch
noident_mh = None

## ok! now do gather -

found = []
weighted_missed = 1
is_abundance = query.minhash.track_abundance and not args.ignore_abundance
orig_query_mh = query.minhash
next_query = query
gather_iter = GatherDatabases(query, counters,
threshold_bp=args.threshold_bp,
ignore_abundance=args.ignore_abundance,
noident_mh=noident_mh)

gather_iter = gather_databases(query, counters, args.threshold_bp,
args.ignore_abundance)
for result, weighted_missed, next_query in gather_iter:
for result, weighted_missed in gather_iter:
if not len(found): # first result? print header.
if is_abundance:
print_results("")
Expand Down Expand Up @@ -737,6 +748,11 @@ def gather(args):
break


# report on thresholding -
if gather_iter.query:
# if still a query, then we failed the threshold.
notify(f'found less than {format_bp(args.threshold_bp)} in common. => exiting')

# basic reporting:
print_results(f'\nfound {len(found)} matches total;')
if args.num_results and len(found) == args.num_results:
Expand All @@ -745,6 +761,8 @@ def gather(args):
p_covered = (1 - weighted_missed) * 100
print_results(f'the recovered matches hit {p_covered:.1f}% of the query')
print_results('')
if gather_iter.scaled != query.minhash.scaled:
print_results(f'WARNING: final scaled was {gather_iter.scaled}, vs query scaled of {query.minhash.scaled}')

# save CSV?
if found and args.output:
Expand Down Expand Up @@ -772,25 +790,31 @@ def gather(args):

# save unassigned hashes?
if args.output_unassigned:
if not len(next_query.minhash):
remaining_query = gather_iter.query
if not (remaining_query.minhash or noident_mh):
notify('no unassigned hashes to save with --output-unassigned!')
else:
notify(f"saving unassigned hashes to '{args.output_unassigned}'")

if noident_mh:
remaining_mh = remaining_query.minhash.to_mutable()
remaining_mh += noident_mh
remaining_query.minhash = remaining_mh

if is_abundance:
# next_query is flattened; reinflate abundances
hashes = set(next_query.minhash.hashes)
# remaining_query is flattened; reinflate abundances
hashes = set(remaining_query.minhash.hashes)
orig_abunds = orig_query_mh.hashes
abunds = { h: orig_abunds[h] for h in hashes }

abund_query_mh = orig_query_mh.copy_and_clear()
# orig_query might have been downsampled...
abund_query_mh.downsample(scaled=next_query.minhash.scaled)
abund_query_mh.downsample(scaled=gather_iter.scaled)
abund_query_mh.set_abundances(abunds)
next_query.minhash = abund_query_mh
remaining_query.minhash = abund_query_mh

with FileOutput(args.output_unassigned, 'wt') as fp:
sig.save_signatures([ next_query ], fp)
sig.save_signatures([ remaining_query ], fp)

if picklist:
sourmash_args.report_picklist(args, picklist)
Expand All @@ -800,7 +824,7 @@ def gather(args):

def multigather(args):
"Gather many signatures against multiple databases."
from .search import gather_databases, format_bp
from .search import GatherDatabases, format_bp

set_quiet(args.quiet)
moltype = sourmash_args.calculate_moltype(args)
Expand Down Expand Up @@ -858,16 +882,23 @@ def multigather(args):
counters = []
prefetch_query = query.copy()
prefetch_query.minhash = prefetch_query.minhash.flatten()
noident_mh = prefetch_query.minhash.to_mutable()

counters = []
for db in databases:
counter = db.counter_gather(prefetch_query, args.threshold_bp)
for found_sig in counter.siglist:
noident_mh.remove_many(found_sig.minhash)
counters.append(counter)

found = []
weighted_missed = 1
is_abundance = query.minhash.track_abundance and not args.ignore_abundance
for result, weighted_missed, next_query in gather_databases(query, counters, args.threshold_bp, args.ignore_abundance):
gather_iter = GatherDatabases(query, counters,
threshold_bp=args.threshold_bp,
ignore_abundance=args.ignore_abundance,
noident_mh=noident_mh)
for result, weighted_missed in gather_iter:
if not len(found): # first result? print header.
if is_abundance:
print_results("")
Expand Down Expand Up @@ -895,6 +926,10 @@ def multigather(args):
name)
found.append(result)

# report on thresholding -
if gather_iter.query.minhash:
# if still a query, then we failed the threshold.
notify(f'found less than {format_bp(args.threshold_bp)} in common. => exiting')

# basic reporting
print_results('\nfound {} matches total;', len(found))
Expand Down Expand Up @@ -938,18 +973,21 @@ def multigather(args):

output_unassigned = output_base + '.unassigned.sig'
with open(output_unassigned, 'wt') as fp:
remaining_query = gather_iter.query
if noident_mh:
remaining_mh = remaining_query.minhash.to_mutable()
remaining_mh += noident_mh.downsample(scaled=remaining_mh.scaled)
remaining_query.minhash = remaining_mh

if not found:
notify('nothing found - entire query signature unassigned.')
elif not len(query.minhash):
elif not remaining_query:
notify('no unassigned hashes! not saving.')
else:
notify('saving unassigned hashes to "{}"', output_unassigned)

e = MinHash(ksize=query.minhash.ksize, n=0,
scaled=next_query.minhash.scaled)
e.add_many(next_query.minhash.hashes)
# CTB: note, multigather does not save abundances
sig.save_signatures([ sig.SourmashSignature(e) ], fp)
sig.save_signatures([ remaining_query ], fp)
n += 1

# fini, next query!
Expand Down Expand Up @@ -1134,6 +1172,7 @@ def prefetch(args):

# iterate over signatures in db one at a time, for each db;
# find those with sufficient overlap
ident_mh = query_mh.copy_and_clear()
noident_mh = query_mh.to_mutable()

did_a_search = False # track whether we did _any_ search at all!
Expand All @@ -1157,8 +1196,10 @@ def prefetch(args):
for result in prefetch_database(query, db, args.threshold_bp):
match = result.match

# track remaining "untouched" hashes.
noident_mh.remove_many(match.minhash.hashes)
# track found & "untouched" hashes.
match_mh = match.minhash.downsample(scaled=query.minhash.scaled)
ident_mh += query.minhash & match_mh.flatten()
noident_mh.remove_many(match.minhash)

# output match info as we go
if csvout_fp:
Expand Down Expand Up @@ -1194,15 +1235,14 @@ def prefetch(args):
notify(f"saved {matches_out.count} matches to CSV file '{args.output}'")
csvout_fp.close()

matched_query_mh = query_mh.to_mutable()
matched_query_mh.remove_many(noident_mh.hashes)
notify(f"of {len(query_mh)} distinct query hashes, {len(matched_query_mh)} were found in matches above threshold.")
assert len(query_mh) == len(ident_mh) + len(noident_mh)
notify(f"of {len(query_mh)} distinct query hashes, {len(ident_mh)} were found in matches above threshold.")
notify(f"a total of {len(noident_mh)} query hashes remain unmatched.")

if args.save_matching_hashes:
filename = args.save_matching_hashes
notify(f"saving {len(matched_query_mh)} matched hashes to '{filename}'")
ss = sig.SourmashSignature(matched_query_mh)
notify(f"saving {len(ident_mh)} matched hashes to '{filename}'")
ss = sig.SourmashSignature(ident_mh)
with open(filename, "wt") as fp:
sig.save_signatures([ss], fp)

Expand Down
12 changes: 9 additions & 3 deletions src/sourmash/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,11 +1052,17 @@ def write_csv_header(cls, fp):
w = csv.DictWriter(fp, fieldnames=cls.required_keys)
w.writeheader()

def write_to_csv(self, fp):
def write_to_csv(self, fp, write_header=False):
    """Write the manifest rows as CSV to the specified file handle.

    Args:
        fp: writable text file handle that receives the CSV output.
        write_header: when true, emit the header row (via
            `write_csv_header`) before the data rows.

    A row's 'signature' entry, if present, is left out of the output.
    The rows themselves are not modified.
    """
    w = csv.DictWriter(fp, fieldnames=self.required_keys)

    if write_header:
        self.write_csv_header(fp)

    for row in self.rows:
        # don't write signature! Filter it out of a copy instead of
        # deleting the key, so that writing the CSV does not strip the
        # signature from the manifest's own rows.
        w.writerow({k: v for k, v in row.items() if k != 'signature'})

@classmethod
Expand Down Expand Up @@ -1122,7 +1128,7 @@ def _select(self, *, ksize=None, moltype=None, scaled=0, num=0,

if picklist:
matching_rows = ( row for row in matching_rows
if picklist.matches_siginfo(row) )
if picklist.matches_manifest_row(row) )

# return only the internal filenames!
for row in matching_rows:
Expand All @@ -1141,8 +1147,8 @@ def locations(self):

# track/remove duplicates
if loc not in seen:
yield loc
seen.add(loc)
yield loc

def __contains__(self, ss):
"Does this manifest contain this signature?"
Expand Down
48 changes: 38 additions & 10 deletions src/sourmash/picklist.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"Picklist code for extracting subsets of signatures."
import csv
from enum import Enum

# set up preprocessing functions for column stuff
preprocess = {}
Expand All @@ -17,6 +18,11 @@
preprocess['md5short'] = lambda x: x[:8]


class PickStyle(Enum):
    """How a picklist is applied when selecting signatures.

    With INCLUDE, a value matches when it is present in the picklist;
    with EXCLUDE, a value matches when it is absent from the picklist.
    """
    # keep only entries found in the picklist (the default style)
    INCLUDE = 1
    # keep only entries *not* found in the picklist
    EXCLUDE = 2


class SignaturePicklist:
"""Picklist class for subsetting collections of signatures.
Expand All @@ -41,15 +47,17 @@ class SignaturePicklist:
supported_coltypes = ('md5', 'md5prefix8', 'md5short',
'name', 'ident', 'identprefix')

def __init__(self, coltype, *, pickfile=None, column_name=None):
def __init__(self, coltype, *, pickfile=None, column_name=None, pickstyle=PickStyle.INCLUDE):
"create a picklist of column type 'coltype'."
self.coltype = coltype
self.pickfile = pickfile
self.column_name = column_name
self.pickstyle = pickstyle

if coltype not in self.supported_coltypes:
raise ValueError(f"invalid picklist column type '{coltype}'")


self.preprocess_fn = preprocess[coltype]
self.pickset = None
self.found = set()
Expand All @@ -60,6 +68,15 @@ def from_picklist_args(cls, argstr):
"load a picklist from an argument string 'pickfile:column:coltype'"
picklist = argstr.split(':')
if len(picklist) != 3:
if len(picklist) == 4:
pickfile, column, coltype, pickstyle = picklist
if pickstyle == 'include':
return cls(coltype, pickfile=pickfile, column_name=column, pickstyle=PickStyle.INCLUDE)
elif pickstyle == 'exclude':
return cls(coltype, pickfile=pickfile, column_name=column, pickstyle=PickStyle.EXCLUDE)
else:
raise ValueError(f"invalid picklist 'pickstyle' argument, '{pickstyle}': must be 'include' or 'exclude'")

raise ValueError(f"invalid picklist argument '{argstr}'")

assert len(picklist) == 3
Expand Down Expand Up @@ -131,13 +148,18 @@ def __contains__(self, ss):
self.n_queries += 1

# determine if ok or not.
if q in self.pickset:
self.found.add(q)
return True
if self.pickstyle == PickStyle.INCLUDE:
if q in self.pickset:
self.found.add(q)
return True
elif self.pickstyle == PickStyle.EXCLUDE:
if q not in self.pickset:
self.found.add(q)
return True
return False

def matches_siginfo(self, siginfo):
# match on metadata info for signature, not signature itself
def matches_manifest_row(self, row):
"does the given manifest row match this picklist?"
if self.coltype == 'md5':
colkey = 'md5'
elif self.coltype in ('md5prefix8', 'md5short'):
Expand All @@ -147,12 +169,18 @@ def matches_siginfo(self, siginfo):
else:
assert 0

q = siginfo[colkey]
q = row[colkey]
q = self.preprocess_fn(q)
self.n_queries += 1
if q in self.pickset:
self.found.add(q)
return True

if self.pickstyle == PickStyle.INCLUDE:
if q in self.pickset:
self.found.add(q)
return True
elif self.pickstyle == PickStyle.EXCLUDE:
if q not in self.pickset:
self.found.add(q)
return True
return False

def filter(self, it):
Expand Down
Loading

0 comments on commit 6ebec9c

Please sign in to comment.