Skip to content

Commit

Permalink
Rework prefilter
Browse files (browse the repository at this point in the history)
  • Loading branch information
martin-steinegger committed Apr 29, 2023
1 parent 92d8cc3 commit 8a89305
Showing 1 changed file with 29 additions and 17 deletions.
46 changes: 29 additions & 17 deletions src/prefiltering/ungappedprefilter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ int doRescorealldiagonal(Parameters &par, DBReader<unsigned int> &qdbr, DBWriter


Debug::Progress progress(dbSize);
std::vector<hit_t> shortResults;

#pragma omp parallel
{
Expand All @@ -70,17 +71,16 @@ int doRescorealldiagonal(Parameters &par, DBReader<unsigned int> &qdbr, DBWriter
thread_idx = (unsigned int) omp_get_thread_num();
#endif
char buffer[1024+32768];
std::vector<hit_t> shortResults;
shortResults.reserve(std::max(static_cast<size_t >(1), tdbr->getSize()/5));
std::vector<hit_t> threadShortResults;
Sequence qSeq(par.maxSeqLen, querySeqType, subMat, 0, false, par.compBiasCorrection);
Sequence tSeq(par.maxSeqLen, targetSeqType, subMat, 0, false, par.compBiasCorrection);
SmithWaterman aligner(par.maxSeqLen, subMat->alphabetSize,
par.compBiasCorrection, par.compBiasCorrectionScale, targetSeqType);

std::string resultBuffer;
resultBuffer.reserve(262144);
#pragma omp for schedule(dynamic, 1)
for (size_t id = dbStart; id < (dbStart+dbSize); id++) {
#pragma omp master
progress.updateProgress();
char *querySeqData = qdbr.getData(id, thread_idx);
size_t queryKey = qdbr.getDbKey(id);
Expand All @@ -93,7 +93,7 @@ int doRescorealldiagonal(Parameters &par, DBReader<unsigned int> &qdbr, DBWriter
}else{
aligner.ssw_init(&qSeq, tinySubMat, subMat);
}

#pragma omp for schedule(dynamic, 1) nowait
for (size_t tId = 0; tId < tdbr->getSize(); tId++) {
unsigned int targetKey = tdbr->getDbKey(tId);
const bool isIdentity = (queryKey == targetKey && (par.includeIdentity || sameDB))? true : false;
Expand All @@ -109,28 +109,40 @@ int doRescorealldiagonal(Parameters &par, DBReader<unsigned int> &qdbr, DBWriter

int score = aligner.ungapped_alignment(tSeq.numSequence, tSeq.L);
bool hasDiagScore = (score > par.minDiagScoreThr);
double evalue = evaluer->computeEvalue(score, qSeq.L);
double evalue = 0.0;
// check if evalThr != inf
if (par.evalThr < std::numeric_limits<double>::max()) {
evalue = evaluer->computeEvalue(score, qSeq.L);
}
bool hasEvalue = (evalue <= par.evalThr);
// --filter-hits
if (isIdentity || (hasDiagScore && hasEvalue)) {
hit_t hit;
hit.seqId = targetKey;
hit.prefScore = score;
hit.diagonal = 0;
shortResults.emplace_back(hit);
threadShortResults.emplace_back(hit);
}
}

SORT_SERIAL(shortResults.begin(), shortResults.end(), hit_t::compareHitsByScoreAndId);
size_t maxSeqs = std::min(par.maxResListLen, shortResults.size());
for (size_t i = 0; i < maxSeqs; ++i) {
size_t len = QueryMatcher::prefilterHitToBuffer(buffer, shortResults[i]);
resultBuffer.append(buffer, len);
#pragma omp critical
{
shortResults.insert(shortResults.end(), threadShortResults.begin(), threadShortResults.end());
threadShortResults.clear();
}
#pragma omp barrier
#pragma omp master
{
SORT_PARALLEL(shortResults.begin(), shortResults.end(), hit_t::compareHitsByScoreAndId);
size_t maxSeqs = std::min(par.maxResListLen, shortResults.size());
for (size_t i = 0; i < maxSeqs; ++i) {
size_t len = QueryMatcher::prefilterHitToBuffer(buffer, shortResults[i]);
resultBuffer.append(buffer, len);
}

resultWriter.writeData(resultBuffer.c_str(), resultBuffer.length(), queryKey, thread_idx);
resultBuffer.clear();
shortResults.clear();
resultWriter.writeData(resultBuffer.c_str(), resultBuffer.length(), queryKey, 0);
resultBuffer.clear();
shortResults.clear();
}
}
}

Expand Down Expand Up @@ -163,7 +175,7 @@ int ungappedprefilter(int argc, const char **argv, const Command &command) {

qdbr.decomposeDomainByAminoAcid(MMseqsMPI::rank, MMseqsMPI::numProc, &dbFrom, &dbSize);
std::pair<std::string, std::string> tmpOutput = Util::createTmpFileNames(par.db3, par.db3Index, MMseqsMPI::rank);
DBWriter resultWriter(tmpOutput.first.c_str(), tmpOutput.second.c_str(), par.threads, par.compressed, Parameters::DBTYPE_PREFILTER_RES);
DBWriter resultWriter(tmpOutput.first.c_str(), tmpOutput.second.c_str(), 1 par.compressed, Parameters::DBTYPE_PREFILTER_RES);
resultWriter.open();
int status = doRescorealldiagonal(par, qdbr, resultWriter, dbFrom, dbSize);
resultWriter.close();
Expand All @@ -178,7 +190,7 @@ int ungappedprefilter(int argc, const char **argv, const Command &command) {
DBWriter::mergeResults(par.db3, par.db3Index, splitFiles);
}
#else
DBWriter resultWriter(par.db3.c_str(), par.db3Index.c_str(), par.threads, par.compressed, Parameters::DBTYPE_PREFILTER_RES);
DBWriter resultWriter(par.db3.c_str(), par.db3Index.c_str(), 1, par.compressed, Parameters::DBTYPE_PREFILTER_RES);
resultWriter.open();
int status = doRescorealldiagonal(par, qdbr, resultWriter, 0, qdbr.getSize());
resultWriter.close();
Expand Down

0 comments on commit 8a89305

Please sign in to comment.