Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add: Covariance for Arm Neon #139

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions include/simsimd/probability.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,27 +36,34 @@ extern "C" {
*/
// One kl / js / cov triple per input type (f64, f32, f16).
// Every function takes two input arrays `a` and `b`, an element count `n`, and an output pointer.
// The `cov` entries are generated by SIMSIMD_MAKE_COV, which divides the summed products of
// deviations from the means by `n - 1`.
SIMSIMD_PUBLIC void simsimd_kl_f64_serial(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, simsimd_distance_t* divergence);
SIMSIMD_PUBLIC void simsimd_js_f64_serial(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, simsimd_distance_t* divergence);
SIMSIMD_PUBLIC void simsimd_cov_f64_serial(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, simsimd_distance_t* result);
SIMSIMD_PUBLIC void simsimd_kl_f32_serial(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* divergence);
SIMSIMD_PUBLIC void simsimd_js_f32_serial(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* divergence);
SIMSIMD_PUBLIC void simsimd_cov_f32_serial(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* result);
SIMSIMD_PUBLIC void simsimd_kl_f16_serial(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* divergence);
SIMSIMD_PUBLIC void simsimd_js_f16_serial(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* divergence);
SIMSIMD_PUBLIC void simsimd_cov_f16_serial(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* result);

/* Double-precision serial backends for all numeric types.
* For single-precision computation check out the "*_serial" counterparts of those "*_accurate" functions.
*/
// The `cov` entries here come from SIMSIMD_MAKE_COV instantiated with an f64 accumulator,
// so both the means and the deviation products are accumulated in double precision.
SIMSIMD_PUBLIC void simsimd_kl_f32_accurate(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* divergence);
SIMSIMD_PUBLIC void simsimd_js_f32_accurate(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* divergence);
SIMSIMD_PUBLIC void simsimd_cov_f32_accurate(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* result);
SIMSIMD_PUBLIC void simsimd_kl_f16_accurate(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* divergence);
SIMSIMD_PUBLIC void simsimd_js_f16_accurate(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* divergence);
SIMSIMD_PUBLIC void simsimd_cov_f16_accurate(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* result);

/* SIMD-powered backends for Arm NEON, mostly using 32-bit arithmetic over 128-bit words.
* By far the most portable backend, covering most Arm v8 devices, over a billion phones, and almost all
* server CPUs produced before 2023.
*/
// The `cov` entries are two-pass covariance kernels (means first, then deviation products)
// that accumulate in four-lane f32 vectors; the f16 variant widens its inputs to f32 on load.
SIMSIMD_PUBLIC void simsimd_kl_f32_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* divergence);
SIMSIMD_PUBLIC void simsimd_js_f32_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* divergence);
SIMSIMD_PUBLIC void simsimd_cov_f32_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* result);
SIMSIMD_PUBLIC void simsimd_kl_f16_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* divergence);
SIMSIMD_PUBLIC void simsimd_js_f16_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* divergence);
SIMSIMD_PUBLIC void simsimd_cov_f16_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* result);

/* SIMD-powered backends for AVX2 CPUs of Haswell generation and newer, using 32-bit arithmetic over 256-bit words.
* First demonstrated in 2011, at least one Haswell-based processor was still being sold in 2022 — the Pentium G3420.
Expand Down Expand Up @@ -106,20 +113,48 @@ SIMSIMD_PUBLIC void simsimd_js_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_
*result = (simsimd_distance_t)d / 2; \
}

/**
 *  @brief  Generates `simsimd_cov_<input_type>_<name>`, a scalar sample-covariance kernel.
 *
 *  The generated function makes two passes over the inputs: the first accumulates the
 *  means of `a` and `b`, the second sums the products of the deviations from those means.
 *  The sum is divided by `n - 1` and written to `*result`.
 *
 *  With fewer than two samples that divisor is zero (`n == 1`, yielding 0/0) or would wrap
 *  around for an unsigned size type (`n == 0`), so the function writes 0 and returns early.
 *
 *  @param name              Suffix of the generated function name (e.g. `serial`, `accurate`).
 *  @param input_type        Element type tag; inputs are `simsimd_<input_type>_t const*`.
 *  @param accumulator_type  Type tag used for the means and the running sum.
 *  @param converter         Macro or function applied to every element before accumulation.
 */
