Skip to content

Commit

Permalink
TimeoutCallback Python
Browse files Browse the repository at this point in the history
Summary: facebookresearch#3351

Differential Revision: D56856744
  • Loading branch information
Amir Sadoughi authored and facebook-github-bot committed May 3, 2024
1 parent 73c2bd5 commit 15e055c
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 5 deletions.
6 changes: 6 additions & 0 deletions faiss/impl/AuxIndexStructures.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,4 +255,10 @@ bool TimeoutCallback::want_interrupt() {
return false;
}

void TimeoutCallback::reset(double timeout_in_seconds) {
auto tc(new faiss::TimeoutCallback());
faiss::InterruptCallback::instance.reset(tc);
tc->set_timeout(timeout_in_seconds);
}

} // namespace faiss
1 change: 1 addition & 0 deletions faiss/impl/AuxIndexStructures.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ struct TimeoutCallback : InterruptCallback {
double timeout;
bool want_interrupt() override;
void set_timeout(double timeout_in_seconds);
static void reset(double timeout_in_seconds);
};

/// set implementation optimized for fast access.
Expand Down
11 changes: 11 additions & 0 deletions faiss/python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,3 +316,14 @@ def deserialize_index_binary(data):
reader = VectorIOReader()
copy_array_to_vector(data, reader.data)
return read_index_binary(reader)


class TimeoutGuard:
def __init__(self, timeout_in_seconds: float):
self.timeout = timeout_in_seconds

def __enter__(self):
TimeoutCallback.reset(self.timeout)

def __exit__(self, exc_type, exc_value, traceback):
PythonInterruptCallback.reset()
9 changes: 7 additions & 2 deletions faiss/python/swigfaiss.swig
Original file line number Diff line number Diff line change
Expand Up @@ -1041,7 +1041,9 @@ PyObject *swig_ptr (PyObject *a)
PyErr_SetString(PyExc_ValueError, "did not recognize array type");
return NULL;
}
%}

%inline %{

struct PythonInterruptCallback: faiss::InterruptCallback {

Expand All @@ -1056,15 +1058,18 @@ struct PythonInterruptCallback: faiss::InterruptCallback {
return err == -1;
}

static void reset() {
faiss::InterruptCallback::instance.reset(new PythonInterruptCallback());
}
};

%}

%init %{
/* needed, else crash at runtime */
import_array();

faiss::InterruptCallback::instance.reset(new PythonInterruptCallback());

PythonInterruptCallback::reset();
%}

// return a pointer usable as input for functions that expect pointers
Expand Down
4 changes: 1 addition & 3 deletions tests/test_callback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ TEST(TestCallback, timeout) {

faiss::Clustering kmeans(d, k, cp);

auto tc(new faiss::TimeoutCallback());
faiss::InterruptCallback::instance.reset(tc);
tc->set_timeout(0.010);
faiss::TimeoutCallback::reset(0.010);
EXPECT_THROW(kmeans.train(n, vecs.data(), *index), faiss::FaissException);
delete index;
}
32 changes: 32 additions & 0 deletions tests/test_callback_py.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest
import numpy as np
import faiss


class TestCallbackPy(unittest.TestCase):
def setUp(self) -> None:
super().setUp()

def test_timeout(self) -> None:
n = 1_000_000
k = 100
d = 128
niter = 10

x = np.random.rand(n, d).astype('float32')
index = faiss.IndexFlat(d)

cp = faiss.ClusteringParameters()
cp.niter = niter
cp.verbose = False

kmeans = faiss.Clustering(d, k, cp)

with self.assertRaises(RuntimeError):
with faiss.TimeoutGuard(0.010):
kmeans.train(x, index)

0 comments on commit 15e055c

Please sign in to comment.