Skip to content

Commit

Permalink
A simple optimization to improve normalize-by-median: one can quickly…
Browse files Browse the repository at this point in the history
… observe that checking if the median of a set is greater than some cutoff is equivalent to checking if more than half the elements of that set are greater than some cutoff. The latter avoids doing a lookup for every kmer every time, and avoids a costly sort. On a small dataset (1m ecoli reads), this was an 18% performance improvement.

Implements the median_at_least function in C++ land, exposes it in CPython, and updates normalize-by-median.py.
  • Loading branch information
camillescott committed Jun 9, 2015
1 parent f824f64 commit 3dd1a9f
Show file tree
Hide file tree
Showing 7 changed files with 214 additions and 3 deletions.
8 changes: 8 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
2015-06-08 Camille Scott <camille.scott.w@gmail.com>

* lib/hashtable.{cc,hh}: Add filter_on_median method to check
if median k-mer count is above a cutoff
* khmer/_khmermodule.cc: Expose filter_on_median to python-land
* scripts/normalize-by-median.py: Switch to new filter_on_median
* tests/test_counting_hash.py: Tests for new method

2015-06-08 Luiz Irber <khmer@luizirber.org>

* tests/test_hll.py: test return values from consume_{string,fasta}.
Expand Down
27 changes: 27 additions & 0 deletions khmer/_khmermodule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1149,6 +1149,32 @@ hashtable_get_median_count(khmer_KHashtable_Object * me, PyObject * args)
return Py_BuildValue("iff", med, average, stddev);
}

static
PyObject *
hashtable_median_at_least(khmer_KHashtable_Object * me, PyObject * args)
{
Hashtable * hashtable = me->hashtable;

const char * long_str;
unsigned int cutoff;

if (!PyArg_ParseTuple(args, "sI", &long_str, &cutoff)) {
return NULL;
}

if (strlen(long_str) < hashtable->ksize()) {
PyErr_SetString(PyExc_ValueError,
"string length must >= the hashtable k-mer size");
return NULL;
}

if (hashtable->median_at_least(long_str, cutoff)) {
Py_RETURN_TRUE;
}
Py_RETURN_FALSE;

}

