diff --git a/src/sourmash/command_compute.py b/src/sourmash/command_compute.py index 1dda0bcccd..683c4d2ffd 100644 --- a/src/sourmash/command_compute.py +++ b/src/sourmash/command_compute.py @@ -328,6 +328,60 @@ def __init__(self, ksizes, seed, protein, dayhoff, hp, dna, num_hashes, track_ab self.track_abundance = track_abundance self.scaled = scaled + def to_param_str(self): + "Convert object to equivalent params str." + pi = [] + + if self.dna: + pi.append("dna") + elif self.protein: + pi.append("protein") + elif self.hp: + pi.append("hp") + elif self.dayhoff: + pi.append("dayhoff") + else: + assert 0 # must be one of the previous + + if self.dna: + kstr = [f"k={k}" for k in self.ksizes] + else: + # for protein, divide ksize by three. + kstr = [f"k={k//3}" for k in self.ksizes] + assert kstr + pi.extend(kstr) + + if self.num_hashes != 0: + pi.append(f"num={self.num_hashes}") + elif self.scaled != 0: + pi.append(f"scaled={self.scaled}") + else: + assert 0 + + if self.track_abundance: + pi.append("abund") + # noabund is default + + if self.seed != 42: + pi.append(f"seed={self.seed}") + # self.seed + + return ",".join(pi) + + def __repr__(self): + return f"ComputeParameters({self.ksizes}, {self.seed}, {self.protein}, {self.dayhoff}, {self.hp}, {self.dna}, {self.num_hashes}, {self.track_abundance}, {self.scaled})" + + def __eq__(self, other): + return (self.ksizes == other.ksizes and + self.seed == other.seed and + self.protein == other.protein and + self.dayhoff == other.dayhoff and + self.hp == other.hp and + self.dna == other.dna and + self.num_hashes == other.num_hashes and + self.track_abundance == other.track_abundance and + self.scaled == other.scaled) + @staticmethod def from_args(args): ptr = lib.computeparams_new() diff --git a/src/sourmash/command_sketch.py b/src/sourmash/command_sketch.py index 6b5c1c5b51..8414fcdba7 100644 --- a/src/sourmash/command_sketch.py +++ b/src/sourmash/command_sketch.py @@ -82,8 +82,7 @@ def _parse_params_str(params_str): class _signatures_for_sketch_factory(object): "Build sigs on demand, based on args input to 'sketch'." - def __init__(self, params_str_list, default_moltype, mult_ksize_by_3): - + def __init__(self, params_str_list, default_moltype): # first, set up defaults per-moltype defaults = {} for moltype, pstr in DEFAULTS.items(): @@ -94,7 +93,7 @@ def __init__(self, params_str_list, default_moltype, mult_ksize_by_3): # next, fill out params_list self.params_list = [] - self.mult_ksize_by_3 = mult_ksize_by_3 + self.mult_ksize_by_3 = True if params_str_list: # parse each params_str passed in, using default_moltype if none @@ -103,17 +102,21 @@ def __init__(self, params_str_list, default_moltype, mult_ksize_by_3): moltype, params = _parse_params_str(params_str) if moltype and moltype != 'dna' and default_moltype == 'dna': raise ValueError(f"Incompatible sketch type ({default_moltype}) and parameter override ({moltype}) in '{params_str}'; maybe use 'sketch translate'?") - elif moltype == 'dna' and default_moltype != 'dna': + elif moltype == 'dna' and default_moltype and default_moltype != 'dna': raise ValueError(f"Incompatible sketch type ({default_moltype}) and parameter override ({moltype}) in '{params_str}'") elif moltype is None: + if default_moltype is None: + raise ValueError(f"No default moltype and none specified in param string") moltype = default_moltype self.params_list.append((moltype, params)) else: + if default_moltype is None: + raise ValueError(f"No default moltype and none specified in param string") # no params str? default to a single sig, using default_moltype. self.params_list.append((default_moltype, {})) - def get_compute_params(self): + def get_compute_params(self, *, split_ksizes=False): for moltype, params_d in self.params_list: # get defaults for this moltype from self.defaults: default_params = self.defaults[moltype] @@ -134,26 +137,33 @@ def get_compute_params(self): if not ksizes: ksizes = def_ksizes - if self.mult_ksize_by_3: + # 'command sketch' adjusts k-mer sizes by 3 if non-DNA sketch. + if self.mult_ksize_by_3 and not def_dna: ksizes = [ k*3 for k in ksizes ] - params_obj = ComputeParameters(ksizes, - params_d.get('seed', def_seed), - def_protein, - def_dayhoff, - def_hp, - def_dna, - params_d.get('num', def_num), - params_d.get('track_abundance', - def_abund), - params_d.get('scaled', def_scaled)) - - yield params_obj - - def __call__(self): + make_param = lambda ksizes: ComputeParameters(ksizes, + params_d.get('seed', def_seed), + def_protein, + def_dayhoff, + def_hp, + def_dna, + params_d.get('num', def_num), + params_d.get('track_abundance', + def_abund), + params_d.get('scaled', def_scaled)) + + if split_ksizes: + for ksize in ksizes: + params_obj = make_param([ksize]) + yield params_obj + else: + params_obj = make_param(ksizes) + yield params_obj + + def __call__(self, *, split_ksizes=False): "Produce a new set of signatures built to match the param strings." sigs = [] - for params in self.get_compute_params(): + for params in self.get_compute_params(split_ksizes=split_ksizes): sig = SourmashSignature.from_params(params) sigs.append(sig) @@ -214,8 +224,7 @@ def dna(args): try: signatures_factory = _signatures_for_sketch_factory(args.param_string, - 'dna', - mult_ksize_by_3=False) + 'dna') except ValueError as e: error(f"Error creating signatures: {str(e)}") sys.exit(-1) @@ -244,8 +253,7 @@ def protein(args): try: signatures_factory = _signatures_for_sketch_factory(args.param_string, - moltype, - mult_ksize_by_3=True) + moltype) except ValueError as e: error(f"Error creating signatures: {str(e)}") sys.exit(-1) @@ -274,8 +282,7 @@ def translate(args): try: signatures_factory = _signatures_for_sketch_factory(args.param_string, - moltype, - mult_ksize_by_3=True) + moltype) except ValueError as e: error(f"Error creating signatures: {str(e)}") sys.exit(-1) diff --git a/src/sourmash/index/__init__.py b/src/sourmash/index/__init__.py index db8cb00c97..67f5605cbc 100644 --- a/src/sourmash/index/__init__.py +++ b/src/sourmash/index/__init__.py @@ -341,7 +341,7 @@ def select(self, ksize=None, moltype=None, scaled=None, num=None, def select_signature(ss, *, ksize=None, moltype=None, scaled=0, num=0, containment=False, abund=None, picklist=None): - "Check that the given signature matches the specificed requirements." + "Check that the given signature matches the specified requirements." # ksize match? if ksize and ksize != ss.minhash.ksize: return False diff --git a/tests/test_sourmash_sketch.py b/tests/test_sourmash_sketch.py index 20769fc27e..1391979a8b 100644 --- a/tests/test_sourmash_sketch.py +++ b/tests/test_sourmash_sketch.py @@ -67,8 +67,136 @@ def test_do_sourmash_sketch_check_num_bounds_more_than_maximum(runtmp): assert "WARNING: num value should be <= 50000. Continuing anyway." in runtmp.last_result.err +def test_empty_factory(): + with pytest.raises(ValueError): + factory = _signatures_for_sketch_factory([], None) + + +def test_no_default_moltype_factory_nonempty(): + with pytest.raises(ValueError): + factory = _signatures_for_sketch_factory(["k=31"], None) + + +def test_factory_no_default_moltype_dna(): + factory = _signatures_for_sketch_factory(['dna'], None) + params_list = list(factory.get_compute_params()) + assert len(params_list) == 1 + + params = params_list[0] + assert params.dna + + +def test_factory_no_default_moltype_protein(): + factory = _signatures_for_sketch_factory(['protein'], None) + params_list = list(factory.get_compute_params()) + assert len(params_list) == 1 + + params = params_list[0] + assert params.protein + + +def test_factory_dna_nosplit(): + factory = _signatures_for_sketch_factory(['k=31,k=51'], 'dna') + params_list = list(factory.get_compute_params(split_ksizes=False)) + assert len(params_list) == 1 + + params = params_list[0] + assert params.ksizes == [31,51] + + +def test_factory_dna_split(): + factory = _signatures_for_sketch_factory(['k=31,k=51'], 'dna') + params_list = list(factory.get_compute_params(split_ksizes=True)) + assert len(params_list) == 2 + + params = params_list[0] + assert params.ksizes == [31] + params = params_list[1] + assert params.ksizes == [51] + + +def test_factory_protein_nosplit(): + factory = _signatures_for_sketch_factory(['k=10,k=9'], 'protein') + params_list = list(factory.get_compute_params(split_ksizes=False)) + assert len(params_list) == 1 + + params = params_list[0] + assert params.ksizes == [30, 27] + + +def test_factory_protein_split(): + factory = _signatures_for_sketch_factory(['k=10,k=9'], 'protein') + params_list = list(factory.get_compute_params(split_ksizes=True)) + assert len(params_list) == 2 + + params = params_list[0] + assert params.ksizes == [30] + params = params_list[1] + assert params.ksizes == [27] + + +def test_factory_dna_equal(): + factory1 = _signatures_for_sketch_factory(['dna'], None) + params_list1 = list(factory1.get_compute_params()) + assert len(params_list1) == 1 + params1 = params_list1[0] + + factory2 = _signatures_for_sketch_factory([], 'dna') + params_list2 = list(factory2.get_compute_params()) + assert len(params_list2) == 1 + params2 = params_list2[0] + + assert params1 == params2 + assert repr(params1) == repr(params2) + + +def test_factory_protein_equal(): + factory1 = _signatures_for_sketch_factory(['protein'], None) + params_list1 = list(factory1.get_compute_params()) + assert len(params_list1) == 1 + params1 = params_list1[0] + + factory2 = _signatures_for_sketch_factory([], 'protein') + params_list2 = list(factory2.get_compute_params()) + assert len(params_list2) == 1 + params2 = params_list2[0] + + assert params1 == params2 + assert repr(params1) == repr(params2) + + +def test_factory_dna_multi_ksize_eq(): + factory1 = _signatures_for_sketch_factory(['k=21,k=31,dna'], None) + params_list1 = list(factory1.get_compute_params()) + assert len(params_list1) == 1 + params1 = params_list1[0] + + factory2 = _signatures_for_sketch_factory(['k=21,k=31'], 'dna') + params_list2 = list(factory2.get_compute_params()) + assert len(params_list2) == 1 + params2 = params_list2[0] + + assert params1 == params2 + assert repr(params1) == repr(params2) + + +def test_factory_protein_multi_ksize_eq(): + factory1 = _signatures_for_sketch_factory(['k=10,k=11,protein'], None) + params_list1 = list(factory1.get_compute_params()) + assert len(params_list1) == 1 + params1 = params_list1[0] + + factory2 = _signatures_for_sketch_factory(['k=10,k=11'], 'protein') + params_list2 = list(factory2.get_compute_params()) + assert len(params_list2) == 1 + params2 = params_list2[0] + + assert params1 == params2 + assert repr(params1) == repr(params2) + + def test_dna_defaults(): - factory = _signatures_for_sketch_factory([], 'dna', False) + factory = _signatures_for_sketch_factory([], 'dna') params_list = list(factory.get_compute_params()) assert len(params_list) == 1 @@ -87,7 +215,7 @@ def test_dna_defaults(): def test_dna_override_1(): factory = _signatures_for_sketch_factory(['k=21,scaled=2000,abund'], - 'dna', False) + 'dna') params_list = list(factory.get_compute_params()) assert len(params_list) == 1 @@ -106,44 +234,41 @@ def test_dna_override_1(): def test_scaled_param_requires_equal(): with pytest.raises(ValueError): - factory = _signatures_for_sketch_factory(['k=21,scaled'], - 'dna', False) + factory = _signatures_for_sketch_factory(['k=21,scaled'], 'dna') def test_k_param_requires_equal(): with pytest.raises(ValueError): - factory = _signatures_for_sketch_factory(['k'], - 'dna', False) + factory = _signatures_for_sketch_factory(['k'], 'dna') def test_k_param_requires_equal_2(): with pytest.raises(ValueError) as exc: - factory = _signatures_for_sketch_factory(['k='], - 'dna', False) + factory = _signatures_for_sketch_factory(['k='], 'dna') + def test_seed_param_requires_equal(): with pytest.raises(ValueError) as exc: - factory = _signatures_for_sketch_factory(['seed='], - 'dna', False) + factory = _signatures_for_sketch_factory(['seed='], 'dna') + def test_num_param_requires_equal(): with pytest.raises(ValueError) as exc: - factory = _signatures_for_sketch_factory(['num='], - 'dna', False) + factory = _signatures_for_sketch_factory(['num='], 'dna') + def test_dna_override_bad_1(): with pytest.raises(ValueError): factory = _signatures_for_sketch_factory(['k=21,scaledFOO=2000,abund'], - 'dna', False) + 'dna') def test_dna_override_bad_2(): with pytest.raises(ValueError): - factory = _signatures_for_sketch_factory(['k=21,protein'], - 'dna', False) + factory = _signatures_for_sketch_factory(['k=21,protein'], 'dna') def test_protein_defaults(): - factory = _signatures_for_sketch_factory([], 'protein', True) + factory = _signatures_for_sketch_factory([], 'protein') params_list = list(factory.get_compute_params()) assert len(params_list) == 1 @@ -162,12 +287,11 @@ def test_protein_defaults(): def test_protein_override_bad_2(): with pytest.raises(ValueError): - factory = _signatures_for_sketch_factory(['k=21,dna'], - 'protein', False) + factory = _signatures_for_sketch_factory(['k=21,dna'], 'protein') def test_protein_override_bad_rust_foo(): # mimic 'sourmash sketch protein -p dna' - factory = _signatures_for_sketch_factory([], 'protein', False) + factory = _signatures_for_sketch_factory([], 'protein') # reach in and avoid error checking to construct a bad params_list. factory.params_list = [('dna', {})] @@ -188,7 +312,7 @@ def test_protein_override_bad_rust_foo(): def test_dayhoff_defaults(): - factory = _signatures_for_sketch_factory([], 'dayhoff', True) + factory = _signatures_for_sketch_factory([], 'dayhoff') params_list = list(factory.get_compute_params()) assert len(params_list) == 1 @@ -207,11 +331,10 @@ def test_dayhoff_defaults(): def test_dayhoff_override_bad_2(): with pytest.raises(ValueError): - factory = _signatures_for_sketch_factory(['k=21,dna'], - 'dayhoff', False) + factory = _signatures_for_sketch_factory(['k=21,dna'], 'dayhoff') def test_hp_defaults(): - factory = _signatures_for_sketch_factory([], 'hp', True) + factory = _signatures_for_sketch_factory([], 'hp') params_list = list(factory.get_compute_params()) assert len(params_list) == 1 @@ -230,8 +353,7 @@ def test_hp_defaults(): def test_hp_override_bad_2(): with pytest.raises(ValueError): - factory = _signatures_for_sketch_factory(['k=21,dna'], - 'hp', False) + factory = _signatures_for_sketch_factory(['k=21,dna'], 'hp') def test_multiple_moltypes(): @@ -239,7 +361,7 @@ def test_multiple_moltypes(): 'k=19,num=400,dayhoff,abund', 'k=30,scaled=200,hp', 'k=30,scaled=200,seed=58'] - factory = _signatures_for_sketch_factory(params_foo, 'protein', True) + factory = _signatures_for_sketch_factory(params_foo, 'protein') params_list = list(factory.get_compute_params()) assert len(params_list) == 4 @@ -289,6 +411,28 @@ def test_multiple_moltypes(): assert params.protein +@pytest.mark.parametrize("input_param_str, expected_output", + [('protein', 'protein,k=10,scaled=200'), + ('dna', 'dna,k=31,scaled=1000'), + ('hp', 'hp,k=42,scaled=200'), + ('dayhoff', 'dayhoff,k=16,scaled=200'), + ('dna,seed=52', 'dna,k=31,scaled=1000,seed=52'), + ('dna,num=500', 'dna,k=31,num=500'), + ('scaled=1100,dna', 'dna,k=31,scaled=1100'), + ('dna,abund', 'dna,k=31,scaled=1000,abund') + ]) +def test_compute_parameters_to_param_str(input_param_str, expected_output): + factory = _signatures_for_sketch_factory([input_param_str], None) + params_list = list(factory.get_compute_params()) + assert len(params_list) == 1 + params = params_list[0] + + actual_output_str = params.to_param_str() + + assert actual_output_str == expected_output, (actual_output_str, + expected_output) + + ### command line tests