Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

TimeoutCallback Python #3413

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions faiss/gpu/perf/PerfClustering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <vector>

#include <cuda_profiler_api.h>
#include <faiss/impl/AuxIndexStructures.h>

DEFINE_int32(num, 10000, "# of vecs");
DEFINE_int32(k, 100, "# of clusters");
Expand All @@ -34,6 +35,7 @@ DEFINE_int64(
"minimum size to use CPU -> GPU paged copies");
DEFINE_int64(pinned_mem, -1, "pinned memory allocation to use");
DEFINE_int32(max_points, -1, "max points per centroid");
DEFINE_double(timeout, 0, "timeout in seconds");

using namespace faiss::gpu;

Expand Down Expand Up @@ -99,10 +101,14 @@ int main(int argc, char** argv) {
cp.max_points_per_centroid = FLAGS_max_points;
}

auto tc = new faiss::TimeoutCallback();
faiss::InterruptCallback::instance.reset(tc);

faiss::Clustering kmeans(FLAGS_dim, FLAGS_k, cp);

// Time k-means
{
tc->set_timeout(FLAGS_timeout);
CpuTimer timer;

kmeans.train(FLAGS_num, vecs.data(), *(gpuIndex.getIndex()));
Expand Down
25 changes: 25 additions & 0 deletions faiss/impl/AuxIndexStructures.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,4 +236,29 @@ size_t InterruptCallback::get_period_hint(size_t flops) {
return std::max((size_t)10 * 10 * 1000 * 1000 / (flops + 1), (size_t)1);
}

void TimeoutCallback::set_timeout(double timeout_in_seconds) {
timeout = timeout_in_seconds;
start = std::chrono::steady_clock::now();
}

/// Report whether the time budget set by set_timeout() has been exceeded.
///
/// A timeout of exactly 0 means "no deadline" and always yields false.
/// Once the deadline has passed, the callback fires a single time: it
/// returns true and clears `timeout`, so subsequent calls return false
/// until set_timeout() is called again.
///
/// @return true if more than `timeout` seconds have elapsed since `start`.
bool TimeoutCallback::want_interrupt() {
    if (timeout == 0) {
        return false;
    }
    // Measure directly in double-precision seconds so the value compared
    // against the double `timeout` is not first rounded to a float count of
    // milliseconds.
    const std::chrono::duration<double> elapsed =
            std::chrono::steady_clock::now() - start;
    if (elapsed.count() > timeout) {
        timeout = 0; // disarm: report the interrupt only once
        return true;
    }
    return false;
}

/// Replace the global interrupt callback with a freshly armed
/// TimeoutCallback.
///
/// @param timeout_in_seconds  time budget in seconds, forwarded to
///                            set_timeout(); counted from this call.
void TimeoutCallback::reset(double timeout_in_seconds) {
    auto* tc = new faiss::TimeoutCallback();
    // Arm the callback *before* publishing it, so the global instance is
    // never observable in a half-configured state.
    tc->set_timeout(timeout_in_seconds);
    // NOTE(review): `instance` is assumed to be an owning smart pointer that
    // takes over `tc` and disposes of the previous callback — confirm
    // against the InterruptCallback declaration.
    faiss::InterruptCallback::instance.reset(tc);
}

} // namespace faiss
8 changes: 8 additions & 0 deletions faiss/impl/AuxIndexStructures.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,14 @@ struct FAISS_API InterruptCallback {
static size_t get_period_hint(size_t flops);
};

/// Interrupt callback that asks for an interrupt once a given amount of
/// time, measured on std::chrono::steady_clock, has elapsed.
///
/// NOTE(review): the base struct is declared with FAISS_API; this struct
/// probably needs the same export annotation — confirm for shared-library
/// builds.
struct TimeoutCallback : InterruptCallback {
    /// Instant from which the current timeout is measured; written by
    /// set_timeout().
    std::chrono::time_point<std::chrono::steady_clock> start;

    /// Time budget in seconds. 0 means "no timeout"; initialised to 0 so a
    /// callback that was never armed does not read an indeterminate value.
    double timeout = 0;

    /// @return true once, when more than `timeout` seconds have passed
    ///         since `start`; false otherwise.
    bool want_interrupt() override;

    /// Set the time budget and restart the clock.
    void set_timeout(double timeout_in_seconds);

    /// Install a new TimeoutCallback with the given budget as
    /// InterruptCallback::instance.
    static void reset(double timeout_in_seconds);
};