#define SIMSIMD_MAKE_COV(name, input_type, accumulator_type, converter) \
    SIMSIMD_PUBLIC void simsimd_cov_##input_type##_##name(simsimd_##input_type##_t const* a, \
                                                          simsimd_##input_type##_t const* b, simsimd_size_t n, \
                                                          simsimd_distance_t* result) { \
        /* Covariance with an (n - 1) divisor is undefined for n < 2; report 0 instead of NaN. */ \
        if (n < 2) { \
            *result = 0; \
            return; \
        } \
        simsimd_##accumulator_type##_t mean_a = 0; \
        simsimd_##accumulator_type##_t mean_b = 0; \
        simsimd_##accumulator_type##_t d = 0; \
        /* Pass 1: means of both inputs. */ \
        for (simsimd_size_t i = 0; i != n; ++i) { \
            simsimd_##accumulator_type##_t ai = converter(a[i]); \
            simsimd_##accumulator_type##_t bi = converter(b[i]); \
            mean_a += ai; \
            mean_b += bi; \
        } \
        mean_a /= n; \
        mean_b /= n; \
        /* Pass 2: sum of products of deviations from the means. */ \
        for (simsimd_size_t i = 0; i != n; ++i) { \
            simsimd_##accumulator_type##_t ai = converter(a[i]); \
            simsimd_##accumulator_type##_t bi = converter(b[i]); \
            d += (ai - mean_a) * (bi - mean_b); \
        } \
        *result = (simsimd_distance_t)d / (n - 1); \
    }

