Skip to content

Commit

Permalink
TimeoutCallback C++ and Python (facebookresearch#3417)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#3417

facebookresearch#3351

Reviewed By: junjieqi

Differential Revision: D57120422

fbshipit-source-id: e2e446642e7be8647f5115f90916fad242e31286
  • Loading branch information
Amir Sadoughi authored and abhinavdangeti committed Jul 12, 2024
1 parent 279cf07 commit 44e9c87
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 2 deletions.
6 changes: 6 additions & 0 deletions faiss/gpu/perf/PerfClustering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <vector>

#include <cuda_profiler_api.h>
#include <faiss/impl/AuxIndexStructures.h>

DEFINE_int32(num, 10000, "# of vecs");
DEFINE_int32(k, 100, "# of clusters");
Expand All @@ -34,6 +35,7 @@ DEFINE_int64(
"minimum size to use CPU -> GPU paged copies");
DEFINE_int64(pinned_mem, -1, "pinned memory allocation to use");
DEFINE_int32(max_points, -1, "max points per centroid");
DEFINE_double(timeout, 0, "timeout in seconds");

using namespace faiss::gpu;

Expand Down Expand Up @@ -99,10 +101,14 @@ int main(int argc, char** argv) {
cp.max_points_per_centroid = FLAGS_max_points;
}

auto tc = new faiss::TimeoutCallback();
faiss::InterruptCallback::instance.reset(tc);

faiss::Clustering kmeans(FLAGS_dim, FLAGS_k, cp);

// Time k-means
{
tc->set_timeout(FLAGS_timeout);
CpuTimer timer;

kmeans.train(FLAGS_num, vecs.data(), *(gpuIndex.getIndex()));
Expand Down
25 changes: 25 additions & 0 deletions faiss/impl/AuxIndexStructures.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,4 +236,29 @@ size_t InterruptCallback::get_period_hint(size_t flops) {
return std::max((size_t)10 * 10 * 1000 * 1000 / (flops + 1), (size_t)1);
}

void TimeoutCallback::set_timeout(double timeout_in_seconds) {
timeout = timeout_in_seconds;
start = std::chrono::steady_clock::now();
}

/// Report whether the armed budget has been exceeded.
///
/// @return false while no timeout is armed (timeout == 0) or while the
///         elapsed time since set_timeout() does not exceed the budget;
///         true exactly once when the budget is exceeded, after which the
///         callback disarms itself.
bool TimeoutCallback::want_interrupt() {
    if (timeout == 0) {
        // 0 is the "disarmed" sentinel: never request an interruption.
        return false;
    }
    // Measure directly in double-precision seconds so the comparison uses
    // the same unit and precision as `timeout`, instead of going through
    // float milliseconds and narrowing back to float.
    const std::chrono::duration<double> elapsed =
            std::chrono::steady_clock::now() - start;
    if (elapsed.count() > timeout) {
        // One-shot: disarm so subsequent polls do not fire again.
        timeout = 0;
        return true;
    }
    return false;
}

void TimeoutCallback::reset(double timeout_in_seconds) {
auto tc(new faiss::TimeoutCallback());
faiss::InterruptCallback::instance.reset(tc);
tc->set_timeout(timeout_in_seconds);
}

} // namespace faiss
8 changes: 8 additions & 0 deletions faiss/impl/AuxIndexStructures.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,14 @@ struct FAISS_API InterruptCallback {
static size_t get_period_hint(size_t flops);
};

/// Interrupt callback that requests an interruption once a time budget,
/// measured on std::chrono::steady_clock, has been exceeded.
struct TimeoutCallback : InterruptCallback {
    /// Instant from which the current budget is measured; updated by
    /// set_timeout().
    std::chrono::time_point<std::chrono::steady_clock> start;
    /// Budget in seconds. 0 means "no timeout armed". Initialized here so
    /// that a default-initialized instance is disarmed instead of holding
    /// an indeterminate value.
    double timeout = 0;
    /// @return true once the armed budget has been exceeded.
    bool want_interrupt() override;
    /// Arm the callback with a budget in seconds, starting now.
    void set_timeout(double timeout_in_seconds);
    /// Install a new TimeoutCallback as InterruptCallback::instance and arm
    /// it with the given budget in seconds.
    static void reset(double timeout_in_seconds);
};

/// set implementation optimized for fast access.
struct VisitedTable {
std::vector<uint8_t> visited;
Expand Down
11 changes: 11 additions & 0 deletions faiss/python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,3 +316,14 @@ def deserialize_index_binary(data):
reader = VectorIOReader()
copy_array_to_vector(data, reader.data)
return read_index_binary(reader)


class TimeoutGuard:
    """Context manager that installs a timeout-based interrupt callback.

    On entry, ``TimeoutCallback.reset`` is called with the configured
    budget. On exit, whether or not an exception was raised,
    ``PythonInterruptCallback.reset`` is called. Note that exit always
    installs the Python callback; it does not restore whatever callback
    was active before entry.

    Args:
        timeout_in_seconds: budget handed to ``TimeoutCallback.reset``.
    """

    def __init__(self, timeout_in_seconds: float):
        self.timeout = timeout_in_seconds

    def __enter__(self):
        TimeoutCallback.reset(self.timeout)
        # Return the guard so ``with TimeoutGuard(t) as guard`` binds it
        # instead of None.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Returns None (falsy): an exception raised inside the ``with``
        # body is propagated to the caller, not suppressed.
        PythonInterruptCallback.reset()
9 changes: 7 additions & 2 deletions faiss/python/swigfaiss.swig
Original file line number Diff line number Diff line change
Expand Up @@ -1041,7 +1041,9 @@ PyObject *swig_ptr (PyObject *a)
PyErr_SetString(PyExc_ValueError, "did not recognize array type");
return NULL;
}
%}

