Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved vectorized ordinal mapper #210

Merged
merged 3 commits into from
Oct 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 70 additions & 62 deletions woltka/ordinal.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def ordinal_mapper(fh, coords, idmap, fmt=None, excl=None, n=2**20, th=0.8,
th : float
Minimum ratio of overlap length to alignment length for a match.
prefix : bool
Prefix gene IDs with nucleotide IDs.
Prefix gene IDs with genome IDs.

See Also
--------
Expand All @@ -194,93 +194,103 @@ def ordinal_mapper(fh, coords, idmap, fmt=None, excl=None, n=2**20, th=0.8,
Yields
------
list of str
Query queue.
Read (query) queue.
list of set of str
Subject(s) queue.
Genes (subjects) queue.
"""
it = iter_align(fh, fmt, excl, True)

# arguments for flush_chunk
args = (coords, idmap, th, prefix)

# cached lists of read Ids and lengths (pre-allocate space)
# gene Ids are unique, but read Ids can have duplicates (i.e., one read is
# mapped to multiple loci on a genome), therefore an incremental integer
# here replaces the original read Id as its identifier
rids = [None] * n
lens = np.empty((n,), dtype=np.uint32)
# cached read information
qrys = [None] * n
lens = np.empty(n, dtype=np.uint32)
begs = np.empty(n, dtype=np.int64)
ends = np.empty(n, dtype=np.int64)

# cached map of reads to per-genome coordinates
locmap = defaultdict(list)
# arguments for flush_chunk
args = (qrys, lens, begs, ends, coords, idmap, th, prefix)

# current read index in the cached lists; will reset after each flush
# current read index in cache; will reset after each flush
idx = 0

# subject-to-indices mapping of cache
sub2idx = defaultdict(list)

# parse alignment file
for query, records in it:

# exclude hits with unavailable or zero length
records = [x for x in records if x[2]]

# when chunk limit is about to be exceeded by the next query, match
# currently cached reads with genes, flush, and reset
if idx + len(records) > n:
yield flush_chunk(idx, locmap, rids, lens, *args)
locmap = defaultdict(list)
yield flush_chunk(idx, sub2idx, *args)
idx = 0
sub2idx = defaultdict(list)

# extract alignment info and add to cache
# extract read info and add to cache
# hits with unavailable or zero length are excluded
for subject, _, length, beg, end in records:
rids[idx] = query
lens[idx] = length
locmap[subject].extend((
(beg << 24) + idx,
(end << 24) + (1 << 23) + idx))
idx += 1
if length:
qrys[idx] = query
lens[idx] = length
begs[idx] = beg
ends[idx] = end
sub2idx[subject].append(idx)
idx += 1

# final flush
yield flush_chunk(idx, locmap, rids, lens, *args)
yield flush_chunk(idx, sub2idx, *args)


def flush_chunk(n, rlocmap, rids, rlens, glocmap, gidmap, th, prefix):
def flush_chunk(n, idxmap, rids, lens, begs, ends, glocmap, gidmap, th,
prefix):
"""Match reads in current chunk with genes from all genomes.

Parameters
----------
n : int
Number of reads to flush.
rlocmap : dict of list
Read coordinates per genome.
idxmap : dict of list of int
Read indices per genome.
rids : list of str
Read IDs.
rlens : np.array(-1, dtype=int64)
Read identifiers.
lens : np.array(-1, dtype=uint32)
Read lengths.
begs : np.array(-1, dtype=int64)
Read start coordinates.
ends : np.array(-1, dtype=int64)
Read end coordinates.
glocmap : dict of list
Gene coordinates per genome.
gidmap : dict of list
Gene identifiers.
Gene identifiers per genome.
th : float
Length threshold.
prefix : bool
Prefix gene IDs with nucleotide IDs.
Prefix gene IDs with genome IDs.

Returns
-------
list of str
Query queue.
Read (query) queue.
list of set of str
Subject(s) queue.
Genes (subjects) queue.
"""
# master read-to-gene(s) map
res = defaultdict(set)

# effective length = length * th
rels = np.ceil(rlens[:n] * th).astype(np.uint32)
# calculate effective lengths of reads
rels = np.ceil(lens[:n] * th).astype(np.uint32)

# iterate over nucleotides
for nucl, rlocs in rlocmap.items():
# encode read start and end positions
idx = np.arange(n)
begs[:n] <<= 24
ends[:n] <<= 24
begs[:n] += idx
ends[:n] += idx + (1 << 23)

# it's possible that no gene was annotated on the nucleotide
# iterate over genomes:
for nucl, idx in idxmap.items():

# in case no gene was annotated on the genome
try:
glocs = glocmap[nucl]
except KeyError:
Expand All @@ -292,33 +302,33 @@ def flush_chunk(n, rlocmap, rids, rlens, glocmap, gidmap, th, prefix):
# append prefix if needed
pfx = nucl + '_' if prefix else ''

# convert list to array
rlocs = np.array(rlocs, dtype=np.int64)
# pair read starts and ends
idx = np.array(idx, dtype=np.uint32)
m = idx.size
locs = np.empty(2 * m, dtype=np.int64)
locs[0::2] = begs[idx]
locs[1::2] = ends[idx]

# execute ordinal algorithm when reads are many
# 10 (>5 reads) is an empirically determined cutoff
if rlocs.size > 10:
# 5 is an empirically determined cutoff
if m > 5:

# merge pre-sorted genes with reads of unknown sorting status
queue = np.concatenate((glocs, rlocs))
queue = np.concatenate((glocs, locs))

# sort genes and reads into a mixture
# timsort is efficient for this task
queue.sort(kind='stable')

# a potentially more efficient method is to use sortednp:
# rlocs.sort(kind='stable')
# queue = sortednp.merge(glocs, rlocs)

# map reads to genes using the core algorithm
matches = match_read_gene(queue, rels)
gen = match_read_gene(queue, rels)

# execute naive algorithm when reads are few
else:
matches = match_read_gene_quart(glocs, rlocs, rels)
gen = match_read_gene_quart(glocs, locs, rels)

# add read-gene pairs to the master map
for read, gene in matches:
for read, gene in gen:
res[rids[read]].add(pfx + gids[gene])

# return matching read Ids and gene Ids
Expand Down Expand Up @@ -446,23 +456,21 @@ def encode_genes(lst):

# order each pair of start and end coordinates such that smaller one
# comes first
# faster than np.sort since there are only two numbers
# < is slightly faster than np.less
cmp = beg < end
lo = np.where(cmp, beg, end)
hi = np.where(cmp, end, beg)

# encode coordinate, start/end, is gene, and index into one integer
lo = np.left_shift(lo - 1, 24) + (1 << 22) + idx
hi = np.left_shift(hi, 24) + (3 << 22) + idx
lo = (lo - 1 << 24) + (1 << 22) + idx
hi = (hi << 24) + (3 << 22) + idx

# fastest way to interleave two arrays
# https://stackoverflow.com/questions/5347065/
que = np.empty((2 * n,), dtype=np.int64)
que[0::2] = lo
que[1::2] = hi
queue = np.empty(2 * n, dtype=np.int64)
queue[0::2] = lo
queue[1::2] = hi

return que
return queue


@njit((int64[:], uint32[:]))
Expand Down
19 changes: 12 additions & 7 deletions woltka/tests/test_ordinal.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from tempfile import mkdtemp
from io import StringIO
from functools import partial
from collections import defaultdict

import numpy as np
import numpy.testing as npt
Expand Down Expand Up @@ -332,17 +333,21 @@ def test_flush_chunk(self):
'r9 n1 95 20 0 0 1 20 95 82 1 1',
'rx nx 95 0 0 0 0 0 0 0 1 1',
'# end of file')))
idx, rids, rlens, locmap = 0, [], [], {}
idx, sub2idx = 0, defaultdict(list)
qrys, lens, begs, ends = [], [], [], []
for query, records in parse_b6o_file_ex(aln):
for subject, _, length, beg, end in records:
rids.append(query)
rlens.append(length)
locmap.setdefault(subject, []).extend((
(beg << 24) + idx, (end << 24) + (1 << 23) + idx))
qrys.append(query)
lens.append(length)
begs.append(beg)
ends.append(end)
sub2idx[subject].append(idx)
idx += 1
rlens = np.array(rlens)
lens = np.array(lens, dtype=np.uint32)
begs = np.array(begs, dtype=np.int64)
ends = np.array(ends, dtype=np.int64)
obs = flush_chunk(
len(rids), locmap, rids, rlens, coords, idmap, 0.8, False)
idx, sub2idx, qrys, lens, begs, ends, coords, idmap, 0.8, False)
exp = [('r1', 'g1'),
('r5', 'g2'),
('r6', 'g2'),
Expand Down