Skip to content

Commit

Permalink
Integrate IVF-Flat from RAFT (#2521)
Browse files Browse the repository at this point in the history
Summary:
This is a design proposal that demonstrates an approach to enabling optional support for [RAFT](https://github.com/rapidsai/raft) versions of IVF PQ and IVF Flat (and brute force w/ fused k-selection when k <= 64). There are still a few open issues and design discussions needed for the new RAFT index types to support the full range of features of that FAISS' current gpu index types.

Checklist for the integration todos:
- [x] Rebase on current `main` branch
- [X] The raft handle has been plugged directly into the StandardGpuResources
- [X] `FlatIndex` passing Googletests
- [x] Use `CodePacker` to support `copyFrom()` and `copyTo()`
- [X] `IVF-flat passing Googletests
- [ ] Raise appropriate exceptions for operations which are not yet supported by RAFT

Additional features we've discussed:
- [x] Separate IVF lists into individual memory chunks
- [ ] Saving/loading

To build FAISS w/ optional RAFT support:
```
mkdir build
cd build
cmake ../ -DFAISS_ENABLE_RAFT=ON -DFAISS_ENABLE_GPU=ON
make -j
```

For development/testing, we've also supplied a bash script to make things easier: `build.sh`

Below is a benchmark comparing the training of IVF Flat indices for RAFT and FAISS:
![image](https://user-images.githubusercontent.com/1242464/194944737-8b808f11-e28e-4556-82d1-1ea4b0707283.png)

The benchmark was produced using Googlebench in [this](https://github.com/tfeher/raft/tree/raft_faiss_bench) RAFT fork. We're going to provide benchmarks for the queries as well. There are still a couple bottlenecks to be removed in the IVF-Flat training implementation and we'll update the current benchmark when ready.

Pull Request resolved: #2521

Test Plan: `buck test mode/debuck test mode/dev-nosan //faiss/gpu/test:test_gpu_index_ivfflat`

Reviewed By: algoriddle

Differential Revision: D49118319

Pulled By: mdouze

fbshipit-source-id: 5916108bc27154acf7c92021ba579a6ca85d730b
  • Loading branch information
cjnolet authored and facebook-github-bot committed Oct 5, 2023
1 parent 458633c commit edcf743
Show file tree
Hide file tree
Showing 14 changed files with 1,231 additions and 91 deletions.
58 changes: 58 additions & 0 deletions build.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#!/bin/bash

# NOTE: This file is temporary for the proof-of-concept branch and will be removed before this PR is merged

BUILD_TYPE=Release
BUILD_DIR=build/

RAFT_REPO_REL=""
EXTRA_CMAKE_ARGS=""
set -e

if [[ ${RAFT_REPO_REL} != "" ]]; then
RAFT_REPO_PATH="`readlink -f \"${RAFT_REPO_REL}\"`"
EXTRA_CMAKE_ARGS="${EXTRA_CMAKE_ARGS} -DCPM_raft_SOURCE=${RAFT_REPO_PATH}"
fi

if [ "$1" == "clean" ]; then
rm -rf build
rm -rf .cache
exit 0
fi

if [ "$1" == "test" ]; then
make -C build -j test
exit 0
fi

if [ "$1" == "test-raft" ]; then
./build/faiss/gpu/test/TestRaftIndexIVFFlat
exit 0
fi

mkdir -p $BUILD_DIR
cd $BUILD_DIR

cmake \
-DFAISS_ENABLE_GPU=ON \
-DFAISS_ENABLE_RAFT=ON \
-DFAISS_ENABLE_PYTHON=OFF \
-DBUILD_TESTING=ON \
-DBUILD_SHARED_LIBS=OFF \
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
-DFAISS_OPT_LEVEL=avx2 \
-DRAFT_NVTX=OFF \
-DCMAKE_CUDA_ARCHITECTURES="NATIVE" \
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache \
-DCMAKE_C_COMPILER_LAUNCHER=ccache \
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
${EXTRA_CMAKE_ARGS} \
../


# make -C build -j12 faiss
cmake --build . -j12
# make -C build -j12 swigfaiss
# (cd build/faiss/python && python setup.py install)

2 changes: 1 addition & 1 deletion cmake/thirdparty/fetch_rapids.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# or implied. See the License for the specific language governing permissions and limitations under
# the License.
# =============================================================================
set(RAPIDS_VERSION "23.06")
set(RAPIDS_VERSION "23.08")

if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/FAISS_RAPIDS.cmake)
file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-${RAPIDS_VERSION}/RAPIDS.cmake
Expand Down
4 changes: 3 additions & 1 deletion faiss/gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -238,9 +238,11 @@ generate_ivf_interleaved_code()

if(FAISS_ENABLE_RAFT)
list(APPEND FAISS_GPU_HEADERS
impl/RaftIVFFlat.cuh
impl/RaftFlatIndex.cuh)
list(APPEND FAISS_GPU_SRC
impl/RaftFlatIndex.cu)
impl/RaftFlatIndex.cu
impl/RaftIVFFlat.cu)

