Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prefilter mappings to save memory in batched mapping #301

Merged
merged 21 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
6cc82cb
feat: Add flag to disable sequence grouping with -G option
ekg Nov 21, 2024
999d458
feat: Implement subset filtering optimization for memory efficiency
ekg Nov 25, 2024
2ec4192
refactor: Improve atomic chain ID update with max subset calculation
ekg Nov 25, 2024
b4ee587
refactor: Improve mapping aggregation across target subsets
ekg Nov 25, 2024
b64fadb
check if the approx mapping flag is set to correctly detect if we are…
ekg Nov 25, 2024
a735977
refactor: Improve mapping filtering and chaining logic in parallel pr…
ekg Nov 25, 2024
3fa77f9
refactor: Move subset mapping filtering and chain ID updates to corre…
ekg Nov 25, 2024
b3bada2
refactor: Modify mapping filtering to preserve merged and non-merged …
ekg Nov 25, 2024
b09c722
feat: Add mapping filtering logic for merged and non-merged results
ekg Nov 25, 2024
ea1b006
fix: Update QueryMappingOutput and CombinedMappingResults to resolve …
ekg Nov 25, 2024
0247b17
fix: Resolve compilation errors in computeMap.hpp mapping logic
ekg Nov 25, 2024
b0dcb86
fix: Update computeMap.hpp to resolve compilation errors with mapping…
ekg Nov 25, 2024
28a3a8f
fix: Refactor CombinedMappingResults to simplify mapping handling
ekg Nov 25, 2024
f45b330
fix: Resolve compilation errors in computeMap.hpp by updating mapping…
ekg Nov 25, 2024
25d11a4
refactor: Optimize chain ID generation to use sequential, smaller IDs
ekg Nov 25, 2024
e218f13
refactor: Preserve chain relationships in updateChainIds by using bas…
ekg Nov 25, 2024
f1d6046
refactor: Implement dense chain ID mapping with compact range generation
ekg Nov 25, 2024
e97d224
refactor: Move chain ID compaction to MapModule with atomic offset
ekg Nov 25, 2024
a13e992
feat: Add -X flag to control self-mapping behavior with -G
ekg Nov 25, 2024
ce2175d
refactor: Remove redundant mappings parameter and use -n/--mappings f…
ekg Nov 27, 2024
85a85f8
Merge branch 'main' of github.com:waveygang/wfmash into prefilter
ekg Dec 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 18 additions & 11 deletions src/interface/parse_args.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,14 @@ void parse_args(int argc,
args::Group mapping_opts(options_group, "Mapping:");
args::Flag approx_mapping(mapping_opts, "", "output approximate mappings (no alignment)", {'m', "approx-mapping"});
args::ValueFlag<float> map_pct_identity(mapping_opts, "FLOAT", "minimum mapping identity [70]", {'p', "map-pct-id"});
args::ValueFlag<uint32_t> num_mappings(mapping_opts, "INT", "number of mappings to keep per query/target pair [1]", {'n', "mappings"});
args::ValueFlag<uint32_t> num_mappings(mapping_opts, "INT", "number of mappings to keep per segment [1]", {'n', "mappings"});
args::ValueFlag<std::string> segment_length(mapping_opts, "INT", "segment length for mapping [1k]", {'s', "segment-length"});
args::ValueFlag<std::string> block_length(mapping_opts, "INT", "minimum block length [3*segment-length]", {'l', "block-length"});
args::Flag one_to_one(mapping_opts, "", "Perform one-to-one filtering", {'o', "one-to-one"});
args::Flag lower_triangular(mapping_opts, "", "Only compute the lower triangular for all-vs-all mapping", {'L', "lower-triangular"});
args::ValueFlag<char> skip_prefix(mapping_opts, "C", "map between sequence groups with different prefix [#]", {'Y', "group-prefix"});
args::Flag disable_grouping(mapping_opts, "", "disable sequence grouping and exclude self mappings", {'G', "disable-grouping"});
args::Flag enable_self_mappings(mapping_opts, "", "enable self mappings (overrides -G)", {'X', "self-maps"});
args::ValueFlag<std::string> target_prefix(mapping_opts, "pfx", "use only targets whose names start with this prefix", {'T', "target-prefix"});
args::ValueFlag<std::string> target_list(mapping_opts, "FILE", "file containing list of target sequence names to use", {'R', "target-list"});
args::ValueFlag<std::string> query_prefix(mapping_opts, "pfxs", "filter queries by comma-separated prefixes", {'Q', "query-prefix"});
Expand Down Expand Up @@ -169,11 +171,17 @@ void parse_args(int argc,
exit(1);
}

map_parameters.skip_self = false;
map_parameters.skip_self = !args::get(enable_self_mappings);
map_parameters.lower_triangular = lower_triangular ? args::get(lower_triangular) : false;
map_parameters.keep_low_pct_id = true;

if (skip_prefix) {
if (disable_grouping) {
map_parameters.prefix_delim = '\0';
map_parameters.skip_prefix = false;
if (!args::get(enable_self_mappings)) {
map_parameters.skip_self = true;
}
} else if (skip_prefix) {
map_parameters.prefix_delim = args::get(skip_prefix);
map_parameters.skip_prefix = map_parameters.prefix_delim != '\0';
} else {
Expand Down Expand Up @@ -303,7 +311,7 @@ void parse_args(int argc,
exit(1);
}

if (!yeet_parameters.approx_mapping && s > 10000) {
if (!approx_mapping && s > 10000) {
std::cerr << "[wfmash] ERROR: segment length (-s) must be <= 10kb when running alignment." << std::endl
<< "[wfmash] For larger values, use -m/--approx-mapping to generate mappings," << std::endl
<< "[wfmash] then align them with: wfmash ... -i mappings.paf" << std::endl;
Expand Down Expand Up @@ -332,7 +340,7 @@ void parse_args(int argc,
exit(1);
}

if (!yeet_parameters.approx_mapping && l > 30000) {
if (!approx_mapping && l > 30000) {
std::cerr << "[wfmash] ERROR: block length (-l) must be <= 30kb when running alignment." << std::endl
<< "[wfmash] For larger values, use -m/--approx-mapping to generate mappings," << std::endl
<< "[wfmash] then align them with: wfmash ... -i mappings.paf" << std::endl;
Expand Down Expand Up @@ -363,7 +371,7 @@ void parse_args(int argc,
std::cerr << "[wfmash] ERROR: max mapping length must be greater than 0." << std::endl;
exit(1);
}
if (!yeet_parameters.approx_mapping && l > 100000) {
if (!approx_mapping && l > 100000) {
std::cerr << "[wfmash] ERROR: max mapping length (-P) must be <= 100kb when running alignment." << std::endl
<< "[wfmash] For larger values, use -m/--approx-mapping to generate mappings," << std::endl
<< "[wfmash] then align them with: wfmash ... -i mappings.paf" << std::endl;
Expand Down Expand Up @@ -664,12 +672,11 @@ void parse_args(int argc,
}
#endif

args::ValueFlag<int> num_mappings_for_segments(mapping_opts, "N", "number of mappings per segment [1]", {"mappings-per-segment"});
if (num_mappings_for_segments) {
if (args::get(num_mappings_for_segments) > 0) {
map_parameters.numMappingsForSegment = args::get(num_mappings_for_segments) ;
if (num_mappings) {
if (args::get(num_mappings) > 0) {
map_parameters.numMappingsForSegment = args::get(num_mappings);
} else {
std::cerr << "[wfmash] ERROR, skch::parseandSave, the number of mappings to retain for each segment has to be grater than 0." << std::endl;
std::cerr << "[wfmash] ERROR: the number of mappings to retain (-n) must be greater than 0." << std::endl;
exit(1);
}
} else {
Expand Down
115 changes: 107 additions & 8 deletions src/map/include/computeMap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,13 @@ namespace skch
{
struct QueryMappingOutput {
std::string queryName;
std::vector<MappingResult> results;
std::vector<MappingResult> results; // Non-merged mappings
std::vector<MappingResult> mergedResults; // Maximally merged mappings
std::mutex mutex;
progress_meter::ProgressMeter& progress;
QueryMappingOutput(const std::string& name, const std::vector<MappingResult>& r,
const std::vector<MappingResult>& mr, progress_meter::ProgressMeter& p)
: queryName(name), results(r), mergedResults(mr), progress(p) {}
};

struct FragmentData {
Expand Down Expand Up @@ -149,6 +153,9 @@ namespace skch
typedef atomic_queue::AtomicQueue<std::string*, 1024> writer_atomic_queue_t;
typedef atomic_queue::AtomicQueue<QueryMappingOutput*, 1024, nullptr, true, true, false, false> query_output_atomic_queue_t;
typedef atomic_queue::AtomicQueue<FragmentData*, 8192, nullptr, true, true, false, false> fragment_atomic_queue_t;

// Track maximum chain ID seen across all subsets
std::atomic<offset_t> maxChainIdSeen{0};


void processFragment(FragmentData* fragment,
Expand Down Expand Up @@ -178,6 +185,10 @@ namespace skch
{
std::lock_guard<std::mutex> lock(fragment->output->mutex);
fragment->output->results.insert(fragment->output->results.end(), l2Mappings.begin(), l2Mappings.end());
// Initialize mergedResults with same mappings
if (param.mergeMappings && param.split) {
fragment->output->mergedResults.insert(fragment->output->mergedResults.end(), l2Mappings.begin(), l2Mappings.end());
}
}

// Update progress after processing the fragment
Expand Down Expand Up @@ -548,6 +559,7 @@ namespace skch
}
std::cerr << ", average size: " << std::fixed << std::setprecision(0) << avg_subset_size << "bp" << std::endl;

typedef std::vector<MappingResult> MappingResultsVector_t;
std::unordered_map<seqno_t, MappingResultsVector_t> combinedMappings;

// Build index for the current subset
Expand Down Expand Up @@ -666,6 +678,15 @@ namespace skch
}
output_thread.join();

// Process both merged and non-merged mappings
for (auto& [querySeqId, mappings] : combinedMappings) {
if (param.mergeMappings && param.split) {
filterMaximallyMerged(mappings, param, progress);
} else {
filterNonMergedMappings(mappings, param, progress);
}
}

progress.finish();

}
Expand All @@ -681,6 +702,9 @@ namespace skch
"[wfmash::mashmap] mapping ("
+ std::to_string(subset_count + 1) + "/" + std::to_string(total_subsets) + ")");

// Create temporary storage for this subset's mappings
std::unordered_map<seqno_t, MappingResultsVector_t> subsetMappings;

// Launch reader thread
std::thread reader([&]() {
reader_thread(input_queue, reader_done, progress, *idManager);
Expand All @@ -701,9 +725,9 @@ namespace skch
});
}

// Launch aggregator thread
// Launch aggregator thread with subset storage
std::thread aggregator([&]() {
aggregator_thread(merged_queue, workers_done, combinedMappings);
aggregator_thread(merged_queue, workers_done, subsetMappings);
});

// Wait for all threads to complete
Expand All @@ -721,7 +745,22 @@ namespace skch

aggregator.join();

// Reset flags and clear aggregatedMappings for next iteration
// Filter mappings within this subset before merging with previous results
for (auto& [querySeqId, mappings] : subsetMappings) {

// Merge with existing mappings for this query
if (combinedMappings.find(querySeqId) == combinedMappings.end()) {
combinedMappings[querySeqId] = std::move(mappings);
} else {
combinedMappings[querySeqId].insert(
combinedMappings[querySeqId].end(),
std::make_move_iterator(mappings.begin()),
std::make_move_iterator(mappings.end())
);
}
}

// Reset flags for next iteration
reader_done.store(false);
workers_done.store(false);
fragments_done.store(false);
Expand Down Expand Up @@ -939,6 +978,11 @@ namespace skch
}

mappingBoundarySanityCheck(input, output->results);

// Filter and get both merged and non-merged mappings
auto [nonMergedMappings, mergedMappings] = filterSubsetMappings(output->results, input->progress);
output->results = std::move(nonMergedMappings);
output->mergedResults = std::move(mergedMappings);

return output;
}
Expand Down Expand Up @@ -1005,10 +1049,12 @@ namespace skch
QueryMappingOutput* output = nullptr;
if (merged_queue.try_pop(output)) {
seqno_t querySeqId = idManager->getSequenceId(output->queryName);
auto& mappings = output->results;
// Chain IDs are already compacted in mapModule
combinedMappings[querySeqId].insert(
combinedMappings[querySeqId].end(),
output->results.begin(),
output->results.end()
mappings.begin(),
mappings.end()
);
delete output;
} else if (workers_done.load() && merged_queue.was_empty()) {
Expand Down Expand Up @@ -2159,6 +2205,53 @@ namespace skch
* @param begin Iterator to the start of the chain
* @param end Iterator to the end of the chain
*/
/**
* @brief Filter mappings within a subset before aggregation
* @param mappings Mappings to filter
* @param param Algorithm parameters
*/
std::pair<MappingResultsVector_t, MappingResultsVector_t> filterSubsetMappings(MappingResultsVector_t& mappings, progress_meter::ProgressMeter& progress) {
if (mappings.empty()) return {MappingResultsVector_t(), MappingResultsVector_t()};

// Only merge once and keep both versions
auto maximallyMergedMappings = mergeMappingsInRange(mappings, param.chain_gap, progress);

// Build dense chain ID mapping
std::unordered_map<offset_t, offset_t> id_map;
offset_t next_id = 0;

// First pass - build the mapping from both sets
for (const auto& mapping : mappings) {
if (id_map.count(mapping.splitMappingId) == 0) {
id_map[mapping.splitMappingId] = next_id++;
}
}
for (const auto& mapping : maximallyMergedMappings) {
if (id_map.count(mapping.splitMappingId) == 0) {
id_map[mapping.splitMappingId] = next_id++;
}
}

// Get atomic offset for this batch of chain IDs
offset_t base_id = maxChainIdSeen.fetch_add(id_map.size(), std::memory_order_relaxed);

// Apply compacted IDs with offset
for (auto& mapping : mappings) {
mapping.splitMappingId = id_map[mapping.splitMappingId] + base_id;
}
for (auto& mapping : maximallyMergedMappings) {
mapping.splitMappingId = id_map[mapping.splitMappingId] + base_id;
}

return {std::move(mappings), std::move(maximallyMergedMappings)};
}

/**
* @brief Update chain IDs to prevent conflicts between subsets
* @param mappings Mappings whose chain IDs need updating
* @param maxId Current maximum chain ID seen
*/

void computeChainStatistics(std::vector<MappingResult>::iterator begin, std::vector<MappingResult>::iterator end) {
offset_t chain_start_query = std::numeric_limits<offset_t>::max();
offset_t chain_end_query = std::numeric_limits<offset_t>::min();
Expand Down Expand Up @@ -2297,8 +2390,14 @@ namespace skch
auto& mappings = *(task->second);

std::string queryName = idManager->getSequenceName(querySeqId);
processAggregatedMappings(queryName, mappings, progress);

// Final filtering pass on pre-filtered mappings
if (param.filterMode == filter::MAP || param.filterMode == filter::ONETOONE) {
MappingResultsVector_t filteredMappings;
filterByGroup(mappings, filteredMappings, param.numMappingsForSegment - 1,
param.filterMode == filter::ONETOONE, *idManager, progress);
mappings = std::move(filteredMappings);
}

std::stringstream ss;
reportReadMappings(mappings, queryName, ss);

Expand Down
Loading