diff --git a/faiss/gpu/perf/PerfClustering.cpp b/faiss/gpu/perf/PerfClustering.cpp index 0322f0e490..532557fe20 100644 --- a/faiss/gpu/perf/PerfClustering.cpp +++ b/faiss/gpu/perf/PerfClustering.cpp @@ -17,6 +17,7 @@ #include #include +#include DEFINE_int32(num, 10000, "# of vecs"); DEFINE_int32(k, 100, "# of clusters"); @@ -34,6 +35,7 @@ DEFINE_int64( "minimum size to use CPU -> GPU paged copies"); DEFINE_int64(pinned_mem, -1, "pinned memory allocation to use"); DEFINE_int32(max_points, -1, "max points per centroid"); +DEFINE_double(timeout, 0, "timeout in seconds"); using namespace faiss::gpu; @@ -99,10 +101,14 @@ int main(int argc, char** argv) { cp.max_points_per_centroid = FLAGS_max_points; } + auto tc = new faiss::TimeoutCallback(); + faiss::InterruptCallback::instance.reset(tc); + faiss::Clustering kmeans(FLAGS_dim, FLAGS_k, cp); // Time k-means { + tc->set_timeout(FLAGS_timeout); CpuTimer timer; kmeans.train(FLAGS_num, vecs.data(), *(gpuIndex.getIndex())); diff --git a/faiss/impl/AuxIndexStructures.cpp b/faiss/impl/AuxIndexStructures.cpp index cebe8a1e23..e2b2791e55 100644 --- a/faiss/impl/AuxIndexStructures.cpp +++ b/faiss/impl/AuxIndexStructures.cpp @@ -236,4 +236,29 @@ size_t InterruptCallback::get_period_hint(size_t flops) { return std::max((size_t)10 * 10 * 1000 * 1000 / (flops + 1), (size_t)1); } +void TimeoutCallback::set_timeout(double timeout_in_seconds) { + timeout = timeout_in_seconds; + start = std::chrono::steady_clock::now(); +} + +bool TimeoutCallback::want_interrupt() { + if (timeout == 0) { + return false; + } + auto end = std::chrono::steady_clock::now(); + std::chrono::duration duration = end - start; + float elapsed_in_seconds = duration.count() / 1000.0; + if (elapsed_in_seconds > timeout) { + timeout = 0; + return true; + } + return false; +} + +void TimeoutCallback::reset(double timeout_in_seconds) { + auto tc(new faiss::TimeoutCallback()); + faiss::InterruptCallback::instance.reset(tc); + tc->set_timeout(timeout_in_seconds); +} + } // namespace faiss diff --git a/faiss/impl/AuxIndexStructures.h b/faiss/impl/AuxIndexStructures.h index f8b5cca842..7e12a1a3af 100644 --- a/faiss/impl/AuxIndexStructures.h +++ b/faiss/impl/AuxIndexStructures.h @@ -161,6 +161,14 @@ struct FAISS_API InterruptCallback { static size_t get_period_hint(size_t flops); }; +struct TimeoutCallback : InterruptCallback { + std::chrono::time_point start; + double timeout; + bool want_interrupt() override; + void set_timeout(double timeout_in_seconds); + static void reset(double timeout_in_seconds); +}; + /// set implementation optimized for fast access. struct VisitedTable { std::vector visited; diff --git a/faiss/python/__init__.py b/faiss/python/__init__.py index 95be4254dc..0562d1dd89 100644 --- a/faiss/python/__init__.py +++ b/faiss/python/__init__.py @@ -316,3 +316,14 @@ def deserialize_index_binary(data): reader = VectorIOReader() copy_array_to_vector(data, reader.data) return read_index_binary(reader) + + +class TimeoutGuard: + def __init__(self, timeout_in_seconds: float): + self.timeout = timeout_in_seconds + + def __enter__(self): + TimeoutCallback.reset(self.timeout) + + def __exit__(self, exc_type, exc_value, traceback): + PythonInterruptCallback.reset() diff --git a/faiss/python/swigfaiss.swig b/faiss/python/swigfaiss.swig index 5c9a7b3fa7..85e04d322c 100644 --- a/faiss/python/swigfaiss.swig +++ b/faiss/python/swigfaiss.swig @@ -1041,7 +1041,9 @@ PyObject *swig_ptr (PyObject *a) PyErr_SetString(PyExc_ValueError, "did not recognize array type"); return NULL; } +%} +%inline %{ struct PythonInterruptCallback: faiss::InterruptCallback { @@ -1056,15 +1058,18 @@ struct PythonInterruptCallback: faiss::InterruptCallback { return err == -1; } + static void reset() { + faiss::InterruptCallback::instance.reset(new PythonInterruptCallback()); + } }; + %} %init %{ /* needed, else crash at runtime */ import_array(); - faiss::InterruptCallback::instance.reset(new PythonInterruptCallback()); - + PythonInterruptCallback::reset(); %} // return a pointer usable as input for functions that expect pointers diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 443195eecb..3980d7dd7c 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -34,6 +34,7 @@ set(FAISS_TEST_SRC test_fastscan_perf.cpp test_disable_pq_sdc_tables.cpp test_common_ivf_empty_index.cpp + test_callback.cpp ) add_executable(faiss_test ${FAISS_TEST_SRC}) diff --git a/tests/test_callback.cpp b/tests/test_callback.cpp new file mode 100644 index 0000000000..cdfadf1d39 --- /dev/null +++ b/tests/test_callback.cpp @@ -0,0 +1,37 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include +#include + +TEST(TestCallback, timeout) { + int n = 1000; + int k = 100; + int d = 128; + int niter = 1000000000; + int seed = 42; + + std::vector vecs(n * d); + faiss::float_rand(vecs.data(), vecs.size(), seed); + + auto index(new faiss::IndexFlat(d)); + + faiss::ClusteringParameters cp; + cp.niter = niter; + cp.verbose = false; + + faiss::Clustering kmeans(d, k, cp); + + faiss::TimeoutCallback::reset(0.010); + EXPECT_THROW(kmeans.train(n, vecs.data(), *index), faiss::FaissException); + delete index; +} diff --git a/tests/test_callback_py.py b/tests/test_callback_py.py new file mode 100644 index 0000000000..0ec176dd86 --- /dev/null +++ b/tests/test_callback_py.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +import numpy as np +import faiss + + +class TestCallbackPy(unittest.TestCase): + def setUp(self) -> None: + super().setUp() + + def test_timeout(self) -> None: + n = 1000 + k = 100 + d = 128 + niter = 1_000_000_000 + + x = np.random.rand(n, d).astype('float32') + index = faiss.IndexFlat(d) + + cp = faiss.ClusteringParameters() + cp.niter = niter + cp.verbose = False + + kmeans = faiss.Clustering(d, k, cp) + + with self.assertRaises(RuntimeError): + with faiss.TimeoutGuard(0.010): + kmeans.train(x, index)