diff --git a/src/interface/parse_args.hpp b/src/interface/parse_args.hpp
index c6eeb48a..c90aa263 100644
--- a/src/interface/parse_args.hpp
+++ b/src/interface/parse_args.hpp
@@ -99,12 +99,14 @@ void parse_args(int argc,
     args::Group mapping_opts(options_group, "Mapping:");
     args::Flag approx_mapping(mapping_opts, "", "output approximate mappings (no alignment)", {'m', "approx-mapping"});
     args::ValueFlag map_pct_identity(mapping_opts, "FLOAT", "minimum mapping identity [70]", {'p', "map-pct-id"});
-    args::ValueFlag num_mappings(mapping_opts, "INT", "number of mappings to keep per query/target pair [1]", {'n', "mappings"});
+    args::ValueFlag num_mappings(mapping_opts, "INT", "number of mappings to keep per segment [1]", {'n', "mappings"});
     args::ValueFlag segment_length(mapping_opts, "INT", "segment length for mapping [1k]", {'s', "segment-length"});
     args::ValueFlag block_length(mapping_opts, "INT", "minimum block length [3*segment-length]", {'l', "block-length"});
     args::Flag one_to_one(mapping_opts, "", "Perform one-to-one filtering", {'o', "one-to-one"});
     args::Flag lower_triangular(mapping_opts, "", "Only compute the lower triangular for all-vs-all mapping", {'L', "lower-triangular"});
     args::ValueFlag skip_prefix(mapping_opts, "C", "map between sequence groups with different prefix [#]", {'Y', "group-prefix"});
+    args::Flag disable_grouping(mapping_opts, "", "disable sequence grouping and exclude self mappings", {'G', "disable-grouping"});
+    args::Flag enable_self_mappings(mapping_opts, "", "enable self mappings (overrides -G)", {'X', "self-maps"});
     args::ValueFlag target_prefix(mapping_opts, "pfx", "use only targets whose names start with this prefix", {'T', "target-prefix"});
     args::ValueFlag target_list(mapping_opts, "FILE", "file containing list of target sequence names to use", {'R', "target-list"});
     args::ValueFlag query_prefix(mapping_opts, "pfxs", "filter queries by comma-separated prefixes", {'Q', "query-prefix"});
@@ -169,11 +171,17 @@ void parse_args(int argc,
         exit(1);
     }

-    map_parameters.skip_self = false;
+    map_parameters.skip_self = !args::get(enable_self_mappings);
     map_parameters.lower_triangular = lower_triangular ? args::get(lower_triangular) : false;
     map_parameters.keep_low_pct_id = true;

-    if (skip_prefix) {
+    if (disable_grouping) {
+        map_parameters.prefix_delim = '\0';
+        map_parameters.skip_prefix = false;
+        if (!args::get(enable_self_mappings)) {
+            map_parameters.skip_self = true;
+        }
+    } else if (skip_prefix) {
         map_parameters.prefix_delim = args::get(skip_prefix);
         map_parameters.skip_prefix = map_parameters.prefix_delim != '\0';
     } else {
@@ -303,7 +311,7 @@ void parse_args(int argc,
             exit(1);
         }

-        if (!yeet_parameters.approx_mapping && s > 10000) {
+        if (!approx_mapping && s > 10000) {
             std::cerr << "[wfmash] ERROR: segment length (-s) must be <= 10kb when running alignment." << std::endl
                       << "[wfmash] For larger values, use -m/--approx-mapping to generate mappings," << std::endl
                       << "[wfmash] then align them with: wfmash ... -i mappings.paf" << std::endl;
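The interplay of the new `-G/--disable-grouping` and `-X/--self-maps` flags with `-Y/--group-prefix` is easy to misread, so here is a minimal, self-contained sketch of the precedence encoded above. The `resolve_grouping` helper and the `GroupingConfig` struct are illustrative only (they are not part of wfmash), and the default '#' delimiter is an assumption taken from the `-Y` help text; the real fallback `else` branch lies outside this diff.

```cpp
#include <iostream>

// Illustrative stand-in for the relevant map_parameters fields.
struct GroupingConfig {
    bool skip_self;
    bool skip_prefix;
    char prefix_delim;
};

GroupingConfig resolve_grouping(bool enable_self_mappings,  // -X
                                bool disable_grouping,      // -G
                                bool have_skip_prefix,      // -Y given?
                                char skip_prefix_delim) {
    // Default: self mappings are skipped unless -X is given; '#' is assumed from
    // the -Y help text (the real default branch is not part of this diff).
    GroupingConfig cfg{!enable_self_mappings, true, '#'};
    if (disable_grouping) {
        // -G: no prefix grouping; self mappings stay excluded unless -X is also set.
        cfg.prefix_delim = '\0';
        cfg.skip_prefix = false;
        cfg.skip_self = !enable_self_mappings;
    } else if (have_skip_prefix) {
        // -Y: group sequences by the user-supplied delimiter.
        cfg.prefix_delim = skip_prefix_delim;
        cfg.skip_prefix = cfg.prefix_delim != '\0';
    }
    return cfg;
}

int main() {
    // Enumerate the four -X/-G combinations to show the precedence (-X always wins).
    for (bool X : {false, true}) {
        for (bool G : {false, true}) {
            GroupingConfig cfg = resolve_grouping(X, G, false, '#');
            std::cout << "-X=" << X << " -G=" << G
                      << " -> skip_self=" << cfg.skip_self
                      << ", skip_prefix=" << cfg.skip_prefix << "\n";
        }
    }
    return 0;
}
```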
@@ -332,7 +340,7 @@ void parse_args(int argc,
            exit(1);
        }

-       if (!yeet_parameters.approx_mapping && l > 30000) {
+       if (!approx_mapping && l > 30000) {
           std::cerr << "[wfmash] ERROR: block length (-l) must be <= 30kb when running alignment." << std::endl
                     << "[wfmash] For larger values, use -m/--approx-mapping to generate mappings," << std::endl
                     << "[wfmash] then align them with: wfmash ... -i mappings.paf" << std::endl;
@@ -363,7 +371,7 @@ void parse_args(int argc,
           std::cerr << "[wfmash] ERROR: max mapping length must be greater than 0." << std::endl;
           exit(1);
       }
-      if (!yeet_parameters.approx_mapping && l > 100000) {
+      if (!approx_mapping && l > 100000) {
          std::cerr << "[wfmash] ERROR: max mapping length (-P) must be <= 100kb when running alignment." << std::endl
                    << "[wfmash] For larger values, use -m/--approx-mapping to generate mappings," << std::endl
                    << "[wfmash] then align them with: wfmash ... -i mappings.paf" << std::endl;
@@ -664,12 +672,11 @@ void parse_args(int argc,
    }
#endif

-    args::ValueFlag num_mappings_for_segments(mapping_opts, "N", "number of mappings per segment [1]", {"mappings-per-segment"});
-    if (num_mappings_for_segments) {
-        if (args::get(num_mappings_for_segments) > 0) {
-            map_parameters.numMappingsForSegment = args::get(num_mappings_for_segments) ;
+    if (num_mappings) {
+        if (args::get(num_mappings) > 0) {
+            map_parameters.numMappingsForSegment = args::get(num_mappings);
        } else {
-            std::cerr << "[wfmash] ERROR, skch::parseandSave, the number of mappings to retain for each segment has to be grater than 0." << std::endl;
+            std::cerr << "[wfmash] ERROR: the number of mappings to retain (-n) must be greater than 0." << std::endl;
            exit(1);
        }
    } else {
diff --git a/src/map/include/computeMap.hpp b/src/map/include/computeMap.hpp
index ba1d8137..af7f6984 100644
--- a/src/map/include/computeMap.hpp
+++ b/src/map/include/computeMap.hpp
@@ -55,9 +55,13 @@ namespace skch
 {
    struct QueryMappingOutput {
        std::string queryName;
-       std::vector<MappingResult> results;
+       std::vector<MappingResult> results;        // Non-merged mappings
+       std::vector<MappingResult> mergedResults;  // Maximally merged mappings
        std::mutex mutex;
        progress_meter::ProgressMeter& progress;
+       QueryMappingOutput(const std::string& name, const std::vector<MappingResult>& r,
+                          const std::vector<MappingResult>& mr, progress_meter::ProgressMeter& p)
+           : queryName(name), results(r), mergedResults(mr), progress(p) {}
    };

    struct FragmentData {
@@ -149,6 +153,9 @@ namespace skch

      typedef atomic_queue::AtomicQueue writer_atomic_queue_t;
      typedef atomic_queue::AtomicQueue query_output_atomic_queue_t;
      typedef atomic_queue::AtomicQueue fragment_atomic_queue_t;
+
+      // Track maximum chain ID seen across all subsets
+      std::atomic<offset_t> maxChainIdSeen{0};

      void processFragment(FragmentData* fragment,
@@ -178,6 +185,10 @@ namespace skch
        {
            std::lock_guard<std::mutex> lock(fragment->output->mutex);
            fragment->output->results.insert(fragment->output->results.end(), l2Mappings.begin(), l2Mappings.end());
+           // Initialize mergedResults with same mappings
+           if (param.mergeMappings && param.split) {
+               fragment->output->mergedResults.insert(fragment->output->mergedResults.end(), l2Mappings.begin(), l2Mappings.end());
+           }
        }

        // Update progress after processing the fragment
@@ -548,6 +559,7 @@ namespace skch
        }
        std::cerr << ", average size: " << std::fixed << std::setprecision(0) << avg_subset_size << "bp" << std::endl;

+       typedef std::vector<MappingResult> MappingResultsVector_t;
        std::unordered_map<seqno_t, MappingResultsVector_t> combinedMappings;

        // Build index for the current subset
@@ -666,6 +678,15 @@ namespace skch
        }
        output_thread.join();

+       // Process both merged and non-merged mappings
+       for (auto& [querySeqId, mappings] : combinedMappings) {
+           if (param.mergeMappings && param.split) {
+               filterMaximallyMerged(mappings, param, progress);
+           } else {
+               filterNonMergedMappings(mappings, param, progress);
+           }
+       }
+
        progress.finish();
    }

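The computeMap.hpp changes keep two views of every query's mappings: `results` stays non-merged, while `mergedResults` is seeded with the same fragments and later maximally merged, and the final per-query filter picks one view depending on `param.mergeMappings && param.split`. The sketch below illustrates that dual-track idea with simplified stand-in types; `MappingResult`, `QueryOutput`, and `mergeChains` here are toy versions, not the wfmash implementations.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Stand-ins for wfmash types; field names follow the diff, everything else is illustrative.
struct MappingResult { int64_t queryStart = 0, queryEnd = 0; };

// Per-query output carrying both views, mirroring QueryMappingOutput above.
struct QueryOutput {
    std::string queryName;
    std::vector<MappingResult> results;        // non-merged mappings
    std::vector<MappingResult> mergedResults;  // maximally merged mappings
};

// Toy merge: collapse everything into one covering interval (the real code
// merges chains within param.chain_gap).
std::vector<MappingResult> mergeChains(const std::vector<MappingResult>& in) {
    if (in.empty()) return {};
    MappingResult merged = in.front();
    for (const auto& m : in) {
        merged.queryStart = std::min(merged.queryStart, m.queryStart);
        merged.queryEnd   = std::max(merged.queryEnd, m.queryEnd);
    }
    return {merged};
}

int main() {
    // Fragments seed both views identically, as processFragment does above...
    QueryOutput out{"query1", {{0, 100}, {150, 300}}, {{0, 100}, {150, 300}}};
    // ...and only mergedResults is merged afterwards, so the two views diverge
    // and the final filter can choose whichever one the parameters call for.
    out.mergedResults = mergeChains(out.mergedResults);
    std::cout << out.queryName << ": " << out.results.size() << " raw, "
              << out.mergedResults.size() << " merged\n";
    return 0;
}
```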
@@ -681,6 +702,9 @@ namespace skch
               "[wfmash::mashmap] mapping (" + std::to_string(subset_count + 1) + "/" + std::to_string(total_subsets) + ")");

+          // Create temporary storage for this subset's mappings
+          std::unordered_map<seqno_t, MappingResultsVector_t> subsetMappings;
+
           // Launch reader thread
           std::thread reader([&]() {
               reader_thread(input_queue, reader_done, progress, *idManager);
           });
@@ -701,9 +725,9 @@ namespace skch
               });
           }

-          // Launch aggregator thread
+          // Launch aggregator thread with subset storage
           std::thread aggregator([&]() {
-              aggregator_thread(merged_queue, workers_done, combinedMappings);
+              aggregator_thread(merged_queue, workers_done, subsetMappings);
           });

           // Wait for all threads to complete
@@ -721,7 +745,22 @@ namespace skch

           aggregator.join();

-          // Reset flags and clear aggregatedMappings for next iteration
+          // Filter mappings within this subset before merging with previous results
+          for (auto& [querySeqId, mappings] : subsetMappings) {
+
+              // Merge with existing mappings for this query
+              if (combinedMappings.find(querySeqId) == combinedMappings.end()) {
+                  combinedMappings[querySeqId] = std::move(mappings);
+              } else {
+                  combinedMappings[querySeqId].insert(
+                      combinedMappings[querySeqId].end(),
+                      std::make_move_iterator(mappings.begin()),
+                      std::make_move_iterator(mappings.end())
+                  );
+              }
+          }
+
+          // Reset flags for next iteration
           reader_done.store(false);
           workers_done.store(false);
           fragments_done.store(false);
@@ -939,6 +978,11 @@ namespace skch
           }

           mappingBoundarySanityCheck(input, output->results);
+
+          // Filter and get both merged and non-merged mappings
+          auto [nonMergedMappings, mergedMappings] = filterSubsetMappings(output->results, input->progress);
+          output->results = std::move(nonMergedMappings);
+          output->mergedResults = std::move(mergedMappings);

           return output;
       }
@@ -1005,10 +1049,12 @@ namespace skch
               QueryMappingOutput* output = nullptr;
               if (merged_queue.try_pop(output)) {
                   seqno_t querySeqId = idManager->getSequenceId(output->queryName);
+                  auto& mappings = output->results;
+                  // Chain IDs are already compacted in mapModule
                   combinedMappings[querySeqId].insert(
                       combinedMappings[querySeqId].end(),
-                      output->results.begin(),
-                      output->results.end()
+                      mappings.begin(),
+                      mappings.end()
                   );
                   delete output;
               } else if (workers_done.load() && merged_queue.was_empty()) {
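After each subset is mapped and filtered, its per-query results are folded into `combinedMappings` by move rather than by copy. Below is a small standalone sketch of that accumulation pattern with simplified stand-in types; the real code keys on `seqno_t` and stores `MappingResultsVector_t`, and `accumulateSubset` is a hypothetical helper, not a wfmash function.

```cpp
#include <cstdint>
#include <iostream>
#include <iterator>
#include <unordered_map>
#include <vector>

// Illustrative stand-ins for the wfmash types.
using SeqId = int32_t;
struct MappingResult { int64_t queryStart = 0, queryEnd = 0; };
using MappingVec = std::vector<MappingResult>;

// Move one subset's per-query mappings into the running combined map
// without copying elements, mirroring the loop added after aggregator.join().
void accumulateSubset(std::unordered_map<SeqId, MappingVec>& combined,
                      std::unordered_map<SeqId, MappingVec>& subset) {
    for (auto& [querySeqId, mappings] : subset) {
        auto it = combined.find(querySeqId);
        if (it == combined.end()) {
            // First subset that saw this query: steal the whole vector.
            combined[querySeqId] = std::move(mappings);
        } else {
            // Otherwise append by move; the subset's vector is left in a moved-from state.
            it->second.insert(it->second.end(),
                              std::make_move_iterator(mappings.begin()),
                              std::make_move_iterator(mappings.end()));
        }
    }
}

int main() {
    std::unordered_map<SeqId, MappingVec> combined;
    std::unordered_map<SeqId, MappingVec> subset1{{7, {{0, 100}}}};
    std::unordered_map<SeqId, MappingVec> subset2{{7, {{200, 300}}}, {9, {{0, 50}}}};
    accumulateSubset(combined, subset1);
    accumulateSubset(combined, subset2);
    std::cout << "query 7 has " << combined[7].size() << " mappings across subsets\n";  // 2
    return 0;
}
```

The move-based append matters for the real `MappingResult` objects, which are larger than the toy struct used here; for trivially copyable elements the move degenerates to a copy.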
@@ -2159,6 +2205,53 @@ namespace skch
       * @param begin Iterator to the start of the chain
       * @param end Iterator to the end of the chain
       */
+     /**
+      * @brief Filter mappings within a subset before aggregation
+      * @param mappings Mappings to filter
+      * @param param Algorithm parameters
+      */
+     std::pair<MappingResultsVector_t, MappingResultsVector_t> filterSubsetMappings(MappingResultsVector_t& mappings, progress_meter::ProgressMeter& progress) {
+         if (mappings.empty()) return {MappingResultsVector_t(), MappingResultsVector_t()};
+
+         // Only merge once and keep both versions
+         auto maximallyMergedMappings = mergeMappingsInRange(mappings, param.chain_gap, progress);
+
+         // Build dense chain ID mapping
+         std::unordered_map<offset_t, offset_t> id_map;
+         offset_t next_id = 0;
+
+         // First pass - build the mapping from both sets
+         for (const auto& mapping : mappings) {
+             if (id_map.count(mapping.splitMappingId) == 0) {
+                 id_map[mapping.splitMappingId] = next_id++;
+             }
+         }
+         for (const auto& mapping : maximallyMergedMappings) {
+             if (id_map.count(mapping.splitMappingId) == 0) {
+                 id_map[mapping.splitMappingId] = next_id++;
+             }
+         }
+
+         // Get atomic offset for this batch of chain IDs
+         offset_t base_id = maxChainIdSeen.fetch_add(id_map.size(), std::memory_order_relaxed);
+
+         // Apply compacted IDs with offset
+         for (auto& mapping : mappings) {
+             mapping.splitMappingId = id_map[mapping.splitMappingId] + base_id;
+         }
+         for (auto& mapping : maximallyMergedMappings) {
+             mapping.splitMappingId = id_map[mapping.splitMappingId] + base_id;
+         }
+
+         return {std::move(mappings), std::move(maximallyMergedMappings)};
+     }
+
+     /**
+      * @brief Update chain IDs to prevent conflicts between subsets
+      * @param mappings Mappings whose chain IDs need updating
+      * @param maxId Current maximum chain ID seen
+      */
+
      void computeChainStatistics(std::vector<MappingResult>::iterator begin, std::vector<MappingResult>::iterator end) {
          offset_t chain_start_query = std::numeric_limits<offset_t>::max();
          offset_t chain_end_query = std::numeric_limits<offset_t>::min();
@@ -2297,8 +2390,14 @@ namespace skch
              auto& mappings = *(task->second);
              std::string queryName = idManager->getSequenceName(querySeqId);

-             processAggregatedMappings(queryName, mappings, progress);
-
+             // Final filtering pass on pre-filtered mappings
+             if (param.filterMode == filter::MAP || param.filterMode == filter::ONETOONE) {
+                 MappingResultsVector_t filteredMappings;
+                 filterByGroup(mappings, filteredMappings, param.numMappingsForSegment - 1,
+                               param.filterMode == filter::ONETOONE, *idManager, progress);
+                 mappings = std::move(filteredMappings);
+             }
+
              std::stringstream ss;
              reportReadMappings(mappings, queryName, ss);
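`filterSubsetMappings` above compacts each batch's `splitMappingId` values into a dense 0..N-1 range and then shifts them by a base reserved from the shared atomic counter `maxChainIdSeen`, so chain IDs never collide across subsets processed in parallel. Here is a minimal sketch of that scheme with simplified types; the `Mapping` struct, `compactChainIds`, and `nextChainIdBase` are illustrative stand-ins, not the wfmash code.

```cpp
#include <atomic>
#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

// Simplified stand-ins for the wfmash types used above.
using offset_t = int64_t;
struct Mapping { offset_t splitMappingId; };

// Global counter playing the role of maxChainIdSeen.
std::atomic<offset_t> nextChainIdBase{0};

// Remap arbitrary chain IDs in one batch to a dense range, then shift the
// whole range by an atomically reserved base so batches never collide.
void compactChainIds(std::vector<Mapping>& batch) {
    std::unordered_map<offset_t, offset_t> id_map;
    offset_t next_id = 0;
    for (const auto& m : batch) {
        if (id_map.count(m.splitMappingId) == 0) {
            id_map[m.splitMappingId] = next_id++;
        }
    }
    // Reserve id_map.size() consecutive IDs for this batch in one atomic step.
    offset_t base_id = nextChainIdBase.fetch_add(static_cast<offset_t>(id_map.size()),
                                                 std::memory_order_relaxed);
    for (auto& m : batch) {
        m.splitMappingId = id_map[m.splitMappingId] + base_id;
    }
}

int main() {
    std::vector<Mapping> a{{1000}, {1000}, {42}};  // two distinct chains
    std::vector<Mapping> b{{1000}, {7}};           // raw IDs overlap with batch a
    compactChainIds(a);
    compactChainIds(b);
    for (const auto& m : a) std::cout << "a: " << m.splitMappingId << "\n";  // e.g. 0 0 1
    for (const auto& m : b) std::cout << "b: " << m.splitMappingId << "\n";  // e.g. 2 3
    return 0;
}
```

Reserving a whole block of IDs with a single `fetch_add` keeps the hot per-mapping loop free of synchronization; only one atomic operation is needed per batch.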