Skip to content

Commit

Permalink
Implemented algorithms for K-Nearest Neighbor Search (KNN) (#2554)
Browse files Browse the repository at this point in the history
* Extended API with functions for vector similarity based on KD-trees https://en.wikipedia.org/wiki/K-d_tree

ndpi_kd_tree* ndpi_kd_create(u_int num_dimensions);
void ndpi_kd_free(ndpi_kd_tree *tree);
void ndpi_kd_clear(ndpi_kd_tree *tree);
bool ndpi_kd_insert(ndpi_kd_tree *tree, const double *data_vector, void *user_data);
ndpi_kd_tree_result *ndpi_kd_nearest(ndpi_kd_tree *tree, const double *data_vector);
u_int32_t ndpi_kd_num_results(ndpi_kd_tree_result *res);
bool ndpi_kd_result_end(ndpi_kd_tree_result *res);
double* ndpi_kd_result_get_item(ndpi_kd_tree_result *res, double **user_data);
bool ndpi_kd_result_next(ndpi_kd_tree_result *res);
void ndpi_kd_result_free(ndpi_kd_tree_result *res);
double ndpi_kd_distance(double *a1, double *b2, u_int num_dimensions);
  • Loading branch information
lucaderi authored Sep 10, 2024
1 parent f4d2002 commit 7fdc4b2
Show file tree
Hide file tree
Showing 9 changed files with 1,698 additions and 13 deletions.
86 changes: 86 additions & 0 deletions example/ndpiReader.c
Original file line number Diff line number Diff line change
Expand Up @@ -6113,6 +6113,91 @@ void loadStressTest() {

/* *********************************************** */

void kdUnitTest() {
ndpi_kd_tree *t = ndpi_kd_create(5);
double v[][5] = {
{ 0, 4, 2, 3, 4 },
{ 0, 1, 2, 3, 6 },
{ 1, 2, 3, 4, 5 },
};
double v1[5] = { 0, 1, 2, 3, 8 };
u_int i, sz = 5*sizeof(double), num = sizeof(v) / sz;
ndpi_kd_tree_result *res;
double *ret, *to_find = v[1];

assert(t);

for(i=0; i<num; i++)
assert(ndpi_kd_insert(t, v[i], NULL) == true);

assert((res = ndpi_kd_nearest(t, to_find)) != NULL);
assert(ndpi_kd_num_results(res) == 1);
assert((ret = ndpi_kd_result_get_item(res, NULL)) != NULL);
assert(memcmp(ret, to_find, sz) == 0);
ndpi_kd_result_free(res);

assert((res = ndpi_kd_nearest(t, v1)) != NULL);
assert(ndpi_kd_num_results(res) == 1);
assert((ret = ndpi_kd_result_get_item(res, NULL)) != NULL);
assert(memcmp(ret, v1, sz) != 0);
assert(ndpi_kd_distance(ret, v1, 5) == 4.);
ndpi_kd_result_free(res);

ndpi_kd_free(t);
}

/* *********************************************** */

void ballTreeUnitTest() {
ndpi_btree *ball_tree;
double v[][5] = {
{ 0, 4, 2, 3, 4 },
{ 0, 1, 2, 3, 6 },
{ 1, 2, 3, 4, 5 },
};
double v1[] = { 0, 1, 2, 3, 8 };
double *rows[] = { v[0], v[1], v[2] };
double *q_rows[] = { v1 };
u_int32_t num_columns = 5;
u_int32_t num_rows = sizeof(v) / (sizeof(double)*num_columns);
ndpi_knn result;
u_int32_t nun_results = 2;
int i, j;

ball_tree = ndpi_btree_init(rows, num_rows, num_columns);
assert(ball_tree != NULL);
result = ndpi_btree_query(ball_tree, q_rows,
sizeof(q_rows) / sizeof(double*),
num_columns, nun_results);

assert(result.n_samples == 2);

for (i = 0; i < result.n_samples; i++) {
printf("{\"knn_idx\": [");
for (j = 0; j < result.n_neighbors; j++)
{
printf("%d", result.indices[i][j]);
if (j != result.n_neighbors - 1)
printf(", ");
}
printf("],\n \"knn_dist\": [");
for (j = 0; j < result.n_neighbors; j++)
{
printf("%.12lf", result.distances[i][j]);
if (j != result.n_neighbors - 1)
printf(", ");
}
printf("]\n}\n");
if (i != result.n_samples - 1)
printf(", ");
}

ndpi_free_knn(result);
ndpi_free_btree(ball_tree);
}

/* *********************************************** */

void encodeDomainsUnitTest() {
NDPI_PROTOCOL_BITMASK all;
struct ndpi_detection_module_struct *ndpi_str = ndpi_init_detection_module(NULL);
Expand Down Expand Up @@ -6290,6 +6375,7 @@ int main(int argc, char **argv) {
exit(0);
#endif

kdUnitTest();
encodeDomainsUnitTest();
loadStressTest();
domainsUnitTest();
Expand Down
75 changes: 64 additions & 11 deletions src/include/ndpi_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ extern "C" {
*
*/
extern u_int16_t ndpi_get_proto_by_name(struct ndpi_detection_module_struct *ndpi_mod, const char *name);

/**
* Return the name of the protocol given its ID
*
Expand All @@ -716,8 +716,8 @@ extern "C" {
* @par id = the protocol id
* @return the name of the protocol
*
*/
extern ndpi_master_app_protocol ndpi_get_protocol_by_name(struct ndpi_detection_module_struct *ndpi_str, const char *name);
*/
extern ndpi_master_app_protocol ndpi_get_protocol_by_name(struct ndpi_detection_module_struct *ndpi_str, const char *name);

/**
* Return the ID of the category
Expand Down Expand Up @@ -1059,7 +1059,7 @@ extern "C" {
bool ndpi_is_proto(ndpi_master_app_protocol proto, u_int16_t p);
bool ndpi_is_proto_unknown(ndpi_master_app_protocol proto);
bool ndpi_is_proto_equals(ndpi_master_app_protocol to_check, ndpi_master_app_protocol to_match, bool exact_match_only);

ndpi_proto_defaults_t* ndpi_get_proto_defaults(struct ndpi_detection_module_struct *ndpi_mod);
u_int ndpi_get_ndpi_num_supported_protocols(struct ndpi_detection_module_struct *ndpi_mod);
u_int ndpi_get_ndpi_num_custom_protocols(struct ndpi_detection_module_struct *ndpi_mod);
Expand Down Expand Up @@ -1753,7 +1753,7 @@ extern "C" {

void ndpi_md5(const u_char *data, size_t data_len, u_char hash[16]);
void ndpi_sha256(const u_char *data, size_t data_len, u_int8_t sha_hash[32]);

u_int16_t ndpi_crc16_ccit(const void* data, size_t n_bytes);
u_int16_t ndpi_crc16_ccit_false(const void *data, size_t n_bytes);
u_int16_t ndpi_crc16_xmodem(const void *data, size_t n_bytes);
Expand Down Expand Up @@ -1865,6 +1865,59 @@ extern "C" {

/* ******************************* */

/* create a kd-tree for num_dimensions vector items */
ndpi_kd_tree* ndpi_kd_create(u_int num_dimensions);

/* free the ndpi_kd_tree */
void ndpi_kd_free(ndpi_kd_tree *tree);

/* remove all the elements from the tree */
void ndpi_kd_clear(ndpi_kd_tree *tree);

/* insert a node, specifying its position, and optional data.
Return true = OK, false otherwise
*/
bool ndpi_kd_insert(ndpi_kd_tree *tree, const double *data_vector, void *user_data);

/* Find the nearest node from a given point.
* This function returns a pointer to a result set with at most one element.
*/
ndpi_kd_tree_result *ndpi_kd_nearest(ndpi_kd_tree *tree, const double *data_vector);

/* returns the size of the result set (in elements) */
u_int32_t ndpi_kd_num_results(ndpi_kd_tree_result *res);

/* returns the current element and updates user_data with the data put during insert */
double* ndpi_kd_result_get_item(ndpi_kd_tree_result *res, double **user_data);

/* frees a result set returned by kd_nearest_range() */
void ndpi_kd_result_free(ndpi_kd_tree_result *res);

/* Returns the distance (square root of the individual elements difference) */
double ndpi_kd_distance(double *a1, double *b2, u_int num_dimensions);

/* ******************************* */

/*
Ball Tree: similar to KD-tree but more efficient with high cardinalities
- https://en.wikipedia.org/wiki/Ball_tree
- https://www.geeksforgeeks.org/ball-tree-and-kd-tree-algorithms/
- https://varshasaini.in/kd-tree-and-ball-tree-knn-algorithm/
- https://varshasaini.in/k-nearest-neighbor-knn-algorithm-in-machine-learning/
NOTE:
with ball tree, data is a vector of vector pointers (no array)
*/
ndpi_btree* ndpi_btree_init(double **data, u_int32_t n_rows, u_int32_t n_columns);
ndpi_knn ndpi_btree_query(ndpi_btree *b, double **query_data,
u_int32_t query_data_num_rows, u_int32_t query_data_num_columns,
u_int32_t max_num_results);
void ndpi_free_knn(ndpi_knn knn);
void ndpi_free_btree(ndpi_btree *tree);

/* ******************************* */

/*
* Finds outliers using Z-score
* Z-Score = (Value - Mean) / StdDev
Expand Down Expand Up @@ -1912,9 +1965,9 @@ extern "C" {
*
*/
double ndpi_pearson_correlation(u_int32_t *values_a, u_int32_t *values_b, u_int16_t num_values);

/* ******************************* */

/*
* Checks if a specified value is an outlier with respect to past values
* using the Z-score.
Expand All @@ -1926,12 +1979,12 @@ extern "C" {
* t = 1 - The value to check should not exceed the past values
* t > 1 - The value to check has to be within (t * stddev) boundaries
* @par lower - [out] Lower threshold
* @par upper - [out] Upper threshold
* @par upper - [out] Upper threshold
*
* @return true if the specified value is an outlier, false otherwise
*
*/

bool ndpi_is_outlier(u_int32_t *past_values, u_int32_t num_past_values,
u_int32_t value_to_check, float threshold,
float *lower, float *upper);
Expand Down Expand Up @@ -2113,7 +2166,7 @@ extern "C" {
ndpi_domain_classify *s,
u_int16_t *class_id /* out */,
char *hostname);

/* ******************************* */

/*
Expand Down Expand Up @@ -2251,7 +2304,7 @@ extern "C" {
*/
u_int ndpi_encode_domain(struct ndpi_detection_module_struct *ndpi_str,
char *domain, char *out, u_int out_len);

/* ******************************* */

const char *ndpi_lru_cache_idx_to_name(lru_cache_type idx);
Expand Down
16 changes: 16 additions & 0 deletions src/include/ndpi_typedefs.h
Original file line number Diff line number Diff line change
Expand Up @@ -1763,6 +1763,22 @@ struct ndpi_bin {

/* **************************************** */

/* Implemented in third_party/src/kdtree.c */
typedef void ndpi_kd_tree;
typedef void ndpi_kd_tree_result;

/* Implemented in third_party/src/ball.c */
typedef void ndpi_btree;

typedef struct {
double **distances;
int **indices;
int n_samples;
int n_neighbors;
} ndpi_knn;

/* **************************************** */

#define HW_HISTORY_LEN 4
#define MAX_SQUARE_ERROR_ITERATIONS 64 /* MUST be < num_values_rollup (256 max) */

Expand Down
67 changes: 65 additions & 2 deletions src/lib/ndpi_analyze.c
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
#include "ndpi_api.h"
#include "ndpi_config.h"
#include "third_party/include/hll.h"

#include "third_party/include/kdtree.h"
#include "third_party/include/ball.h"
#include "ndpi_replace_printf.h"

/* ********************************************************************************* */
Expand Down Expand Up @@ -1916,7 +1917,7 @@ u_int32_t ndpi_crc32(const void *data, size_t length, u_int32_t crc)
{
const u_int8_t *p = (const u_int8_t*)data;
crc = ~crc;

while (length--)
{
crc = crc32_ieee_table[(crc ^ *p++) & 0xFF] ^ (crc >> 8);
Expand Down Expand Up @@ -2078,3 +2079,65 @@ void ndpi_popcount_count(struct ndpi_popcount *h, const u_int8_t *buf, u_int32_t

h->tot_bytes_count += buf_len;
}

/* ********************************************************************************* */
/* ********************************************************************************* */

ndpi_kd_tree* ndpi_kd_create(u_int num_dimensions) { return(kd_create((int)num_dimensions)); }

void ndpi_kd_free(ndpi_kd_tree *tree) { kd_free((struct kdtree *)tree); }

void ndpi_kd_clear(ndpi_kd_tree *tree) { kd_clear((struct kdtree *)tree); }

bool ndpi_kd_insert(ndpi_kd_tree *tree, const double *data_vector, void *user_data) {
return(kd_insert((struct kdtree *)tree, data_vector, user_data) == 0 ? true : false);
}

ndpi_kd_tree_result *ndpi_kd_nearest(ndpi_kd_tree *tree, const double *data_vector) {
return(kd_nearest((struct kdtree *)tree, data_vector));
}

u_int32_t ndpi_kd_num_results(ndpi_kd_tree_result *res) { return((u_int32_t)kd_res_size((struct kdres*)res)); }

double* ndpi_kd_result_get_item(ndpi_kd_tree_result *res, double **user_data) {
return(kd_res_item((struct kdres*)res, user_data));
}

void ndpi_kd_result_free(ndpi_kd_tree_result *res) { kd_res_free((struct kdres *)res); }

double ndpi_kd_distance(double *a1, double *a2, u_int num_dimensions) {
double dist_sq = 0, diff;
u_int i;

for(i=0; i<num_dimensions; i++) {
diff = a1[i] - a2[i];

#if 0
if(diff != 0) {
printf("Difference %.3f at position %u\n", diff, pos);
}
#endif
dist_sq += diff*diff;
}

return(dist_sq);
}

/* ********************************************************************************* */
/* ********************************************************************************* */

ndpi_btree* ndpi_btree_init(double **data, u_int32_t n_rows, u_int32_t n_columns) {
return((ndpi_btree*)btree_init(data, (int)n_rows, (int)n_columns, 30));
}

ndpi_knn ndpi_btree_query(ndpi_btree *b, double **query_data,
u_int32_t query_data_num_rows, u_int32_t query_data_num_columns,
u_int32_t max_num_results) {
return(btree_query((t_btree*)b, query_data, (int)query_data_num_rows,
(int)query_data_num_columns, (int)max_num_results));
}

void ndpi_free_knn(ndpi_knn knn) { free_knn(knn, knn.n_samples); }

void ndpi_free_btree(ndpi_btree *b) { free_tree((t_btree*)b); }

Loading

0 comments on commit 7fdc4b2

Please sign in to comment.