// Instantiate the scalar kernels; the trailing comment on each line names the generated function.
// For SIMSIMD_MAKE_COV the arguments are (name suffix, input type, accumulator type, converter).
// The KL/JS generators are defined outside this view; they appear to take the same leading
// arguments plus one extra trailing constant (presumably a division epsilon, going by its name - confirm).
// "serial" covariance accumulates in the input's own precision (f32 for f16 inputs);
// "accurate" covariance accumulates in f64.
SIMSIMD_MAKE_KL(serial, f64, f64, SIMSIMD_IDENTIFY, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_kl_f64_serial
SIMSIMD_MAKE_JS(serial, f64, f64, SIMSIMD_IDENTIFY, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_js_f64_serial
SIMSIMD_MAKE_COV(serial, f64, f64, SIMSIMD_IDENTIFY) // simsimd_cov_f64_serial

SIMSIMD_MAKE_KL(serial, f32, f32, SIMSIMD_IDENTIFY, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_kl_f32_serial
SIMSIMD_MAKE_JS(serial, f32, f32, SIMSIMD_IDENTIFY, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_js_f32_serial
SIMSIMD_MAKE_COV(serial, f32, f32, SIMSIMD_IDENTIFY) // simsimd_cov_f32_serial

SIMSIMD_MAKE_KL(serial, f16, f32, SIMSIMD_UNCOMPRESS_F16, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_kl_f16_serial
SIMSIMD_MAKE_JS(serial, f16, f32, SIMSIMD_UNCOMPRESS_F16, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_js_f16_serial
SIMSIMD_MAKE_COV(serial, f16, f32, SIMSIMD_UNCOMPRESS_F16) // simsimd_cov_f16_serial

SIMSIMD_MAKE_KL(accurate, f32, f64, SIMSIMD_IDENTIFY, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_kl_f32_accurate
SIMSIMD_MAKE_JS(accurate, f32, f64, SIMSIMD_IDENTIFY, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_js_f32_accurate
SIMSIMD_MAKE_COV(accurate, f32, f64, SIMSIMD_IDENTIFY) // simsimd_cov_f32_accurate

SIMSIMD_MAKE_KL(accurate, f16, f64, SIMSIMD_UNCOMPRESS_F16, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_kl_f16_accurate
SIMSIMD_MAKE_JS(accurate, f16, f64, SIMSIMD_UNCOMPRESS_F16, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_js_f16_accurate
SIMSIMD_MAKE_COV(accurate, f16, f64, SIMSIMD_UNCOMPRESS_F16) // simsimd_cov_f16_accurate

#if SIMSIMD_TARGET_ARM
#if SIMSIMD_TARGET_NEON
Expand Down Expand Up @@ -201,6 +236,46 @@ SIMSIMD_PUBLIC void simsimd_js_f32_neon(simsimd_f32_t const* a, simsimd_f32_t co
*result = sum;
}

/**
 *  @brief  Sample covariance of two single-precision arrays, vectorized with Arm NEON.
 *
 *  Makes two passes over the inputs, four lanes at a time with a scalar loop for the
 *  remaining elements: the first pass accumulates the sums that give the means, the second
 *  accumulates the products of deviations from those means. The total is divided by `n - 1`.
 *  All intermediate arithmetic is carried out in `simsimd_f32_t`.
 *
 *  @param a       First input array of `n` elements.
 *  @param b       Second input array of `n` elements.
 *  @param n       Number of elements in each array.
 *  @param result  Output; receives the covariance, or 0 when `n < 2`.
 */
SIMSIMD_PUBLIC void simsimd_cov_f32_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n,
                                         simsimd_distance_t* result) {
    // With fewer than two samples the `n - 1` divisor is zero (n == 1, giving 0/0) or would
    // wrap around for an unsigned size type (n == 0); report 0 instead.
    if (n < 2) {
        *result = 0;
        return;
    }

    // Computing the means of a and b
    float32x4_t sum_vec_a = vdupq_n_f32(0);
    float32x4_t sum_vec_b = vdupq_n_f32(0);
    simsimd_size_t i = 0;
    for (; i + 4 <= n; i += 4) {
        float32x4_t a_vec = vld1q_f32(a + i);
        float32x4_t b_vec = vld1q_f32(b + i);
        sum_vec_a = vaddq_f32(sum_vec_a, a_vec);
        sum_vec_b = vaddq_f32(sum_vec_b, b_vec);
    }
    // Horizontal reduction of the four lanes, then the scalar tail
    simsimd_f32_t sum_a = vaddvq_f32(sum_vec_a);
    simsimd_f32_t sum_b = vaddvq_f32(sum_vec_b);
    for (; i < n; ++i) {
        sum_a += a[i];
        sum_b += b[i];
    }
    simsimd_f32_t mean_a = sum_a / n;
    simsimd_f32_t mean_b = sum_b / n;
    float32x4_t mean_a_vec = vdupq_n_f32(mean_a);
    float32x4_t mean_b_vec = vdupq_n_f32(mean_b);

    // Computing covariance
    float32x4_t sum_vec = vdupq_n_f32(0);
    i = 0;
    for (; i + 4 <= n; i += 4) {
        float32x4_t a_vec = vld1q_f32(a + i);
        float32x4_t b_vec = vld1q_f32(b + i);
        float32x4_t prod_vec = vmulq_f32(vsubq_f32(a_vec, mean_a_vec), vsubq_f32(b_vec, mean_b_vec));
        sum_vec = vaddq_f32(sum_vec, prod_vec);
    }
    simsimd_f32_t sum = vaddvq_f32(sum_vec);
    for (; i < n; ++i) {
        simsimd_f32_t prod = (a[i] - mean_a) * (b[i] - mean_b);
        sum += prod;
    }
    simsimd_f32_t cov = sum / (n - 1);
    *result = cov;
}

#pragma clang attribute pop
#pragma GCC pop_options

Expand Down Expand Up @@ -260,6 +335,46 @@ SIMSIMD_PUBLIC void simsimd_js_f16_neon(simsimd_f16_t const* a, simsimd_f16_t co
*result = sum;
}

/**
 *  @brief  Sample covariance of two half-precision arrays, vectorized with Arm NEON.
 *
 *  Inputs are loaded four at a time and widened to f32 with `vcvt_f32_f16`; the remaining
 *  elements go through `SIMSIMD_UNCOMPRESS_F16` in a scalar loop. The first pass accumulates
 *  the sums that give the means, the second accumulates the products of deviations from
 *  those means. The total is divided by `n - 1`.
 *  All intermediate arithmetic is carried out in `simsimd_f32_t`.
 *
 *  @param a       First input array of `n` elements.
 *  @param b       Second input array of `n` elements.
 *  @param n       Number of elements in each array.
 *  @param result  Output; receives the covariance, or 0 when `n < 2`.
 */
SIMSIMD_PUBLIC void simsimd_cov_f16_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n,
                                         simsimd_distance_t* result) {
    // With fewer than two samples the `n - 1` divisor is zero (n == 1, giving 0/0) or would
    // wrap around for an unsigned size type (n == 0); report 0 instead.
    if (n < 2) {
        *result = 0;
        return;
    }

    // Computing the means of a and b
    float32x4_t sum_vec_a = vdupq_n_f32(0);
    float32x4_t sum_vec_b = vdupq_n_f32(0);
    simsimd_size_t i = 0;
    for (; i + 4 <= n; i += 4) {
        float32x4_t a_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const*)a + i));
        float32x4_t b_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const*)b + i));
        sum_vec_a = vaddq_f32(sum_vec_a, a_vec);
        sum_vec_b = vaddq_f32(sum_vec_b, b_vec);
    }
    // Horizontal reduction of the four lanes, then the scalar tail
    simsimd_f32_t sum_a = vaddvq_f32(sum_vec_a);
    simsimd_f32_t sum_b = vaddvq_f32(sum_vec_b);
    for (; i < n; ++i) {
        sum_a += SIMSIMD_UNCOMPRESS_F16(a[i]);
        sum_b += SIMSIMD_UNCOMPRESS_F16(b[i]);
    }
    simsimd_f32_t mean_a = sum_a / n;
    simsimd_f32_t mean_b = sum_b / n;
    float32x4_t mean_a_vec = vdupq_n_f32(mean_a);
    float32x4_t mean_b_vec = vdupq_n_f32(mean_b);

    // Computing covariance
    float32x4_t sum_vec = vdupq_n_f32(0);
    i = 0;
    for (; i + 4 <= n; i += 4) {
        float32x4_t a_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const*)a + i));
        float32x4_t b_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const*)b + i));
        float32x4_t prod_vec = vmulq_f32(vsubq_f32(a_vec, mean_a_vec), vsubq_f32(b_vec, mean_b_vec));
        sum_vec = vaddq_f32(sum_vec, prod_vec);
    }
    simsimd_f32_t sum = vaddvq_f32(sum_vec);
    for (; i < n; ++i) {
        simsimd_f32_t prod = (SIMSIMD_UNCOMPRESS_F16(a[i]) - mean_a) * (SIMSIMD_UNCOMPRESS_F16(b[i]) - mean_b);
        sum += prod;
    }
    simsimd_f32_t cov = sum / (n - 1);
    *result = cov;
}

