Skip to content

Commit

Permalink
Merge pull request #43 from matsui528/follow_up
Browse files Browse the repository at this point in the history
Follow up PR for #42
  • Loading branch information
matsui528 authored Sep 5, 2021
2 parents 9a7de86 + 67acf61 commit 8f59f51
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 27 deletions.
12 changes: 5 additions & 7 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,22 @@ jobs:
strategy:
matrix:
# https://github.blog/2019-08-08-github-actions-now-supports-ci-cd/
#os: [ubuntu-latest, macos-latest]
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: [3.5, 3.6, 3.7, 3.8]
# https://stackoverflow.com/questions/57810623/how-to-select-the-c-c-compiler-used-for-a-github-actions-job:
compiler: [gcc, clang, cl]
# Don't check ubuntu+clang
exclude:
# Excluding clang in ubuntu and windows
# ubuntu: gcc
- os: ubuntu-latest
compiler: clang
- os: windows-latest
compiler: clang
# Excluding cl in ubuntu and mac
- os: ubuntu-latest
compiler: cl
# mac: gcc, clang
- os: macos-latest
compiler: cl
# Excluding gcc in windows
# win: cl
- os: windows-latest
compiler: clang
- os: windows-latest
compiler: gcc
steps:
Expand Down
34 changes: 21 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,30 +29,37 @@ The search can be operated for a subset of a database. | Rii remains fast even a
You can install the package via pip. This library works with Python 3.5+ on Linux/macOS/WSL/Windows 10 (x64, using the MSVC flags /arch:AVX2, /openmp:llvm, and /fp:fast).

```
pip install git+https://github.com/ashleyabraham/rii.git
pip install rii
```
or use the pre-compiled binary for Windows 10 (may require the MS Visual Studio Build Tools)

```
pip install https://github.com/ashleyabraham/rii/releases/download/v0.2.7/rii-0.2.7-cp38-cp38-win_amd64.whl
```

<details>
<summary>For windows (maintained by @ashleyabraham)</summary>

### Pre-compiled binary for Windows 10 (may require the MS Visual Studio Build Tools)
```
pip install https://github.com/ashleyabraham/rii/releases/download/v0.2.7/rii-0.2.7-cp38-cp38-win_amd64.whl
```

### OpenMP
In order to use OpenMP 3.0, the /openmp:llvm flag is used, which can cause warnings about multiple OpenMP runtimes being loaded; use at your discretion when combined with other parallel-processing libraries. To suppress the warning, use

### Windows (notes)
#### OpenMP
In order to use OpenMP 3.0, the /openmp:llvm flag is used, which can cause warnings about multiple OpenMP runtimes being loaded; use at your discretion when combined with other parallel-processing libraries. To suppress the warning, use
`SET KMP_DUPLICATE_LIB_OK=TRUE`

### SIMD
The /arch:AVX2 flag is used in MSVC to set the appropriate SIMD preprocessor definitions and enable the corresponding compiler intrinsics

</details>

`SET KMP_DUPLICATE_LIB_OK=TRUE`

#### SIMD
The /arch:AVX2 flag is used in MSVC to set the appropriate SIMD preprocessor definitions and enable the corresponding compiler intrinsics


