Skip to content

Commit

Permalink
Make the tests to not crash... sometimes
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin committed Jul 6, 2022
1 parent b7144a9 commit 38733bb
Show file tree
Hide file tree
Showing 3 changed files with 296 additions and 309 deletions.
171 changes: 69 additions & 102 deletions faiss/gpu/raft/RaftIndexIVFFlat.cu
Original file line number Diff line number Diff line change
Expand Up @@ -38,70 +38,49 @@ RaftIndexIVFFlat::RaftIndexIVFFlat(
faiss::MetricType metric,
GpuIndexIVFFlatConfig config)
: GpuIndexIVFFlat(provider, dims, nlist, metric, config),
raft_handle(resources_->getDefaultStream(config_.device)) {
this->is_trained = false;
}
raft_handle(resources_->getDefaultStream(config_.device)) {}

RaftIndexIVFFlat::~RaftIndexIVFFlat() {}
RaftIndexIVFFlat::~RaftIndexIVFFlat() {
RaftIndexIVFFlat::reset();
}

void RaftIndexIVFFlat::copyFrom(const faiss::IndexIVFFlat* index) {
printf("Copying from...\n");

// TODO: Need to copy necessary memory from the index and set any needed
// params.
DeviceScope scope(config_.device);

GpuIndex::copyFrom(index);

FAISS_ASSERT(index->nlist > 0);
FAISS_THROW_IF_NOT_FMT(
index->nlist <= (Index::idx_t)std::numeric_limits<int>::max(),
"GPU index only supports %zu inverted lists",
(size_t)std::numeric_limits<int>::max());
nlist = index->nlist;

FAISS_THROW_IF_NOT_FMT(
index->nprobe > 0 && index->nprobe <= getMaxKSelection(),
"GPU index only supports nprobe <= %zu; passed %zu",
(size_t)getMaxKSelection(),
index->nprobe);
nprobe = index->nprobe;

// config.device = config_.device;
if (!index->is_trained) {
// copied in GpuIndex::copyFrom
FAISS_ASSERT(!is_trained && ntotal == 0);
return;
if (index->is_trained && index->ntotal > 0) {
// TODO: A proper copy of the index without retraining
// For now, just get all the data from the index, and train our index
// anew.
auto stream = raft_handle.get_stream();
auto total_elems = size_t(index->ntotal) * size_t(index->d);
rmm::device_uvector<float> buf_dev(total_elems, stream);
{
std::vector<float> buf_host(total_elems);
index->reconstruct_n(0, index->ntotal, buf_host.data());
raft::copy(buf_dev.data(), buf_host.data(), total_elems, stream);
}
FAISS_ASSERT(index->d == this->d);
FAISS_ASSERT(index->metric_arg == this->metric_arg);
FAISS_ASSERT(index->metric_type == this->metric_type);
FAISS_ASSERT(index->nlist == this->nlist);
RaftIndexIVFFlat::rebuildRaftIndex(buf_dev.data(), index->ntotal);
} else {
// index is not trained, so we can remove ours as well (if there was
// any)
raft_knn_index.reset();
}

// copied in GpuIndex::copyFrom
// ntotal can exceed max int, but the number of vectors per inverted
// list cannot exceed this. We check this in the subclasses.
FAISS_ASSERT(is_trained && (ntotal == index->ntotal));

// Since we're trained, the quantizer must have data
FAISS_ASSERT(index->quantizer->ntotal > 0);

raft::spatial::knn::ivf_flat::index_params raft_idx_params;
raft_idx_params.n_lists = nlist;

switch (metric_type) {
case faiss::METRIC_L2:
raft_idx_params.metric = raft::distance::DistanceType::L2Expanded;
break;
case faiss::METRIC_INNER_PRODUCT:
raft_idx_params.metric = raft::distance::DistanceType::InnerProduct;
break;
default:
FAISS_THROW_MSG("Metric is not supported.");
}

// TODO: Invoke corresponding call on the RAFT side to copy quantizer
/**
* For example:
* raft_knn_index.emplace(raft::spatial::knn::ivf_flat::make_index<T>(
* raft_handle, raft_idx_params, (faiss::Index::idx_t)d);
*/
this->is_trained = index->is_trained;
}

void RaftIndexIVFFlat::reserveMemory(size_t numVecs) {
Expand Down Expand Up @@ -137,23 +116,8 @@ size_t RaftIndexIVFFlat::reclaimMemory() {
}

void RaftIndexIVFFlat::train(Index::idx_t n, const float* x) {
// For now, only support <= max int results
FAISS_THROW_IF_NOT_FMT(
n <= (Index::idx_t)std::numeric_limits<int>::max(),
"GPU index only supports up to %d indices",
std::numeric_limits<int>::max());

DeviceScope scope(config_.device);

if (this->is_trained) {
FAISS_ASSERT(raft_knn_index.has_value());
return;
}

raft::spatial::knn::ivf_flat::index_params raft_idx_params;
raft_idx_params.n_lists = nlist;
raft_idx_params.metric = raft::distance::DistanceType::L2Expanded;

// TODO: This should only train the quantizer portion of the index
/**
* For example:
Expand All @@ -163,16 +127,11 @@ void RaftIndexIVFFlat::train(Index::idx_t n, const float* x) {
* raft::spatial::knn::ivf_flat::train_quantizer(
* raft_handle, *raft_knn_index, const_cast<float*>(x), n);
*
* NB: ivf_flat does not have a quantizer. Training here imply kmeans?
*/

raft_knn_index.emplace(raft::spatial::knn::ivf_flat::build(
raft_handle,
raft_idx_params,
const_cast<float*>(x),
n,
(faiss::Index::idx_t)d));

raft_handle.sync_stream();
RaftIndexIVFFlat::rebuildRaftIndex(x, n);
}

int RaftIndexIVFFlat::getListLength(int listId) const {
Expand Down Expand Up @@ -208,8 +167,8 @@ std::vector<uint8_t> RaftIndexIVFFlat::getListVectorData(
}

void RaftIndexIVFFlat::reset() {
std::cout << "Calling reset()" << std::endl;
raft_knn_index.reset();
this->ntotal = 0;
}

std::vector<Index::idx_t> RaftIndexIVFFlat::getListIndices(int listId) const {
Expand All @@ -232,28 +191,20 @@ void RaftIndexIVFFlat::addImpl_(
const float* x,
const Index::idx_t* xids) {
// Device is already set in GpuIndex::add
FAISS_ASSERT(raft_knn_index.has_value());
FAISS_ASSERT(is_trained);
FAISS_ASSERT(n > 0);
/* TODO:
At the moment, raft does not support adding vectors, and does not support
providing indices with the vectors even in training
// Data is already resident on the GPU
Tensor<float, 2, true> data(const_cast<float*>(x), {n, (int)this->d});
Tensor<Index::idx_t, 1, true> labels(const_cast<Index::idx_t*>(xids), {n});

// // Not all vectors may be able to be added (some may contain NaNs etc)
// index_->addVectors(data, labels);
//
// // but keep the ntotal based on the total number of vectors that we
// // attempted to add
ntotal += n;

std::cout << "Calling addImpl_ with " << n << " vectors." << std::endl;

// TODO: Invoke corresponding call in raft::ivf_flat
/**
* For example:
* raft::spatial::knn::ivf_flat::add_vectors(
* raft_handle, *raft_knn_index, n, x, xids);
For now, just do the training anew
*/
raft_knn_index.reset();

// Not all vectors may be able to be added (some may contain NaNs etc)
// but keep the ntotal based on the total number of vectors that we
// attempted to add index_->addVectors(data, labels);
RaftIndexIVFFlat::rebuildRaftIndex(x, n);
}

void RaftIndexIVFFlat::searchImpl_(
Expand All @@ -267,28 +218,44 @@ void RaftIndexIVFFlat::searchImpl_(
FAISS_ASSERT(n > 0);
FAISS_THROW_IF_NOT(nprobe > 0 && nprobe <= nlist);

// Data is already resident on the GPU
Tensor<float, 2, true> queries(const_cast<float*>(x), {n, (int)this->d});
Tensor<float, 2, true> outDistances(distances, {n, k});
Tensor<Index::idx_t, 2, true> outLabels(
const_cast<Index::idx_t*>(labels), {n, k});

// TODO: Populate the rest of the params properly.
raft::spatial::knn::ivf_flat::search_params raft_idx_params;
raft_idx_params.n_probes = nprobe;

raft::spatial::knn::ivf_flat::search_params pams;
pams.n_probes = nprobe;
raft::spatial::knn::ivf_flat::search<float, faiss::Index::idx_t>(
raft_handle,
raft_idx_params,
pams,
*raft_knn_index,
const_cast<float*>(x),
static_cast<std::uint32_t>(n),
static_cast<std::uint32_t>(k),
static_cast<faiss::Index::idx_t*>(labels),
labels,
distances);

raft_handle.sync_stream();
}

void RaftIndexIVFFlat::rebuildRaftIndex(const float* x, Index::idx_t n_rows) {
raft::spatial::knn::ivf_flat::index_params pams;

pams.n_lists = this->nlist;
switch (this->metric_type) {
case faiss::METRIC_L2:
pams.metric = raft::distance::DistanceType::L2Expanded;
break;
case faiss::METRIC_INNER_PRODUCT:
pams.metric = raft::distance::DistanceType::InnerProduct;
break;
default:
FAISS_THROW_MSG("Metric is not supported.");
}
pams.metric_arg = this->metric_arg;
pams.kmeans_trainset_fraction = 1.0;

raft_knn_index.emplace(raft::spatial::knn::ivf_flat::build(
this->raft_handle, pams, x, n_rows, uint32_t(this->d)));
this->raft_handle.sync_stream();
this->is_trained = true;
this->ntotal = n_rows;
}

} // namespace gpu
} // namespace faiss
2 changes: 2 additions & 0 deletions faiss/gpu/raft/RaftIndexIVFFlat.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ class RaftIndexIVFFlat : public GpuIndexIVFFlat {
float* distances,
Index::idx_t* labels) const override;

void rebuildRaftIndex(const float* x, Index::idx_t n_rows);

const raft::handle_t raft_handle;
std::optional<raft::spatial::knn::ivf_flat::index<float, Index::idx_t>> raft_knn_index{std::nullopt};
};
Expand Down
Loading

0 comments on commit 38733bb

Please sign in to comment.