#pragma clang attribute pop
#pragma GCC pop_options
#endif // SIMSIMD_TARGET_NEON
Expand Down
7 changes: 7 additions & 0 deletions include/simsimd/simsimd.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ typedef enum {
simsimd_metric_js_k = 's', ///< Jensen-Shannon divergence
simsimd_metric_jensen_shannon_k = 's', ///< Jensen-Shannon divergence alias

simsimd_metric_cov_k = 'r', ///< Covariance
simsimd_metric_covariance_k = 'r' ///< Covariance alias

} simsimd_metric_kind_t;

/**
Expand Down Expand Up @@ -318,6 +321,7 @@ SIMSIMD_PUBLIC void simsimd_find_metric_punned( //
case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f32_neon, *c = simsimd_cap_neon_k; return;
case simsimd_metric_js_k: *m = (m_t)&simsimd_js_f32_neon, *c = simsimd_cap_neon_k; return;
case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_f32_neon, *c = simsimd_cap_neon_k; return;
case simsimd_metric_cov_k: *m = (m_t)&simsimd_cov_f32_neon, *c = simsimd_cap_neon_k; return;
default: break;
}
#endif
Expand All @@ -339,6 +343,7 @@ SIMSIMD_PUBLIC void simsimd_find_metric_punned( //
case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f32_serial, *c = simsimd_cap_serial_k; return;
case simsimd_metric_js_k: *m = (m_t)&simsimd_js_f32_serial, *c = simsimd_cap_serial_k; return;
case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_f32_serial, *c = simsimd_cap_serial_k; return;
case simsimd_metric_cov_k: *m = (m_t)&simsimd_cov_f32_serial, *c = simsimd_cap_serial_k; return;
default: break;
}

Expand All @@ -364,6 +369,7 @@ SIMSIMD_PUBLIC void simsimd_find_metric_punned( //
case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f16_neon, *c = simsimd_cap_neon_k; return;
case simsimd_metric_js_k: *m = (m_t)&simsimd_js_f16_neon, *c = simsimd_cap_neon_k; return;
case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_f16_neon, *c = simsimd_cap_neon_k; return;
case simsimd_metric_cov_k: *m = (m_t)&simsimd_cov_f16_neon, *c = simsimd_cap_neon_k; return;
default: break;
}
#endif
Expand Down Expand Up @@ -397,6 +403,7 @@ SIMSIMD_PUBLIC void simsimd_find_metric_punned( //
case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f16_serial, *c = simsimd_cap_serial_k; return;
case simsimd_metric_js_k: *m = (m_t)&simsimd_js_f16_serial, *c = simsimd_cap_serial_k; return;
case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_f16_serial, *c = simsimd_cap_serial_k; return;
case simsimd_metric_cov_k: *m = (m_t)&simsimd_cov_f16_serial, *c = simsimd_cap_serial_k; return;
default: break;
}

Expand Down
8 changes: 8 additions & 0 deletions python/lib.c
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ simsimd_metric_kind_t python_string_to_metric_kind(char const* name) {
return simsimd_metric_js_k;
else if (same_string(name, "jaccard"))
return simsimd_metric_jaccard_k;
else if (same_string(name, "covariance") || same_string(name, "cov"))
return simsimd_metric_cov_k;
else
return simsimd_metric_unknown_k;
}
Expand Down Expand Up @@ -629,6 +631,7 @@ static PyObject* api_cos_pointer(PyObject* self, PyObject* args) { return impl_p
/* Forwards `args` to `impl_pointer` for the dot-product metric kind and returns its result unchanged. */
static PyObject* api_dot_pointer(PyObject* self, PyObject* args) {
    PyObject* pointer_object = impl_pointer(simsimd_metric_dot_k, args);
    return pointer_object;
}
/* Forwards `args` to `impl_pointer` for the Kullback-Leibler metric kind and returns its result unchanged.
 * NOTE(review): the method table in this file registers "pointer_to_kullbackleibler" with
 * `api_dot_pointer`, not with this function - confirm whether that is intended. */