static
PyObject *
hashtable_n_tags(khmer_KHashtable_Object * me, PyObject * args)
Expand Down Expand Up @@ -2302,6 +2328,7 @@ static PyMethodDef khmer_hashtable_methods[] = {
{ "find_all_tags_list", (PyCFunction)hashtable_find_all_tags_list, METH_VARARGS, "Find all tags within range of the given k-mer, return as list" },
{ "consume_fasta_and_tag", (PyCFunction)hashtable_consume_fasta_and_tag, METH_VARARGS, "Count all k-mers in a given file" },
{ "get_median_count", (PyCFunction)hashtable_get_median_count, METH_VARARGS, "Get the median, average, and stddev of the k-mer counts in the string" },
{ "median_at_least", (PyCFunction)hashtable_median_at_least, METH_VARARGS, "Return true if the median is at least the given cutoff" },
{ "extract_unique_paths", (PyCFunction)hashtable_extract_unique_paths, METH_VARARGS, "" },
{ "load_stop_tags", (PyCFunction)hashtable_load_stop_tags, METH_VARARGS, "" },
{ "save_stop_tags", (PyCFunction)hashtable_save_stop_tags, METH_VARARGS, "" },
Expand Down
20 changes: 20 additions & 0 deletions lib/hashtable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,26 @@ void Hashtable::get_median_count(const std::string &s,
median = counts[counts.size() / 2]; // rounds down
}

//
// Optimized filter function for normalize-by-median
//
bool Hashtable::median_at_least(const std::string &s,
unsigned int cutoff) {
KMerIterator kmers(s.c_str(), _ksize);
unsigned int min_req = 0.5 + float(s.size() - _ksize + 1) / 2;
unsigned int num_cutoff_kmers = 0;
while(!kmers.done()) {
HashIntoType kmer = kmers.next();
if (this->get_count(kmer) >= cutoff) {
++num_cutoff_kmers;
if (num_cutoff_kmers >= min_req) {
return true;
}
}
}
return false;
}

void Hashtable::save_tagset(std::string outfilename)
{
ofstream outfile(outfilename.c_str(), ios::binary);
Expand Down
3 changes: 3 additions & 0 deletions lib/hashtable.hh
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,9 @@ public:
unsigned long long &n_consumed
);

bool median_at_least(const std::string &s,
unsigned int cutoff);

void get_median_count(const std::string &s,
BoundedCounterType &median,
float &average,
Expand Down
4 changes: 1 addition & 3 deletions scripts/normalize-by-median.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,7 @@ def __call__(self, input_filename, force_paired=False):
continue

seq = record.sequence.replace('N', 'A')
med, _, _ = self.htable.get_median_count(seq)

if med < desired_coverage:
if not self.htable.median_at_least(seq, desired_coverage):
passed_filter = True

if passed_length and passed_filter:
Expand Down
131 changes: 131 additions & 0 deletions tests/test_counting_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,137 @@ def test_simple_median():
assert int(stddev * 100) == 50 # .5


def test_median_at_least():
hi = khmer.new_counting_hash(6, 1e6, 2)

hi.consume("AAAAAA")
assert hi.median_at_least("AAAAAA", 1)
assert hi.median_at_least("AAAAAA", 2) is False

hi.consume("AAAAAA")
assert hi.median_at_least("AAAAAA", 2)
assert hi.median_at_least("AAAAAA", 3) is False

hi.consume("AAAAAA")
assert hi.median_at_least("AAAAAA", 3)
assert hi.median_at_least("AAAAAA", 4) is False

hi.consume("AAAAAA")
assert hi.median_at_least("AAAAAA", 4)
assert hi.median_at_least("AAAAAA", 5) is False

hi.consume("AAAAAA")
assert hi.median_at_least("AAAAAA", 5)
assert hi.median_at_least("AAAAAA", 6) is False


def test_median_at_least_single_gt():
K = 20
hi = khmer.new_counting_hash(K, 1e6, 2)

kmers = ['ATCGATCGATCGATCGATCG',
'GTACGTACGTACGTACGTAC',
'TTAGTTAGTTAGTTAGTTAG']

for kmer in kmers:
hi.consume(kmer)
assert hi.median_at_least(kmer, 1) is True


def test_median_at_least_single_lt():
K = 20
hi = khmer.new_counting_hash(K, 1e6, 2)

kmers = ['ATCGATCGATCGATCGATCG',
'GTACGTACGTACGTACGTAC',
'TTAGTTAGTTAGTTAGTTAG']

for kmer in kmers:
hi.consume(kmer)
assert hi.median_at_least(kmer, 2) is False


def test_median_at_least_odd_gt():
# test w/odd number of k-mers
K = 20
hi = khmer.new_counting_hash(K, 1e6, 2)

seqs = ['ATCGATCGATCGATCGATCGCC',
'GTACGTACGTACGTACGTACCC',
'TTAGTTAGTTAGTTAGTTAGCC']

for seq in seqs:
hi.consume(seq)
assert hi.median_at_least(seq, 1) is True


def test_median_at_least_odd_lt():
K = 20
hi = khmer.new_counting_hash(K, 1e6, 2)

seqs = ['ATCGATCGATCGATCGATCGCC',
'GTACGTACGTACGTACGTACCC',
'TTAGTTAGTTAGTTAGTTAGCC']

for seq in seqs:
hi.consume(seq)
assert hi.median_at_least(seq, 2) is False


# Test median with even number of k-mers
def test_median_at_least_even_gt():
K = 20
hi = khmer.new_counting_hash(K, 1e6, 2)

seqs = ['ATCGATCGATCGATCGATCGCCC',
'GTACGTACGTACGTACGTACCCC',
'TTAGTTAGTTAGTTAGTTAGCCC']

for seq in seqs:
hi.consume(seq)
assert hi.median_at_least(seq, 1) is True


def test_median_at_least_even_lt():
K = 20
hi = khmer.new_counting_hash(K, 1e6, 2)

seqs = ['ATCGATCGATCGATCGATCGCCC',
'GTACGTACGTACGTACGTACCCC',
'TTAGTTAGTTAGTTAGTTAGCCC']

for seq in seqs:
hi.consume(seq)
assert hi.median_at_least(seq, 2) is False


def test_median_at_least_comp():
K = 20
C = 4
hi = khmer.new_counting_hash(K, 1e6, 2)

seqs = ['ATCGATCGATCGATCGATCGCCC',
'GTACGTACGTACGTACGTACCCC',
'TTAGTTAGTTAGTTAGTTAGCCC']

for seq in seqs:
hi.consume(seq)
hi.consume(seq)
hi.consume(seq)

med, _, _ = hi.get_median_count(seq)
assert hi.median_at_least(seq, C) is (med >= C)


def test_median_at_least_exception():
ht = khmer.new_counting_hash(20, 1e6, 2)
try:
ht.median_at_least('ATGGCTGATCGAT', 1)
assert 0, "should have thrown ValueError"
except ValueError as e:
pass


def test_simple_kadian():
hi = khmer.new_counting_hash(6, 1e6, 2)
hi.consume("ACTGCTATCTCTAGAGCTATG")
Expand Down
24 changes: 24 additions & 0 deletions tests/test_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,30 @@ def test_normalize_by_median():
assert seqs[0].startswith('GGTTGACGGGGCTCAGGGGG'), seqs


@attr('known_failing')
def test_normalize_by_median_known_good():
CUTOFF = '2'

infile = utils.get_temp_filename('test.fa.gz')
in_dir = os.path.dirname(infile)
shutil.copyfile(utils.get_test_data('100k-filtered.fa.gz'), infile)

script = scriptpath('normalize-by-median.py')
args = ['-C', CUTOFF, '-k', '20', '-x', '4e6', infile]
(status, out, err) = utils.runscript(script, args, in_dir)

outfile = infile + '.keep'
assert os.path.exists(outfile), outfile
iter_known = screed.open(utils.get_test_data('100k-filtered.fa.keep.gz'))
iter_out = screed.open(outfile)
try:
for rknown, rout in zip(iter_known, iter_out):
assert rknown.name == rout.name
except Exception as e:
print e
assert False


def test_normalize_by_median_report_fp():
infile = utils.get_temp_filename('test.fa')
in_dir = os.path.dirname(infile)
Expand Down

0 comments on commit 3dd1a9f

Please sign in to comment.