Skip to content

Commit

Permalink
Improve: Expose pointers for USearch
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Oct 4, 2023
1 parent a594758 commit d8acf83
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 16 deletions.
2 changes: 1 addition & 1 deletion include/simsimd/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ typedef union {
* @brief Computes `1/sqrt(x)` using the trick from Quake 3, replacing
* magic numbers with the ones suggested by Jan Kadlec.
*/
inline static float simsimd_approximate_inverse_square_root(float number) {
inline static simsimd_f32_t simsimd_approximate_inverse_square_root(simsimd_f32_t number) {
simsimd_f32i32_t conv;
conv.f = number;
conv.i = 0x5F1FFFF9 - (conv.i >> 1);
Expand Down
78 changes: 65 additions & 13 deletions python/lib.c
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,21 @@ simsimd_datatype_t numpy_string_to_datatype(char const* name) {
return simsimd_datatype_unknown_k;
}

/**
 * @brief Translates a short type code or type name into a SimSIMD datatype.
 *
 * Accepted spellings, compared with `same_string`:
 *   "f" or "f32" -> `simsimd_datatype_f32_k`
 *   "h" or "f16" -> `simsimd_datatype_f16_k`
 *   "c" or "i8"  -> `simsimd_datatype_i8_k`
 *   "b" or "b1"  -> `simsimd_datatype_b1_k`
 *   "d" or "f64" -> `simsimd_datatype_f64_k`
 *
 * @param name String holding the type code to look up.
 * @return The matching datatype constant, or `simsimd_datatype_unknown_k`
 *         when none of the spellings above match.
 */
simsimd_datatype_t python_string_to_datatype(char const* name) {
    if (same_string(name, "f") || same_string(name, "f32")) {
        return simsimd_datatype_f32_k;
    }
    if (same_string(name, "h") || same_string(name, "f16")) {
        return simsimd_datatype_f16_k;
    }
    if (same_string(name, "c") || same_string(name, "i8")) {
        return simsimd_datatype_i8_k;
    }
    if (same_string(name, "b") || same_string(name, "b1")) {
        return simsimd_datatype_b1_k;
    }
    if (same_string(name, "d") || same_string(name, "f64")) {
        return simsimd_datatype_f64_k;
    }
    return simsimd_datatype_unknown_k;
}

static PyObject* api_get_capabilities(PyObject* self) {
simsimd_capability_t caps = simsimd_capabilities();
PyObject* cap_dict = PyDict_New();
Expand Down Expand Up @@ -209,25 +224,62 @@ static PyObject* api_pairs(simsimd_metric_kind_t metric_kind, PyObject* args) {

/**
 * @brief Placeholder entry point that is not implemented yet.
 *
 * The body used to be empty, so control reached the end of a function that is
 * declared to return `PyObject*` without returning anything. It now follows
 * the CPython error protocol instead: an exception is set and NULL is returned.
 *
 * @param metric Unused.
 * @param args Unused.
 * @return Always NULL, with `NotImplementedError` set.
 */
static PyObject* api_combos(simsimd_metric_punned_t metric, PyObject* args) {
    (void)metric;
    (void)args;
    PyErr_SetString(PyExc_NotImplementedError, "api_combos is not implemented");
    return NULL;
}

/** Forwards `args` to `api_pairs` with the squared Euclidean (L2sq) metric kind. */
static PyObject* api_pairs_l2sq(PyObject* self, PyObject* args) {
    PyObject* distances = api_pairs(simsimd_metric_l2sq_k, args);
    return distances;
}

/** Forwards `args` to `api_pairs` with the cosine metric kind. */
static PyObject* api_pairs_cos(PyObject* self, PyObject* args) {
    PyObject* distances = api_pairs(simsimd_metric_cos_k, args);
    return distances;
}

/** Forwards `args` to `api_pairs` with the inner product metric kind. */
static PyObject* api_pairs_ip(PyObject* self, PyObject* args) {
    PyObject* distances = api_pairs(simsimd_metric_ip_k, args);
    return distances;
}
/**
 * @brief Looks up the kernel for a metric kind and datatype, returning its
 *        address as a Python `int`.
 *
 * @param metric_kind Metric to resolve through `simsimd_metric_punned`.
 * @param args Positional-argument tuple; item 0 is read as the type name and
 *             translated with `python_string_to_datatype`.
 * @return A Python integer holding the function pointer value, or NULL with
 *         an exception set:
 *         - `TypeError` when item 0 cannot be fetched from `args`,
 *         - `ValueError("Invalid type name")` when it is not convertible to UTF-8,
 *         - `ValueError("Unsupported type")` when the name maps to no datatype,
 *         - `ValueError("No such metric")` when no kernel is found.
 */
static PyObject* api_pointer(simsimd_metric_kind_t metric_kind, PyObject* args) {
    // `PyTuple_GetItem` yields NULL when the item is missing; that must not
    // reach `PyUnicode_AsUTF8`, which expects a valid object.
    PyObject* type_object = PyTuple_GetItem(args, 0);
    if (!type_object) {
        PyErr_SetString(PyExc_TypeError, "Expected a type name as the first argument");
        return NULL;
    }

    char const* type_name = PyUnicode_AsUTF8(type_object);
    if (!type_name) {
        PyErr_SetString(PyExc_ValueError, "Invalid type name");
        return NULL;
    }

    // Check the resolved datatype, not the string pointer that was already
    // validated above; otherwise unknown names slip through to the lookup.
    simsimd_datatype_t datatype = python_string_to_datatype(type_name);
    if (datatype == simsimd_datatype_unknown_k) {
        PyErr_SetString(PyExc_ValueError, "Unsupported type");
        return NULL;
    }

    // NOTE(review): 0xFFFFFFFF presumably enables every capability bit for the
    // lookup - confirm against `simsimd_metric_punned`.
    simsimd_metric_punned_t metric = simsimd_metric_punned(metric_kind, datatype, 0xFFFFFFFF);
    if (metric == NULL) {
        PyErr_SetString(PyExc_ValueError, "No such metric");
        return NULL;
    }

    return PyLong_FromUnsignedLongLong((unsigned long long)metric);
}

/** Forwards `args` to `api_pointer` with the squared Euclidean (L2sq) metric kind. */
static PyObject* api_l2sq_pointer(PyObject* self, PyObject* args) {
    PyObject* address = api_pointer(simsimd_metric_l2sq_k, args);
    return address;
}

/** Forwards `args` to `api_pointer` with the cosine metric kind. */
static PyObject* api_cos_pointer(PyObject* self, PyObject* args) {
    PyObject* address = api_pointer(simsimd_metric_cos_k, args);
    return address;
}

/** Forwards `args` to `api_pointer` with the inner product metric kind. */
static PyObject* api_ip_pointer(PyObject* self, PyObject* args) {
    PyObject* address = api_pointer(simsimd_metric_ip_k, args);
    return address;
}

/** Forwards `args` to `api_pairs` with the squared Euclidean (L2sq) metric kind. */
static PyObject* api_l2sq(PyObject* self, PyObject* args) {
    PyObject* distances = api_pairs(simsimd_metric_l2sq_k, args);
    return distances;
}

/** Forwards `args` to `api_pairs` with the cosine metric kind. */
static PyObject* api_cos(PyObject* self, PyObject* args) {
    PyObject* distances = api_pairs(simsimd_metric_cos_k, args);
    return distances;
}

/** Forwards `args` to `api_pairs` with the inner product metric kind. */
static PyObject* api_ip(PyObject* self, PyObject* args) {
    PyObject* distances = api_pairs(simsimd_metric_ip_k, args);
    return distances;
}

/**
 * Method table of the Python module: maps each Python-visible function name
 * to its C handler, calling convention, and help string.
 *
 * The stale `get_*_address` entries were removed: they referenced handlers
 * that are not defined alongside this table, and they carried their own
 * terminator, which left the `pointer_to_*` entries outside the array.
 */
static PyMethodDef simsimd_methods[] = {
    // Introspecting library and hardware capabilities
    {"get_capabilities", api_get_capabilities, METH_NOARGS, "Get hardware capabilities"},

    // NumPy and SciPy compatible interfaces (two matrix or vector arguments)
    {"sqeuclidean", api_l2sq, METH_VARARGS, "L2sq (Sq. Euclidean) distances between a pair of matrices"},
    {"cosine", api_cos, METH_VARARGS, "Cosine (Angular) distances between a pair of matrices"},
    {"inner", api_ip, METH_VARARGS, "Inner (Dot) Product distances between a pair of matrices"},

    // Compute distance between each pair of the two collections of inputs (two matrix arguments)
    // NOTE(review): these currently share the handlers of the entries above - confirm that is intended.
    {"cdist_sqeuclidean", api_l2sq, METH_VARARGS, "L2sq (Sq. Euclidean) distances between every pair of vectors"},
    {"cdist_cosine", api_cos, METH_VARARGS, "Cosine (Angular) distances between every pair of vectors"},
    {"cdist_inner", api_ip, METH_VARARGS, "Inner (Dot) Product distances between every pair of vectors"},

    // Pairwise distances between observations in n-dimensional space (single matrix argument)
    // NOTE(review): these also share the handlers of the two-argument entries - confirm that is intended.
    {"pdist_sqeuclidean", api_l2sq, METH_VARARGS, "L2sq (Sq. Euclidean) distances between every pair of vectors"},
    {"pdist_cosine", api_cos, METH_VARARGS, "Cosine (Angular) distances between every pair of vectors"},
    {"pdist_inner", api_ip, METH_VARARGS, "Inner (Dot) Product distances between every pair of vectors"},

    // Exposing underlying API for USearch
    {"pointer_to_sqeuclidean", api_l2sq_pointer, METH_VARARGS, "L2sq (Sq. Euclidean) function pointer as `int`"},
    {"pointer_to_cosine", api_cos_pointer, METH_VARARGS, "Cosine (Angular) function pointer as `int`"},
    {"pointer_to_inner", api_ip_pointer, METH_VARARGS, "Inner (Dot) Product function pointer as `int`"},

    // Sentinel
    {NULL, NULL, 0, NULL}};

static PyModuleDef simsimd_module = {
PyModuleDef_HEAD_INIT,
Expand Down
19 changes: 17 additions & 2 deletions python/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,30 @@
from scipy.spatial.distance import cosine, sqeuclidean


def test_pointers_availability():
    """Checks that each `pointer_to_*` getter reports a non-zero value for "f32", "f16" and "i8" (needed for USearch compatibility)."""
    getter_names = ("pointer_to_sqeuclidean", "pointer_to_cosine", "pointer_to_inner")
    for dtype_name in ("f32", "f16", "i8"):
        for getter_name in getter_names:
            # Resolve the getter lazily, right before the call, as the unrolled form did.
            assert getattr(simd, getter_name)(dtype_name) != 0


@pytest.mark.parametrize("ndim", [3, 97, 1536])
@pytest.mark.parametrize("dtype", [np.float32, np.float16])
def test_dot(ndim, dtype):
    """Compares simd.inner() with 1 - numpy.inner() on random f16 and f32 vectors, allowing a 1% relative error."""
    a = np.random.randn(ndim).astype(dtype)
    b = np.random.randn(ndim).astype(dtype)

    # The module exposes the inner-product distance as `inner`; the stale
    # `simd.dot` / `np.dot` lines were dropped.
    expected = 1 - np.inner(a, b)
    result = simd.inner(a, b)

    np.testing.assert_allclose(expected, result, rtol=1e-2)

Expand Down

0 comments on commit d8acf83

Please sign in to comment.