static PyObject* api_kl_pointer(PyObject* self, PyObject* args) {
    PyObject* pointer_object = impl_pointer(simsimd_metric_kl_k, args);
    return pointer_object;
}
/* Forwards `args` to `impl_pointer` for the Jensen-Shannon metric kind and returns its result unchanged.
 * NOTE(review): the method table in this file registers "pointer_to_jensenshannon" with
 * `api_dot_pointer`, not with this function - confirm whether that is intended. */
static PyObject* api_js_pointer(PyObject* self, PyObject* args) {
    PyObject* pointer_object = impl_pointer(simsimd_metric_js_k, args);
    return pointer_object;
}
/* Forwards `args` to `impl_pointer` for the covariance metric kind and returns its result unchanged.
 * Registered in the method table as "pointer_to_covariance". */
static PyObject* api_cov_pointer(PyObject* self, PyObject* args) {
    PyObject* pointer_object = impl_pointer(simsimd_metric_cov_k, args);
    return pointer_object;
}
/* Forwards `args` to `impl_pointer` for the Hamming metric kind and returns its result unchanged. */
static PyObject* api_hamming_pointer(PyObject* self, PyObject* args) {
    PyObject* pointer_object = impl_pointer(simsimd_metric_hamming_k, args);
    return pointer_object;
}
Expand All @@ -654,6 +657,9 @@ static PyObject* api_kl(PyObject* self, PyObject* const* args, Py_ssize_t nargs)
/* Fastcall entry point registered as "jensenshannon": hands the positional arguments to
 * `impl_metric` with the Jensen-Shannon metric kind and returns its result unchanged. */
