From ce8bc5971c1a4f0a4dcfdfcc146c2e688d189643 Mon Sep 17 00:00:00 2001 From: Amir Sadoughi Date: Mon, 6 May 2024 13:14:40 -0700 Subject: [PATCH] TimeoutCallback C++ (#3397) Summary: https://github.com/facebookresearch/faiss/issues/3351 Reviewed By: junjieqi Differential Revision: D56732720 --- faiss/gpu/perf/PerfClustering.cpp | 6 +++++ faiss/impl/AuxIndexStructures.cpp | 19 +++++++++++++++ faiss/impl/AuxIndexStructures.h | 7 ++++++ tests/CMakeLists.txt | 1 + tests/test_callback.cpp | 39 +++++++++++++++++++++++++++++++ 5 files changed, 72 insertions(+) create mode 100644 tests/test_callback.cpp diff --git a/faiss/gpu/perf/PerfClustering.cpp b/faiss/gpu/perf/PerfClustering.cpp index 0322f0e490..532557fe20 100644 --- a/faiss/gpu/perf/PerfClustering.cpp +++ b/faiss/gpu/perf/PerfClustering.cpp @@ -17,6 +17,7 @@ #include #include +#include DEFINE_int32(num, 10000, "# of vecs"); DEFINE_int32(k, 100, "# of clusters"); @@ -34,6 +35,7 @@ DEFINE_int64( "minimum size to use CPU -> GPU paged copies"); DEFINE_int64(pinned_mem, -1, "pinned memory allocation to use"); DEFINE_int32(max_points, -1, "max points per centroid"); +DEFINE_double(timeout, 0, "timeout in seconds"); using namespace faiss::gpu; @@ -99,10 +101,14 @@ int main(int argc, char** argv) { cp.max_points_per_centroid = FLAGS_max_points; } + auto tc = new faiss::TimeoutCallback(); + faiss::InterruptCallback::instance.reset(tc); + faiss::Clustering kmeans(FLAGS_dim, FLAGS_k, cp); // Time k-means { + tc->set_timeout(FLAGS_timeout); CpuTimer timer; kmeans.train(FLAGS_num, vecs.data(), *(gpuIndex.getIndex())); diff --git a/faiss/impl/AuxIndexStructures.cpp b/faiss/impl/AuxIndexStructures.cpp index cebe8a1e23..01c7dd5267 100644 --- a/faiss/impl/AuxIndexStructures.cpp +++ b/faiss/impl/AuxIndexStructures.cpp @@ -236,4 +236,23 @@ size_t InterruptCallback::get_period_hint(size_t flops) { return std::max((size_t)10 * 10 * 1000 * 1000 / (flops + 1), (size_t)1); } +void TimeoutCallback::set_timeout(double timeout_in_seconds) { + timeout = timeout_in_seconds; + start = std::chrono::steady_clock::now(); +} + +bool TimeoutCallback::want_interrupt() { + if (timeout == 0) { + return false; + } + auto end = std::chrono::steady_clock::now(); + std::chrono::duration duration = end - start; + float elapsed_in_seconds = duration.count() / 1000.0; + if (elapsed_in_seconds > timeout) { + timeout = 0; + return true; + } + return false; +} + } // namespace faiss diff --git a/faiss/impl/AuxIndexStructures.h b/faiss/impl/AuxIndexStructures.h index f8b5cca842..5dc15eae46 100644 --- a/faiss/impl/AuxIndexStructures.h +++ b/faiss/impl/AuxIndexStructures.h @@ -161,6 +161,13 @@ struct FAISS_API InterruptCallback { static size_t get_period_hint(size_t flops); }; +struct TimeoutCallback : InterruptCallback { + std::chrono::time_point start; + double timeout; + bool want_interrupt() override; + void set_timeout(double timeout_in_seconds); +}; + /// set implementation optimized for fast access. struct VisitedTable { std::vector visited; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 443195eecb..3980d7dd7c 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -34,6 +34,7 @@ set(FAISS_TEST_SRC test_fastscan_perf.cpp test_disable_pq_sdc_tables.cpp test_common_ivf_empty_index.cpp + test_callback.cpp ) add_executable(faiss_test ${FAISS_TEST_SRC}) diff --git a/tests/test_callback.cpp b/tests/test_callback.cpp new file mode 100644 index 0000000000..2fc7d6715c --- /dev/null +++ b/tests/test_callback.cpp @@ -0,0 +1,39 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include +#include + +TEST(TestCallback, timeout) { + int n = 1000; + int k = 100; + int d = 128; + int niter = 1000000000; + int seed = 42; + + std::vector vecs(n * d); + faiss::float_rand(vecs.data(), vecs.size(), seed); + + auto index(new faiss::IndexFlat(d)); + + faiss::ClusteringParameters cp; + cp.niter = niter; + cp.verbose = false; + + faiss::Clustering kmeans(d, k, cp); + + auto tc(new faiss::TimeoutCallback()); + faiss::InterruptCallback::instance.reset(tc); + tc->set_timeout(0.010); + EXPECT_THROW(kmeans.train(n, vecs.data(), *index), faiss::FaissException); + delete index; +}