/// set implementation optimized for fast access.
struct VisitedTable {
std::vector<uint8_t> visited;
Expand Down
11 changes: 11 additions & 0 deletions faiss/python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,3 +316,14 @@ def deserialize_index_binary(data):
reader = VectorIOReader()
copy_array_to_vector(data, reader.data)
return read_index_binary(reader)


class TimeoutGuard:
    """Context manager that applies a time limit to the enclosed block.

    On entry it calls ``TimeoutCallback.reset`` with the configured number
    of seconds; on exit (whether or not the body raised) it calls
    ``PythonInterruptCallback.reset()``. Note that exit always re-installs
    the Python interrupt callback; a different callback that was active
    before entering is not restored.

    Args:
        timeout_in_seconds: time budget in seconds, counted from the moment
            the ``with`` block is entered.
    """

    def __init__(self, timeout_in_seconds: float):
        self.timeout = timeout_in_seconds

    def __enter__(self) -> "TimeoutGuard":
        TimeoutCallback.reset(self.timeout)
        # Return the guard so `with TimeoutGuard(t) as g` binds something
        # useful instead of None.
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Returning None lets any exception from the body propagate.
        PythonInterruptCallback.reset()
9 changes: 7 additions & 2 deletions faiss/python/swigfaiss.swig
Original file line number Diff line number Diff line change
Expand Up @@ -1041,7 +1041,9 @@ PyObject *swig_ptr (PyObject *a)
PyErr_SetString(PyExc_ValueError, "did not recognize array type");
return NULL;
}
%}

%inline %{

struct PythonInterruptCallback: faiss::InterruptCallback {

Expand All @@ -1056,15 +1058,18 @@ struct PythonInterruptCallback: faiss::InterruptCallback {
return err == -1;
}

static void reset() {
faiss::InterruptCallback::instance.reset(new PythonInterruptCallback());
}
};

%}

%init %{
/* needed, else crash at runtime */
import_array();

faiss::InterruptCallback::instance.reset(new PythonInterruptCallback());

PythonInterruptCallback::reset();
%}

// return a pointer usable as input for functions that expect pointers
Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ set(FAISS_TEST_SRC
test_fastscan_perf.cpp
test_disable_pq_sdc_tables.cpp
test_common_ivf_empty_index.cpp
test_callback.cpp
)

add_executable(faiss_test ${FAISS_TEST_SRC})
Expand Down
37 changes: 37 additions & 0 deletions tests/test_callback.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/**
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <gtest/gtest.h>

#include <faiss/Clustering.h>
#include <faiss/IndexFlat.h>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissException.h>
#include <faiss/utils/random.h>

/// Checks that an armed TimeoutCallback aborts a k-means training run:
/// with a 10 ms budget, Clustering::train is expected to throw a
/// FaissException.
TEST(TestCallback, timeout) {
    const int n = 1000000;
    const int k = 100;
    const int d = 128;
    const int niter = 10;
    const int seed = 42;

    // Compute the element count in size_t rather than int so the product
    // cannot overflow if the test sizes are ever increased.
    std::vector<float> vecs(static_cast<size_t>(n) * d);
    faiss::float_rand(vecs.data(), vecs.size(), seed);

    // Automatic storage: no manual delete needed on any exit path.
    faiss::IndexFlat index(d);

    faiss::ClusteringParameters cp;
    cp.niter = niter;
    cp.verbose = false;

    faiss::Clustering kmeans(d, k, cp);

    // Install the timeout immediately before training starts.
    faiss::TimeoutCallback::reset(0.010);
    EXPECT_THROW(kmeans.train(n, vecs.data(), index), faiss::FaissException);
}
32 changes: 32 additions & 0 deletions tests/test_callback_py.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest
import numpy as np
import faiss


class TestCallbackPy(unittest.TestCase):
    """Python-side check that TimeoutGuard interrupts a long training run."""

    def setUp(self) -> None:
        super().setUp()

    def test_timeout(self) -> None:
        """Training under a 10 ms TimeoutGuard must raise RuntimeError."""
        num_vectors, num_clusters, dim = 1_000_000, 100, 128

        data = np.random.rand(num_vectors, dim).astype('float32')
        flat_index = faiss.IndexFlat(dim)

        params = faiss.ClusteringParameters()
        params.verbose = False
        params.niter = 10

        clustering = faiss.Clustering(dim, num_clusters, params)

        with self.assertRaises(RuntimeError), faiss.TimeoutGuard(0.010):
            clustering.train(data, flat_index)