static PyObject* api_js(PyObject* self, PyObject* const* args, Py_ssize_t nargs) {
    PyObject* outcome = impl_metric(simsimd_metric_js_k, args, nargs);
    return outcome;
}
/* Fastcall entry point registered as "covariance": hands the positional arguments to
 * `impl_metric` with the covariance metric kind and returns its result unchanged. */
static PyObject* api_cov(PyObject* self, PyObject* const* args, Py_ssize_t nargs) {
    PyObject* outcome = impl_metric(simsimd_metric_cov_k, args, nargs);
    return outcome;
}
/* Fastcall entry point: hands the positional arguments to `impl_metric` with the
 * Hamming metric kind and returns its result unchanged. */
static PyObject* api_hamming(PyObject* self, PyObject* const* args, Py_ssize_t nargs) {
    PyObject* outcome = impl_metric(simsimd_metric_hamming_k, args, nargs);
    return outcome;
}
Expand All @@ -677,6 +683,7 @@ static PyMethodDef simsimd_methods[] = {
{"jaccard", api_jaccard, METH_FASTCALL, "Jaccard (Bitwise Tanimoto) distances between a pair of matrices"},
{"kullbackleibler", api_kl, METH_FASTCALL, "Kullback-Leibler divergence between probability distributions"},
{"jensenshannon", api_js, METH_FASTCALL, "Jensen-Shannon divergence between probability distributions"},
{"covariance", api_cov, METH_FASTCALL, "Covariance between a pair of samples"},

// Conventional `cdist` and `pdist` interfaces with third string argument, and optional `threads` arg
{"cdist", api_cdist, METH_VARARGS | METH_KEYWORDS,
Expand All @@ -688,6 +695,7 @@ static PyMethodDef simsimd_methods[] = {
{"pointer_to_inner", api_dot_pointer, METH_VARARGS, "Inner (Dot) Product function pointer as `int`"},
{"pointer_to_kullbackleibler", api_dot_pointer, METH_VARARGS, "Kullback-Leibler function pointer as `int`"},
{"pointer_to_jensenshannon", api_dot_pointer, METH_VARARGS, "Jensen-Shannon function pointer as `int`"},
{"pointer_to_covariance", api_cov_pointer, METH_VARARGS, "Covariance function pointer as `int`"},

// Sentinel
{NULL, NULL, 0, NULL}};
Expand Down
14 changes: 14 additions & 0 deletions python/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,20 @@ def test_jensen_shannon(ndim, dtype):

np.testing.assert_allclose(expected, result, atol=SIMSIMD_ATOL, rtol=0)

@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed")
@pytest.mark.repeat(50)
@pytest.mark.parametrize("ndim", [11, 97, 1536])
@pytest.mark.parametrize("dtype", ["float32", "float16"])
def test_covariance(ndim, dtype):
    """Checks simd.covariance() against the off-diagonal entry of numpy.cov() for f32 and f16 inputs."""
    # Reseed on every repetition so each run sees fresh samples.
    np.random.seed()
    first = np.random.randn(ndim).astype(dtype)
    second = np.random.randn(ndim).astype(dtype)

    # numpy.cov returns the 2x2 covariance matrix; entry [0, 1] is cov(first, second).
    covariance_matrix = np.cov(first, second)
    reference = covariance_matrix[0, 1]
    measured = simd.covariance(first, second)

    # Absolute tolerance only.
    np.testing.assert_allclose(reference, measured, atol=SIMSIMD_ATOL, rtol=0)

@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed")
@pytest.mark.repeat(50)
Expand Down