Skip to content

Commit

Permalink
Use min to limit the index of next prefetch.
Browse files Browse the repository at this point in the history
  • Loading branch information
kishorenc committed May 15, 2023
1 parent b5c2eba commit 573ab84
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,9 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
tableint candidate_id = *(datal + j);
// if (candidate_id == 0) continue;
#ifdef USE_SSE
_mm_prefetch((char *) (visited_array + *(datal + j)), _MM_HINT_T0);
_mm_prefetch(getDataByInternalId(*(datal + j)), _MM_HINT_T0);
size_t next_index = std::min(size - 1, j + 1);
_mm_prefetch((char *) (visited_array + *(datal + next_index)), _MM_HINT_T0);
_mm_prefetch(getDataByInternalId(*(datal + next_index)), _MM_HINT_T0);
#endif
if (visited_array[candidate_id] == visited_array_tag) continue;
visited_array[candidate_id] = visited_array_tag;
Expand Down Expand Up @@ -343,8 +344,9 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
int candidate_id = *(data + j);
// if (candidate_id == 0) continue;
#ifdef USE_SSE
_mm_prefetch((char *) (visited_array + *(data + j)), _MM_HINT_T0);
_mm_prefetch(data_level0_memory_ + (*(data + j)) * size_data_per_element_ + offsetData_,
size_t next_index = std::min(size, j + 1);
_mm_prefetch((char *) (visited_array + *(data + next_index)), _MM_HINT_T0);
_mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_,
_MM_HINT_T0); ////////////
#endif
if (!(visited_array[candidate_id] == visited_array_tag)) {
Expand Down Expand Up @@ -1007,7 +1009,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
#endif
for (int i = 0; i < size; i++) {
#ifdef USE_SSE
_mm_prefetch(getDataByInternalId(*(datal + i)), _MM_HINT_T0);
size_t next_index = std::min(size - 1, i + 1);
_mm_prefetch(getDataByInternalId(*(datal + next_index)), _MM_HINT_T0);
#endif
tableint cand = datal[i];
dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_);
Expand Down

0 comments on commit 573ab84

Please sign in to comment.