target_compile_definitions(faiss PUBLIC USE_NVIDIA_RAFT=1)
target_compile_definitions(faiss_avx2 PUBLIC USE_NVIDIA_RAFT=1)
Expand Down
49 changes: 43 additions & 6 deletions faiss/gpu/GpuIndexIVF.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
#include <faiss/gpu/impl/IVFBase.cuh>
#include <faiss/gpu/utils/CopyUtils.cuh>

#if defined USE_NVIDIA_RAFT
#include <raft/core/handle.hpp>
#include <raft/neighbors/ivf_flat.cuh>
#endif

namespace faiss {
namespace gpu {

Expand Down Expand Up @@ -444,14 +449,46 @@ void GpuIndexIVF::trainQuantizer_(idx_t n, const float* x) {
printf("Training IVF quantizer on %ld vectors in %dD\n", n, d);
}

// leverage the CPU-side k-means code, which works for the GPU
// flat index as well
quantizer->reset();
Clustering clus(this->d, nlist, this->cp);
clus.verbose = verbose;
clus.train(n, x, *quantizer);
quantizer->is_trained = true;

#if defined USE_NVIDIA_RAFT

if (config_.use_raft) {
const raft::device_resources& raft_handle =
resources_->getRaftHandleCurrentDevice();

raft::neighbors::ivf_flat::index_params raft_idx_params;
raft_idx_params.n_lists = nlist;
raft_idx_params.metric = metric_type == faiss::METRIC_L2
? raft::distance::DistanceType::L2Expanded
: raft::distance::DistanceType::InnerProduct;
raft_idx_params.add_data_on_build = false;
raft_idx_params.kmeans_trainset_fraction = 1.0;
raft_idx_params.kmeans_n_iters = cp.niter;
raft_idx_params.adaptive_centers = !cp.frozen_centroids;

auto raft_index = raft::neighbors::ivf_flat::build(
raft_handle, raft_idx_params, x, n, (idx_t)d);

raft_handle.sync_stream();

quantizer->train(nlist, raft_index.centers().data_handle());
quantizer->add(nlist, raft_index.centers().data_handle());
} else
#else
if (config_.use_raft) {
FAISS_THROW_MSG(
"RAFT has not been compiled into the current version so it cannot be used.");
} else
#endif
{
// leverage the CPU-side k-means code, which works for the GPU
// flat index as well
Clustering clus(this->d, nlist, this->cp);
clus.verbose = verbose;
clus.train(n, x, *quantizer);
}
quantizer->is_trained = true;
FAISS_ASSERT(quantizer->ntotal == nlist);
}

Expand Down
13 changes: 7 additions & 6 deletions faiss/gpu/GpuIndexIVF.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,23 +73,24 @@ class GpuIndexIVF : public GpuIndex, public IndexIVFInterface {
virtual void updateQuantizer() = 0;

/// Returns the number of inverted lists we're managing
idx_t getNumLists() const;
virtual idx_t getNumLists() const;

/// Returns the number of vectors present in a particular inverted list
idx_t getListLength(idx_t listId) const;
virtual idx_t getListLength(idx_t listId) const;

/// Return the encoded vector data contained in a particular inverted list,
/// for debugging purposes.
/// If gpuFormat is true, the data is returned as it is encoded in the
/// GPU-side representation.
/// Otherwise, it is converted to the CPU format.
/// compliant format, while the native GPU format may differ.
std::vector<uint8_t> getListVectorData(idx_t listId, bool gpuFormat = false)
const;
virtual std::vector<uint8_t> getListVectorData(
idx_t listId,
bool gpuFormat = false) const;

/// Return the vector indices contained in a particular inverted list, for
/// debugging purposes.
std::vector<idx_t> getListIndices(idx_t listId) const;
virtual std::vector<idx_t> getListIndices(idx_t listId) const;

void search_preassigned(
idx_t n,
Expand Down Expand Up @@ -121,7 +122,7 @@ class GpuIndexIVF : public GpuIndex, public IndexIVFInterface {
int getCurrentNProbe_(const SearchParameters* params) const;
void verifyIVFSettings_() const;
bool addImplRequiresIDs_() const override;
void trainQuantizer_(idx_t n, const float* x);
virtual void trainQuantizer_(idx_t n, const float* x);

/// Called from GpuIndex for add/add_with_ids
void addImpl_(idx_t n, const float* x, const idx_t* ids) override;
Expand Down
103 changes: 83 additions & 20 deletions faiss/gpu/GpuIndexIVFFlat.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
#include <faiss/gpu/utils/CopyUtils.cuh>
#include <faiss/gpu/utils/Float16.cuh>

#if defined USE_NVIDIA_RAFT
#include <faiss/gpu/impl/RaftIVFFlat.cuh>
#endif

#include <limits>

namespace faiss {
Expand Down Expand Up @@ -70,8 +74,7 @@ GpuIndexIVFFlat::GpuIndexIVFFlat(
// no other quantizer that we need to train, so this is sufficient
if (this->is_trained) {
FAISS_ASSERT(this->quantizer);

index_.reset(new IVFFlat(
set_index_(
resources_.get(),
this->d,
this->nlist,
Expand All @@ -81,14 +84,62 @@ GpuIndexIVFFlat::GpuIndexIVFFlat(
nullptr, // no scalar quantizer
ivfFlatConfig_.interleavedLayout,
ivfFlatConfig_.indicesOptions,
config_.memorySpace));
config_.memorySpace);
baseIndex_ = std::static_pointer_cast<IVFBase, IVFFlat>(index_);
updateQuantizer();
}
}

GpuIndexIVFFlat::~GpuIndexIVFFlat() {}

void GpuIndexIVFFlat::set_index_(
GpuResources* resources,
int dim,
int nlist,
faiss::MetricType metric,
float metricArg,
bool useResidual,
/// Optional ScalarQuantizer
faiss::ScalarQuantizer* scalarQ,
bool interleavedLayout,
IndicesOptions indicesOptions,
MemorySpace space) {
#if defined USE_NVIDIA_RAFT

if (config_.use_raft) {
index_.reset(new RaftIVFFlat(
resources,
dim,
nlist,
metric,
metricArg,
useResidual,
scalarQ,
interleavedLayout,
indicesOptions,
space));
} else
#else
if (config_.use_raft) {
FAISS_THROW_MSG(
"RAFT has not been compiled into the current version so it cannot be used.");
} else
#endif
{
index_.reset(new IVFFlat(
resources,
dim,
nlist,
metric,
metricArg,
useResidual,
scalarQ,
interleavedLayout,
indicesOptions,
space));
}
}

void GpuIndexIVFFlat::reserveMemory(size_t numVecs) {
DeviceScope scope(config_.device);

Expand All @@ -110,25 +161,25 @@ void GpuIndexIVFFlat::copyFrom(const faiss::IndexIVFFlat* index) {

// The other index might not be trained
if (!index->is_trained) {
FAISS_ASSERT(!this->is_trained);
FAISS_ASSERT(!is_trained);
return;
}

// Otherwise, we can populate ourselves from the other index
FAISS_ASSERT(this->is_trained);
FAISS_ASSERT(is_trained);

// Copy our lists as well
index_.reset(new IVFFlat(
set_index_(
resources_.get(),
this->d,
this->nlist,
d,
nlist,
index->metric_type,
index->metric_arg,
false, // no residual
nullptr, // no scalar quantizer
ivfFlatConfig_.interleavedLayout,
ivfFlatConfig_.indicesOptions,
config_.memorySpace));
config_.memorySpace);
baseIndex_ = std::static_pointer_cast<IVFBase, IVFFlat>(index_);
updateQuantizer();

Expand Down Expand Up @@ -201,18 +252,30 @@ void GpuIndexIVFFlat::train(idx_t n, const float* x) {

FAISS_ASSERT(!index_);

// FIXME: GPUize more of this
// First, make sure that the data is resident on the CPU, if it is not on
// the CPU, as we depend upon parts of the CPU code
auto hostData = toHost<float, 2>(
(float*)x,
resources_->getDefaultStream(config_.device),
{n, this->d});

trainQuantizer_(n, hostData.data());
#if defined USE_NVIDIA_RAFT
if (config_.use_raft) {
// No need to copy the data to host
trainQuantizer_(n, x);
} else
#else
if (config_.use_raft) {
FAISS_THROW_MSG(
"RAFT has not been compiled into the current version so it cannot be used.");
} else
#endif
{
// FIXME: GPUize more of this
// First, make sure that the data is resident on the CPU, if it is not
// on the CPU, as we depend upon parts of the CPU code
auto hostData = toHost<float, 2>(
(float*)x,
resources_->getDefaultStream(config_.device),
{n, this->d});
trainQuantizer_(n, hostData.data());
}

// The quantizer is now trained; construct the IVF index
index_.reset(new IVFFlat(
set_index_(
resources_.get(),
this->d,
this->nlist,
Expand All @@ -222,7 +285,7 @@ void GpuIndexIVFFlat::train(idx_t n, const float* x) {
nullptr, // no scalar quantizer
ivfFlatConfig_.interleavedLayout,
ivfFlatConfig_.indicesOptions,
config_.memorySpace));
config_.memorySpace);
baseIndex_ = std::static_pointer_cast<IVFBase, IVFFlat>(index_);
updateQuantizer();

Expand Down
15 changes: 15 additions & 0 deletions faiss/gpu/GpuIndexIVFFlat.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#pragma once

#include <faiss/gpu/GpuIndexIVF.h>
#include <faiss/impl/ScalarQuantizer.h>

#include <memory>

namespace faiss {
Expand Down Expand Up @@ -86,6 +88,19 @@ class GpuIndexIVFFlat : public GpuIndexIVF {
void train(idx_t n, const float* x) override;

protected:
void set_index_(
GpuResources* resources,
int dim,
int nlist,
faiss::MetricType metric,
float metricArg,
bool useResidual,
/// Optional ScalarQuantizer
faiss::ScalarQuantizer* scalarQ,
bool interleavedLayout,
IndicesOptions indicesOptions,
MemorySpace space);

/// Our configuration options
const GpuIndexIVFFlatConfig ivfFlatConfig_;

Expand Down
6 changes: 5 additions & 1 deletion faiss/gpu/StandardGpuResources.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,11 @@ void StandardGpuResourcesImpl::initializeForDevice(int device) {

defaultStreams_[device] = defaultStream;

cudaStream_t asyncCopyStream = nullptr;
#if defined USE_NVIDIA_RAFT
raftHandles_.emplace(std::make_pair(device, defaultStream));
#endif

cudaStream_t asyncCopyStream = 0;
CUDA_VERIFY(
cudaStreamCreateWithFlags(&asyncCopyStream, cudaStreamNonBlocking));

Expand Down
Loading

0 comments on commit edcf743

Please sign in to comment.