## [Documentation](https://rii.readthedocs.io/en/latest/index.html)
- [Tutorial](https://rii.readthedocs.io/en/latest/source/tutorial.html)
- [Tips](https://rii.readthedocs.io/en/latest/source/tips.html)
- [API](https://rii.readthedocs.io/en/latest/source/api.html)


## Usage

### Basic ANN
Expand Down Expand Up @@ -104,13 +111,13 @@ print(ids, dists) # e.g., [728 85 132] [14.80522156 15.92787838 16.28690338]
```python
# Add new vectors
X2 = np.random.random((1000, D)).astype(np.float32)
e.add_configure(vecs=X2) # Now N is 11000
e.add(vecs=X2) # Now N is 11000
e.query(q=q) # Ok. (0.12 msec / query)

# However, if you add quite a lot of vectors, the search might become slower
# because the data structure has been optimized for the initial item size (N=10000)
X3 = np.random.random((1000000, D)).astype(np.float32)
e.add_configure(vecs=X3) # A lot. Now N is 1011000
e.add(vecs=X3) # A lot. Now N is 1011000
e.query(q=q) # Slower (0.96 msec/query)

# In such case, run the reconfigure function. That updates the data structure
Expand Down Expand Up @@ -156,3 +163,4 @@ e1.merge(e2) # Now e1 contains both X1 and X2

## Credits
- The logo is designed by [@richardbmx](https://github.com/richardbmx) ([#4](https://github.com/matsui528/rii/issues/4))
- The Windows implementation is by [@ashleyabraham](https://github.com/ashleyabraham) ([#42](https://github.com/matsui528/rii/pull/42))
7 changes: 6 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,16 @@ def build_extensions(self):
opts.append('/fp:fast') # -Ofast

if sys.platform not in ['darwin', 'win32']:
opts.append('-fopenmp') # For pqk-means
# For linux
opts.append('-fopenmp') # For pqk-means.

if sys.platform not in ['win32']:
# For linux and mac
opts.append('-march=native') # For fast SIMD computation of L2 distance
opts.append('-mtune=native') # Do optimization (It seems this doesn't boost, but just in case)
opts.append('-Ofast') # This makes the program faster


for ext in self.extensions:
ext.extra_compile_args = opts
if not sys.platform == 'darwin':
Expand Down
24 changes: 20 additions & 4 deletions src/distance.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ static inline __m128 masked_read (int d, const float *x)
// cannot use AVX2 _mm_mask_set1_epi32
}

//#ifdef __AVX__
#if defined(__AVX__)
// Reading function for AVX and AVX512
// This function is from Faiss
Expand All @@ -84,7 +83,6 @@ static inline __m256 masked_read_8 (int d, const float *x)



//#ifdef __AVX512F__
#if defined(__AVX512F__)
// Reading function for AVX512
// reads 0 <= d < 16 floats as __m512
Expand All @@ -109,7 +107,6 @@ static inline __m512 masked_read_16 (int d, const float *x)

// ========================= Distance functions ============================

//#ifdef __AVX512F__
#if defined(__AVX512F__)
static const std::string g_simd_architecture = "avx512";

Expand All @@ -128,31 +125,39 @@ float fvec_L2sqr (const float *x, const float *y, size_t d)
}

__m256 msum2 = _mm512_extractf32x8_ps(msum1, 1);
// msum2 += _mm512_extractf32x8_ps(msum1, 0);
msum2 = _mm256_add_ps(msum2, _mm512_extractf32x8_ps(msum1, 0));

while (d >= 8) {
__m256 mx = _mm256_loadu_ps (x); x += 8;
__m256 my = _mm256_loadu_ps (y); y += 8;
// const __m256 a_m_b1 = mx - my;
const __m256 a_m_b1 = _mm256_sub_ps(mx, my);
// msum2 += a_m_b1 * a_m_b1;
msum2 = _mm256_add_ps(msum2, _mm256_mul_ps(a_m_b1, a_m_b1));
d -= 8;
}

__m128 msum3 = _mm256_extractf128_ps(msum2, 1);
// msum3 += _mm256_extractf128_ps(msum2, 0);
msum3 = _mm_add_ps(msum3, _mm256_extractf128_ps(msum2, 0));

if (d >= 4) {
__m128 mx = _mm_loadu_ps (x); x += 4;
__m128 my = _mm_loadu_ps (y); y += 4;
// const __m128 a_m_b1 = mx - my;
const __m128 a_m_b1 = _mm_sub_ps(mx, my);
// msum3 += a_m_b1 * a_m_b1;
msum3 = _mm_add_ps(msum3, _mm_mul_ps(a_m_b1, a_m_b1));
d -= 4;
}

if (d > 0) {
__m128 mx = masked_read (d, x);
__m128 my = masked_read (d, y);
// __m128 a_m_b1 = mx - my;
__m128 a_m_b1 = _mm_sub_ps(mx, my);
// msum3 += a_m_b1 * a_m_b1;
msum3 = _mm_add_ps(msum3, _mm_mul_ps(a_m_b1, a_m_b1));
}

Expand All @@ -173,26 +178,33 @@ float fvec_L2sqr (const float *x, const float *y, size_t d)
while (d >= 8) {
__m256 mx = _mm256_loadu_ps (x); x += 8;
__m256 my = _mm256_loadu_ps (y); y += 8;
const __m256 a_m_b1 = _mm256_sub_ps(mx, my); // mx - my;
// const __m256 a_m_b1 = mx - my;
const __m256 a_m_b1 = _mm256_sub_ps(mx, my);
// msum1 += a_m_b1 * a_m_b1;
msum1 = _mm256_add_ps(msum1, _mm256_mul_ps(a_m_b1 ,a_m_b1));
d -= 8;
}

__m128 msum2 = _mm256_extractf128_ps(msum1, 1);
// msum2 += _mm256_extractf128_ps(msum1, 0);
msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0));

if (d >= 4) {
__m128 mx = _mm_loadu_ps (x); x += 4;
__m128 my = _mm_loadu_ps (y); y += 4;
// const __m128 a_m_b1 = mx - my;
const __m128 a_m_b1 = _mm_sub_ps(mx, my);
// msum2 += a_m_b1 * a_m_b1;
msum2 = _mm_add_ps(msum2, _mm_mul_ps(a_m_b1, a_m_b1));
d -= 4;
}

if (d > 0) {
__m128 mx = masked_read (d, x);
__m128 my = masked_read (d, y);
// __m128 a_m_b1 = mx - my;
__m128 a_m_b1 = _mm_sub_ps(mx, my);
// msum2 += a_m_b1 * a_m_b1;
msum2 = _mm_add_ps(msum2, _mm_mul_ps(a_m_b1, a_m_b1));
}

Expand All @@ -214,7 +226,9 @@ float fvec_L2sqr (const float *x, const float *y, size_t d)
while (d >= 4) {
__m128 mx = _mm_loadu_ps (x); x += 4;
__m128 my = _mm_loadu_ps (y); y += 4;
// const __m128 a_m_b1 = mx - my;
const __m128 a_m_b1 = _mm_sub_ps(mx, my);
// msum1 += a_m_b1 * a_m_b1;
msum1 = _mm_add_ps(msum1, _mm_mul_ps(a_m_b1, a_m_b1));
d -= 4;
}
Expand All @@ -223,7 +237,9 @@ float fvec_L2sqr (const float *x, const float *y, size_t d)
// add the last 1, 2 or 3 values
__m128 mx = masked_read (d, x);
__m128 my = masked_read (d, y);
// __m128 a_m_b1 = mx - my;
__m128 a_m_b1 = _mm_sub_ps(mx, my);
// msum1 += a_m_b1 * a_m_b1;
msum1 = _mm_add_ps(msum1, _mm_mul_ps(a_m_b1, a_m_b1));
}

Expand Down
20 changes: 18 additions & 2 deletions tests/test_rii.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#from .context import rii
import rii
from .context import rii
import unittest
import numpy as np
import nanopq
Expand Down Expand Up @@ -54,6 +53,23 @@ def test_reconfigure(self):
self.assertEqual(len(e.posting_lists), nlist)
self.assertEqual(sum([len(plist) for plist in e.posting_lists]), N)

def test_simple_add_configure(self):
    """Adding vectors in two batches accumulates N, and reconfigure rebuilds the index.

    After each reconfigure(nlist=...), the coarse centers and posting lists
    must match the requested nlist, and every stored vector must appear in
    exactly one posting list.
    """
    num_subspaces, num_codewords = 4, 20
    first_batch, second_batch, dim = 300, 700, 40
    vecs_a = np.random.random((first_batch, dim)).astype(np.float32)
    vecs_b = np.random.random((second_batch, dim)).astype(np.float32)
    codec = nanopq.PQ(M=num_subspaces, Ks=num_codewords, verbose=True).fit(vecs=vecs_a)
    engine = rii.Rii(fine_quantizer=codec)
    # First batch: N reflects the vectors added so far.
    engine.add(vecs=vecs_a)
    self.assertEqual(engine.N, first_batch)
    # Second batch: N accumulates across add() calls.
    engine.add(vecs=vecs_b)
    total = first_batch + second_batch
    self.assertEqual(engine.N, total)
    for nlist in (5, 100):
        engine.reconfigure(nlist=nlist)
        self.assertEqual(engine.nlist, nlist)
        self.assertEqual(engine.coarse_centers.shape, (nlist, num_subspaces))
        self.assertEqual(len(engine.posting_lists), nlist)
        # All stored vectors are distributed over the posting lists.
        self.assertEqual(sum(len(plist) for plist in engine.posting_lists), total)

def test_add_configure(self):
M, Ks = 4, 20
N, D = 1000, 40
Expand Down

0 comments on commit 8f59f51

Please sign in to comment.