Skip to content

Commit

Permalink
update - index build with multi-thread support
Browse files Browse the repository at this point in the history
  • Loading branch information
quito418 committed Apr 27, 2022
1 parent 16ebdaf commit 4f1e25b
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 81 deletions.
238 changes: 158 additions & 80 deletions src/Learnedindex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include "LearnedIndex_seeding.h"
#include "memcpy_bwamem.h"
#include <vector>
#include <time.h>
#ifdef __cplusplus
extern "C" {
#endif
Expand Down Expand Up @@ -139,7 +140,7 @@ void buildSAandLEP(char* prefix, int num_threads){
char pac_file_name[PATH_MAX];
strcpy_s(pac_file_name, PATH_MAX, prefix);
strcat_s(pac_file_name, PATH_MAX, ".pac");
printf("[Build-LearnedIndexmode] pac2nt function...\n");
fprintf(stderr,"[Build-LearnedIndexmode] pac2nt function...\n");
pac2nt_(pac_file_name, reference_seq);


Expand Down Expand Up @@ -215,7 +216,7 @@ void buildSAandLEP(char* prefix, int num_threads){
}
}

printf("[Build-LearnedIndexmode] ref seq len = %ld\n", pac_len);
fprintf(stderr,"[Build-LearnedIndexmode] ref seq len = %ld\n", pac_len);
binary_ref_stream.write(binary_ref_seq, pac_len * sizeof(char));
binary_ref_stream.close();
assert(pac_len%2 ==0 );
Expand All @@ -224,7 +225,7 @@ void buildSAandLEP(char* prefix, int num_threads){
binary_ref_seq[i] = 3;
reference_seq += "T";
}
printf("[Build-LearnedIndexmode] padded ref len = %ld\n", pac_len);
fprintf(stderr,"[Build-LearnedIndexmode] padded ref len = %ld\n", pac_len);
//build suffix array
size = (pac_len + 2) * sizeof(int64_t);
int64_t *suffix_array=(int64_t *)_mm_malloc(size, 64);
Expand All @@ -233,11 +234,12 @@ void buildSAandLEP(char* prefix, int num_threads){
startTick = __rdtsc();


printf("[Build-LearnedIndexmode] Building Suffix array\n");
fprintf(stderr,"[Build-LearnedIndexmode] Building Suffix array with sais library\n");
uint64_t query_k_mer = 32; // fix this to 32, use front 32 character for Learned Index model inference

clock_t t;
t = clock();
status = saisxx(reference_seq.c_str(), suffix_array, pac_len);

fprintf(stderr, "%.2f sec\n", (float)(clock() - t) / CLOCKS_PER_SEC);
// [ Ref ][ Ref_complement ][TTTTTTTTTT...T]

uint64_t total_sa_num = pac_len - padded_t_len;//
Expand Down Expand Up @@ -279,92 +281,174 @@ void buildSAandLEP(char* prefix, int num_threads){
std::cerr << "unable to open " << pos_filename << std::endl;
exit(EXIT_FAILURE);
}
// #if MEM_TRADEOFF && READ_FROM_FILE
// Design3: ISA
uint8_t *ref_to_sapos = (uint8_t *)_mm_malloc(5*total_sa_num*sizeof(uint8_t) , 64);
// #endif

uint8_t c;
bwtintv_t ik, ok[4];
uint32_t prevHits;
uint64_t r;
uint64_t sa_count=0;
printf("[Build-LearnedIndexmode] Build and Save SA, Pos\n");
for (i=0 ; i< pac_len; i++){

if ( pac_len - padded_t_len <= suffix_array[i] && pac_len > suffix_array[i] ){
continue;
}

uint32_t pos_val = suffix_array[i] >>8 ;
uint8_t ls_val= suffix_array[i] & 0xff;
pos_out.write(reinterpret_cast<const char*>(&pos_val), sizeof(uint32_t));
pos_out.write(reinterpret_cast<const char*>(&ls_val), sizeof(uint8_t));

possa_out.write(reinterpret_cast<const char*>(&pos_val), sizeof(uint32_t));
possa_out.write(reinterpret_cast<const char*>(&ls_val), sizeof(uint8_t));
// #if MEM_TRADEOFF && READ_FROM_FILE

*(uint32_t*)(ref_to_sapos+ suffix_array[i]*5) = (uint32_t) (sa_count>>8);
*(ref_to_sapos+ suffix_array[i]*5+4) = (uint8_t)(sa_count & 0xff);
// #endif
sa_count++;
assert(suffix_array[i] < pac_len && suffix_array[i]>=0);


// char c = binary_ref_seq[suffix_array[i]-1];
uint64_t binary_suffix_array = 0;
uint64_t reverse_bit = 0;
for (r =0 ; r < query_k_mer ; r++){
binary_suffix_array = binary_suffix_array << 2;
reverse_bit = reverse_bit <<2;
switch (binary_ref_seq[ (suffix_array[i]+r)%pac_len]){
case 0:
binary_suffix_array = (binary_suffix_array|0);

break;
case 1:
binary_suffix_array = (binary_suffix_array|1);
break;
case 2:
binary_suffix_array = (binary_suffix_array|2);
break;
case 3:
binary_suffix_array = (binary_suffix_array|3);
break;
fprintf(stderr,"[Build-LearnedIndexmode] Writing index files... should take a while\n");
uint64_t index_build_batch_size = 10000;


uint8_t *ref_to_sapos = (uint8_t *)_mm_malloc(5*total_sa_num*sizeof(uint8_t) , 64);

uint8_t pos_out_batch[index_build_batch_size*5];
uint8_t possa_out_batch[index_build_batch_size*13];
uint8_t sa_out_batch[index_build_batch_size*8];

uint64_t cumulative_sa = 0;
for (i=0 ; i< pac_len; i += index_build_batch_size){
int padded_t_flag = 0;
uint64_t write_num = i + index_build_batch_size < pac_len ? index_build_batch_size : pac_len - i;


#pragma omp parallel num_threads(num_threads) shared(i, index_build_batch_size, padded_t_flag, write_num)
{
#pragma omp for schedule(monotonic:dynamic)
for (uint64_t j =i; j < i+write_num; j++){
if (padded_t_flag) continue;
if ( pac_len - padded_t_len <= suffix_array[j] && pac_len > suffix_array[j] ){
#pragma omp atomic
padded_t_flag++; // if there is padded T in current batch, we should process it with single-thread
continue;
}
// fill in ref2sa (Inverse suffix array)
uint32_t val_ref2sa_4 = (cumulative_sa + (j-i)) >> 8;
uint8_t val_ref2sa_1 = (cumulative_sa + (j-i)) & 0xff;
memcpy( ref_to_sapos + suffix_array[j]*5, &val_ref2sa_4, 4);
memcpy( ref_to_sapos + suffix_array[j]*5 + 4, &val_ref2sa_1, 1);

// fill in suffix array in packed binary form
uint32_t pos_val = suffix_array[j] >>8 ;
uint8_t ls_val= suffix_array[j] & 0xff;
memcpy(pos_out_batch + (j-i)*5, &pos_val, 4 );
memcpy(pos_out_batch + (j-i)*5 + 4, &ls_val, 1 );

// fill in suffix array and corresponding 64-bit suffix in packed binary form
memcpy(possa_out_batch + (j-i)*13, &pos_val, 4 );
memcpy(possa_out_batch + (j-i)*13 + 4, &ls_val, 1 );

// The code below generates the 64-bit suffix that is added to the .suffixarray_uint64 and possa_packed files
uint64_t binary_suffix_array = 0;
uint64_t reverse_bit = 0;
for (r =0 ; r < query_k_mer ; r++){
binary_suffix_array = binary_suffix_array << 2;
reverse_bit = reverse_bit <<2;
switch (binary_ref_seq[ (suffix_array[j]+r)%pac_len]){
case 0:
binary_suffix_array = (binary_suffix_array|0);
break;
case 1:
binary_suffix_array = (binary_suffix_array|1);
break;
case 2:
binary_suffix_array = (binary_suffix_array|2);
break;
case 3:
binary_suffix_array = (binary_suffix_array|3);
break;

}
switch (binary_ref_seq[ (suffix_array[j]+query_k_mer-r-1)%pac_len]){
case 0:
reverse_bit = (reverse_bit|0);
break;
case 1:
reverse_bit = (reverse_bit|1);
break;
case 2:
reverse_bit = (reverse_bit|2);
break;
case 3:
reverse_bit = (reverse_bit|3);
break;

}
}

memcpy(possa_out_batch + (j-i)*13+5, &reverse_bit, 8 );
memcpy(sa_out_batch + (j-i)*8, &binary_suffix_array, 8 );

}
switch (binary_ref_seq[ (suffix_array[i]+query_k_mer-r-1)%pac_len]){
case 0:
reverse_bit = (reverse_bit|0);
break;
case 1:
reverse_bit = (reverse_bit|1);
break;
case 2:
reverse_bit = (reverse_bit|2);
break;
case 3:
reverse_bit = (reverse_bit|3);
break;
#pragma omp barrier
}

uint64_t padded_t_num = 0;
if (padded_t_flag){
for (uint64_t j =i; j < i+write_num; j++){
if ( pac_len - padded_t_len <= suffix_array[j] && pac_len > suffix_array[j] ){
padded_t_num ++;
continue;
}
// fill in ref2sa (Inverse suffix array)
uint32_t val_ref2sa_4 = (cumulative_sa + (j-i-padded_t_num)) >> 8;
uint8_t val_ref2sa_1 = (cumulative_sa + (j-i-padded_t_num)) & 0xff;
memcpy( ref_to_sapos + suffix_array[j]*5, &val_ref2sa_4, 4);
memcpy( ref_to_sapos + suffix_array[j]*5 + 4, &val_ref2sa_1, 1);

// fill in suffix array in packed binary form
uint32_t pos_val = suffix_array[j] >>8 ;
uint8_t ls_val= suffix_array[j] & 0xff;
memcpy(pos_out_batch + (j-i-padded_t_num)*5, &pos_val, 4 );
memcpy(pos_out_batch + (j-i-padded_t_num)*5 + 4, &ls_val, 1 );

// fill in suffix array and corresponding 64-bit suffix in packed binary form
memcpy(possa_out_batch + (j-i-padded_t_num)*13, &pos_val, 4 );
memcpy(possa_out_batch + (j-i-padded_t_num)*13 + 4, &ls_val, 1 );

// The code below generates the 64-bit suffix that is added to the .suffixarray_uint64 and possa_packed files
uint64_t binary_suffix_array = 0;
uint64_t reverse_bit = 0;
for (r =0 ; r < query_k_mer ; r++){
binary_suffix_array = binary_suffix_array << 2;
reverse_bit = reverse_bit <<2;
switch (binary_ref_seq[ (suffix_array[j]+r)%pac_len]){
case 0:
binary_suffix_array = (binary_suffix_array|0);
break;
case 1:
binary_suffix_array = (binary_suffix_array|1);
break;
case 2:
binary_suffix_array = (binary_suffix_array|2);
break;
case 3:
binary_suffix_array = (binary_suffix_array|3);
break;

}
switch (binary_ref_seq[ (suffix_array[j]+query_k_mer-r-1)%pac_len]){
case 0:
reverse_bit = (reverse_bit|0);
break;
case 1:
reverse_bit = (reverse_bit|1);
break;
case 2:
reverse_bit = (reverse_bit|2);
break;
case 3:
reverse_bit = (reverse_bit|3);
break;
}
}
memcpy(possa_out_batch + (j-i-padded_t_num)*13+5, &reverse_bit, 8 );
memcpy(sa_out_batch + (j-i-padded_t_num)*8, &binary_suffix_array, 8 );
}
}



possa_out.write(reinterpret_cast<const char*>(&reverse_bit), sizeof(uint64_t));
sa_out.write(reinterpret_cast<const char*>(&binary_suffix_array), sizeof(uint64_t));
write_num -= padded_t_num;
cumulative_sa += write_num;

pos_out.write(pos_out_batch, 5*write_num );
possa_out.write(possa_out_batch, 13*write_num );
sa_out.write(sa_out_batch, 8*write_num );

}
// Check that the number of suffix array entries written is correct
assert(sa_count== total_sa_num);
pos_out.close();
possa_out.close();
sa_out.close();

// #if MEM_TRADEOFF && READ_FROM_FILE
char ref_to_sapos_name[PATH_MAX];
strcpy_s(ref_to_sapos_name, PATH_MAX, prefix);
strcat_s(ref_to_sapos_name, PATH_MAX, ".ref2sa_packed");
Expand All @@ -373,14 +457,8 @@ void buildSAandLEP(char* prefix, int num_threads){
ref2sa_stream.write((char*)ref_to_sapos, 5*total_sa_num * sizeof(uint8_t));
ref2sa_stream.close();
_mm_free(ref_to_sapos);
// #endif


// hit_out.close();
fprintf(stderr, "build suffix-array ticks = %llu\n", __rdtsc() - startTick);




_mm_free(binary_ref_seq);
_mm_free(suffix_array);
Expand Down
5 changes: 4 additions & 1 deletion src/bwtindex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ int bwa_index(int argc, char *argv[]) // the "index" command
fprintf(stderr, "Usage: bwa-meme index [options] <in.fasta>\n\n");
fprintf(stderr, "Options: -a STR BWT construction algorithm: bwtsw, is, rb2, mem2, ert or meme \n");
fprintf(stderr, " -p STR prefix of the index [same as fasta name]\n");
fprintf(stderr, " -t INT number of threads for ERT index building [%d]\n", num_threads);
fprintf(stderr, " -t INT number of threads for MEME index building [%d]\n", num_threads);
fprintf(stderr, " -6 index files named as <in.fasta>.64.* instead of <in.fasta>.* \n");
fprintf(stderr, "\n");
fprintf(stderr, "Warning: `-a bwtsw' does not work for short genomes, while `-a is' and\n");
Expand Down Expand Up @@ -349,6 +349,7 @@ int bwa_idx_build_Learned_index(const char *fa, const char *prefix, int num_thre
int64_t l_pac;

{ // nucleotide indexing
t = clock();
gzFile fp = xzopen(fa, "r");
// t = clock();
// fprintf(stderr, "[bwa_index] Pack FASTA... ");
Expand All @@ -366,7 +367,9 @@ int bwa_idx_build_Learned_index(const char *fa, const char *prefix, int num_thre
if (bwa_verbose >= 3) {
fprintf(stderr, "[M::%s] Building Index for bwa-meme...\n", __func__ );
}

buildSAandLEP(prefix, num_threads);
fprintf(stderr, "Took %.2f sec for index build\n", (float)(clock() - t) / CLOCKS_PER_SEC);
// bwa_idx_destroy(bid);
}
return 0;
Expand Down

0 comments on commit 4f1e25b

Please sign in to comment.