From 3dd1a9f92cdfd9d624b4587dc89f936d2b356137 Mon Sep 17 00:00:00 2001 From: Camille Scott Date: Mon, 8 Jun 2015 21:03:51 -0600 Subject: [PATCH] A simple optimization to improve normalize-by-median: one can quickly observe that checking if the median of a set is at least some cutoff is equivalent to checking if at least half the elements of that set are at least that cutoff. The latter avoids doing a lookup for every kmer every time, and avoids a costly sort. On a small dataset (1m ecoli reads), this was an 18% performance improvement. Implements the median_at_least function in C++ land, exposes it in CPython, and updates normalize-by-median.py. --- ChangeLog | 8 ++ khmer/_khmermodule.cc | 27 +++++++ lib/hashtable.cc | 20 +++++ lib/hashtable.hh | 3 + scripts/normalize-by-median.py | 4 +- tests/test_counting_hash.py | 131 +++++++++++++++++++++++++++++++++ tests/test_scripts.py | 24 ++++++ 7 files changed, 214 insertions(+), 3 deletions(-) diff --git a/ChangeLog b/ChangeLog index f2da76b0ea..0818010c3e 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,11 @@ +2015-06-08 Camille Scott + + * lib/hashtable.{cc,hh}: Add median_at_least method to check + if the median k-mer count is at least a cutoff + * khmer/_khmermodule.cc: Expose median_at_least to python-land + * scripts/normalize-by-median.py: Switch to new median_at_least + * tests/test_counting_hash.py: Tests for new method + 2015-06-08 Luiz Irber * tests/test_hll.py: test return values from consume_{string,fasta}. 
diff --git a/khmer/_khmermodule.cc b/khmer/_khmermodule.cc index d5f9c09955..70e44b8633 100644 --- a/khmer/_khmermodule.cc +++ b/khmer/_khmermodule.cc @@ -1149,6 +1149,32 @@ hashtable_get_median_count(khmer_KHashtable_Object * me, PyObject * args) return Py_BuildValue("iff", med, average, stddev); } +static +PyObject * +hashtable_median_at_least(khmer_KHashtable_Object * me, PyObject * args) +{ + Hashtable * hashtable = me->hashtable; + + const char * long_str; + unsigned int cutoff; + + if (!PyArg_ParseTuple(args, "sI", &long_str, &cutoff)) { + return NULL; + } + + if (strlen(long_str) < hashtable->ksize()) { + PyErr_SetString(PyExc_ValueError, + "string length must >= the hashtable k-mer size"); + return NULL; + } + + if (hashtable->median_at_least(long_str, cutoff)) { + Py_RETURN_TRUE; + } + Py_RETURN_FALSE; + +} + static PyObject * hashtable_n_tags(khmer_KHashtable_Object * me, PyObject * args) @@ -2302,6 +2328,7 @@ static PyMethodDef khmer_hashtable_methods[] = { { "find_all_tags_list", (PyCFunction)hashtable_find_all_tags_list, METH_VARARGS, "Find all tags within range of the given k-mer, return as list" }, { "consume_fasta_and_tag", (PyCFunction)hashtable_consume_fasta_and_tag, METH_VARARGS, "Count all k-mers in a given file" }, { "get_median_count", (PyCFunction)hashtable_get_median_count, METH_VARARGS, "Get the median, average, and stddev of the k-mer counts in the string" }, + { "median_at_least", (PyCFunction)hashtable_median_at_least, METH_VARARGS, "Return true if the median is at least the given cutoff" }, { "extract_unique_paths", (PyCFunction)hashtable_extract_unique_paths, METH_VARARGS, "" }, { "load_stop_tags", (PyCFunction)hashtable_load_stop_tags, METH_VARARGS, "" }, { "save_stop_tags", (PyCFunction)hashtable_save_stop_tags, METH_VARARGS, "" }, diff --git a/lib/hashtable.cc b/lib/hashtable.cc index 211cb7052a..52a9fe4589 100644 --- a/lib/hashtable.cc +++ b/lib/hashtable.cc @@ -232,6 +232,26 @@ void Hashtable::get_median_count(const std::string 
&s, median = counts[counts.size() / 2]; // rounds down } +// +// Optimized filter function for normalize-by-median +// +bool Hashtable::median_at_least(const std::string &s, + unsigned int cutoff) { + KMerIterator kmers(s.c_str(), _ksize); + unsigned int min_req = 0.5 + float(s.size() - _ksize + 1) / 2; + unsigned int num_cutoff_kmers = 0; + while(!kmers.done()) { + HashIntoType kmer = kmers.next(); + if (this->get_count(kmer) >= cutoff) { + ++num_cutoff_kmers; + if (num_cutoff_kmers >= min_req) { + return true; + } + } + } + return false; +} + void Hashtable::save_tagset(std::string outfilename) { ofstream outfile(outfilename.c_str(), ios::binary); diff --git a/lib/hashtable.hh b/lib/hashtable.hh index 932babc495..41ff3e6b4a 100644 --- a/lib/hashtable.hh +++ b/lib/hashtable.hh @@ -292,6 +292,9 @@ public: unsigned long long &n_consumed ); + bool median_at_least(const std::string &s, + unsigned int cutoff); + void get_median_count(const std::string &s, BoundedCounterType &median, float &average, diff --git a/scripts/normalize-by-median.py b/scripts/normalize-by-median.py index ae87e6d52c..f6d758d47e 100755 --- a/scripts/normalize-by-median.py +++ b/scripts/normalize-by-median.py @@ -117,9 +117,7 @@ def __call__(self, input_filename, force_paired=False): continue seq = record.sequence.replace('N', 'A') - med, _, _ = self.htable.get_median_count(seq) - - if med < desired_coverage: + if not self.htable.median_at_least(seq, desired_coverage): passed_filter = True if passed_length and passed_filter: diff --git a/tests/test_counting_hash.py b/tests/test_counting_hash.py index 9346901665..9b6e3c8e8d 100644 --- a/tests/test_counting_hash.py +++ b/tests/test_counting_hash.py @@ -217,6 +217,137 @@ def test_simple_median(): assert int(stddev * 100) == 50 # .5 +def test_median_at_least(): + hi = khmer.new_counting_hash(6, 1e6, 2) + + hi.consume("AAAAAA") + assert hi.median_at_least("AAAAAA", 1) + assert hi.median_at_least("AAAAAA", 2) is False + + hi.consume("AAAAAA") + assert 
hi.median_at_least("AAAAAA", 2) + assert hi.median_at_least("AAAAAA", 3) is False + + hi.consume("AAAAAA") + assert hi.median_at_least("AAAAAA", 3) + assert hi.median_at_least("AAAAAA", 4) is False + + hi.consume("AAAAAA") + assert hi.median_at_least("AAAAAA", 4) + assert hi.median_at_least("AAAAAA", 5) is False + + hi.consume("AAAAAA") + assert hi.median_at_least("AAAAAA", 5) + assert hi.median_at_least("AAAAAA", 6) is False + + +def test_median_at_least_single_gt(): + K = 20 + hi = khmer.new_counting_hash(K, 1e6, 2) + + kmers = ['ATCGATCGATCGATCGATCG', + 'GTACGTACGTACGTACGTAC', + 'TTAGTTAGTTAGTTAGTTAG'] + + for kmer in kmers: + hi.consume(kmer) + assert hi.median_at_least(kmer, 1) is True + + +def test_median_at_least_single_lt(): + K = 20 + hi = khmer.new_counting_hash(K, 1e6, 2) + + kmers = ['ATCGATCGATCGATCGATCG', + 'GTACGTACGTACGTACGTAC', + 'TTAGTTAGTTAGTTAGTTAG'] + + for kmer in kmers: + hi.consume(kmer) + assert hi.median_at_least(kmer, 2) is False + + +def test_median_at_least_odd_gt(): + # test w/odd number of k-mers + K = 20 + hi = khmer.new_counting_hash(K, 1e6, 2) + + seqs = ['ATCGATCGATCGATCGATCGCC', + 'GTACGTACGTACGTACGTACCC', + 'TTAGTTAGTTAGTTAGTTAGCC'] + + for seq in seqs: + hi.consume(seq) + assert hi.median_at_least(seq, 1) is True + + +def test_median_at_least_odd_lt(): + K = 20 + hi = khmer.new_counting_hash(K, 1e6, 2) + + seqs = ['ATCGATCGATCGATCGATCGCC', + 'GTACGTACGTACGTACGTACCC', + 'TTAGTTAGTTAGTTAGTTAGCC'] + + for seq in seqs: + hi.consume(seq) + assert hi.median_at_least(seq, 2) is False + + +# Test median with even number of k-mers +def test_median_at_least_even_gt(): + K = 20 + hi = khmer.new_counting_hash(K, 1e6, 2) + + seqs = ['ATCGATCGATCGATCGATCGCCC', + 'GTACGTACGTACGTACGTACCCC', + 'TTAGTTAGTTAGTTAGTTAGCCC'] + + for seq in seqs: + hi.consume(seq) + assert hi.median_at_least(seq, 1) is True + + +def test_median_at_least_even_lt(): + K = 20 + hi = khmer.new_counting_hash(K, 1e6, 2) + + seqs = ['ATCGATCGATCGATCGATCGCCC', + 
'GTACGTACGTACGTACGTACCCC', + 'TTAGTTAGTTAGTTAGTTAGCCC'] + + for seq in seqs: + hi.consume(seq) + assert hi.median_at_least(seq, 2) is False + + +def test_median_at_least_comp(): + K = 20 + C = 4 + hi = khmer.new_counting_hash(K, 1e6, 2) + + seqs = ['ATCGATCGATCGATCGATCGCCC', + 'GTACGTACGTACGTACGTACCCC', + 'TTAGTTAGTTAGTTAGTTAGCCC'] + + for seq in seqs: + hi.consume(seq) + hi.consume(seq) + hi.consume(seq) + + med, _, _ = hi.get_median_count(seq) + assert hi.median_at_least(seq, C) is (med >= C) + + +def test_median_at_least_exception(): + ht = khmer.new_counting_hash(20, 1e6, 2) + try: + ht.median_at_least('ATGGCTGATCGAT', 1) + assert 0, "should have thrown ValueError" + except ValueError as e: + pass + + def test_simple_kadian(): hi = khmer.new_counting_hash(6, 1e6, 2) hi.consume("ACTGCTATCTCTAGAGCTATG") diff --git a/tests/test_scripts.py b/tests/test_scripts.py index ca3a1d7582..79d219ffab 100644 --- a/tests/test_scripts.py +++ b/tests/test_scripts.py @@ -603,6 +603,30 @@ def test_normalize_by_median(): assert seqs[0].startswith('GGTTGACGGGGCTCAGGGGG'), seqs +@attr('known_failing') +def test_normalize_by_median_known_good(): + CUTOFF = '2' + + infile = utils.get_temp_filename('test.fa.gz') + in_dir = os.path.dirname(infile) + shutil.copyfile(utils.get_test_data('100k-filtered.fa.gz'), infile) + + script = scriptpath('normalize-by-median.py') + args = ['-C', CUTOFF, '-k', '20', '-x', '4e6', infile] + (status, out, err) = utils.runscript(script, args, in_dir) + + outfile = infile + '.keep' + assert os.path.exists(outfile), outfile + iter_known = screed.open(utils.get_test_data('100k-filtered.fa.keep.gz')) + iter_out = screed.open(outfile) + try: + for rknown, rout in zip(iter_known, iter_out): + assert rknown.name == rout.name + except Exception as e: + print e + assert False + + def test_normalize_by_median_report_fp(): infile = utils.get_temp_filename('test.fa') in_dir = os.path.dirname(infile)