Skip to content

Commit

Permalink
Rework masking and add N repeat masking
Browse files Browse the repository at this point in the history
  • Loading branch information
martin-steinegger committed Dec 8, 2024
1 parent e3b16fa commit a2d01d0
Show file tree
Hide file tree
Showing 22 changed files with 241 additions and 181 deletions.
14 changes: 9 additions & 5 deletions lib/tantan/tantan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ struct Tantan {
}
};

void maskSequences(uchar *seqBeg,
int maskSequences(uchar *seqBeg,
uchar *seqEnd,
int maxRepeatOffset,
const const_double_ptr *likelihoodRatioMatrix,
Expand All @@ -498,7 +498,7 @@ void maskSequences(uchar *seqBeg,
repeatOffsetProbDecay, firstGapProb, otherGapProb,
probabilities);

maskProbableLetters(seqBeg, seqEnd, probabilities, minMaskProb, maskTable);
return maskProbableLetters(seqBeg, seqEnd, probabilities, minMaskProb, maskTable);
}

void getProbabilities(const uchar *seqBeg,
Expand All @@ -517,17 +517,21 @@ void getProbabilities(const uchar *seqBeg,
tantan.calcRepeatProbs(probabilities);
}

void maskProbableLetters(uchar *seqBeg,
int maskProbableLetters(uchar *seqBeg,
uchar *seqEnd,
const float *probabilities,
double minMaskProb,
const uchar *maskTable) {
int masked = 0;
while (seqBeg < seqEnd) {
if (*probabilities >= minMaskProb)
*seqBeg = maskTable[*seqBeg];
if (*probabilities >= minMaskProb) {
*seqBeg = maskTable[*seqBeg];
masked++;
}
++probabilities;
++seqBeg;
}
return masked;
}

void countTransitions(const uchar *seqBeg,
Expand Down
4 changes: 2 additions & 2 deletions lib/tantan/tantan.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ namespace tantan {
typedef unsigned char uchar;
typedef const double *const_double_ptr;

void maskSequences(uchar *seqBeg,
int maskSequences(uchar *seqBeg,
uchar *seqEnd,
int maxRepeatOffset,
const const_double_ptr *likelihoodRatioMatrix,
Expand Down Expand Up @@ -87,7 +87,7 @@ void getProbabilities(const uchar *seqBeg,
// The following routine masks each letter whose corresponding entry
// in "probabilities" is >= minMaskProb.

void maskProbableLetters(uchar *seqBeg,
int maskProbableLetters(uchar *seqBeg,
uchar *seqEnd,
const float *probabilities,
double minMaskProb,
Expand Down
1 change: 0 additions & 1 deletion src/alignment/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ set(alignment_header_files
alignment/MsaFilter.h
alignment/MultipleAlignment.h
alignment/PSSMCalculator.h
alignment/PSSMMasker.h
alignment/StripedSmithWaterman.h
alignment/BandedNucleotideAligner.h
alignment/DistanceCalculator.h
Expand Down
52 changes: 0 additions & 52 deletions src/alignment/PSSMMasker.h

This file was deleted.

2 changes: 2 additions & 0 deletions src/commons/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ set(commons_header_files
commons/itoa.h
commons/KSeqBufferReader.h
commons/KSeqWrapper.h
commons/Masker.h
commons/MathUtil.h
commons/MemoryMapped.h
commons/MemoryTracker.h
Expand Down Expand Up @@ -60,6 +61,7 @@ set(commons_source_files
commons/FileUtil.cpp
commons/HeaderSummarizer.cpp
commons/KSeqWrapper.cpp
commons/Masker.cpp
commons/MemoryMapped.cpp
commons/MemoryTracker.cpp
commons/MMseqsMPI.cpp
Expand Down
136 changes: 136 additions & 0 deletions src/commons/Masker.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
#include "Masker.h"
#include <algorithm> // for std::toupper

Masker::Masker(BaseMatrix &s) : subMat(s), probMatrix(s)
{
maxSeqLen = 1;
charSequence = (unsigned char *)malloc(maxSeqLen * sizeof(char));
maskLetterNum = subMat.aa2num[(int)'X'];
}

Masker::~Masker() {
free(charSequence);
}

int Masker::maskSequence(Sequence & seq, bool maskTantan, double maskProb,
bool maskLowerCaseLetter, int maskNrepeats) {

int maskedResidues = 0;

if(maskTantan){
// 1. Apply tantan masking without influencing by repeat mask
maskedResidues += tantan::maskSequences(seq.numSequence,
(seq.numSequence + seq.L),
50 /*maxCycleLength*/,
probMatrix.probMatrixPointers,
0.005 /*repeatProb*/,
0.05 /*repeatEndProb*/,
0.9 /*repeatOffsetProbDecay*/,
0, 0,
maskProb /*minMaskProb*/,
probMatrix.hardMaskTable);
}
if( maskNrepeats > 0){
// 2. Generate the mask for repeats
maskedResidues += maskRepeats(seq.numSequence, seq.L, maskNrepeats, maskLetterNum);
}
// 3. Handle lowercase masking
if(maskLowerCaseLetter){
if ((Parameters::isEqualDbtype(seq.getSequenceType(), Parameters::DBTYPE_AMINO_ACIDS) ||
Parameters::isEqualDbtype(seq.getSequenceType(), Parameters::DBTYPE_NUCLEOTIDES))) {
const char *charSeq = seq.getSeqData();
for (int i = 0; i < seq.L; i++) {
if (std::islower((unsigned char)charSeq[i])) {
seq.numSequence[i] = maskLetterNum; // Apply masking
maskedResidues++;
}
}
}
}
// 4. Finalize masking
if(maskTantan || maskNrepeats || maskLowerCaseLetter){
finalizeMasking(seq.numSequence, seq.L);
}
return maskedResidues;
}

void Masker::maskPssm(Sequence& centerSequence, float maskProb, PSSMCalculator::Profile& pssmRes) {
if ((size_t)centerSequence.L > maxSeqLen) {
maxSeqLen = sizeof(char) * centerSequence.L * 1.5;
charSequence = (unsigned char*)realloc(charSequence, maxSeqLen);
}
memcpy(charSequence, centerSequence.numSequence, sizeof(unsigned char) * centerSequence.L);
tantan::maskSequences(charSequence, charSequence + centerSequence.L,
50 /*options.maxCycleLength*/,
probMatrix.probMatrixPointers,
0.005 /*options.repeatProb*/,
0.05 /*options.repeatEndProb*/,
0.9 /*options.repeatOffsetProbDecay*/,
0, 0,
maskProb /*options.minMaskProb*/,
probMatrix.hardMaskTable);

for (int pos = 0; pos < centerSequence.L; pos++) {
if (charSequence[pos] == maskLetterNum) {
for (size_t aa = 0; aa < Sequence::PROFILE_AA_SIZE; aa++) {
pssmRes.pssm[pos * Sequence::PROFILE_AA_SIZE + aa] = -1;
}
}
}
}


int Masker::maskRepeats(unsigned char * numSequence, const unsigned int seqLen, int maskNrepeating, char maskChar) {

unsigned int repeatCount = 0;
int startOfRepeat = -1;
char previousChar = '\0';
int maskedResidues = 0; // Counter for masked residues

for (unsigned int pos = 0; pos < seqLen; ++pos) {
char currentChar = numSequence[pos];

if (currentChar == previousChar) {
repeatCount++;
} else {
if (repeatCount > (unsigned int)maskNrepeating) {
for (unsigned int i = startOfRepeat; i < pos; ++i) {
numSequence[i] = maskChar;
maskedResidues++;
}
}
repeatCount = 1;
startOfRepeat = pos;
previousChar = currentChar;
}
}

// Handle the last run
if (repeatCount > (unsigned int)maskNrepeating) {
for (unsigned int i = startOfRepeat; i < seqLen; ++i) {
numSequence[i] = maskChar;
maskedResidues++;
}
}

return maskedResidues;
}

void Masker::finalizeMasking(unsigned char * numSequence, const unsigned int seqLen) {
unsigned char maskChar = probMatrix.hardMaskTable[0];

for (unsigned int i = 0; i < seqLen; i++) {
unsigned char code = numSequence[i];
numSequence[i] = (code == maskChar || code == maskLetterNum) ? maskLetterNum : numSequence[i];
}
}

void Masker::applySoftmasking(unsigned char *charSequence, const unsigned char * num_sequence, unsigned int seqLen) {
for (unsigned int pos = 0; pos < seqLen; pos++) {
// If masked, lowercase (soft) or uppercase (hard) could be applied here if needed.
// For simplicity, we treat maskChar as masked and others as uppercase:
charSequence[pos] = (num_sequence[pos] == maskLetterNum)
? (char)std::tolower(charSequence[pos])
: (char)std::toupper(charSequence[pos]);
}
}
37 changes: 37 additions & 0 deletions src/commons/Masker.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#ifndef MMSEQS_MASKER_H
#define MMSEQS_MASKER_H

#include "Parameters.h"
#include "Sequence.h"
#include "SubstitutionMatrix.h"
#include "tantan.h"
#include "PSSMCalculator.h"
#include <cctype>

class Masker {
public:
Masker(BaseMatrix &subMat);

~Masker();

int maskSequence(Sequence & seq, bool maskTantan, double maskProb,
bool maskLowerCaseLetter, int maskNrepeating);

void maskPssm(Sequence& centerSequence, float maskProb, PSSMCalculator::Profile& pssmRes);

void applySoftmasking(unsigned char *charSequence, const unsigned char * numSequence, unsigned int seqLen);

char maskLetterNum;

private:
int maskRepeats(unsigned char *numSequence, const unsigned int seqLen, int maskNrepeating, char maskChar);

void finalizeMasking(unsigned char * numSequence, const unsigned int seqLen);

BaseMatrix &subMat;
ProbabilityMatrix probMatrix;

unsigned char * charSequence;
size_t maxSeqLen;
};
#endif
9 changes: 8 additions & 1 deletion src/commons/Parameters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@ Parameters::Parameters():
PARAM_MAX_SEQ_LEN(PARAM_MAX_SEQ_LEN_ID, "--max-seq-len", "Max sequence length", "Maximum sequence length", typeid(size_t), (void *) &maxSeqLen, "^[0-9]{1}[0-9]*", MMseqsParameter::COMMAND_COMMON | MMseqsParameter::COMMAND_EXPERT),
PARAM_DIAGONAL_SCORING(PARAM_DIAGONAL_SCORING_ID, "--diag-score", "Diagonal scoring", "Use ungapped diagonal scoring during prefilter", typeid(bool), (void *) &diagonalScoring, "", MMseqsParameter::COMMAND_PREFILTER | MMseqsParameter::COMMAND_EXPERT),
PARAM_EXACT_KMER_MATCHING(PARAM_EXACT_KMER_MATCHING_ID, "--exact-kmer-matching", "Exact k-mer matching", "Extract only exact k-mers for matching (range 0-1)", typeid(int), (void *) &exactKmerMatching, "^[0-1]{1}$", MMseqsParameter::COMMAND_PREFILTER | MMseqsParameter::COMMAND_EXPERT),
PARAM_MASK_RESIDUES(PARAM_MASK_RESIDUES_ID, "--mask", "Mask residues", "Mask sequences in k-mer stage: 0: w/o low complexity masking, 1: with low complexity masking", typeid(int), (void *) &maskMode, "^[0-1]{1}", MMseqsParameter::COMMAND_PREFILTER | MMseqsParameter::COMMAND_EXPERT),
PARAM_MASK_RESIDUES(PARAM_MASK_RESIDUES_ID, "--mask", "Mask residues", "Mask sequences in prefilter stage with tantan: 0: w/o low complexity masking, 1: with low complexity masking", typeid(int), (void *) &maskMode, "^[0-1]{1}", MMseqsParameter::COMMAND_PREFILTER | MMseqsParameter::COMMAND_EXPERT),
PARAM_MASK_PROBABILTY(PARAM_MASK_PROBABILTY_ID, "--mask-prob", "Mask residues probability", "Mask sequences is probablity is above threshold", typeid(float), (void *) &maskProb, "^0(\\.[0-9]+)?|^1(\\.0+)?$", MMseqsParameter::COMMAND_PREFILTER | MMseqsParameter::COMMAND_EXPERT),
PARAM_MASK_LOWER_CASE(PARAM_MASK_LOWER_CASE_ID, "--mask-lower-case", "Mask lower case residues", "Lowercase letters will be excluded from k-mer search 0: include region, 1: exclude region", typeid(int), (void *) &maskLowerCaseMode, "^[0-1]{1}", MMseqsParameter::COMMAND_PREFILTER | MMseqsParameter::COMMAND_EXPERT),
PARAM_MASK_N_REPEAT(PARAM_MASK_N_REPEAT_ID, "--mask-n-repeat", "Mask lower letter repeating N times", "Repeat letters that occure > threshold in a rwo", typeid(int), (void *) &maskNrepeats, "^[0-9]{1}[0-9]*$", MMseqsParameter::COMMAND_PREFILTER | MMseqsParameter::COMMAND_EXPERT),
PARAM_MIN_DIAG_SCORE(PARAM_MIN_DIAG_SCORE_ID, "--min-ungapped-score", "Minimum diagonal score", "Accept only matches with ungapped alignment score above threshold", typeid(int), (void *) &minDiagScoreThr, "^[0-9]{1}[0-9]*$", MMseqsParameter::COMMAND_PREFILTER | MMseqsParameter::COMMAND_EXPERT),
PARAM_K_SCORE(PARAM_K_SCORE_ID, "--k-score", "k-score", "k-mer threshold for generating similar k-mer lists", typeid(MultiParam<SeqProf<int>>), (void *) &kmerScore, "^[0-9]{1}[0-9]*$", MMseqsParameter::COMMAND_PREFILTER | MMseqsParameter::COMMAND_EXPERT),
PARAM_MAX_SEQS(PARAM_MAX_SEQS_ID, "--max-seqs", "Max results per query", "Maximum results per query sequence allowed to pass the prefilter (affects sensitivity)", typeid(size_t), (void *) &maxResListLen, "^[1-9]{1}[0-9]*$", MMseqsParameter::COMMAND_PREFILTER),
Expand Down Expand Up @@ -427,6 +428,7 @@ Parameters::Parameters():
prefilter.push_back(&PARAM_MASK_RESIDUES);
prefilter.push_back(&PARAM_MASK_PROBABILTY);
prefilter.push_back(&PARAM_MASK_LOWER_CASE);
prefilter.push_back(&PARAM_MASK_N_REPEAT);
prefilter.push_back(&PARAM_MIN_DIAG_SCORE);
prefilter.push_back(&PARAM_TAXON_LIST);
prefilter.push_back(&PARAM_INCLUDE_IDENTITY);
Expand Down Expand Up @@ -788,6 +790,7 @@ Parameters::Parameters():
indexdb.push_back(&PARAM_MASK_RESIDUES);
indexdb.push_back(&PARAM_MASK_PROBABILTY);
indexdb.push_back(&PARAM_MASK_LOWER_CASE);
indexdb.push_back(&PARAM_MASK_N_REPEAT);
indexdb.push_back(&PARAM_SPACED_KMER_MODE);
indexdb.push_back(&PARAM_SPACED_KMER_PATTERN);
indexdb.push_back(&PARAM_S);
Expand Down Expand Up @@ -815,6 +818,7 @@ Parameters::Parameters():
kmerindexdb.push_back(&PARAM_MASK_RESIDUES);
kmerindexdb.push_back(&PARAM_MASK_PROBABILTY);
kmerindexdb.push_back(&PARAM_MASK_LOWER_CASE);
kmerindexdb.push_back(&PARAM_MASK_N_REPEAT);
kmerindexdb.push_back(&PARAM_CHECK_COMPATIBLE);
kmerindexdb.push_back(&PARAM_SEARCH_TYPE);
kmerindexdb.push_back(&PARAM_SPACED_KMER_MODE);
Expand Down Expand Up @@ -992,6 +996,7 @@ Parameters::Parameters():
kmermatcher.push_back(&PARAM_MASK_RESIDUES);
kmermatcher.push_back(&PARAM_MASK_PROBABILTY);
kmermatcher.push_back(&PARAM_MASK_LOWER_CASE);
kmermatcher.push_back(&PARAM_MASK_N_REPEAT);
kmermatcher.push_back(&PARAM_COV_MODE);
kmermatcher.push_back(&PARAM_K);
kmermatcher.push_back(&PARAM_C);
Expand All @@ -1013,6 +1018,7 @@ Parameters::Parameters():
kmersearch.push_back(&PARAM_MASK_RESIDUES);
kmersearch.push_back(&PARAM_MASK_PROBABILTY);
kmersearch.push_back(&PARAM_MASK_LOWER_CASE);
kmersearch.push_back(&PARAM_MASK_N_REPEAT);
kmersearch.push_back(&PARAM_COV_MODE);
kmersearch.push_back(&PARAM_C);
kmersearch.push_back(&PARAM_MAX_SEQ_LEN);
Expand Down Expand Up @@ -2350,6 +2356,7 @@ void Parameters::setDefaults() {
maskMode = 1;
maskProb = 0.9;
maskLowerCaseMode = 0;
maskNrepeats = 0;
minDiagScoreThr = 15;
spacedKmer = true;
includeIdentity = false;
Expand Down
2 changes: 2 additions & 0 deletions src/commons/Parameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ class Parameters {
int maskMode; // mask low complex areas
float maskProb; // mask probability
int maskLowerCaseMode; // mask lowercase letters in prefilter and kmermatchers
int maskNrepeats; // mask letters that occur at least N times in a row

int minDiagScoreThr; // min diagonal score
int spacedKmer; // Spaced Kmers
Expand Down Expand Up @@ -754,6 +755,7 @@ class Parameters {
PARAMETER(PARAM_MASK_RESIDUES)
PARAMETER(PARAM_MASK_PROBABILTY)
PARAMETER(PARAM_MASK_LOWER_CASE)
PARAMETER(PARAM_MASK_N_REPEAT)

PARAMETER(PARAM_MIN_DIAG_SCORE)
PARAMETER(PARAM_K_SCORE)
Expand Down
Loading

0 comments on commit a2d01d0

Please sign in to comment.