From 753fe63a6a14614a3d25801dbaedda4389bddbc3 Mon Sep 17 00:00:00 2001 From: "Erik Garrison (aider)" Date: Tue, 19 Nov 2024 19:31:44 -0600 Subject: [PATCH] refactor: Parallelize k-mer frequency counting and index building with thread-local processing --- src/map/include/winSketch.hpp | 165 +++++++++++++++++++++++----------- 1 file changed, 111 insertions(+), 54 deletions(-) diff --git a/src/map/include/winSketch.hpp b/src/map/include/winSketch.hpp index d196b8f0..c25fb566 100644 --- a/src/map/include/winSketch.hpp +++ b/src/map/include/winSketch.hpp @@ -231,72 +231,129 @@ namespace skch total_windows, "[wfmash::mashmap] building index"); - // First pass - count k-mer frequencies + // Parallel k-mer frequency counting + std::vector thread_kmer_freqs(param.threads); + std::vector freq_threads; + + // Split outputs into chunks for parallel processing + size_t chunk_size = (threadOutputs.size() + param.threads - 1) / param.threads; + + for (size_t t = 0; t < param.threads; ++t) { + freq_threads.emplace_back([&, t]() { + size_t start = t * chunk_size; + size_t end = std::min(start + chunk_size, threadOutputs.size()); + + for (size_t i = start; i < end; ++i) { + for (const MinmerInfo& mi : *threadOutputs[i]) { + thread_kmer_freqs[t][mi.hash]++; + } + } + }); + } + + for (auto& thread : freq_threads) { + thread.join(); + } + + // Merge frequency maps HF_Map_t kmer_freqs; - for (auto* output : threadOutputs) { - for (const MinmerInfo& mi : *output) { - kmer_freqs[mi.hash]++; + for (const auto& thread_freq : thread_kmer_freqs) { + for (const auto& [hash, freq] : thread_freq) { + kmer_freqs[hash] += freq; } } - // Second pass - build filtered indexes - uint64_t total_kmers = 0; - uint64_t filtered_kmers = 0; + // Parallel index building + std::vector thread_pos_indexes(param.threads); + std::vector thread_minmer_indexes(param.threads); + std::vector thread_total_kmers(param.threads, 0); + std::vector thread_filtered_kmers(param.threads, 0); + std::vector index_threads; + + for (size_t t = 0; t < param.threads; ++t) { + index_threads.emplace_back([&, t]() { + size_t start = t * chunk_size; + size_t end = std::min(start + chunk_size, threadOutputs.size()); + + for (size_t i = start; i < end; ++i) { + for (const MinmerInfo& mi : *threadOutputs[i]) { + thread_total_kmers[t]++; + + auto freq_it = kmer_freqs.find(mi.hash); + if (freq_it == kmer_freqs.end()) { + continue; // Should never happen + } - // Clear existing indexes - minmerPosLookupIndex.clear(); - minmerIndex.clear(); + uint64_t freq = freq_it->second; + uint64_t min_occ = 10; + uint64_t max_occ = std::numeric_limits::max(); + uint64_t count_threshold; + + if (param.max_kmer_freq <= 1.0) { + count_threshold = std::min(max_occ, + std::max(min_occ, + (uint64_t)(total_windows * param.max_kmer_freq))); + } else { + count_threshold = std::min(max_occ, + std::max(min_occ, + (uint64_t)param.max_kmer_freq)); + } - for (auto* output : threadOutputs) { - for (const MinmerInfo& mi : *output) { - total_kmers++; - - auto freq_it = kmer_freqs.find(mi.hash); - if (freq_it == kmer_freqs.end()) { - continue; // Should never happen - } + if (freq > count_threshold && freq > min_occ) { + thread_filtered_kmers[t]++; + continue; + } - uint64_t freq = freq_it->second; - uint64_t min_occ = 10; // minimum occurrence threshold to prevent over-filtering in small datasets - uint64_t max_occ = std::numeric_limits::max(); // no upper limit on occurrences - uint64_t count_threshold; - - if (param.max_kmer_freq <= 1.0) { - // Calculate threshold based on fraction, but respect min/max bounds - count_threshold = std::min(max_occ, - std::max(min_occ, - (uint64_t)(total_windows * param.max_kmer_freq))); - } else { - // Use direct count threshold, but respect min/max bounds - count_threshold = std::min(max_occ, - std::max(min_occ, - (uint64_t)param.max_kmer_freq)); - } + auto& pos_list = thread_pos_indexes[t][mi.hash]; + if (pos_list.size() == 0 + || pos_list.back().hash != mi.hash + || pos_list.back().pos != mi.wpos) { + pos_list.push_back(IntervalPoint {mi.wpos, mi.hash, mi.seqId, side::OPEN}); + pos_list.push_back(IntervalPoint {mi.wpos_end, mi.hash, mi.seqId, side::CLOSE}); + } else { + pos_list.back().pos = mi.wpos_end; + } - // Filter only if BOTH conditions are met: - // 1. Frequency exceeds the calculated threshold - // 2. Count exceeds minimum occurrence threshold - if (freq > count_threshold && freq > min_occ) { - filtered_kmers++; - continue; + thread_minmer_indexes[t].push_back(mi); + index_progress.increment(1); + } + delete threadOutputs[i]; } + }); + } - // Add to position lookup index - auto& pos_list = minmerPosLookupIndex[mi.hash]; - if (pos_list.size() == 0 - || pos_list.back().hash != mi.hash - || pos_list.back().pos != mi.wpos) { - pos_list.push_back(IntervalPoint {mi.wpos, mi.hash, mi.seqId, side::OPEN}); - pos_list.push_back(IntervalPoint {mi.wpos_end, mi.hash, mi.seqId, side::CLOSE}); - } else { - pos_list.back().pos = mi.wpos_end; - } + for (auto& thread : index_threads) { + thread.join(); + } + + // Merge results + uint64_t total_kmers = std::accumulate(thread_total_kmers.begin(), thread_total_kmers.end(), 0ULL); + uint64_t filtered_kmers = std::accumulate(thread_filtered_kmers.begin(), thread_filtered_kmers.end(), 0ULL); - // Add to minmer index - minmerIndex.push_back(mi); - index_progress.increment(1); + // Clear and resize main indexes + minmerPosLookupIndex.clear(); + minmerIndex.clear(); + + // Reserve approximate space + size_t total_minmers = 0; + for (const auto& thread_index : thread_minmer_indexes) { + total_minmers += thread_index.size(); + } + minmerIndex.reserve(total_minmers); + + // Merge position lookup indexes + for (auto& thread_pos_index : thread_pos_indexes) { + for (auto& [hash, pos_list] : thread_pos_index) { + auto& main_pos_list = minmerPosLookupIndex[hash]; + main_pos_list.insert(main_pos_list.end(), pos_list.begin(), pos_list.end()); } - delete output; + } + + // Merge minmer indexes + for (auto& thread_index : thread_minmer_indexes) { + minmerIndex.insert(minmerIndex.end(), + std::make_move_iterator(thread_index.begin()), + std::make_move_iterator(thread_index.end())); } // Finish second progress meter