diff --git a/build.sh b/build.sh new file mode 100644 index 0000000000..bb9985ce25 --- /dev/null +++ b/build.sh @@ -0,0 +1,58 @@ +#!/bin/bash + +# NOTE: This file is temporary for the proof-of-concept branch and will be removed before this PR is merged + +BUILD_TYPE=Release +BUILD_DIR=build/ + +RAFT_REPO_REL="" +EXTRA_CMAKE_ARGS="" +set -e + +if [[ ${RAFT_REPO_REL} != "" ]]; then + RAFT_REPO_PATH="`readlink -f \"${RAFT_REPO_REL}\"`" + EXTRA_CMAKE_ARGS="${EXTRA_CMAKE_ARGS} -DCPM_raft_SOURCE=${RAFT_REPO_PATH}" +fi + +if [ "$1" == "clean" ]; then + rm -rf build + rm -rf .cache + exit 0 +fi + +if [ "$1" == "test" ]; then + make -C build -j test + exit 0 +fi + +if [ "$1" == "test-raft" ]; then + ./build/faiss/gpu/test/TestRaftIndexIVFFlat + exit 0 +fi + +mkdir -p $BUILD_DIR +cd $BUILD_DIR + +cmake \ + -DFAISS_ENABLE_GPU=ON \ + -DFAISS_ENABLE_RAFT=ON \ + -DFAISS_ENABLE_PYTHON=OFF \ + -DBUILD_TESTING=ON \ + -DBUILD_SHARED_LIBS=OFF \ + -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ + -DFAISS_OPT_LEVEL=avx2 \ + -DRAFT_NVTX=OFF \ + -DCMAKE_CUDA_ARCHITECTURES="NATIVE" \ + -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ + -DCMAKE_CUDA_COMPILER_LAUNCHER=ccache \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + ${EXTRA_CMAKE_ARGS} \ + ../ + + +# make -C build -j12 faiss +cmake --build . -j12 +# make -C build -j12 swigfaiss +# (cd build/faiss/python && python setup.py install) + diff --git a/cmake/thirdparty/fetch_rapids.cmake b/cmake/thirdparty/fetch_rapids.cmake index 044a369606..229c488196 100644 --- a/cmake/thirdparty/fetch_rapids.cmake +++ b/cmake/thirdparty/fetch_rapids.cmake @@ -15,7 +15,7 @@ # or implied. See the License for the specific language governing permissions and limitations under # the License. # ============================================================================= -set(RAPIDS_VERSION "23.06") +set(RAPIDS_VERSION "23.08") if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/FAISS_RAPIDS.cmake) file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-${RAPIDS_VERSION}/RAPIDS.cmake diff --git a/faiss/gpu/CMakeLists.txt b/faiss/gpu/CMakeLists.txt index eca3889698..ad7d2103fa 100644 --- a/faiss/gpu/CMakeLists.txt +++ b/faiss/gpu/CMakeLists.txt @@ -238,9 +238,11 @@ generate_ivf_interleaved_code() if(FAISS_ENABLE_RAFT) list(APPEND FAISS_GPU_HEADERS + impl/RaftIVFFlat.cuh impl/RaftFlatIndex.cuh) list(APPEND FAISS_GPU_SRC - impl/RaftFlatIndex.cu) + impl/RaftFlatIndex.cu + impl/RaftIVFFlat.cu) target_compile_definitions(faiss PUBLIC USE_NVIDIA_RAFT=1) target_compile_definitions(faiss_avx2 PUBLIC USE_NVIDIA_RAFT=1) diff --git a/faiss/gpu/GpuIndexIVF.cu b/faiss/gpu/GpuIndexIVF.cu index f2ed323605..c83008307d 100644 --- a/faiss/gpu/GpuIndexIVF.cu +++ b/faiss/gpu/GpuIndexIVF.cu @@ -16,6 +16,11 @@ #include #include +#if defined USE_NVIDIA_RAFT +#include +#include +#endif + namespace faiss { namespace gpu { @@ -444,14 +449,46 @@ void GpuIndexIVF::trainQuantizer_(idx_t n, const float* x) { printf("Training IVF quantizer on %ld vectors in %dD\n", n, d); } - // leverage the CPU-side k-means code, which works for the GPU - // flat index as well quantizer->reset(); - Clustering clus(this->d, nlist, this->cp); - clus.verbose = verbose; - clus.train(n, x, *quantizer); - quantizer->is_trained = true; +#if defined USE_NVIDIA_RAFT + + if (config_.use_raft) { + const raft::device_resources& raft_handle = + resources_->getRaftHandleCurrentDevice(); + + raft::neighbors::ivf_flat::index_params raft_idx_params; + raft_idx_params.n_lists = nlist; + raft_idx_params.metric = metric_type == 
faiss::METRIC_L2 + ? raft::distance::DistanceType::L2Expanded + : raft::distance::DistanceType::InnerProduct; + raft_idx_params.add_data_on_build = false; + raft_idx_params.kmeans_trainset_fraction = 1.0; + raft_idx_params.kmeans_n_iters = cp.niter; + raft_idx_params.adaptive_centers = !cp.frozen_centroids; + + auto raft_index = raft::neighbors::ivf_flat::build( + raft_handle, raft_idx_params, x, n, (idx_t)d); + + raft_handle.sync_stream(); + + quantizer->train(nlist, raft_index.centers().data_handle()); + quantizer->add(nlist, raft_index.centers().data_handle()); + } else +#else + if (config_.use_raft) { + FAISS_THROW_MSG( + "RAFT has not been compiled into the current version so it cannot be used."); + } else +#endif + { + // leverage the CPU-side k-means code, which works for the GPU + // flat index as well + Clustering clus(this->d, nlist, this->cp); + clus.verbose = verbose; + clus.train(n, x, *quantizer); + } + quantizer->is_trained = true; FAISS_ASSERT(quantizer->ntotal == nlist); } diff --git a/faiss/gpu/GpuIndexIVF.h b/faiss/gpu/GpuIndexIVF.h index 48096eaaf0..a9f092d35b 100644 --- a/faiss/gpu/GpuIndexIVF.h +++ b/faiss/gpu/GpuIndexIVF.h @@ -73,10 +73,10 @@ class GpuIndexIVF : public GpuIndex, public IndexIVFInterface { virtual void updateQuantizer() = 0; /// Returns the number of inverted lists we're managing - idx_t getNumLists() const; + virtual idx_t getNumLists() const; /// Returns the number of vectors present in a particular inverted list - idx_t getListLength(idx_t listId) const; + virtual idx_t getListLength(idx_t listId) const; /// Return the encoded vector data contained in a particular inverted list, /// for debugging purposes. @@ -84,12 +84,13 @@ class GpuIndexIVF : public GpuIndex, public IndexIVFInterface { /// GPU-side representation. /// Otherwise, it is converted to the CPU format. /// compliant format, while the native GPU format may differ. - std::vector getListVectorData(idx_t listId, bool gpuFormat = false) - const; + virtual std::vector getListVectorData( + idx_t listId, + bool gpuFormat = false) const; /// Return the vector indices contained in a particular inverted list, for /// debugging purposes. 
- std::vector getListIndices(idx_t listId) const; + virtual std::vector getListIndices(idx_t listId) const; void search_preassigned( idx_t n, @@ -121,7 +122,7 @@ class GpuIndexIVF : public GpuIndex, public IndexIVFInterface { int getCurrentNProbe_(const SearchParameters* params) const; void verifyIVFSettings_() const; bool addImplRequiresIDs_() const override; - void trainQuantizer_(idx_t n, const float* x); + virtual void trainQuantizer_(idx_t n, const float* x); /// Called from GpuIndex for add/add_with_ids void addImpl_(idx_t n, const float* x, const idx_t* ids) override; diff --git a/faiss/gpu/GpuIndexIVFFlat.cu b/faiss/gpu/GpuIndexIVFFlat.cu index 285efee970..750096e153 100644 --- a/faiss/gpu/GpuIndexIVFFlat.cu +++ b/faiss/gpu/GpuIndexIVFFlat.cu @@ -15,6 +15,10 @@ #include #include +#if defined USE_NVIDIA_RAFT +#include +#endif + #include namespace faiss { @@ -70,8 +74,7 @@ GpuIndexIVFFlat::GpuIndexIVFFlat( // no other quantizer that we need to train, so this is sufficient if (this->is_trained) { FAISS_ASSERT(this->quantizer); - - index_.reset(new IVFFlat( + set_index_( resources_.get(), this->d, this->nlist, @@ -81,7 +84,7 @@ GpuIndexIVFFlat::GpuIndexIVFFlat( nullptr, // no scalar quantizer ivfFlatConfig_.interleavedLayout, ivfFlatConfig_.indicesOptions, - config_.memorySpace)); + config_.memorySpace); baseIndex_ = std::static_pointer_cast(index_); updateQuantizer(); } @@ -89,6 +92,54 @@ GpuIndexIVFFlat::GpuIndexIVFFlat( GpuIndexIVFFlat::~GpuIndexIVFFlat() {} +void GpuIndexIVFFlat::set_index_( + GpuResources* resources, + int dim, + int nlist, + faiss::MetricType metric, + float metricArg, + bool useResidual, + /// Optional ScalarQuantizer + faiss::ScalarQuantizer* scalarQ, + bool interleavedLayout, + IndicesOptions indicesOptions, + MemorySpace space) { +#if defined USE_NVIDIA_RAFT + + if (config_.use_raft) { + index_.reset(new RaftIVFFlat( + resources, + dim, + nlist, + metric, + metricArg, + useResidual, + scalarQ, + interleavedLayout, + indicesOptions, + space)); + } else +#else + if (config_.use_raft) { + FAISS_THROW_MSG( + "RAFT has not been compiled into the current version so it cannot be used."); + } else +#endif + { + index_.reset(new IVFFlat( + resources, + dim, + nlist, + metric, + metricArg, + useResidual, + scalarQ, + interleavedLayout, + indicesOptions, + space)); + } +} + void GpuIndexIVFFlat::reserveMemory(size_t numVecs) { DeviceScope scope(config_.device); @@ -110,25 +161,25 @@ void GpuIndexIVFFlat::copyFrom(const faiss::IndexIVFFlat* index) { // The other index might not be trained if (!index->is_trained) { - FAISS_ASSERT(!this->is_trained); + FAISS_ASSERT(!is_trained); return; } // Otherwise, we can populate ourselves from the other index - FAISS_ASSERT(this->is_trained); + FAISS_ASSERT(is_trained); // Copy our lists as well - index_.reset(new IVFFlat( + set_index_( resources_.get(), - this->d, - this->nlist, + d, + nlist, index->metric_type, index->metric_arg, false, // no residual nullptr, // no scalar quantizer ivfFlatConfig_.interleavedLayout, ivfFlatConfig_.indicesOptions, - config_.memorySpace)); + config_.memorySpace); baseIndex_ = std::static_pointer_cast(index_); updateQuantizer(); @@ -201,18 +252,30 @@ void GpuIndexIVFFlat::train(idx_t n, const float* x) { FAISS_ASSERT(!index_); - // FIXME: GPUize more of this - // First, make sure that the data is resident on the CPU, if it is not on - // the CPU, as we depend upon parts of the CPU code - auto hostData = toHost( - (float*)x, - resources_->getDefaultStream(config_.device), - {n, this->d}); - - 
trainQuantizer_(n, hostData.data()); +#if defined USE_NVIDIA_RAFT + if (config_.use_raft) { + // No need to copy the data to host + trainQuantizer_(n, x); + } else +#else + if (config_.use_raft) { + FAISS_THROW_MSG( + "RAFT has not been compiled into the current version so it cannot be used."); + } else +#endif + { + // FIXME: GPUize more of this + // First, make sure that the data is resident on the CPU, if it is not + // on the CPU, as we depend upon parts of the CPU code + auto hostData = toHost( + (float*)x, + resources_->getDefaultStream(config_.device), + {n, this->d}); + trainQuantizer_(n, hostData.data()); + } // The quantizer is now trained; construct the IVF index - index_.reset(new IVFFlat( + set_index_( resources_.get(), this->d, this->nlist, @@ -222,7 +285,7 @@ void GpuIndexIVFFlat::train(idx_t n, const float* x) { nullptr, // no scalar quantizer ivfFlatConfig_.interleavedLayout, ivfFlatConfig_.indicesOptions, - config_.memorySpace)); + config_.memorySpace); baseIndex_ = std::static_pointer_cast(index_); updateQuantizer(); diff --git a/faiss/gpu/GpuIndexIVFFlat.h b/faiss/gpu/GpuIndexIVFFlat.h index 9206d20f61..d7508feef4 100644 --- a/faiss/gpu/GpuIndexIVFFlat.h +++ b/faiss/gpu/GpuIndexIVFFlat.h @@ -8,6 +8,8 @@ #pragma once #include +#include + #include namespace faiss { @@ -86,6 +88,19 @@ class GpuIndexIVFFlat : public GpuIndexIVF { void train(idx_t n, const float* x) override; protected: + void set_index_( + GpuResources* resources, + int dim, + int nlist, + faiss::MetricType metric, + float metricArg, + bool useResidual, + /// Optional ScalarQuantizer + faiss::ScalarQuantizer* scalarQ, + bool interleavedLayout, + IndicesOptions indicesOptions, + MemorySpace space); + /// Our configuration options const GpuIndexIVFFlatConfig ivfFlatConfig_; diff --git a/faiss/gpu/StandardGpuResources.cpp b/faiss/gpu/StandardGpuResources.cpp index 418912aa4a..754025d049 100644 --- a/faiss/gpu/StandardGpuResources.cpp +++ b/faiss/gpu/StandardGpuResources.cpp @@ -362,7 +362,11 @@ void StandardGpuResourcesImpl::initializeForDevice(int device) { defaultStreams_[device] = defaultStream; - cudaStream_t asyncCopyStream = nullptr; +#if defined USE_NVIDIA_RAFT + raftHandles_.emplace(std::make_pair(device, defaultStream)); +#endif + + cudaStream_t asyncCopyStream = 0; CUDA_VERIFY( cudaStreamCreateWithFlags(&asyncCopyStream, cudaStreamNonBlocking)); diff --git a/faiss/gpu/impl/IVFBase.cuh b/faiss/gpu/impl/IVFBase.cuh index 0ac2cd2843..2bb319d002 100644 --- a/faiss/gpu/impl/IVFBase.cuh +++ b/faiss/gpu/impl/IVFBase.cuh @@ -45,7 +45,7 @@ class IVFBase { /// Clear out all inverted lists, but retain the coarse quantizer /// and the product quantizer info - void reset(); + virtual void reset(); /// Return the number of dimensions we are indexing idx_t getDim() const; @@ -59,29 +59,30 @@ class IVFBase { /// For debugging purposes, return the list length of a particular /// list - idx_t getListLength(idx_t listId) const; + virtual idx_t getListLength(idx_t listId) const; /// Return the list indices of a particular list back to the CPU - std::vector getListIndices(idx_t listId) const; + virtual std::vector getListIndices(idx_t listId) const; /// Return the encoded vectors of a particular list back to the CPU - std::vector getListVectorData(idx_t listId, bool gpuFormat) const; + virtual std::vector getListVectorData(idx_t listId, bool gpuFormat) + const; /// Copy all inverted lists from a CPU representation to ourselves - void copyInvertedListsFrom(const InvertedLists* ivf); + virtual void 
copyInvertedListsFrom(const InvertedLists* ivf); /// Copy all inverted lists from ourselves to a CPU representation - void copyInvertedListsTo(InvertedLists* ivf); + virtual void copyInvertedListsTo(InvertedLists* ivf); /// Update our coarse quantizer with this quantizer instance; may be a CPU /// or GPU quantizer - void updateQuantizer(Index* quantizer); + virtual void updateQuantizer(Index* quantizer); /// Classify and encode/add vectors to our IVF lists. /// The input data must be on our current device. /// Returns the number of vectors successfully added. Vectors may /// not be able to be added because they contain NaNs. - idx_t addVectors( + virtual idx_t addVectors( Index* coarseQuantizer, Tensor& vecs, Tensor& indices); @@ -111,7 +112,7 @@ class IVFBase { protected: /// Adds a set of codes and indices to a list, with the representation /// coming from the CPU equivalent - void addEncodedVectorsToList_( + virtual void addEncodedVectorsToList_( idx_t listId, // resident on the host const void* codes, diff --git a/faiss/gpu/impl/IVFFlat.cuh b/faiss/gpu/impl/IVFFlat.cuh index 60b454f622..246fc18b16 100644 --- a/faiss/gpu/impl/IVFFlat.cuh +++ b/faiss/gpu/impl/IVFFlat.cuh @@ -60,17 +60,17 @@ class IVFFlat : public IVFBase { size_t getCpuVectorsEncodingSize_(idx_t numVecs) const override; /// Translate to our preferred GPU encoding - std::vector translateCodesToGpu_( + virtual std::vector translateCodesToGpu_( std::vector codes, idx_t numVecs) const override; /// Translate from our preferred GPU encoding - std::vector translateCodesFromGpu_( + virtual std::vector translateCodesFromGpu_( std::vector codes, idx_t numVecs) const override; /// Encode the vectors that we're adding and append to our IVF lists - void appendVectors_( + virtual void appendVectors_( Tensor& vecs, Tensor& ivfCentroidResiduals, Tensor& indices, @@ -84,7 +84,7 @@ class IVFFlat : public IVFBase { /// Shared IVF search implementation, used by both search and /// searchPreassigned - void searchImpl_( + virtual void searchImpl_( Tensor& queries, Tensor& coarseDistances, Tensor& coarseIndices, diff --git a/faiss/gpu/impl/RaftIVFFlat.cu b/faiss/gpu/impl/RaftIVFFlat.cu new file mode 100644 index 0000000000..2c6afb795c --- /dev/null +++ b/faiss/gpu/impl/RaftIVFFlat.cu @@ -0,0 +1,604 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace faiss { +namespace gpu { + +RaftIVFFlat::RaftIVFFlat( + GpuResources* res, + int dim, + int nlist, + faiss::MetricType metric, + float metricArg, + bool useResidual, + faiss::ScalarQuantizer* scalarQ, + bool interleavedLayout, + IndicesOptions indicesOptions, + MemorySpace space) + : IVFFlat(res, + dim, + nlist, + metric, + metricArg, + useResidual, + scalarQ, + interleavedLayout, + indicesOptions, + space) { + FAISS_THROW_IF_NOT_MSG( + indicesOptions == INDICES_64_BIT, + "only INDICES_64_BIT is supported for RAFT index"); + reset(); +} + +RaftIVFFlat::~RaftIVFFlat() {} + +/// Find the approximate k nearest neighbors for `queries` against +/// our database +void RaftIVFFlat::search( + Index* coarseQuantizer, + Tensor& queries, + int nprobe, + int k, + Tensor& outDistances, + Tensor& outIndices) { + // TODO: We probably don't want to ignore the coarse quantizer here... + + uint32_t numQueries = queries.getSize(0); + uint32_t cols = queries.getSize(1); + uint32_t k_ = k; + + // Device is already set in GpuIndex::search + FAISS_ASSERT(raft_knn_index.has_value()); + FAISS_ASSERT(numQueries > 0); + FAISS_ASSERT(cols == dim_); + FAISS_THROW_IF_NOT(nprobe > 0 && nprobe <= numLists_); + + const raft::device_resources& raft_handle = + resources_->getRaftHandleCurrentDevice(); + raft::neighbors::ivf_flat::search_params pams; + pams.n_probes = nprobe; + + auto queries_view = raft::make_device_matrix_view( + queries.data(), (idx_t)numQueries, (idx_t)cols); + auto out_inds_view = raft::make_device_matrix_view( + outIndices.data(), (idx_t)numQueries, (idx_t)k_); + auto out_dists_view = raft::make_device_matrix_view( + outDistances.data(), (idx_t)numQueries, (idx_t)k_); + + raft::neighbors::ivf_flat::search( + raft_handle, + pams, + raft_knn_index.value(), + queries_view, + out_inds_view, + out_dists_view); + + /// Identify NaN rows and mask their nearest neighbors + auto nan_flag = raft::make_device_vector(raft_handle, numQueries); + + validRowIndices_(queries, nan_flag.data_handle()); + + raft::linalg::map_offset( + raft_handle, + raft::make_device_vector_view(outIndices.data(), numQueries * k_), + [nan_flag = nan_flag.data_handle(), + out_inds = outIndices.data(), + k_] __device__(uint32_t i) { + uint32_t row = i / k_; + if (!nan_flag[row]) + return idx_t(-1); + return out_inds[i]; + }); + + float max_val = std::numeric_limits::max(); + raft::linalg::map_offset( + raft_handle, + raft::make_device_vector_view(outDistances.data(), numQueries * k_), + [nan_flag = nan_flag.data_handle(), + out_dists = outDistances.data(), + max_val, + k_] __device__(uint32_t i) { + uint32_t row = i / k_; + if (!nan_flag[row]) + return max_val; + return out_dists[i]; + }); +} + +/// Classify and encode/add vectors to our IVF lists. +/// The input data must be on our current device. +/// Returns the number of vectors successfully added. Vectors may +/// not be able to be added because they contain NaNs. 
+idx_t RaftIVFFlat::addVectors(
+        Index* coarseQuantizer,
+        Tensor& vecs,
+        Tensor& indices) {
+    /// TODO: We probably don't want to ignore the coarse quantizer here
+
+    idx_t n_rows = vecs.getSize(0);
+
+    const raft::device_resources& raft_handle =
+            resources_->getRaftHandleCurrentDevice();
+
+    /// Remove NaN values
+    auto nan_flag = raft::make_device_vector(raft_handle, n_rows);
+
+    validRowIndices_(vecs, nan_flag.data_handle());
+
+    idx_t n_rows_valid = thrust::reduce(
+            raft_handle.get_thrust_policy(),
+            nan_flag.data_handle(),
+            nan_flag.data_handle() + n_rows,
+            0);
+
+    if (n_rows_valid < n_rows) {
+        auto gather_indices = raft::make_device_vector(
+                raft_handle, n_rows_valid);
+
+        auto count = thrust::make_counting_iterator(0);
+
+        thrust::copy_if(
+                raft_handle.get_thrust_policy(),
+                count,
+                count + n_rows,
+                gather_indices.data_handle(),
+                [nan_flag = nan_flag.data_handle()] __device__(auto i) {
+                    return nan_flag[i];
+                });
+
+        raft::matrix::gather(
+                raft_handle,
+                raft::make_device_matrix_view(
+                        vecs.data(), n_rows, dim_),
+                raft::make_const_mdspan(gather_indices.view()),
+                (idx_t)16);
+
+        auto valid_indices = raft::make_device_vector(
+                raft_handle, n_rows_valid);
+
+        raft::matrix::gather(
+                raft_handle,
+                raft::make_device_matrix_view(
+                        indices.data(), n_rows, (idx_t)1),
+                raft::make_const_mdspan(gather_indices.view()));
+    }
+
+    FAISS_ASSERT(raft_knn_index.has_value());
+    raft_knn_index.emplace(raft::neighbors::ivf_flat::extend(
+            raft_handle,
+            raft::make_device_matrix_view(
+                    vecs.data(), n_rows_valid, dim_),
+            std::make_optional>(
+                    raft::make_device_vector_view(
+                            indices.data(), n_rows_valid)),
+            raft_knn_index.value()));
+
+    return n_rows_valid;
+}
+
+void RaftIVFFlat::reset() {
+    raft_knn_index.reset();
+}
+
+idx_t RaftIVFFlat::getListLength(idx_t listId) const {
+    FAISS_ASSERT(raft_knn_index.has_value());
+    const raft::device_resources& raft_handle =
+            resources_->getRaftHandleCurrentDevice();
+
+    uint32_t size;
+    raft::update_host(
+            &size,
+            raft_knn_index.value().list_sizes().data_handle() + listId,
+            1,
+            raft_handle.get_stream());
+    raft_handle.sync_stream();
+
+    return static_cast(size);
+}
+
+/// Return the list indices of a particular list back to the CPU
+std::vector RaftIVFFlat::getListIndices(idx_t listId) const {
+    FAISS_ASSERT(raft_knn_index.has_value());
+    const raft::device_resources& raft_handle =
+            resources_->getRaftHandleCurrentDevice();
+    auto stream = raft_handle.get_stream();
+
+    idx_t listSize = getListLength(listId);
+
+    std::vector vec(listSize);
+
+    idx_t* list_indices_ptr;
+
+    // fetch the list indices ptr on host
+    raft::update_host(
+            &list_indices_ptr,
+            raft_knn_index.value().inds_ptrs().data_handle() + listId,
+            1,
+            stream);
+    raft_handle.sync_stream();
+
+    raft::update_host(vec.data(), list_indices_ptr, listSize, stream);
+    raft_handle.sync_stream();
+
+    return vec;
+}
+
+/// Return the encoded vectors of a particular list back to the CPU
+std::vector RaftIVFFlat::getListVectorData(
+        idx_t listId,
+        bool gpuFormat) const {
+    if (gpuFormat) {
+        FAISS_THROW_MSG("gpuFormat is not supported for raft indices");
+    }
+    FAISS_ASSERT(raft_knn_index.has_value());
+
+    const raft::device_resources& raft_handle =
+            resources_->getRaftHandleCurrentDevice();
+    auto stream = raft_handle.get_stream();
+
+    idx_t listSize = getListLength(listId);
+
+    // the interleaved block can be slightly larger than the list size (it's
+    // rounded up)
+    auto gpuListSizeInBytes = getGpuVectorsEncodingSize_(listSize);
+    auto
cpuListSizeInBytes = getCpuVectorsEncodingSize_(listSize); + + std::vector interleaved_codes(gpuListSizeInBytes); + std::vector flat_codes(cpuListSizeInBytes); + + float* list_data_ptr; + + // fetch the list data ptr on host + raft::update_host( + &list_data_ptr, + raft_knn_index.value().data_ptrs().data_handle() + listId, + 1, + stream); + raft_handle.sync_stream(); + + raft::update_host( + interleaved_codes.data(), + reinterpret_cast(list_data_ptr), + gpuListSizeInBytes, + stream); + raft_handle.sync_stream(); + + RaftIVFFlatCodePackerInterleaved packer( + (size_t)listSize, dim_, raft_knn_index.value().veclen()); + packer.unpack_all(interleaved_codes.data(), flat_codes.data()); + return flat_codes; +} + +/// Performs search when we are already given the IVF cells to look at +/// (GpuIndexIVF::search_preassigned implementation) +void RaftIVFFlat::searchPreassigned( + Index* coarseQuantizer, + Tensor& vecs, + Tensor& ivfDistances, + Tensor& ivfAssignments, + int k, + Tensor& outDistances, + Tensor& outIndices, + bool storePairs) { + // TODO: Fill this in! +} + +void RaftIVFFlat::updateQuantizer(Index* quantizer) { + idx_t quantizer_ntotal = quantizer->ntotal; + + const raft::device_resources& raft_handle = + resources_->getRaftHandleCurrentDevice(); + auto stream = raft_handle.get_stream(); + + auto total_elems = size_t(quantizer_ntotal) * size_t(quantizer->d); + + raft::logger::get().set_level(RAFT_LEVEL_TRACE); + + raft::neighbors::ivf_flat::index_params pams; + pams.add_data_on_build = false; + + pams.n_lists = this->numLists_; + + switch (this->metric_) { + case faiss::METRIC_L2: + pams.metric = raft::distance::DistanceType::L2Expanded; + break; + case faiss::METRIC_INNER_PRODUCT: + pams.metric = raft::distance::DistanceType::InnerProduct; + break; + default: + FAISS_THROW_MSG("Metric is not supported."); + } + + raft_knn_index.emplace(raft_handle, pams, (uint32_t)this->dim_); + + cudaMemsetAsync( + raft_knn_index.value().list_sizes().data_handle(), + 0, + raft_knn_index.value().list_sizes().size() * sizeof(uint32_t), + stream); + cudaMemsetAsync( + raft_knn_index.value().data_ptrs().data_handle(), + 0, + raft_knn_index.value().data_ptrs().size() * sizeof(float*), + stream); + cudaMemsetAsync( + raft_knn_index.value().inds_ptrs().data_handle(), + 0, + raft_knn_index.value().inds_ptrs().size() * sizeof(idx_t*), + stream); + + /// Copy (reconstructed) centroids over, rather than re-training + std::vector buf_host(total_elems); + quantizer->reconstruct_n(0, quantizer_ntotal, buf_host.data()); + + raft::update_device( + raft_knn_index.value().centers().data_handle(), + buf_host.data(), + total_elems, + stream); +} + +void RaftIVFFlat::copyInvertedListsFrom(const InvertedLists* ivf) { + size_t nlist = ivf ? ivf->nlist : 0; + size_t ntotal = ivf ? 
ivf->compute_ntotal() : 0; + + raft::device_resources& raft_handle = + resources_->getRaftHandleCurrentDevice(); + + std::vector list_sizes_(nlist); + std::vector indices_(ntotal); + + // the index must already exist + FAISS_ASSERT(raft_knn_index.has_value()); + + auto& raft_lists = raft_knn_index.value().lists(); + + // conservative memory alloc for cloning cpu inverted lists + raft::neighbors::ivf_flat::list_spec raft_list_spec{ + static_cast(dim_), true}; + + for (size_t i = 0; i < nlist; ++i) { + size_t listSize = ivf->list_size(i); + + // GPU index can only support max int entries per list + FAISS_THROW_IF_NOT_FMT( + listSize <= (size_t)std::numeric_limits::max(), + "GPU inverted list can only support " + "%zu entries; %zu found", + (size_t)std::numeric_limits::max(), + listSize); + + // store the list size + list_sizes_[i] = static_cast(listSize); + + raft::neighbors::ivf::resize_list( + raft_handle, + raft_lists[i], + raft_list_spec, + (uint32_t)listSize, + (uint32_t)0); + } + + // Update the pointers and the sizes + raft_knn_index.value().recompute_internal_state(raft_handle); + + for (size_t i = 0; i < nlist; ++i) { + size_t listSize = ivf->list_size(i); + addEncodedVectorsToList_( + i, ivf->get_codes(i), ivf->get_ids(i), listSize); + } + + raft::update_device( + raft_knn_index.value().list_sizes().data_handle(), + list_sizes_.data(), + nlist, + raft_handle.get_stream()); + + // Precompute the centers vector norms for L2Expanded distance + if (this->metric_ == faiss::METRIC_L2) { + raft_knn_index.value().allocate_center_norms(raft_handle); + raft::linalg::rowNorm( + raft_knn_index.value().center_norms().value().data_handle(), + raft_knn_index.value().centers().data_handle(), + raft_knn_index.value().dim(), + (uint32_t)nlist, + raft::linalg::L2Norm, + true, + raft_handle.get_stream()); + } +} + +size_t RaftIVFFlat::getGpuVectorsEncodingSize_(idx_t numVecs) const { + idx_t bits = 32 /* float */; + + // bytes to encode a block of 32 vectors (single dimension) + idx_t bytesPerDimBlock = bits * 32 / 8; // = 128 + + // bytes to fully encode 32 vectors + idx_t bytesPerBlock = bytesPerDimBlock * dim_; + + // number of blocks of 32 vectors we have + idx_t numBlocks = + utils::divUp(numVecs, raft::neighbors::ivf_flat::kIndexGroupSize); + + // total size to encode numVecs + return bytesPerBlock * numBlocks; +} + +void RaftIVFFlat::addEncodedVectorsToList_( + idx_t listId, + const void* codes, + const idx_t* indices, + idx_t numVecs) { + auto stream = resources_->getDefaultStreamCurrentDevice(); + + // This list must already exist + FAISS_ASSERT(raft_knn_index.has_value()); + + // This list must currently be empty + FAISS_ASSERT(getListLength(listId) == 0); + + // If there's nothing to add, then there's nothing we have to do + if (numVecs == 0) { + return; + } + + // The GPU might have a different layout of the memory + auto gpuListSizeInBytes = getGpuVectorsEncodingSize_(numVecs); + auto cpuListSizeInBytes = getCpuVectorsEncodingSize_(numVecs); + + // We only have int32 length representations on the GPU per each + // list; the length is in sizeof(char) + FAISS_ASSERT(gpuListSizeInBytes <= (size_t)std::numeric_limits::max()); + + std::vector interleaved_codes(gpuListSizeInBytes); + RaftIVFFlatCodePackerInterleaved packer( + (size_t)numVecs, (uint32_t)dim_, raft_knn_index.value().veclen()); + + packer.pack_all( + reinterpret_cast(codes), interleaved_codes.data()); + + float* list_data_ptr; + const raft::device_resources& raft_handle = + resources_->getRaftHandleCurrentDevice(); + + /// fetch 
the list data ptr on host + raft::update_host( + &list_data_ptr, + raft_knn_index.value().data_ptrs().data_handle() + listId, + 1, + stream); + raft_handle.sync_stream(); + + raft::update_device( + reinterpret_cast(list_data_ptr), + interleaved_codes.data(), + gpuListSizeInBytes, + stream); + + /// Handle the indices as well + idx_t* list_indices_ptr; + + // fetch the list indices ptr on host + raft::update_host( + &list_indices_ptr, + raft_knn_index.value().inds_ptrs().data_handle() + listId, + 1, + stream); + raft_handle.sync_stream(); + + raft::update_device(list_indices_ptr, indices, numVecs, stream); +} + +void RaftIVFFlat::validRowIndices_( + Tensor& vecs, + bool* nan_flag) { + raft::device_resources& raft_handle = + resources_->getRaftHandleCurrentDevice(); + idx_t n_rows = vecs.getSize(0); + + thrust::fill_n(raft_handle.get_thrust_policy(), nan_flag, n_rows, true); + raft::linalg::map_offset( + raft_handle, + raft::make_device_vector_view(nan_flag, n_rows), + [vecs = vecs.data(), dim_ = this->dim_] __device__(idx_t i) { + for (idx_t col = 0; col < dim_; col++) { + if (!isfinite(vecs[i * dim_ + col])) { + return false; + } + } + return true; + }); +} + +RaftIVFFlatCodePackerInterleaved::RaftIVFFlatCodePackerInterleaved( + size_t list_size, + uint32_t dim, + uint32_t chunk_size) { + this->dim = dim; + this->chunk_size = chunk_size; + // NB: dim should be divisible by the number of 4 byte records in one chunk + FAISS_ASSERT(dim % chunk_size == 0); + nvec = list_size; + code_size = dim * 4; + block_size = + utils::roundUp(nvec, raft::neighbors::ivf_flat::kIndexGroupSize); +} + +void RaftIVFFlatCodePackerInterleaved::pack_1( + const uint8_t* flat_code, + size_t offset, + uint8_t* block) const { + raft::neighbors::ivf_flat::codepacker::pack_1( + reinterpret_cast(flat_code), + reinterpret_cast(block), + dim, + chunk_size, + static_cast(offset)); +} + +void RaftIVFFlatCodePackerInterleaved::unpack_1( + const uint8_t* block, + size_t offset, + uint8_t* flat_code) const { + raft::neighbors::ivf_flat::codepacker::unpack_1( + reinterpret_cast(block), + reinterpret_cast(flat_code), + dim, + chunk_size, + static_cast(offset)); +} + +} // namespace gpu +} // namespace faiss diff --git a/faiss/gpu/impl/RaftIVFFlat.cuh b/faiss/gpu/impl/RaftIVFFlat.cuh new file mode 100644 index 0000000000..3aba501c9f --- /dev/null +++ b/faiss/gpu/impl/RaftIVFFlat.cuh @@ -0,0 +1,149 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include + +#include + +#include + +#include + +namespace faiss { +namespace gpu { + +class RaftIVFFlat : public IVFFlat { + public: + RaftIVFFlat( + GpuResources* resources, + int dim, + int nlist, + faiss::MetricType metric, + float metricArg, + bool useResidual, + /// Optional ScalarQuantizer + faiss::ScalarQuantizer* scalarQ, + bool interleavedLayout, + IndicesOptions indicesOptions, + MemorySpace space); + + ~RaftIVFFlat() override; + + /// Find the approximate k nearest neigbors for `queries` against + /// our database + void search( + Index* coarseQuantizer, + Tensor& queries, + int nprobe, + int k, + Tensor& outDistances, + Tensor& outIndices) override; + + /// Performs search when we are already given the IVF cells to look at + /// (GpuIndexIVF::search_preassigned implementation) + void searchPreassigned( + Index* coarseQuantizer, + Tensor& vecs, + Tensor& ivfDistances, + Tensor& ivfAssignments, + int k, + Tensor& outDistances, + Tensor& outIndices, + bool storePairs) override; + + /// Classify and encode/add vectors to our IVF lists. + /// The input data must be on our current device. + /// Returns the number of vectors successfully added. Vectors may + /// not be able to be added because they contain NaNs. + idx_t addVectors( + Index* coarseQuantizer, + Tensor& vecs, + Tensor& indices) override; + + /// Reserve GPU memory in our inverted lists for this number of vectors + // void reserveMemory(idx_t numVecs) override; + + /// Clear out all inverted lists, but retain the coarse quantizer + /// and the product quantizer info + void reset() override; + + /// For debugging purposes, return the list length of a particular + /// list + idx_t getListLength(idx_t listId) const override; + + /// Return the list indices of a particular list back to the CPU + std::vector getListIndices(idx_t listId) const override; + + /// Return the encoded vectors of a particular list back to the CPU + std::vector getListVectorData(idx_t listId, bool gpuFormat) + const override; + + void updateQuantizer(Index* quantizer) override; + + /// Copy all inverted lists from a CPU representation to ourselves + void copyInvertedListsFrom(const InvertedLists* ivf) override; + + /// Filter out matrix rows containing NaN values + void validRowIndices_(Tensor& vecs, bool* nan_flag); + + protected: + /// Adds a set of codes and indices to a list, with the representation + /// coming from the CPU equivalent + void addEncodedVectorsToList_( + idx_t listId, + // resident on the host + const void* codes, + // resident on the host + const idx_t* indices, + idx_t numVecs) override; + + /// Returns the number of bytes in which an IVF list containing numVecs + /// vectors is encoded on the device. 
Note that due to padding this is not
+    /// the same as the encoding size for a subset of vectors in an IVF list;
+    /// this is the size for an entire IVF list
+    size_t getGpuVectorsEncodingSize_(idx_t numVecs) const override;
+
+    std::optional>
+            raft_knn_index{std::nullopt};
+};
+
+struct RaftIVFFlatCodePackerInterleaved : CodePacker {
+    RaftIVFFlatCodePackerInterleaved(
+            size_t list_size,
+            uint32_t dim,
+            uint32_t chunk_size);
+    void pack_1(const uint8_t* flat_code, size_t offset, uint8_t* block)
+            const final;
+    void unpack_1(const uint8_t* block, size_t offset, uint8_t* flat_code)
+            const final;
+
+   protected:
+    uint32_t chunk_size;
+    uint32_t dim;
+};
+
+} // namespace gpu
+} // namespace faiss
diff --git a/faiss/gpu/test/TestGpuIndexFlat.cpp b/faiss/gpu/test/TestGpuIndexFlat.cpp
index 4f7c95deab..6d9c83e547 100644
--- a/faiss/gpu/test/TestGpuIndexFlat.cpp
+++ b/faiss/gpu/test/TestGpuIndexFlat.cpp
@@ -749,7 +749,6 @@ void testSearchAndReconstruct(bool use_raft) {
         }
     }
 }
-
 TEST(TestGpuIndexFlat, SearchAndReconstruct) {
     testSearchAndReconstruct(false);
 }
@@ -767,4 +766,4 @@ int main(int argc, char** argv) {
     faiss::gpu::setTestSeed(100);
 
     return RUN_ALL_TESTS();
-}
+}
\ No newline at end of file
diff --git a/faiss/gpu/test/TestGpuIndexIVFFlat.cpp b/faiss/gpu/test/TestGpuIndexIVFFlat.cpp
index c4fc95ef29..9fb88e2687 100644
--- a/faiss/gpu/test/TestGpuIndexIVFFlat.cpp
+++ b/faiss/gpu/test/TestGpuIndexIVFFlat.cpp
@@ -30,6 +30,7 @@
 #include
 #include
 #include
+#include "faiss/gpu/GpuIndicesOptions.h"
 
 // FIXME: figure out a better way to test fp16
 constexpr float kF16MaxRelErr = 0.3f;
@@ -55,6 +56,8 @@ struct Options {
                 faiss::gpu::INDICES_64_BIT});
 
         device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1);
+
+        use_raft = false;
     }
 
     std::string toString() const {
@@ -62,7 +65,7 @@
         str << "IVFFlat device " << device << " numVecs " << numAdd << " dim "
             << dim << " numCentroids " << numCentroids << " nprobe " << nprobe
             << " numQuery " << numQuery << " k " << k << " indicesOpt "
-            << indicesOpt;
+            << indicesOpt << " use_raft " << use_raft;
 
         return str.str();
     }
@@ -76,6 +79,7 @@ struct Options {
     int k;
     int device;
     faiss::gpu::IndicesOptions indicesOpt;
+    bool use_raft;
 };
 
 void queryTest(
@@ -106,6 +110,7 @@ void queryTest(
         config.device = opt.device;
         config.indicesOptions = opt.indicesOpt;
         config.flatConfig.useFloat16 = useFloat16CoarseQuantizer;
+        config.use_raft = opt.use_raft;
 
         faiss::gpu::GpuIndexIVFFlat gpuIndex(
                 &res, cpuIndex.d, cpuIndex.nlist, cpuIndex.metric_type, config);
@@ -129,7 +134,10 @@ void queryTest(
     }
 }
 
-void addTest(faiss::MetricType metricType, bool useFloat16CoarseQuantizer) {
+void addTest(
+        faiss::MetricType metricType,
+        bool useFloat16CoarseQuantizer,
+        bool use_raft) {
     for (int tries = 0; tries < 2; ++tries) {
         Options opt;
@@ -153,8 +161,10 @@ void addTest(faiss::MetricType metricType, bool useFloat16CoarseQuantizer) {
 
         faiss::gpu::GpuIndexIVFFlatConfig config;
         config.device = opt.device;
-        config.indicesOptions = opt.indicesOpt;
+        config.indicesOptions =
+                use_raft ?
faiss::gpu::INDICES_64_BIT : opt.indicesOpt; config.flatConfig.useFloat16 = useFloat16CoarseQuantizer; + config.use_raft = use_raft; faiss::gpu::GpuIndexIVFFlat gpuIndex( &res, cpuIndex.d, cpuIndex.nlist, cpuIndex.metric_type, config); @@ -178,7 +188,7 @@ void addTest(faiss::MetricType metricType, bool useFloat16CoarseQuantizer) { } } -void copyToTest(bool useFloat16CoarseQuantizer) { +void copyToTest(bool useFloat16CoarseQuantizer, bool use_raft) { Options opt; std::vector trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim); std::vector addVecs = faiss::gpu::randVecs(opt.numAdd, opt.dim); @@ -188,8 +198,10 @@ void copyToTest(bool useFloat16CoarseQuantizer) { faiss::gpu::GpuIndexIVFFlatConfig config; config.device = opt.device; - config.indicesOptions = opt.indicesOpt; + config.indicesOptions = + use_raft ? faiss::gpu::INDICES_64_BIT : opt.indicesOpt; config.flatConfig.useFloat16 = useFloat16CoarseQuantizer; + config.use_raft = use_raft; faiss::gpu::GpuIndexIVFFlat gpuIndex( &res, opt.dim, opt.numCentroids, faiss::METRIC_L2, config); @@ -229,7 +241,7 @@ void copyToTest(bool useFloat16CoarseQuantizer) { compFloat16 ? 0.30f : 0.015f); } -void copyFromTest(bool useFloat16CoarseQuantizer) { +void copyFromTest(bool useFloat16CoarseQuantizer, bool use_raft) { Options opt; std::vector trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim); std::vector addVecs = faiss::gpu::randVecs(opt.numAdd, opt.dim); @@ -247,8 +259,10 @@ void copyFromTest(bool useFloat16CoarseQuantizer) { faiss::gpu::GpuIndexIVFFlatConfig config; config.device = opt.device; - config.indicesOptions = opt.indicesOpt; + config.indicesOptions = + use_raft ? faiss::gpu::INDICES_64_BIT : opt.indicesOpt; config.flatConfig.useFloat16 = useFloat16CoarseQuantizer; + config.use_raft = use_raft; faiss::gpu::GpuIndexIVFFlat gpuIndex(&res, 1, 1, faiss::METRIC_L2, config); gpuIndex.nprobe = 1; @@ -280,19 +294,35 @@ void copyFromTest(bool useFloat16CoarseQuantizer) { } TEST(TestGpuIndexIVFFlat, Float32_32_Add_L2) { - addTest(faiss::METRIC_L2, false); + addTest(faiss::METRIC_L2, false, false); + +#if defined USE_NVIDIA_RAFT + addTest(faiss::METRIC_L2, false, true); +#endif } TEST(TestGpuIndexIVFFlat, Float32_32_Add_IP) { - addTest(faiss::METRIC_INNER_PRODUCT, false); + addTest(faiss::METRIC_INNER_PRODUCT, false, false); + +#if defined USE_NVIDIA_RAFT + addTest(faiss::METRIC_INNER_PRODUCT, false, true); +#endif } TEST(TestGpuIndexIVFFlat, Float16_32_Add_L2) { - addTest(faiss::METRIC_L2, true); + addTest(faiss::METRIC_L2, true, false); + +#if defined USE_NVIDIA_RAFT + addTest(faiss::METRIC_L2, true, true); +#endif } TEST(TestGpuIndexIVFFlat, Float16_32_Add_IP) { - addTest(faiss::METRIC_INNER_PRODUCT, true); + addTest(faiss::METRIC_INNER_PRODUCT, true, false); + +#if defined USE_NVIDIA_RAFT + addTest(faiss::METRIC_INNER_PRODUCT, true, true); +#endif } // @@ -300,11 +330,25 @@ TEST(TestGpuIndexIVFFlat, Float16_32_Add_IP) { // TEST(TestGpuIndexIVFFlat, Float32_Query_L2) { - queryTest(Options(), faiss::METRIC_L2, false); + Options opt; + queryTest(opt, faiss::METRIC_L2, false); + +#if defined USE_NVIDIA_RAFT + opt.use_raft = true; + opt.indicesOpt = faiss::gpu::INDICES_64_BIT; + queryTest(opt, faiss::METRIC_L2, false); +#endif } TEST(TestGpuIndexIVFFlat, Float32_Query_IP) { - queryTest(Options(), faiss::METRIC_INNER_PRODUCT, false); + Options opt; + queryTest(opt, faiss::METRIC_INNER_PRODUCT, false); + +#if defined USE_NVIDIA_RAFT + opt.use_raft = true; + opt.indicesOpt = faiss::gpu::INDICES_64_BIT; + queryTest(opt, faiss::METRIC_INNER_PRODUCT, 
false); +#endif } TEST(TestGpuIndexIVFFlat, LargeBatch) { @@ -312,16 +356,36 @@ TEST(TestGpuIndexIVFFlat, LargeBatch) { opt.dim = 3; opt.numQuery = 100000; queryTest(opt, faiss::METRIC_L2, false); + +#if defined USE_NVIDIA_RAFT + opt.use_raft = true; + opt.indicesOpt = faiss::gpu::INDICES_64_BIT; + queryTest(opt, faiss::METRIC_L2, false); +#endif } // float16 coarse quantizer TEST(TestGpuIndexIVFFlat, Float16_32_Query_L2) { - queryTest(Options(), faiss::METRIC_L2, true); + Options opt; + queryTest(opt, faiss::METRIC_L2, true); + +#if defined USE_NVIDIA_RAFT + opt.use_raft = true; + opt.indicesOpt = faiss::gpu::INDICES_64_BIT; + queryTest(opt, faiss::METRIC_L2, true); +#endif } TEST(TestGpuIndexIVFFlat, Float16_32_Query_IP) { - queryTest(Options(), faiss::METRIC_INNER_PRODUCT, true); + Options opt; + queryTest(opt, faiss::METRIC_INNER_PRODUCT, true); + +#if defined USE_NVIDIA_RAFT + opt.use_raft = true; + opt.indicesOpt = faiss::gpu::INDICES_64_BIT; + queryTest(opt, faiss::METRIC_INNER_PRODUCT, true); +#endif } // @@ -333,24 +397,48 @@ TEST(TestGpuIndexIVFFlat, Float32_Query_L2_64) { Options opt; opt.dim = 64; queryTest(opt, faiss::METRIC_L2, false); + +#if defined USE_NVIDIA_RAFT + opt.use_raft = true; + opt.indicesOpt = faiss::gpu::INDICES_64_BIT; + queryTest(opt, faiss::METRIC_L2, false); +#endif } TEST(TestGpuIndexIVFFlat, Float32_Query_IP_64) { Options opt; opt.dim = 64; queryTest(opt, faiss::METRIC_INNER_PRODUCT, false); + +#if defined USE_NVIDIA_RAFT + opt.use_raft = true; + opt.indicesOpt = faiss::gpu::INDICES_64_BIT; + queryTest(opt, faiss::METRIC_INNER_PRODUCT, false); +#endif } TEST(TestGpuIndexIVFFlat, Float32_Query_L2_128) { Options opt; opt.dim = 128; queryTest(opt, faiss::METRIC_L2, false); + +#if defined USE_NVIDIA_RAFT + opt.use_raft = true; + opt.indicesOpt = faiss::gpu::INDICES_64_BIT; + queryTest(opt, faiss::METRIC_L2, false); +#endif } TEST(TestGpuIndexIVFFlat, Float32_Query_IP_128) { Options opt; opt.dim = 128; queryTest(opt, faiss::METRIC_INNER_PRODUCT, false); + +#if defined USE_NVIDIA_RAFT + opt.use_raft = true; + opt.indicesOpt = faiss::gpu::INDICES_64_BIT; + queryTest(opt, faiss::METRIC_INNER_PRODUCT, false); +#endif } // @@ -358,11 +446,19 @@ TEST(TestGpuIndexIVFFlat, Float32_Query_IP_128) { // TEST(TestGpuIndexIVFFlat, Float32_32_CopyTo) { - copyToTest(false); + copyToTest(false, false); + +#if defined USE_NVIDIA_RAFT + copyToTest(false, true); +#endif } TEST(TestGpuIndexIVFFlat, Float32_32_CopyFrom) { - copyFromTest(false); + copyFromTest(false, false); + +#if defined USE_NVIDIA_RAFT + copyFromTest(false, true); +#endif } TEST(TestGpuIndexIVFFlat, Float32_negative) { @@ -392,6 +488,14 @@ TEST(TestGpuIndexIVFFlat, Float32_negative) { faiss::gpu::StandardGpuResources res; res.noTempMemory(); + // Construct a positive test set + auto queryVecs = faiss::gpu::randVecs(opt.numQuery, opt.dim); + + // Put all vecs on positive size + for (auto& f : queryVecs) { + f = std::abs(f); + } + faiss::gpu::GpuIndexIVFFlatConfig config; config.device = opt.device; config.indicesOptions = opt.indicesOpt; @@ -401,14 +505,6 @@ TEST(TestGpuIndexIVFFlat, Float32_negative) { gpuIndex.copyFrom(&cpuIndex); gpuIndex.nprobe = opt.nprobe; - // Construct a positive test set - auto queryVecs = faiss::gpu::randVecs(opt.numQuery, opt.dim); - - // Put all vecs on positive size - for (auto& f : queryVecs) { - f = std::abs(f); - } - bool compFloat16 = false; faiss::gpu::compareIndices( queryVecs, @@ -424,6 +520,31 @@ TEST(TestGpuIndexIVFFlat, Float32_negative) { // in fp16. 
Figure out another way to test compFloat16 ? 0.99f : 0.1f, compFloat16 ? 0.65f : 0.015f); + +#if defined USE_NVIDIA_RAFT + config.use_raft = true; + config.indicesOptions = faiss::gpu::INDICES_64_BIT; + + faiss::gpu::GpuIndexIVFFlat raftGpuIndex( + &res, cpuIndex.d, cpuIndex.nlist, cpuIndex.metric_type, config); + raftGpuIndex.copyFrom(&cpuIndex); + raftGpuIndex.nprobe = opt.nprobe; + + faiss::gpu::compareIndices( + queryVecs, + cpuIndex, + raftGpuIndex, + opt.numQuery, + opt.dim, + opt.k, + opt.toString(), + compFloat16 ? kF16MaxRelErr : kF32MaxRelErr, + // FIXME: the fp16 bounds are + // useless when math (the accumulator) is + // in fp16. Figure out another way to test + compFloat16 ? 0.99f : 0.1f, + compFloat16 ? 0.65f : 0.015f); +#endif } // @@ -439,6 +560,13 @@ TEST(TestGpuIndexIVFFlat, QueryNaN) { faiss::gpu::StandardGpuResources res; res.noTempMemory(); + int numQuery = 10; + std::vector nans( + numQuery * opt.dim, std::numeric_limits::quiet_NaN()); + + std::vector distances(numQuery * opt.k, 0); + std::vector indices(numQuery * opt.k, 0); + faiss::gpu::GpuIndexIVFFlatConfig config; config.device = opt.device; config.indicesOptions = opt.indicesOpt; @@ -451,14 +579,31 @@ TEST(TestGpuIndexIVFFlat, QueryNaN) { gpuIndex.train(opt.numTrain, trainVecs.data()); gpuIndex.add(opt.numAdd, addVecs.data()); - int numQuery = 10; - std::vector nans( - numQuery * opt.dim, std::numeric_limits::quiet_NaN()); + gpuIndex.search( + numQuery, nans.data(), opt.k, distances.data(), indices.data()); - std::vector distances(numQuery * opt.k, 0); - std::vector indices(numQuery * opt.k, 0); + for (int q = 0; q < numQuery; ++q) { + for (int k = 0; k < opt.k; ++k) { + EXPECT_EQ(indices[q * opt.k + k], -1); + EXPECT_EQ( + distances[q * opt.k + k], + std::numeric_limits::max()); + } + } - gpuIndex.search( +#if defined USE_NVIDIA_RAFT + config.use_raft = true; + config.indicesOptions = faiss::gpu::INDICES_64_BIT; + std::fill(distances.begin(), distances.end(), 0); + std::fill(indices.begin(), indices.end(), 0); + faiss::gpu::GpuIndexIVFFlat raftGpuIndex( + &res, opt.dim, opt.numCentroids, faiss::METRIC_L2, config); + raftGpuIndex.nprobe = opt.nprobe; + + raftGpuIndex.train(opt.numTrain, trainVecs.data()); + raftGpuIndex.add(opt.numAdd, addVecs.data()); + + raftGpuIndex.search( numQuery, nans.data(), opt.k, distances.data(), indices.data()); for (int q = 0; q < numQuery; ++q) { @@ -469,6 +614,7 @@ TEST(TestGpuIndexIVFFlat, QueryNaN) { std::numeric_limits::max()); } } +#endif } TEST(TestGpuIndexIVFFlat, AddNaN) { @@ -477,15 +623,6 @@ TEST(TestGpuIndexIVFFlat, AddNaN) { faiss::gpu::StandardGpuResources res; res.noTempMemory(); - faiss::gpu::GpuIndexIVFFlatConfig config; - config.device = opt.device; - config.indicesOptions = opt.indicesOpt; - config.flatConfig.useFloat16 = faiss::gpu::randBool(); - - faiss::gpu::GpuIndexIVFFlat gpuIndex( - &res, opt.dim, opt.numCentroids, faiss::METRIC_L2, config); - gpuIndex.nprobe = opt.nprobe; - int numNans = 10; std::vector nans( numNans * opt.dim, std::numeric_limits::quiet_NaN()); @@ -497,6 +634,14 @@ TEST(TestGpuIndexIVFFlat, AddNaN) { } std::vector trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim); + + faiss::gpu::GpuIndexIVFFlatConfig config; + config.device = opt.device; + config.indicesOptions = opt.indicesOpt; + config.flatConfig.useFloat16 = faiss::gpu::randBool(); + faiss::gpu::GpuIndexIVFFlat gpuIndex( + &res, opt.dim, opt.numCentroids, faiss::METRIC_L2, config); + gpuIndex.nprobe = opt.nprobe; gpuIndex.train(opt.numTrain, trainVecs.data()); // should not crash 
@@ -514,6 +659,27 @@
         opt.k,
         distance.data(),
         indices.data());
+
+#if defined USE_NVIDIA_RAFT
+    config.use_raft = true;
+    config.indicesOptions = faiss::gpu::INDICES_64_BIT;
+    faiss::gpu::GpuIndexIVFFlat raftGpuIndex(
+            &res, opt.dim, opt.numCentroids, faiss::METRIC_L2, config);
+    raftGpuIndex.nprobe = opt.nprobe;
+    raftGpuIndex.train(opt.numTrain, trainVecs.data());
+
+    // should not crash
+    EXPECT_EQ(raftGpuIndex.ntotal, 0);
+    raftGpuIndex.add(numNans, nans.data());
+
+    // should not crash
+    raftGpuIndex.search(
+            opt.numQuery,
+            queryVecs.data(),
+            opt.k,
+            distance.data(),
+            indices.data());
+#endif
 }
 
 TEST(TestGpuIndexIVFFlat, UnifiedMemory) {
@@ -570,6 +736,26 @@ TEST(TestGpuIndexIVFFlat, UnifiedMemory) {
             kF32MaxRelErr,
             0.1f,
             0.015f);
+
+#if defined USE_NVIDIA_RAFT
+    config.use_raft = true;
+    config.indicesOptions = faiss::gpu::INDICES_64_BIT;
+    faiss::gpu::GpuIndexIVFFlat raftGpuIndex(
+            &res, dim, numCentroids, faiss::METRIC_L2, config);
+    raftGpuIndex.copyFrom(&cpuIndex);
+    raftGpuIndex.nprobe = nprobe;
+
+    faiss::gpu::compareIndices(
+            cpuIndex,
+            raftGpuIndex,
+            numQuery,
+            dim,
+            k,
+            "Unified Memory",
+            kF32MaxRelErr,
+            0.1f,
+            0.015f);
+#endif
 }
 
 TEST(TestGpuIndexIVFFlat, LongIVFList) {
@@ -628,6 +814,27 @@ TEST(TestGpuIndexIVFFlat, LongIVFList) {
             kF32MaxRelErr,
             0.1f,
             0.015f);
+
+#if defined USE_NVIDIA_RAFT
+    config.use_raft = true;
+    config.indicesOptions = faiss::gpu::INDICES_64_BIT;
+    faiss::gpu::GpuIndexIVFFlat raftGpuIndex(
+            &res, dim, numCentroids, faiss::METRIC_L2, config);
+    raftGpuIndex.train(numTrain, trainVecs.data());
+    raftGpuIndex.add(numAdd, addVecs.data());
+    raftGpuIndex.nprobe = 1;
+
+    faiss::gpu::compareIndices(
+            cpuIndex,
+            raftGpuIndex,
+            numQuery,
+            dim,
+            k,
+            "LongIVFList",
+            kF32MaxRelErr,
+            0.1f,
+            0.015f);
+#endif
 }
 
 int main(int argc, char** argv) {
@@ -637,4 +844,4 @@
     faiss::gpu::setTestSeed(100);
 
     return RUN_ALL_TESTS();
-}
+}
\ No newline at end of file
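
For reference, here is a minimal sketch of how the RAFT-backed path added in this diff is driven from user code, mirroring the config pattern the test changes above use (use_raft plus 64-bit indices). The dimensions, nlist, nprobe, and dataset size below are illustrative assumptions, not values taken from this PR; a build with -DFAISS_ENABLE_RAFT=ON (as in build.sh above) is required, otherwise the use_raft path throws at runtime.

#include <faiss/gpu/GpuIndexIVFFlat.h>
#include <faiss/gpu/StandardGpuResources.h>

#include <random>
#include <vector>

int main() {
    // Illustrative problem size (assumption, not from the PR)
    int dim = 64;
    faiss::idx_t numVecs = 100000;
    int k = 10;

    // Random host-side database vectors; the index copies them to device
    std::mt19937 rng(1234);
    std::uniform_real_distribution<float> dist(0.0f, 1.0f);
    std::vector<float> vecs(numVecs * dim);
    for (auto& v : vecs) {
        v = dist(rng);
    }

    faiss::gpu::StandardGpuResources res;

    faiss::gpu::GpuIndexIVFFlatConfig config;
    config.device = 0;
    // Dispatches to RaftIVFFlat instead of IVFFlat (this PR's set_index_)
    config.use_raft = true;
    // RaftIVFFlat only accepts INDICES_64_BIT (see its constructor check)
    config.indicesOptions = faiss::gpu::INDICES_64_BIT;

    faiss::gpu::GpuIndexIVFFlat index(
            &res, dim, 1024 /* nlist, illustrative */, faiss::METRIC_L2, config);
    index.nprobe = 32;

    // train() takes the RAFT k-means branch in GpuIndexIVF::trainQuantizer_,
    // and add() routes through RaftIVFFlat::addVectors
    index.train(numVecs, vecs.data());
    index.add(numVecs, vecs.data());

    // Query with the first database vector
    std::vector<float> distances(k);
    std::vector<faiss::idx_t> labels(k);
    index.search(1, vecs.data(), k, distances.data(), labels.data());

    return 0;
}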