Skip to content

Commit

Permalink
Merge pull request #301 from waveygang/prefilter
Browse files Browse the repository at this point in the history
Prefilter mappings to save memory in batched mapping
  • Loading branch information
ekg authored Dec 2, 2024
2 parents 9c15c7d + 85a85f8 commit 643af3e
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 19 deletions.
29 changes: 18 additions & 11 deletions src/interface/parse_args.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,14 @@ void parse_args(int argc,
args::Group mapping_opts(options_group, "Mapping:");
args::Flag approx_mapping(mapping_opts, "", "output approximate mappings (no alignment)", {'m', "approx-mapping"});
args::ValueFlag<float> map_pct_identity(mapping_opts, "FLOAT", "minimum mapping identity [70]", {'p', "map-pct-id"});
args::ValueFlag<uint32_t> num_mappings(mapping_opts, "INT", "number of mappings to keep per query/target pair [1]", {'n', "mappings"});
args::ValueFlag<uint32_t> num_mappings(mapping_opts, "INT", "number of mappings to keep per segment [1]", {'n', "mappings"});
args::ValueFlag<std::string> segment_length(mapping_opts, "INT", "segment length for mapping [1k]", {'s', "segment-length"});
args::ValueFlag<std::string> block_length(mapping_opts, "INT", "minimum block length [3*segment-length]", {'l', "block-length"});
args::Flag one_to_one(mapping_opts, "", "Perform one-to-one filtering", {'o', "one-to-one"});
args::Flag lower_triangular(mapping_opts, "", "Only compute the lower triangular for all-vs-all mapping", {'L', "lower-triangular"});
args::ValueFlag<char> skip_prefix(mapping_opts, "C", "map between sequence groups with different prefix [#]", {'Y', "group-prefix"});
args::Flag disable_grouping(mapping_opts, "", "disable sequence grouping and exclude self mappings", {'G', "disable-grouping"});
args::Flag enable_self_mappings(mapping_opts, "", "enable self mappings (overrides -G)", {'X', "self-maps"});
args::ValueFlag<std::string> target_prefix(mapping_opts, "pfx", "use only targets whose names start with this prefix", {'T', "target-prefix"});
args::ValueFlag<std::string> target_list(mapping_opts, "FILE", "file containing list of target sequence names to use", {'R', "target-list"});
args::ValueFlag<std::string> query_prefix(mapping_opts, "pfxs", "filter queries by comma-separated prefixes", {'Q', "query-prefix"});
Expand Down Expand Up @@ -169,11 +171,17 @@ void parse_args(int argc,
exit(1);
}

map_parameters.skip_self = false;
map_parameters.skip_self = !args::get(enable_self_mappings);
map_parameters.lower_triangular = lower_triangular ? args::get(lower_triangular) : false;
map_parameters.keep_low_pct_id = true;

if (skip_prefix) {
if (disable_grouping) {
map_parameters.prefix_delim = '\0';
map_parameters.skip_prefix = false;
if (!args::get(enable_self_mappings)) {
map_parameters.skip_self = true;
}
} else if (skip_prefix) {
map_parameters.prefix_delim = args::get(skip_prefix);
map_parameters.skip_prefix = map_parameters.prefix_delim != '\0';
} else {
Expand Down Expand Up @@ -303,7 +311,7 @@ void parse_args(int argc,
exit(1);
}

if (!yeet_parameters.approx_mapping && s > 10000) {
if (!approx_mapping && s > 10000) {
std::cerr << "[wfmash] ERROR: segment length (-s) must be <= 10kb when running alignment." << std::endl
<< "[wfmash] For larger values, use -m/--approx-mapping to generate mappings," << std::endl
<< "[wfmash] then align them with: wfmash ... -i mappings.paf" << std::endl;
Expand Down Expand Up @@ -332,7 +340,7 @@ void parse_args(int argc,
exit(1);
}

if (!yeet_parameters.approx_mapping && l > 30000) {
if (!approx_mapping && l > 30000) {
std::cerr << "[wfmash] ERROR: block length (-l) must be <= 30kb when running alignment." << std::endl
<< "[wfmash] For larger values, use -m/--approx-mapping to generate mappings," << std::endl
<< "[wfmash] then align them with: wfmash ... -i mappings.paf" << std::endl;
Expand Down Expand Up @@ -363,7 +371,7 @@ void parse_args(int argc,
std::cerr << "[wfmash] ERROR: max mapping length must be greater than 0." << std::endl;
exit(1);
}
if (!yeet_parameters.approx_mapping && l > 100000) {
if (!approx_mapping && l > 100000) {
std::cerr << "[wfmash] ERROR: max mapping length (-P) must be <= 100kb when running alignment." << std::endl
<< "[wfmash] For larger values, use -m/--approx-mapping to generate mappings," << std::endl
<< "[wfmash] then align them with: wfmash ... -i mappings.paf" << std::endl;
Expand Down Expand Up @@ -664,12 +672,11 @@ void parse_args(int argc,
}
#endif

args::ValueFlag<int> num_mappings_for_segments(mapping_opts, "N", "number of mappings per segment [1]", {"mappings-per-segment"});
if (num_mappings_for_segments) {
if (args::get(num_mappings_for_segments) > 0) {
map_parameters.numMappingsForSegment = args::get(num_mappings_for_segments) ;
if (num_mappings) {
if (args::get(num_mappings) > 0) {
map_parameters.numMappingsForSegment = args::get(num_mappings);
} else {
std::cerr << "[wfmash] ERROR, skch::parseandSave, the number of mappings to retain for each segment has to be grater than 0." << std::endl;
std::cerr << "[wfmash] ERROR: the number of mappings to retain (-n) must be greater than 0." << std::endl;
exit(1);
}
} else {
Expand Down
115 changes: 107 additions & 8 deletions src/map/include/computeMap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,13 @@ namespace skch
{
/**
 * @brief Per-query accumulator for mapping results.
 *
 * Holds the query name together with two result vectors (non-merged and
 * maximally merged). The embedded mutex is locked by fragment processing
 * while new mappings are appended to `results` / `mergedResults`.
 * The progress meter is held by reference, so the meter passed to the
 * constructor must outlive this object.
 */
struct QueryMappingOutput {
std::string queryName;
// NOTE(review): `results` is declared on both of the next two lines. This looks
// like the removed and added side of the same diff line shown together; confirm
// that only one declaration exists in the real header.
std::vector<MappingResult> results;
std::vector<MappingResult> results; // Non-merged mappings
std::vector<MappingResult> mergedResults; // Maximally merged mappings
std::mutex mutex; // guards concurrent appends to the result vectors
progress_meter::ProgressMeter& progress; // non-owning reference
// Copies `name`, `r` and `mr` into the members and binds `progress` to `p`.
QueryMappingOutput(const std::string& name, const std::vector<MappingResult>& r,
const std::vector<MappingResult>& mr, progress_meter::ProgressMeter& p)
: queryName(name), results(r), mergedResults(mr), progress(p) {}
};

struct FragmentData {
Expand Down Expand Up @@ -149,6 +153,9 @@ namespace skch
// Queue type aliases. Elements are raw pointers to output strings,
// per-query mapping outputs, and fragments respectively.
using writer_atomic_queue_t = atomic_queue::AtomicQueue<std::string*, 1024>;
using query_output_atomic_queue_t = atomic_queue::AtomicQueue<QueryMappingOutput*, 1024, nullptr, true, true, false, false>;
using fragment_atomic_queue_t = atomic_queue::AtomicQueue<FragmentData*, 8192, nullptr, true, true, false, false>;

// Running chain-ID counter shared across all subsets; callers reserve a
// contiguous range of IDs from it with fetch_add.
std::atomic<offset_t> maxChainIdSeen{0};


void processFragment(FragmentData* fragment,
Expand Down Expand Up @@ -178,6 +185,10 @@ namespace skch
{
std::lock_guard<std::mutex> lock(fragment->output->mutex);
fragment->output->results.insert(fragment->output->results.end(), l2Mappings.begin(), l2Mappings.end());
// Initialize mergedResults with same mappings
if (param.mergeMappings && param.split) {
fragment->output->mergedResults.insert(fragment->output->mergedResults.end(), l2Mappings.begin(), l2Mappings.end());
}
}

// Update progress after processing the fragment
Expand Down Expand Up @@ -548,6 +559,7 @@ namespace skch
}
std::cerr << ", average size: " << std::fixed << std::setprecision(0) << avg_subset_size << "bp" << std::endl;

typedef std::vector<MappingResult> MappingResultsVector_t;
std::unordered_map<seqno_t, MappingResultsVector_t> combinedMappings;

// Build index for the current subset
Expand Down Expand Up @@ -666,6 +678,15 @@ namespace skch
}
output_thread.join();

// Process both merged and non-merged mappings
for (auto& [querySeqId, mappings] : combinedMappings) {
if (param.mergeMappings && param.split) {
filterMaximallyMerged(mappings, param, progress);
} else {
filterNonMergedMappings(mappings, param, progress);
}
}

progress.finish();

}
Expand All @@ -681,6 +702,9 @@ namespace skch
"[wfmash::mashmap] mapping ("
+ std::to_string(subset_count + 1) + "/" + std::to_string(total_subsets) + ")");

// Create temporary storage for this subset's mappings
std::unordered_map<seqno_t, MappingResultsVector_t> subsetMappings;

// Launch reader thread
std::thread reader([&]() {
reader_thread(input_queue, reader_done, progress, *idManager);
Expand All @@ -701,9 +725,9 @@ namespace skch
});
}

// Launch aggregator thread
// Launch aggregator thread with subset storage
std::thread aggregator([&]() {
aggregator_thread(merged_queue, workers_done, combinedMappings);
aggregator_thread(merged_queue, workers_done, subsetMappings);
});

// Wait for all threads to complete
Expand All @@ -721,7 +745,22 @@ namespace skch

aggregator.join();

// Reset flags and clear aggregatedMappings for next iteration
// Filter mappings within this subset before merging with previous results
for (auto& [querySeqId, mappings] : subsetMappings) {

// Merge with existing mappings for this query
if (combinedMappings.find(querySeqId) == combinedMappings.end()) {
combinedMappings[querySeqId] = std::move(mappings);
} else {
combinedMappings[querySeqId].insert(
combinedMappings[querySeqId].end(),
std::make_move_iterator(mappings.begin()),
std::make_move_iterator(mappings.end())
);
}
}

// Reset flags for next iteration
reader_done.store(false);
workers_done.store(false);
fragments_done.store(false);
Expand Down Expand Up @@ -939,6 +978,11 @@ namespace skch
}

mappingBoundarySanityCheck(input, output->results);

// Filter and get both merged and non-merged mappings
auto [nonMergedMappings, mergedMappings] = filterSubsetMappings(output->results, input->progress);
output->results = std::move(nonMergedMappings);
output->mergedResults = std::move(mergedMappings);

return output;
}
Expand Down Expand Up @@ -1005,10 +1049,12 @@ namespace skch
QueryMappingOutput* output = nullptr;
if (merged_queue.try_pop(output)) {
seqno_t querySeqId = idManager->getSequenceId(output->queryName);
auto& mappings = output->results;
// Chain IDs are already compacted in mapModule
combinedMappings[querySeqId].insert(
combinedMappings[querySeqId].end(),
output->results.begin(),
output->results.end()
mappings.begin(),
mappings.end()
);
delete output;
} else if (workers_done.load() && merged_queue.was_empty()) {
Expand Down Expand Up @@ -2159,6 +2205,53 @@ namespace skch
* @param begin Iterator to the start of the chain
* @param end Iterator to the end of the chain
*/
/**
 * @brief Merge a query's mappings once and renumber chain IDs densely.
 *
 * Runs mergeMappingsInRange() on `mappings` to obtain the maximally merged
 * set, then rewrites `splitMappingId` in both the original and the merged
 * set. Each distinct old ID gets a dense index (in order of first
 * appearance, originals first), offset by a base value reserved atomically
 * from `maxChainIdSeen`.
 *
 * @param mappings Mappings to process. Their chain IDs are rewritten in
 *                 place and the vector is moved into the returned pair.
 * @param progress Progress meter forwarded to mergeMappingsInRange().
 * @return Pair of {non-merged mappings, maximally merged mappings}; both
 *         are empty when `mappings` is empty.
 */
std::pair<MappingResultsVector_t, MappingResultsVector_t> filterSubsetMappings(MappingResultsVector_t& mappings, progress_meter::ProgressMeter& progress) {
    if (mappings.empty()) {
        return {MappingResultsVector_t(), MappingResultsVector_t()};
    }

    // Merge exactly once; both the raw and the merged version are kept.
    auto mergedSet = mergeMappingsInRange(mappings, param.chain_gap, progress);

    // Old chain ID -> dense index, assigned on first sight.
    std::unordered_map<offset_t, offset_t> denseIds;
    offset_t nextDenseId = 0;
    const auto registerIds = [&](const auto& batch) {
        for (const auto& entry : batch) {
            // try_emplace only inserts (and only then consumes an index)
            // when the old ID has not been seen yet.
            if (denseIds.try_emplace(entry.splitMappingId, nextDenseId).second) {
                ++nextDenseId;
            }
        }
    };
    registerIds(mappings);
    registerIds(mergedSet);

    // Reserve a contiguous block of IDs for this batch.
    const offset_t firstId = maxChainIdSeen.fetch_add(denseIds.size(), std::memory_order_relaxed);

    // Rewrite every chain ID as (dense index + reserved base).
    const auto renumber = [&](auto& batch) {
        for (auto& entry : batch) {
            entry.splitMappingId = denseIds.at(entry.splitMappingId) + firstId;
        }
    };
    renumber(mappings);
    renumber(mergedSet);

    return {std::move(mappings), std::move(mergedSet)};
}

/**
* @brief Update chain IDs to prevent conflicts between subsets
* @param mappings Mappings whose chain IDs need updating
* @param maxId Current maximum chain ID seen
*/

void computeChainStatistics(std::vector<MappingResult>::iterator begin, std::vector<MappingResult>::iterator end) {
offset_t chain_start_query = std::numeric_limits<offset_t>::max();
offset_t chain_end_query = std::numeric_limits<offset_t>::min();
Expand Down Expand Up @@ -2297,8 +2390,14 @@ namespace skch
auto& mappings = *(task->second);

std::string queryName = idManager->getSequenceName(querySeqId);
processAggregatedMappings(queryName, mappings, progress);

// Final filtering pass on pre-filtered mappings
if (param.filterMode == filter::MAP || param.filterMode == filter::ONETOONE) {
MappingResultsVector_t filteredMappings;
filterByGroup(mappings, filteredMappings, param.numMappingsForSegment - 1,
param.filterMode == filter::ONETOONE, *idManager, progress);
mappings = std::move(filteredMappings);
}

std::stringstream ss;
reportReadMappings(mappings, queryName, ss);

Expand Down

0 comments on commit 643af3e

Please sign in to comment.