%inline %{

struct PythonInterruptCallback: faiss::InterruptCallback {

Expand All @@ -1056,15 +1058,18 @@ struct PythonInterruptCallback: faiss::InterruptCallback {
return err == -1;
}

/// Install a new PythonInterruptCallback as the process-wide interrupt
/// callback, replacing whichever callback was installed before.
static void reset() {
    auto* callback = new PythonInterruptCallback();
    faiss::InterruptCallback::instance.reset(callback);
}
};

%}

%init %{
/* needed, else crash at runtime */
import_array();

faiss::InterruptCallback::instance.reset(new PythonInterruptCallback());

PythonInterruptCallback::reset();
%}

// return a pointer usable as input for functions that expect pointers
Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ set(FAISS_TEST_SRC
test_fastscan_perf.cpp
test_disable_pq_sdc_tables.cpp
test_common_ivf_empty_index.cpp
test_callback.cpp
)

add_executable(faiss_test ${FAISS_TEST_SRC})
Expand Down
37 changes: 37 additions & 0 deletions tests/test_callback.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/**
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <gtest/gtest.h>

#include <faiss/Clustering.h>
#include <faiss/IndexFlat.h>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissException.h>
#include <faiss/utils/random.h>

// Checks that an armed TimeoutCallback aborts a k-means training run that
// would otherwise perform an enormous number of iterations.
TEST(TestCallback, timeout) {
    const int n = 1000;
    const int k = 100;
    const int d = 128;
    // Large enough that training can only end through the interruption.
    const int niter = 1000000000;
    const int seed = 42;

    std::vector<float> vecs(n * d);
    faiss::float_rand(vecs.data(), vecs.size(), seed);

    // Automatic storage: the index is released on every exit path, with no
    // manual delete to forget.
    faiss::IndexFlat index(d);

    faiss::ClusteringParameters cp;
    cp.niter = niter;
    cp.verbose = false;

    faiss::Clustering kmeans(d, k, cp);

    // Arm a 10 ms budget; training is expected to be interrupted with an
    // exception once it is exceeded.
    faiss::TimeoutCallback::reset(0.010);
    EXPECT_THROW(kmeans.train(n, vecs.data(), index), faiss::FaissException);

    // Release the process-wide callback so later tests in this binary do
    // not run with this test's TimeoutCallback still installed.
    // NOTE(review): assumes an empty `instance` means "no callback" --
    // confirm against InterruptCallback::check().
    faiss::InterruptCallback::instance.reset();
}
32 changes: 32 additions & 0 deletions tests/test_callback_py.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest
import numpy as np
import faiss


class TestCallbackPy(unittest.TestCase):
    """Checks that faiss.TimeoutGuard interrupts a long-running training."""

    def test_timeout(self) -> None:
        n = 1000
        k = 100
        d = 128
        # Large enough that training can only end through the interruption.
        niter = 1_000_000_000

        # Fixed seed so the input is reproducible from run to run.
        rs = np.random.RandomState(42)
        x = rs.rand(n, d).astype('float32')
        index = faiss.IndexFlat(d)

        cp = faiss.ClusteringParameters()
        cp.niter = niter
        cp.verbose = False

        kmeans = faiss.Clustering(d, k, cp)

        # The 10 ms budget is expected to abort training with a
        # RuntimeError raised from inside the guarded block.
        with self.assertRaises(RuntimeError):
            with faiss.TimeoutGuard(0.010):
                kmeans.train(x, index)

0 comments on commit 44e9c87

Please sign in to comment.