diff --git a/CMakeLists.txt b/CMakeLists.txt index 7fbcbdf087..73f59f453f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -173,10 +173,22 @@ endif () MESSAGE(STATUS "C++ Compilation flags: " ${CMAKE_CXX_FLAGS}) +if(CMAKE_SYSTEM_PROCESSOR MATCHES "^arm64.*") + set(ARM64 TRUE) +elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64.*") + set(ARM64 TRUE) +elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") + set(X86_64 TRUE) +endif() + #add_definitions(-march=native) add_definitions(-DSIMDE_ENABLE_NATIVE_ALIASES) if (CMAKE_C_COMPILER_ID STREQUAL "Clang" AND CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL "18.0") - add_definitions(-mevex512) + if(X86_64) + add_definitions(-mevex512) + else() + add_definitions(-march=native) + endif() endif () execute_process( diff --git a/example/functions.py b/example/functions.py index 0f590fe148..c275ed7c17 100644 --- a/example/functions.py +++ b/example/functions.py @@ -23,6 +23,8 @@ {"c1": 'test@gmail.com', "c2": 'email'}, {"c1": 'test@hotmail.com', "c2": 'email'}, {"c1": ' abc', "c2": 'abc'}, {"c1": 'abc ', "c2": 'abc'}, {"c1": ' abc ', "c2": 'abc'}]) +# varchar functions + #function char_length res = table_obj.output(["*", "char_length(c1)"]).filter("char_length(c1) = 1").to_df() print(res) @@ -74,6 +76,35 @@ res = table_obj.output(["*", "char_position(c1, 'bc')"]).filter("char_position(c1, c1) <> 0").to_df() print(res) +# math functions +db_obj.drop_table("function_example", ConflictType.Ignore) +db_obj.create_table("function_example", + {"c1": {"type": "integer"}, + "c2": {"type": "double"}}, ConflictType.Error) +table_obj = db_obj.get_table("function_example") +table_obj.insert( + [{"c1": 1, "c2": 2.4}, {"c1": 3, "c2": 4.5}, {"c1": 5, "c2": 6.6}, {"c1": 7, "c2": 8}, + {"c1": 9, "c2": 10}, {"c1": 11, "c2": 12}, {"c1": 13, "c2": 14}, {"c1": 15, "c2": 16},]) + +#function sqrt +res = table_obj.output(["*", "sqrt(c1)", "sqrt(c2)"]).to_df() +print(res) + +res = table_obj.output(["*", "sqrt(c1)", "sqrt(c2)"]).filter("sqrt(c1) = 3").to_df() +print(res) + +#function round +res = table_obj.output(["*", "round(c1)", "round(c2)"]).to_df() +print(res) + +#function ceiling +res = table_obj.output(["*", "ceil(c1)", "ceil(c2)"]).to_df() +print(res) + +#function floor +res = table_obj.output(["*", "floor(c1)", "floor(c2)"]).to_df() +print(res) + res = db_obj.drop_table("function_example") infinity_obj.disconnect() \ No newline at end of file diff --git a/example/http/functions.sh b/example/http/functions.sh index 164b31a03f..0541bec5dc 100644 --- a/example/http/functions.sh +++ b/example/http/functions.sh @@ -44,6 +44,10 @@ curl --request POST \ { "name": "tensor", "type": "tensor,4,float" + }, + { + "name": "decimal", + "type": "double" } ] } ' @@ -62,7 +66,8 @@ curl --request POST \ "vec": [1.0, 1.2, 0.8, 0.9], "sparse_column": {"10":1.1, "20":2.2, "30": 3.3}, "year": 2024, - "tensor": [[1.0, 0.0, 0.0, 0.0], [1.1, 0.0, 0.0, 0.0]] + "tensor": [[1.0, 0.0, 0.0, 0.0], [1.1, 0.0, 0.0, 0.0]], + "decimal": 1.4 }, { "num": 2, @@ -70,7 +75,8 @@ curl --request POST \ "vec": [4.0, 4.2, 4.3, 4.5], "sparse_column": {"40":4.4, "50":5.5, "60": 6.6}, "year": 2023, - "tensor": [[4.0, 0.0, 4.3, 4.5], [4.0, 4.2, 4.4, 5.0]] + "tensor": [[4.0, 0.0, 4.3, 4.5], [4.0, 4.2, 4.4, 5.0]], + "decimal": 1.5 }, { "num": 3, @@ -78,7 +84,8 @@ curl --request POST \ "vec": [4.0, 4.2, 4.3, 4.2], "sparse_column": {"70":7.7, "80":8.8, "90": 9.9}, "year": 2019, - "tensor": [[0.9, 0.1, 0.0, 0.0], [1.1, 0.0, 0.0, 0.0]] + "tensor": [[0.9, 0.1, 0.0, 0.0], [1.1, 0.0, 0.0, 0.0]], + "decimal": -1.4 }, { "num": 4, @@ -86,7 +93,8 @@ curl --request POST \ "vec": [4.0, 4.2, 4.3, 4.5], "sparse_column": {"20":7.7, "80":7.8, "90": 97.9}, "year": 2018, - "tensor": [[5.0, 4.2, 4.3, 4.5], [4.0, 4.2, 4.3, 4.4]] + "tensor": [[5.0, 4.2, 4.3, 4.5], [4.0, 4.2, 4.3, 4.4]], + "decimal": -1.5 }, { "num": 5, @@ -94,7 +102,8 @@ curl --request POST \ "vec": [4.0, 4.2, 4.3, 4.5], "sparse_column": {"20":7.7, "80":7.8, "90": 97.9}, "year": 2018, - "tensor": [[5.0, 4.2, 4.3, 4.5], [4.0, 4.2, 4.3, 4.4]] + "tensor": [[5.0, 4.2, 4.3, 4.5], [4.0, 4.2, 4.3, 4.4]], + "decimal": 1 }, { "num": 6, @@ -102,7 +111,8 @@ curl --request POST \ "vec": [4.0, 4.2, 4.3, 4.5], "sparse_column": {"20":7.7, "80":7.8, "90": 97.9}, "year": 2018, - "tensor": [[5.0, 4.2, 4.3, 4.5], [4.0, 4.2, 4.3, 4.4]] + "tensor": [[5.0, 4.2, 4.3, 4.5], [4.0, 4.2, 4.3, 4.4]], + "decimal": -1 }, { "num": 7, @@ -110,7 +120,8 @@ curl --request POST \ "vec": [4.0, 4.2, 4.3, 4.5], "sparse_column": {"20":7.7, "80":7.8, "90": 97.9}, "year": 2018, - "tensor": [[5.0, 4.2, 4.3, 4.5], [4.0, 4.2, 4.3, 4.4]] + "tensor": [[5.0, 4.2, 4.3, 4.5], [4.0, 4.2, 4.3, 4.4]], + "decimal": 0.4 }, { "num": 8, @@ -118,7 +129,8 @@ curl --request POST \ "vec": [4.0, 4.2, 4.3, 4.5], "sparse_column": {"20":7.7, "80":7.8, "90": 97.9}, "year": 2018, - "tensor": [[5.0, 4.2, 4.3, 4.5], [4.0, 4.2, 4.3, 4.4]] + "tensor": [[5.0, 4.2, 4.3, 4.5], [4.0, 4.2, 4.3, 4.4]], + "decimal": 0.5 } ] ' @@ -316,6 +328,34 @@ curl --request GET \ "filter": "char_position(body, '123') = 1" } ' +# show rows of 'tbl1' with sqrt(num) +echo -e '\n\n-- show rows of 'tbl1' with sqrt(num)' +curl --request GET \ + --url http://localhost:23820/databases/default_db/tables/tbl1/docs \ + --header 'accept: application/json' \ + --header 'content-type: application/json' \ + --data ' + { + "output": + [ + "num", "sqrt(num)" + ] + } ' + +# show rows of 'tbl1' with decimal, round(decimal), ceil(decimal), floor(decimal) +echo -e '\n\n-- show rows of 'tbl1' with decimal, round(decimal), ceil(decimal), floor(decimal)' +curl --request GET \ + --url http://localhost:23820/databases/default_db/tables/tbl1/docs \ + --header 'accept: application/json' \ + --header 'content-type: application/json' \ + --data ' + { + "output": + [ + "decimal", "round(decimal)", "ceil(decimal)", "floor(decimal)" + ] + } ' + # drop tbl1 echo -e '\n\n-- drop tbl1' diff --git a/python/infinity_http.py b/python/infinity_http.py index 8896f50f53..e21fddcaa1 100644 --- a/python/infinity_http.py +++ b/python/infinity_http.py @@ -661,7 +661,7 @@ def to_df(self): for res in self.output_res: for k in res: - print(res[k]) + #print(res[k]) if k not in df_dict: df_dict[k] = () tup = df_dict[k] @@ -686,7 +686,7 @@ def to_df(self): new_tup = tup + (res[k],) df_dict[k] = new_tup # print(self.output_res) - # print(df_dict) + print(df_dict) df_type = {} for k in df_dict: @@ -694,19 +694,33 @@ def to_df(self): df_type[k] = type_to_dtype(col_types[k]) if k in ["DISTANCE", "SCORE", "SIMILARITY"]: df_type[k] = dtype('float32') - # "(c1 + c2)" - k1 = k.replace("(", "") - k1 = k1.replace(")", "") - cols = k1.split("+") + k1.split("-") # ["c1 ", " c2", "c1 + c2"] - # print(cols) - # haven't considered data type priority + # "(c1 + c2)", "sqrt(c1), round(c1)" + k1 = k.replace("(", " ") + k1 = k1.replace(")", " ") + k1 = k1.replace("+", " ") + k1 = k1.replace("-", " ") + cols = k1.split(" ") + #print(cols) + + function_name = "" for col in cols: + #print(function_name) if col.strip() in col_types: df_type[k] = type_to_dtype(col_types[col.strip()]) - if col.strip().isdigit(): + df_type[k] = function_return_type(function_name, df_type[k]) + elif col.strip().isdigit() and df_type[k] != dtype('float64'): df_type[k] = dtype('int32') - if is_float(col.strip()): + df_type[k] = function_return_type(function_name, df_type[k]) + elif is_float(col.strip()): df_type[k] = dtype('float64') + df_type[k] = function_return_type(function_name, df_type[k]) + else: + function_name = col.strip().lower() + if (function_name in functions): + df_type[k] = function_return_type(function_name, None) + if (function_name in bool_functions): + df_type[k] = dtype('bool') + break return pd.DataFrame(df_dict).astype(df_type) def to_arrow(self): diff --git a/python/test_pysdk/common/common_data.py b/python/test_pysdk/common/common_data.py index 080a230eec..ca705a6389 100644 --- a/python/test_pysdk/common/common_data.py +++ b/python/test_pysdk/common/common_data.py @@ -37,6 +37,27 @@ "double", "varchar", "boolean" ] +functions = [ + "sqrt", "round", "ceil", "floor", "filter_text", "filter_fulltext", "or", "and", "not" +] + +bool_functions = [ + "filter_text", "filter_fulltext", "or", "and", "not" +] + +def function_return_type(function_name, param_type) : + if function_name == "sqrt": + return dtype('float64') + elif function_name == "round" or function_name == "ceil" or function_name == "floor": + if(param_type == dtype('int8') or param_type == dtype('int16') or param_type == dtype('int32') or param_type == dtype('int64')): + return param_type + else: + return dtype('float64') + elif function_name == "filter_text" or function_name == "filter_fulltext" or function_name == "or" or function_name == "and" or function_name == "not": + return dtype('bool') + else: + return param_type + unsupport_output = ["_similarity", "_row_id", "_score", "_distance"] type_transfrom = { diff --git a/python/test_pysdk/test_select.py b/python/test_pysdk/test_select.py index 009a93bb25..850fd842d7 100644 --- a/python/test_pysdk/test_select.py +++ b/python/test_pysdk/test_select.py @@ -943,4 +943,56 @@ def test_select_position(self, suffix): .astype({'c1': dtype('O'), 'c2': dtype('O')})) res = db_obj.drop_table("test_select_position"+suffix) + assert res.error_code == ErrorCode.OK + + def test_select_sqrt(self, suffix): + db_obj = self.infinity_obj.get_database("default_db") + db_obj.drop_table("test_select_sqrt"+suffix, ConflictType.Ignore) + db_obj.create_table("test_select_sqrt"+suffix, + {"c1": {"type": "integer"}, + "c2": {"type": "double"}}, ConflictType.Error) + table_obj = db_obj.get_table("test_select_sqrt"+suffix) + table_obj.insert( + [{"c1": '1', "c2": '2'}, {"c1": '4', "c2": '5'}, {"c1": '9', "c2": '10'}, {"c1": '16', "c2": '17'}]) + + res = table_obj.output(["*", "sqrt(c1)", "sqrt(c2)"]).to_df() + print(res) + + res = table_obj.output(["*"]).filter("sqrt(c1) = 2").to_df() + pd.testing.assert_frame_equal(res, pd.DataFrame({'c1': (4,), + 'c2': (5,)}) + .astype({'c1': dtype('int32'), 'c2': dtype('double')})) + + res = db_obj.drop_table("test_select_sqrt"+suffix) + assert res.error_code == ErrorCode.OK + + def test_select_round(self, suffix): + db_obj = self.infinity_obj.get_database("default_db") + db_obj.drop_table("test_select_round"+suffix, ConflictType.Ignore) + db_obj.create_table("test_select_round"+suffix, + {"c1": {"type": "integer"}, + "c2": {"type": "double"}}, ConflictType.Error) + table_obj = db_obj.get_table("test_select_round"+suffix) + table_obj.insert( + [{"c1": '1', "c2": '2.4'}, {"c1": '4', "c2": '-2.4'}, {"c1": '9', "c2": '2.5'}, {"c1": '16', "c2": '-2.5'}]) + + res = table_obj.output(["c1", "round(c2)"]).to_df() + print(res) + pd.testing.assert_frame_equal(res, pd.DataFrame({'c1': (1, 4, 9, 16), + 'round(c2)': (2, -2, 3, -3)}) + .astype({'c1': dtype('int32'), 'round(c2)': dtype('double')})) + + res = table_obj.output(["c1", "ceil(c2)"]).to_df() + print(res) + pd.testing.assert_frame_equal(res, pd.DataFrame({'c1': (1, 4, 9, 16), + 'ceil(c2)': (3, -2, 3, -2)}) + .astype({'c1': dtype('int32'), 'ceil(c2)': dtype('double')})) + + res = table_obj.output(["c1", "floor(c2)"]).to_df() + print(res) + pd.testing.assert_frame_equal(res, pd.DataFrame({'c1': (1, 4, 9, 16), + 'floor(c2)': (2, -3, 2, -3)}) + .astype({'c1': dtype('int32'), 'floor(c2)': dtype('double')})) + + res = db_obj.drop_table("test_select_round"+suffix) assert res.error_code == ErrorCode.OK \ No newline at end of file diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 22a3864ce3..f643115533 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -18,47 +18,52 @@ add_subdirectory(parser) # add_definitions(-msse4.2 -mfma) # add_definitions(-mavx2 -mf16c -mpopcnt) + if(APPLE) - execute_process(COMMAND sh -c "sysctl -a machdep.cpu.features | grep FMA" - RESULT_VARIABLE SUPPORT_FMA - OUTPUT_QUIET - ERROR_QUIET) - - execute_process(COMMAND sh -c "sysctl -a machdep.cpu.features | grep SSE4.2" - RESULT_VARIABLE SUPPORT_SSE42 - OUTPUT_QUIET - ERROR_QUIET) - - execute_process(COMMAND sh -c "sysctl -a machdep.cpu.features | grep AVX2" - RESULT_VARIABLE SUPPORT_AVX2 - OUTPUT_QUIET - ERROR_QUIET) - - execute_process(COMMAND sh -c "sysctl -a machdep.cpu.features | grep AVX512" - RESULT_VARIABLE SUPPORT_AVX512 - OUTPUT_QUIET - ERROR_QUIET) + if(X86_64) + execute_process(COMMAND sh -c "sysctl -a machdep.cpu.features | grep FMA" + RESULT_VARIABLE SUPPORT_FMA + OUTPUT_QUIET + ERROR_QUIET) + + execute_process(COMMAND sh -c "sysctl -a machdep.cpu.features | grep SSE4.2" + RESULT_VARIABLE SUPPORT_SSE42 + OUTPUT_QUIET + ERROR_QUIET) + + execute_process(COMMAND sh -c "sysctl -a machdep.cpu.features | grep AVX2" + RESULT_VARIABLE SUPPORT_AVX2 + OUTPUT_QUIET + ERROR_QUIET) + + execute_process(COMMAND sh -c "sysctl -a machdep.cpu.features | grep AVX512" + RESULT_VARIABLE SUPPORT_AVX512 + OUTPUT_QUIET + ERROR_QUIET) + endif() else() #Linux - execute_process(COMMAND grep -q fma /proc/cpuinfo - RESULT_VARIABLE SUPPORT_FMA - OUTPUT_QUIET - ERROR_QUIET) - - execute_process(COMMAND grep -q sse4_2 /proc/cpuinfo - RESULT_VARIABLE SUPPORT_SSE42 - OUTPUT_QUIET - ERROR_QUIET) - - execute_process(COMMAND grep -q avx2 /proc/cpuinfo - RESULT_VARIABLE SUPPORT_AVX2 - OUTPUT_QUIET - ERROR_QUIET) - - execute_process(COMMAND grep -q avx512 /proc/cpuinfo - RESULT_VARIABLE SUPPORT_AVX512 - OUTPUT_QUIET - ERROR_QUIET) + if(X86_64) + execute_process(COMMAND grep -q fma /proc/cpuinfo + RESULT_VARIABLE SUPPORT_FMA + OUTPUT_QUIET + ERROR_QUIET) + + execute_process(COMMAND grep -q sse4_2 /proc/cpuinfo + RESULT_VARIABLE SUPPORT_SSE42 + OUTPUT_QUIET + ERROR_QUIET) + + execute_process(COMMAND grep -q avx2 /proc/cpuinfo + RESULT_VARIABLE SUPPORT_AVX2 + OUTPUT_QUIET + ERROR_QUIET) + + execute_process(COMMAND grep -q avx512 /proc/cpuinfo + RESULT_VARIABLE SUPPORT_AVX512 + OUTPUT_QUIET + ERROR_QUIET) + endif() endif() @@ -283,22 +288,24 @@ target_include_directories(infinity_core PUBLIC "${CMAKE_SOURCE_DIR}/third_party target_include_directories(infinity_core PUBLIC "${CMAKE_BINARY_DIR}/third_party/pcre2") -if (NOT SUPPORT_FMA EQUAL 0) - message(FATAL_ERROR "This project requires the processor support fused multiply-add (FMA) instructions.") -endif () - -if (NOT SUPPORT_SSE42 EQUAL 0) - message(FATAL_ERROR "This project requires the processor support sse4_2 instructions.") -endif () - -if (SUPPORT_AVX2 EQUAL 0 OR SUPPORT_AVX512 EQUAL 0) - message("Compiled by AVX2 or AVX512") - add_definitions(-march=native) - target_compile_options(infinity_core PRIVATE $<$:-march=native>) -else () - message("Compiled by SSE") - add_definitions(-msse4.2 -mfma) - target_compile_options(infinity_core PRIVATE $<$:-msse4.2 -mfma>) +if(X86_64) + if (NOT SUPPORT_FMA EQUAL 0) + message(FATAL_ERROR "This project requires the processor support fused multiply-add (FMA) instructions.") + endif () + + if (NOT SUPPORT_SSE42 EQUAL 0) + message(FATAL_ERROR "This project requires the processor support sse4_2 instructions.") + endif () + + if (SUPPORT_AVX2 EQUAL 0 OR SUPPORT_AVX512 EQUAL 0) + message("Compiled by AVX2 or AVX512") + add_definitions(-march=native) + target_compile_options(infinity_core PRIVATE $<$:-march=native>) + else () + message("Compiled by SSE") + add_definitions(-msse4.2 -mfma) + target_compile_options(infinity_core PRIVATE $<$:-msse4.2 -mfma>) + endif () endif () add_executable(infinity @@ -622,12 +629,17 @@ target_include_directories(unit_test PUBLIC "${CMAKE_BINARY_DIR}/third_party/pcr # target_compile_options(unit_test PRIVATE $<$:-mavx2 -mfma -mf16c -mpopcnt>) -if (SUPPORT_AVX2 EQUAL 0 OR SUPPORT_AVX512 EQUAL 0) - message("Compiled by AVX2 or AVX512") - add_definitions(-mavx2 -mfma -mf16c -mpopcnt) - target_compile_options(unit_test PRIVATE $<$:-mavx2 -mfma -mf16c -mpopcnt>) -else () - message("Compiled by SSE") - add_definitions(-msse4.2 -mfma) - target_compile_options(unit_test PRIVATE $<$:-msse4.2 -mfma>) +if(X86_64) + if (SUPPORT_AVX2 EQUAL 0 OR SUPPORT_AVX512 EQUAL 0) + message("Compiled by AVX2 or AVX512") + add_definitions(-mavx2 -mfma -mf16c -mpopcnt) + target_compile_options(unit_test PRIVATE $<$:-mavx2 -mfma -mf16c -mpopcnt>) + else () + message("Compiled by SSE") + add_definitions(-msse4.2 -mfma) + target_compile_options(unit_test PRIVATE $<$:-msse4.2 -mfma>) + endif () +else() + add_definitions(-march=native) endif () + diff --git a/src/common/simd/diskann_simd_func.cppm b/src/common/simd/diskann_simd_func.cppm index f07746418a..f1e46180e9 100644 --- a/src/common/simd/diskann_simd_func.cppm +++ b/src/common/simd/diskann_simd_func.cppm @@ -24,6 +24,24 @@ export module diskann_simd_func; namespace infinity { +#if defined(__aarch64__) +inline float hsum256_ps_avx(__m256 v) { + const __m128 x128 = _mm_add_ps(_mm256_extractf128_ps(v, 1), _mm256_castps256_ps128(v)); + const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128)); + const __m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); + return _mm_cvtss_f32(x32); +} + +inline float hsum_ps_sse1(__m128 v) { // v = [ D C | B A ] + __m128 shuf = _mm_shuffle_ps(v, v, _MM_SHUFFLE(2, 3, 0, 1)); // [ C D | A B ] + __m128 sums = _mm_add_ps(v, shuf); // sums = [ D+C C+D | B+A A+B ] + shuf = _mm_movehl_ps(shuf, sums); // [ C D | D+C C+D ] // let the + // compiler avoid a mov by reusing shuf + sums = _mm_add_ss(sums, shuf); + return _mm_cvtss_f32(sums); +} +#endif + export float hsumFloatVec(const float* array, size_t size) { float sum = 0.0f; size_t i = 0; diff --git a/src/common/simd/search_top_1_sgemm.cppm b/src/common/simd/search_top_1_sgemm.cppm index 99c30214eb..3065312754 100644 --- a/src/common/simd/search_top_1_sgemm.cppm +++ b/src/common/simd/search_top_1_sgemm.cppm @@ -199,8 +199,8 @@ void inner_search_top_1_with_sgemm_sse2(u32 dimension, u32 x_id = i + x_part_begin; float *ip_line = x_y_inner_product_buffer.get() + i * y_part_size; - _mm_prefetch(ip_line, _MM_HINT_NTA); - _mm_prefetch(ip_line + 8, _MM_HINT_NTA); + _mm_prefetch((const char *)(ip_line), _MM_HINT_NTA); + _mm_prefetch((const char *)(ip_line + 8), _MM_HINT_NTA); const __m128 mul_minus2 = _mm_set1_ps(-2); @@ -214,8 +214,8 @@ void inner_search_top_1_with_sgemm_sse2(u32 dimension, u32 j = 0; for (; j < (y_part_size / 8) * 8; j += 8, ip_line += 8) { u32 j_id = j + y_part_begin; - _mm_prefetch(ip_line + 16, _MM_HINT_NTA); - _mm_prefetch(ip_line + 24, _MM_HINT_NTA); + _mm_prefetch((const char *)(ip_line + 16), _MM_HINT_NTA); + _mm_prefetch((const char *)(ip_line + 24), _MM_HINT_NTA); __m128 y_norm_0 = _mm_loadu_ps(square_y.get() + j_id); __m128 y_norm_1 = _mm_loadu_ps(square_y.get() + j_id + 4); diff --git a/src/common/simd/search_top_k_sgemm.cppm b/src/common/simd/search_top_k_sgemm.cppm index 635fd3b1db..f1692d6905 100644 --- a/src/common/simd/search_top_k_sgemm.cppm +++ b/src/common/simd/search_top_k_sgemm.cppm @@ -176,8 +176,8 @@ void inner_search_top_k_with_sgemm_sse2(u32 k, u32 x_id = i + x_part_begin; float *ip_line = x_y_inner_product_buffer.get() + i * y_part_size; - _mm_prefetch(ip_line, _MM_HINT_NTA); - _mm_prefetch(ip_line + 8, _MM_HINT_NTA); + _mm_prefetch((const char*)ip_line, _MM_HINT_NTA); + _mm_prefetch((const char*)(ip_line + 8), _MM_HINT_NTA); const __m128 x_norm = _mm_set1_ps(square_x[x_id]); const __m128 mul_minus2 = _mm_set1_ps(-2); @@ -189,8 +189,8 @@ void inner_search_top_k_with_sgemm_sse2(u32 k, u32 j = 0; for (; j < (y_part_size / 8) * 8; j += 8, ip_line += 8) { u32 j_id = j + y_part_begin; - _mm_prefetch(ip_line + 16, _MM_HINT_NTA); - _mm_prefetch(ip_line + 24, _MM_HINT_NTA); + _mm_prefetch((const char*)(ip_line + 16), _MM_HINT_NTA); + _mm_prefetch((const char*)(ip_line + 24), _MM_HINT_NTA); const __m128 y_norm_0 = _mm_loadu_ps(square_y.get() + j_id + 0); const __m128 y_norm_1 = _mm_loadu_ps(square_y.get() + j_id + 4); diff --git a/src/function/builtin_functions.cpp b/src/function/builtin_functions.cpp index e9ba363fbc..8fed8c84b2 100644 --- a/src/function/builtin_functions.cpp +++ b/src/function/builtin_functions.cpp @@ -27,6 +27,10 @@ import sum; import add; import abs; +import sqrt; +import round; +import ceil; +import floor; import and_func; import divide; import equals; @@ -101,8 +105,11 @@ void BuiltinFunctions::RegisterScalarFunction() { // Math functions RegisterAbsFunction(catalog_ptr_); - RegisterPowFunction(catalog_ptr_); + RegisterSqrtFunction(catalog_ptr_); + RegisterRoundFunction(catalog_ptr_); + RegisterCeilFunction(catalog_ptr_); + RegisterFloorFunction(catalog_ptr_); // register comparison operator RegisterEqualsFunction(catalog_ptr_); diff --git a/src/function/scalar/ceil.cpp b/src/function/scalar/ceil.cpp new file mode 100644 index 0000000000..7bb2968b5c --- /dev/null +++ b/src/function/scalar/ceil.cpp @@ -0,0 +1,94 @@ +module; + +#include + +module ceil; + +import stl; +import catalog; +import logical_type; +import infinity_exception; +import scalar_function; +import scalar_function_set; + +import third_party; +import internal_types; +import data_type; + +namespace infinity { + +struct CeilFunctionInt { + template + static inline void Run(SourceType value, TargetType &result) { + result = value; + } +}; + +struct CeilFunctionFloat { + template + static inline bool Run(SourceType value, TargetType &result) { + result = ceil(static_cast(value)); + if (std::isnan(result) || std::isinf(result)) { + return false; + } + return true; + } +}; + +void RegisterCeilFunction(const UniquePtr &catalog_ptr) { + String func_name = "ceil"; + + SharedPtr function_set_ptr = MakeShared(func_name); + + ScalarFunction Ceil_int8(func_name, + {DataType(LogicalType::kTinyInt)}, + DataType(LogicalType::kTinyInt), + &ScalarFunction::UnaryFunction); + function_set_ptr->AddFunction(Ceil_int8); + + ScalarFunction Ceil_int16(func_name, + {DataType(LogicalType::kSmallInt)}, + {DataType(LogicalType::kSmallInt)}, + &ScalarFunction::UnaryFunction); + function_set_ptr->AddFunction(Ceil_int16); + + ScalarFunction Ceil_int32(func_name, + {DataType(LogicalType::kInteger)}, + {DataType(LogicalType::kInteger)}, + &ScalarFunction::UnaryFunction); + function_set_ptr->AddFunction(Ceil_int32); + + ScalarFunction Ceil_int64(func_name, + {DataType(LogicalType::kBigInt)}, + {DataType(LogicalType::kBigInt)}, + &ScalarFunction::UnaryFunction); + function_set_ptr->AddFunction(Ceil_int64); + + ScalarFunction Ceil_float(func_name, + {DataType(LogicalType::kFloat)}, + {DataType(LogicalType::kDouble)}, + &ScalarFunction::UnaryFunctionWithFailure); + function_set_ptr->AddFunction(Ceil_float); + + ScalarFunction Ceil_double(func_name, + {DataType(LogicalType::kDouble)}, + {DataType(LogicalType::kDouble)}, + &ScalarFunction::UnaryFunctionWithFailure); + function_set_ptr->AddFunction(Ceil_double); + + ScalarFunction Ceil_float16(func_name, + {DataType(LogicalType::kFloat16)}, + {DataType(LogicalType::kDouble)}, + &ScalarFunction::UnaryFunctionWithFailure); + function_set_ptr->AddFunction(Ceil_float16); + + ScalarFunction Ceil_bfloat16(func_name, + {DataType(LogicalType::kBFloat16)}, + {DataType(LogicalType::kDouble)}, + &ScalarFunction::UnaryFunctionWithFailure); + function_set_ptr->AddFunction(Ceil_bfloat16); + + Catalog::AddFunctionSet(catalog_ptr.get(), function_set_ptr); +} + +} // namespace infinity \ No newline at end of file diff --git a/src/function/scalar/ceil.cppm b/src/function/scalar/ceil.cppm new file mode 100644 index 0000000000..039e385165 --- /dev/null +++ b/src/function/scalar/ceil.cppm @@ -0,0 +1,13 @@ +module; + +export module ceil; + +import stl; + +namespace infinity { + +class Catalog; + +export void RegisterCeilFunction(const UniquePtr &catalog_ptr); + +} \ No newline at end of file diff --git a/src/function/scalar/floor.cpp b/src/function/scalar/floor.cpp new file mode 100644 index 0000000000..211648de96 --- /dev/null +++ b/src/function/scalar/floor.cpp @@ -0,0 +1,94 @@ +module; + +#include + +module floor; + +import stl; +import catalog; +import logical_type; +import infinity_exception; +import scalar_function; +import scalar_function_set; + +import third_party; +import internal_types; +import data_type; + +namespace infinity { + +struct FloorFunctionInt { + template + static inline void Run(SourceType value, TargetType &result) { + result = value; + } +}; + +struct FloorFunctionFloat { + template + static inline bool Run(SourceType value, TargetType &result) { + result = floor(static_cast(value)); + if (std::isnan(result) || std::isinf(result)) { + return false; + } + return true; + } +}; + +void RegisterFloorFunction(const UniquePtr &catalog_ptr) { + String func_name = "floor"; + + SharedPtr function_set_ptr = MakeShared(func_name); + + ScalarFunction floor_int8(func_name, + {DataType(LogicalType::kTinyInt)}, + DataType(LogicalType::kTinyInt), + &ScalarFunction::UnaryFunction); + function_set_ptr->AddFunction(floor_int8); + + ScalarFunction floor_int16(func_name, + {DataType(LogicalType::kSmallInt)}, + {DataType(LogicalType::kSmallInt)}, + &ScalarFunction::UnaryFunction); + function_set_ptr->AddFunction(floor_int16); + + ScalarFunction floor_int32(func_name, + {DataType(LogicalType::kInteger)}, + {DataType(LogicalType::kInteger)}, + &ScalarFunction::UnaryFunction); + function_set_ptr->AddFunction(floor_int32); + + ScalarFunction floor_int64(func_name, + {DataType(LogicalType::kBigInt)}, + {DataType(LogicalType::kBigInt)}, + &ScalarFunction::UnaryFunction); + function_set_ptr->AddFunction(floor_int64); + + ScalarFunction floor_float(func_name, + {DataType(LogicalType::kFloat)}, + {DataType(LogicalType::kDouble)}, + &ScalarFunction::UnaryFunctionWithFailure); + function_set_ptr->AddFunction(floor_float); + + ScalarFunction floor_double(func_name, + {DataType(LogicalType::kDouble)}, + {DataType(LogicalType::kDouble)}, + &ScalarFunction::UnaryFunctionWithFailure); + function_set_ptr->AddFunction(floor_double); + + ScalarFunction floor_float16(func_name, + {DataType(LogicalType::kFloat16)}, + {DataType(LogicalType::kDouble)}, + &ScalarFunction::UnaryFunctionWithFailure); + function_set_ptr->AddFunction(floor_float16); + + ScalarFunction floor_bfloat16(func_name, + {DataType(LogicalType::kBFloat16)}, + {DataType(LogicalType::kDouble)}, + &ScalarFunction::UnaryFunctionWithFailure); + function_set_ptr->AddFunction(floor_bfloat16); + + Catalog::AddFunctionSet(catalog_ptr.get(), function_set_ptr); +} + +} // namespace infinity \ No newline at end of file diff --git a/src/function/scalar/floor.cppm b/src/function/scalar/floor.cppm new file mode 100644 index 0000000000..0894ce5082 --- /dev/null +++ b/src/function/scalar/floor.cppm @@ -0,0 +1,13 @@ +module; + +export module floor; + +import stl; + +namespace infinity { + +class Catalog; + +export void RegisterFloorFunction(const UniquePtr &catalog_ptr); + +} \ No newline at end of file diff --git a/src/function/scalar/round.cpp b/src/function/scalar/round.cpp new file mode 100644 index 0000000000..36e340a989 --- /dev/null +++ b/src/function/scalar/round.cpp @@ -0,0 +1,94 @@ +module; + +#include + +module round; + +import stl; +import catalog; +import logical_type; +import infinity_exception; +import scalar_function; +import scalar_function_set; + +import third_party; +import internal_types; +import data_type; + +namespace infinity { + +struct RoundFunctionInt { + template + static inline void Run(SourceType value, TargetType &result) { + result = value; + } +}; + +struct RoundFunctionFloat { + template + static inline bool Run(SourceType value, TargetType &result) { + result = round(static_cast(value)); + if (std::isnan(result) || std::isinf(result)) { + return false; + } + return true; + } +}; + +void RegisterRoundFunction(const UniquePtr &catalog_ptr) { + String func_name = "round"; + + SharedPtr function_set_ptr = MakeShared(func_name); + + ScalarFunction round_int8(func_name, + {DataType(LogicalType::kTinyInt)}, + DataType(LogicalType::kTinyInt), + &ScalarFunction::UnaryFunction); + function_set_ptr->AddFunction(round_int8); + + ScalarFunction round_int16(func_name, + {DataType(LogicalType::kSmallInt)}, + {DataType(LogicalType::kSmallInt)}, + &ScalarFunction::UnaryFunction); + function_set_ptr->AddFunction(round_int16); + + ScalarFunction round_int32(func_name, + {DataType(LogicalType::kInteger)}, + {DataType(LogicalType::kInteger)}, + &ScalarFunction::UnaryFunction); + function_set_ptr->AddFunction(round_int32); + + ScalarFunction round_int64(func_name, + {DataType(LogicalType::kBigInt)}, + {DataType(LogicalType::kBigInt)}, + &ScalarFunction::UnaryFunction); + function_set_ptr->AddFunction(round_int64); + + ScalarFunction round_float(func_name, + {DataType(LogicalType::kFloat)}, + {DataType(LogicalType::kDouble)}, + &ScalarFunction::UnaryFunctionWithFailure); + function_set_ptr->AddFunction(round_float); + + ScalarFunction round_double(func_name, + {DataType(LogicalType::kDouble)}, + {DataType(LogicalType::kDouble)}, + &ScalarFunction::UnaryFunctionWithFailure); + function_set_ptr->AddFunction(round_double); + + ScalarFunction round_float16(func_name, + {DataType(LogicalType::kFloat16)}, + {DataType(LogicalType::kDouble)}, + &ScalarFunction::UnaryFunctionWithFailure); + function_set_ptr->AddFunction(round_float16); + + ScalarFunction round_bfloat16(func_name, + {DataType(LogicalType::kBFloat16)}, + {DataType(LogicalType::kDouble)}, + &ScalarFunction::UnaryFunctionWithFailure); + function_set_ptr->AddFunction(round_bfloat16); + + Catalog::AddFunctionSet(catalog_ptr.get(), function_set_ptr); +} + +} // namespace infinity \ No newline at end of file diff --git a/src/function/scalar/round.cppm b/src/function/scalar/round.cppm new file mode 100644 index 0000000000..544de64af8 --- /dev/null +++ b/src/function/scalar/round.cppm @@ -0,0 +1,13 @@ +module; + +export module round; + +import stl; + +namespace infinity { + +class Catalog; + +export void RegisterRoundFunction(const UniquePtr &catalog_ptr); + +} \ No newline at end of file diff --git a/src/function/scalar/sqrt.cpp b/src/function/scalar/sqrt.cpp new file mode 100644 index 0000000000..72c37eb1be --- /dev/null +++ b/src/function/scalar/sqrt.cpp @@ -0,0 +1,87 @@ +module; + +#include + +module sqrt; + +import stl; +import catalog; +import logical_type; +import infinity_exception; +import scalar_function; +import scalar_function_set; + +import third_party; +import internal_types; +import data_type; + +namespace infinity { + +struct SqrtFunction { + template + static inline bool Run(SourceType value, TargetType &result) { + if (value < static_cast(0.0f)) { + return false; + } + result = sqrt(static_cast(value)); + return true; + } +}; + +void RegisterSqrtFunction(const UniquePtr &catalog_ptr) { + String func_name = "sqrt"; + + SharedPtr function_set_ptr = MakeShared(func_name); + + ScalarFunction sqrt_int8(func_name, + {DataType(LogicalType::kTinyInt)}, + DataType(LogicalType::kDouble), + &ScalarFunction::UnaryFunctionWithFailure); + function_set_ptr->AddFunction(sqrt_int8); + + ScalarFunction sqrt_int16(func_name, + {DataType(LogicalType::kSmallInt)}, + {DataType(LogicalType::kDouble)}, + &ScalarFunction::UnaryFunctionWithFailure); + function_set_ptr->AddFunction(sqrt_int16); + + ScalarFunction sqrt_int32(func_name, + {DataType(LogicalType::kInteger)}, + {DataType(LogicalType::kDouble)}, + &ScalarFunction::UnaryFunctionWithFailure); + function_set_ptr->AddFunction(sqrt_int32); + + ScalarFunction sqrt_int64(func_name, + {DataType(LogicalType::kBigInt)}, + {DataType(LogicalType::kDouble)}, + &ScalarFunction::UnaryFunctionWithFailure); + function_set_ptr->AddFunction(sqrt_int64); + + ScalarFunction sqrt_float(func_name, + {DataType(LogicalType::kFloat)}, + {DataType(LogicalType::kDouble)}, + &ScalarFunction::UnaryFunctionWithFailure); + function_set_ptr->AddFunction(sqrt_float); + + ScalarFunction sqrt_double(func_name, + {DataType(LogicalType::kDouble)}, + {DataType(LogicalType::kDouble)}, + &ScalarFunction::UnaryFunctionWithFailure); + function_set_ptr->AddFunction(sqrt_double); + + ScalarFunction sqrt_float16(func_name, + {DataType(LogicalType::kFloat16)}, + {DataType(LogicalType::kDouble)}, + &ScalarFunction::UnaryFunctionWithFailure); + function_set_ptr->AddFunction(sqrt_float16); + + ScalarFunction sqrt_bfloat16(func_name, + {DataType(LogicalType::kBFloat16)}, + {DataType(LogicalType::kDouble)}, + &ScalarFunction::UnaryFunctionWithFailure); + function_set_ptr->AddFunction(sqrt_bfloat16); + + Catalog::AddFunctionSet(catalog_ptr.get(), function_set_ptr); +} + +} // namespace infinity \ No newline at end of file diff --git a/src/function/scalar/sqrt.cppm b/src/function/scalar/sqrt.cppm new file mode 100644 index 0000000000..cf337e2387 --- /dev/null +++ b/src/function/scalar/sqrt.cppm @@ -0,0 +1,13 @@ +module; + +export module sqrt; + +import stl; + +namespace infinity { + +class Catalog; + +export void RegisterSqrtFunction(const UniquePtr &catalog_ptr); + +} \ No newline at end of file diff --git a/src/function/table/knn_scan_data.cpp b/src/function/table/knn_scan_data.cpp index 4d04611cf5..491d988dd9 100644 --- a/src/function/table/knn_scan_data.cpp +++ b/src/function/table/knn_scan_data.cpp @@ -38,6 +38,7 @@ namespace infinity { template <> KnnDistance1::KnnDistance1(KnnDistanceType dist_type) { + dist_type_ = dist_type; switch (dist_type) { case KnnDistanceType::kL2: { dist_func_ = GetSIMD_FUNCTIONS().L2Distance_func_ptr_; @@ -60,6 +61,7 @@ KnnDistance1::KnnDistance1(KnnDistanceType dist_type) { template <> KnnDistance1::KnnDistance1(KnnDistanceType dist_type) { + dist_type_ = dist_type; switch (dist_type) { case KnnDistanceType::kL2: { dist_func_ = GetSIMD_FUNCTIONS().HNSW_U8L2_ptr_; @@ -81,6 +83,7 @@ f32 hnsw_u8ip_f32_wrapper(const u8 *v1, const u8 *v2, SizeT dim) { return static template <> KnnDistance1::KnnDistance1(KnnDistanceType dist_type) { + dist_type_ = dist_type; switch (dist_type) { case KnnDistanceType::kL2: { dist_func_ = &hnsw_u8l2_f32_wrapper; @@ -103,6 +106,7 @@ KnnDistance1::KnnDistance1(KnnDistanceType dist_type) { template <> KnnDistance1::KnnDistance1(KnnDistanceType dist_type) { + dist_type_ = dist_type; switch (dist_type) { case KnnDistanceType::kL2: { dist_func_ = GetSIMD_FUNCTIONS().HNSW_I8L2_ptr_; @@ -124,6 +128,7 @@ f32 hnsw_i8ip_f32_wrapper(const i8 *v1, const i8 *v2, SizeT dim) { return static template <> KnnDistance1::KnnDistance1(KnnDistanceType dist_type) { + dist_type_ = dist_type; switch (dist_type) { case KnnDistanceType::kL2: { dist_func_ = &hnsw_i8l2_f32_wrapper; diff --git a/src/function/table/knn_scan_data.cppm b/src/function/table/knn_scan_data.cppm index 686dde6a50..cf2130a3fb 100644 --- a/src/function/table/knn_scan_data.cppm +++ b/src/function/table/knn_scan_data.cppm @@ -104,6 +104,7 @@ public: using DistFunc = DistType (*)(const QueryDataType *, const QueryDataType *, SizeT); DistFunc dist_func_{}; + KnnDistanceType dist_type_{}; }; template <> diff --git a/src/storage/knn_index/knn_hnsw/data_store/sparse_vec_store.cppm b/src/storage/knn_index/knn_hnsw/data_store/sparse_vec_store.cppm index b55ed93293..bc4c184fa8 100644 --- a/src/storage/knn_index/knn_hnsw/data_store/sparse_vec_store.cppm +++ b/src/storage/knn_index/knn_hnsw/data_store/sparse_vec_store.cppm @@ -149,8 +149,8 @@ public: void Prefetch(SizeT idx, const Meta &meta) const { const SparseVecEle &vec = vecs_[idx]; - _mm_prefetch(vec.indices_.get(), _MM_HINT_T0); - _mm_prefetch(vec.data_.get(), _MM_HINT_T0); + _mm_prefetch((const char*)vec.indices_.get(), _MM_HINT_T0); + _mm_prefetch((const char*)vec.data_.get(), _MM_HINT_T0); } private: diff --git a/src/storage/knn_index/knn_ivf/ivf_index_storage.cpp b/src/storage/knn_index/knn_ivf/ivf_index_storage.cpp index 62bdb3dfd0..a92eeef079 100644 --- a/src/storage/knn_index/knn_ivf/ivf_index_storage.cpp +++ b/src/storage/knn_index/knn_ivf/ivf_index_storage.cpp @@ -271,7 +271,7 @@ void IVF_Index_Storage::SearchIndex(const KnnDistanceBase1 *knn_distance, search_top_k_with_dis(nprobe, dimension, 1, query_f32_ptr, centroids_num, centroids_data, nprobe_result.data(), centroid_dists.get(), false); } for (const auto part_id : nprobe_result) { - ivf_parts_storage_->SearchIndex(part_id, knn_distance, query_ptr, query_element_type, satisfy_filter_func, add_result_func); + ivf_parts_storage_->SearchIndex(part_id, this, knn_distance, query_ptr, query_element_type, satisfy_filter_func, add_result_func); } } diff --git a/src/storage/knn_index/knn_ivf/ivf_index_storage.cppm b/src/storage/knn_index/knn_ivf/ivf_index_storage.cppm index 91ad8849be..463225390e 100644 --- a/src/storage/knn_index/knn_ivf/ivf_index_storage.cppm +++ b/src/storage/knn_index/knn_ivf/ivf_index_storage.cppm @@ -29,6 +29,8 @@ namespace infinity { class LocalFileHandle; class KnnDistanceBase1; +export class IVF_Index_Storage; + // always use float for centroids class IVF_Centroids_Storage { u32 embedding_dimension_ = 0; @@ -68,6 +70,7 @@ public: AppendOneEmbedding(u32 part_id, const void *embedding_ptr, SegmentOffset segment_offset, const IVF_Centroids_Storage *ivf_centroids_storage) = 0; virtual void SearchIndex(u32 part_id, + const IVF_Index_Storage *ivf_index_storage, const KnnDistanceBase1 *knn_distance, const void *query_ptr, EmbeddingDataType query_element_type, @@ -75,7 +78,7 @@ public: const std::function &add_result_func) const = 0; }; -export class IVF_Index_Storage { +class IVF_Index_Storage { const IndexIVFOption ivf_option_ = {}; const LogicalType column_logical_type_ = LogicalType::kInvalid; const EmbeddingDataType embedding_data_type_ = EmbeddingDataType::kElemInvalid; @@ -94,6 +97,8 @@ public: [[nodiscard]] LogicalType column_logical_type() const { return column_logical_type_; } [[nodiscard]] EmbeddingDataType embedding_data_type() const { return embedding_data_type_; } [[nodiscard]] u32 embedding_dimension() const { return embedding_dimension_; } + [[nodiscard]] const IVF_Centroids_Storage &ivf_centroids_storage() const { return ivf_centroids_storage_; } + [[nodiscard]] const IVF_Parts_Storage &ivf_parts_storage() const { return *ivf_parts_storage_; } void Train(u32 training_embedding_num, const f32 *training_data, u32 expect_centroid_num = 0); void AddEmbedding(SegmentOffset segment_offset, const void *embedding_ptr); diff --git a/src/storage/knn_index/knn_ivf/ivf_index_storage_parts.cpp b/src/storage/knn_index/knn_ivf/ivf_index_storage_parts.cpp index 5638883410..5348d3809d 100644 --- a/src/storage/knn_index/knn_ivf/ivf_index_storage_parts.cpp +++ b/src/storage/knn_index/knn_ivf/ivf_index_storage_parts.cpp @@ -37,6 +37,8 @@ import ivf_index_util_func; import mlas_matrix_multiply; import vector_distance; import index_base; +import knn_expr; +import simd_functions; namespace infinity { @@ -74,12 +76,12 @@ class IVF_Part_Storage { const IVF_Centroids_Storage *ivf_centroids_storage, const IVF_Parts_Storage *ivf_parts_storage) = 0; - virtual void SearchIndex(const KnnDistanceBase1 *knn_distance, + virtual void SearchIndex(const IVF_Index_Storage *ivf_index_storage, + const KnnDistanceBase1 *knn_distance, const void *query_ptr, EmbeddingDataType query_element_type, const std::function &satisfy_filter_func, - const std::function &add_result_func, - const IVF_Parts_Storage *ivf_parts_storage) const = 0; + const std::function &add_result_func) const = 0; }; template @@ -114,6 +116,9 @@ class IVF_Parts_Storage_Info // size: real_subspace_centroid_num_ * subspace_num_ UniquePtr subspace_centroid_norms_neg_half_ = {}; +public: + auto real_subspace_centroid_num() const { return real_subspace_centroid_num_; } + const f32 *subspace_centroids_data_at_subspace(const u32 subspace_id) const { return subspace_centroids_data_.get() + subspace_id * subspace_dimension_ * real_subspace_centroid_num_; } @@ -128,7 +133,6 @@ class IVF_Parts_Storage_Info NON_CONST_VERSION_MEMBER_FUNC(subspace_centroids_data_at_subspace); NON_CONST_VERSION_MEMBER_FUNC(subspace_centroid_norms_neg_half_at_subspace); -public: IVF_Parts_Storage_Info(const u32 embedding_dim, const u32 centroids_num, const EmbeddingDataType embedding_data_type, @@ -249,6 +253,42 @@ class IVF_Parts_Storage_Info } } } + + void EncodeResidual(const f32 *residual, u32 *encode_output) const { + const auto xy_buffer = MakeUniqueForOverwrite(real_subspace_centroid_num_); + for (u32 j = 0; j < subspace_num_; ++j) { + matrixA_multiply_transpose_matrixB_output_to_C(residual + j * subspace_dimension_, + subspace_centroids_data_at_subspace(j), + 1, + real_subspace_centroid_num_, + subspace_dimension_, + xy_buffer.get()); + // find max id (for every embedding, find centroid with min l2 distance, and equivalently max (x*y - 0.5*y^2)) + const auto *c_norm_data = subspace_centroid_norms_neg_half_at_subspace(j); + f32 max_neg_distance = std::numeric_limits::lowest(); + u32 max_id = 0; + for (u32 k = 0; k < real_subspace_centroid_num_; ++k) { + if (const f32 neg_distance = xy_buffer[k] + c_norm_data[k]; neg_distance > max_neg_distance) { + max_neg_distance = neg_distance; + max_id = k; + } + } + encode_output[j] = max_id; + } + } + + UniquePtr GetIPTable(const f32 *query) const { + auto ip_table = MakeUniqueForOverwrite(subspace_num_ * real_subspace_centroid_num_); + for (u32 i = 0; i < subspace_num_; ++i) { + matrixA_multiply_matrixB_output_to_C(subspace_centroids_data_at_subspace(i), + query + i * subspace_dimension_, + real_subspace_centroid_num_, + 1, + subspace_dimension_, + ip_table.get() + i * real_subspace_centroid_num_); + } + return ip_table; + } }; template @@ -294,12 +334,14 @@ class IVF_Parts_Storage_T final : public IVF_Parts_Storage_Info { } void SearchIndex(const u32 part_id, + const IVF_Index_Storage *ivf_index_storage, const KnnDistanceBase1 *knn_distance, const void *query_ptr, const EmbeddingDataType query_element_type, const std::function &satisfy_filter_func, const std::function &add_result_func) const override { - return ivf_part_storages_[part_id]->SearchIndex(knn_distance, query_ptr, query_element_type, satisfy_filter_func, add_result_func, this); + return ivf_part_storages_[part_id] + ->SearchIndex(ivf_index_storage, knn_distance, query_ptr, query_element_type, satisfy_filter_func, add_result_func); } }; @@ -378,12 +420,12 @@ class IVF_Part_Storage_Plain final : public IVF_Part_Storage { ++embedding_num_; } - void SearchIndex(const KnnDistanceBase1 *knn_distance, + void SearchIndex(const IVF_Index_Storage *, + const KnnDistanceBase1 *knn_distance, const void *query_ptr, const EmbeddingDataType query_element_type, const std::function &satisfy_filter_func, - const std::function &add_result_func, - const IVF_Parts_Storage *) const override { + const std::function &add_result_func) const override { auto ReturnT = [&] { if constexpr ((query_element_type == EmbeddingDataType::kElemFloat && IsAnyOf) || (query_element_type == src_embedding_data_type && @@ -617,26 +659,39 @@ class IVF_Part_Storage_PQ final : public IVF_Part_Storage { void AppendOneEmbedding(const void *embedding_ptr, const SegmentOffset segment_offset, - const IVF_Centroids_Storage *, - const IVF_Parts_Storage *) override { - const auto *src_embedding_data = static_cast(embedding_ptr); - (void)(src_embedding_data); - // TODO + const IVF_Centroids_Storage *ivf_centroids_storage, + const IVF_Parts_Storage *ivf_parts_storage) override { + const auto dimension = ivf_centroids_storage->embedding_dimension(); + const auto residual = MakeUniqueForOverwrite(dimension); + const auto encoded_codes = MakeUniqueForOverwrite(subspace_num_); + { + const auto [src_embedding_f32, _] = GetF32Ptr(static_cast(embedding_ptr), dimension); + const auto centroid_data = ivf_centroids_storage->data() + part_id() * dimension; + for (u32 i = 0; i < dimension; ++i) { + residual[i] = src_embedding_f32[i] - centroid_data[i]; + } + } + const auto *ivf_parts_storage_info = + dynamic_cast *>(ivf_parts_storage); + assert(ivf_parts_storage_info); + ivf_parts_storage_info->EncodeResidual(residual.get(), encoded_codes.get()); + pq_code_storage_->AppendCodes(encoded_codes.get()); embedding_segment_offsets_.push_back(segment_offset); ++embedding_num_; } - void SearchIndex(const KnnDistanceBase1 *knn_distance, + void SearchIndex(const IVF_Index_Storage *ivf_index_storage, + const KnnDistanceBase1 *knn_distance, const void *query_ptr, const EmbeddingDataType query_element_type, const std::function &satisfy_filter_func, - const std::function &add_result_func, - const IVF_Parts_Storage *) const override { + const std::function &add_result_func) const override { auto ReturnT = [&] { if constexpr ((query_element_type == EmbeddingDataType::kElemFloat && IsAnyOf) || (query_element_type == src_embedding_data_type && (query_element_type == EmbeddingDataType::kElemInt8 || query_element_type == EmbeddingDataType::kElemUInt8))) { - return SearchIndexT(knn_distance, + return SearchIndexT(ivf_index_storage, + knn_distance, static_cast *>(query_ptr), satisfy_filter_func, add_result_func); @@ -661,7 +716,8 @@ class IVF_Part_Storage_PQ final : public IVF_Part_Storage { } template - void SearchIndexT(const KnnDistanceBase1 *knn_distance, + void SearchIndexT(const IVF_Index_Storage *ivf_index_storage, + const KnnDistanceBase1 *knn_distance, const EmbeddingDataTypeToCppTypeT *query_ptr, const std::function &satisfy_filter_func, const std::function &add_result_func) const { @@ -670,19 +726,89 @@ class IVF_Part_Storage_PQ final : public IVF_Part_Storage { if (!knn_distance_1) [[unlikely]] { UnrecoverableError("Invalid KnnDistance1"); } - // TODO - // auto dist_func = knn_distance_1->dist_func_; - // const auto total_embedding_num = embedding_num(); - // for (u32 i = 0; i < total_embedding_num; ++i) { - // const auto segment_offset = embedding_segment_offset(i); - // if (!satisfy_filter_func(segment_offset)) { - // continue; - // } - // auto v_ptr = data_.data() + i * embedding_dimension(); - // auto [calc_ptr, _] = GetSearchCalcPtr(v_ptr, embedding_dimension()); - // auto d = dist_func(calc_ptr, query_ptr, embedding_dimension()); - // add_result_func(d, segment_offset); - // } + const auto &ivf_parts_storage = + static_cast &>(ivf_index_storage->ivf_parts_storage()); + const auto subspace_num = subspace_num_; + const auto real_subspace_centroid_num = ivf_parts_storage.real_subspace_centroid_num(); + const auto dimension = ivf_index_storage->embedding_dimension(); + const auto [query_f32, _] = GetF32Ptr(query_ptr, dimension); + const auto centroid_data = ivf_index_storage->ivf_centroids_storage().data() + part_id() * dimension; + const auto ip_func = GetSIMD_FUNCTIONS().IPDistance_func_ptr_; + switch (const KnnDistanceType dist_type = knn_distance_1->dist_type_; dist_type) { + case KnnDistanceType::kInnerProduct: { + const auto query_centroid_ip = ip_func(query_f32, centroid_data, dimension); + const auto ip_table = ivf_parts_storage.GetIPTable(query_f32); + const auto encoded_codes = MakeUniqueForOverwrite(subspace_num_); + const auto total_embedding_num = embedding_num(); + for (u32 i = 0; i < total_embedding_num; ++i) { + const auto segment_offset = embedding_segment_offset(i); + if (!satisfy_filter_func(segment_offset)) { + continue; + } + pq_code_storage_->ExtractCodes(i, encoded_codes.get()); + f32 d = query_centroid_ip; + for (u32 j = 0; j < subspace_num; ++j) { + d += ip_table[j * real_subspace_centroid_num + encoded_codes[j]]; + } + add_result_func(d, segment_offset); + } + break; + } + case KnnDistanceType::kCosine: { + const auto query_l2 = L2NormSquare(query_f32, dimension); + const auto centroid_l2 = L2NormSquare(centroid_data, dimension); + const auto query_centroid_ip = ip_func(query_f32, centroid_data, dimension); + const auto query_ip_table = ivf_parts_storage.GetIPTable(query_f32); + const auto centroid_ip_table = ivf_parts_storage.GetIPTable(centroid_data); + const auto encoded_codes = MakeUniqueForOverwrite(subspace_num_); + const auto total_embedding_num = embedding_num(); + for (u32 i = 0; i < total_embedding_num; ++i) { + const auto segment_offset = embedding_segment_offset(i); + if (!satisfy_filter_func(segment_offset)) { + continue; + } + pq_code_storage_->ExtractCodes(i, encoded_codes.get()); + f32 ip = query_centroid_ip; + f32 target_l2 = centroid_l2; + for (u32 j = 0; j < subspace_num; ++j) { + ip += query_ip_table[j * real_subspace_centroid_num + encoded_codes[j]]; + target_l2 -= 2.0f * (centroid_ip_table[j * real_subspace_centroid_num + encoded_codes[j]] + + ivf_parts_storage.subspace_centroid_norms_neg_half_at_subspace(j)[encoded_codes[j]]); + } + const auto d = ip / std::sqrt(query_l2 * target_l2); + add_result_func(d, segment_offset); + } + break; + } + case KnnDistanceType::kL2: { + const auto residual_query = MakeUniqueForOverwrite(dimension); + for (u32 i = 0; i < dimension; ++i) { + residual_query[i] = query_f32[i] - centroid_data[i]; + } + const auto residual_query_l2 = L2NormSquare(residual_query.get(), dimension); + const auto residual_ip_table = ivf_parts_storage.GetIPTable(residual_query.get()); + const auto encoded_codes = MakeUniqueForOverwrite(subspace_num_); + const auto total_embedding_num = embedding_num(); + for (u32 i = 0; i < total_embedding_num; ++i) { + const auto segment_offset = embedding_segment_offset(i); + if (!satisfy_filter_func(segment_offset)) { + continue; + } + pq_code_storage_->ExtractCodes(i, encoded_codes.get()); + f32 d = residual_query_l2; + for (u32 j = 0; j < subspace_num; ++j) { + d -= 2.0f * (residual_ip_table[j * real_subspace_centroid_num + encoded_codes[j]] + + ivf_parts_storage.subspace_centroid_norms_neg_half_at_subspace(j)[encoded_codes[j]]); + } + add_result_func(d, segment_offset); + } + break; + } + default: { + RecoverableError(Status::SyntaxError(fmt::format("IVFPQ does not support {} metric now.", KnnExpr::KnnDistanceType2Str(dist_type)))); + break; + } + } } }; diff --git a/src/storage/knn_index/sparse/bmp_alg.cpp b/src/storage/knn_index/sparse/bmp_alg.cpp index b79d90ebdc..a02dd36b7d 100644 --- a/src/storage/knn_index/sparse/bmp_alg.cpp +++ b/src/storage/knn_index/sparse/bmp_alg.cpp @@ -16,7 +16,7 @@ module; #include #include -#include +#include "common/simd/simd_common_intrin_include.h" module bmp_alg; diff --git a/src/storage/knn_index/sparse/bmp_alg_serialize.cpp b/src/storage/knn_index/sparse/bmp_alg_serialize.cpp index e598f6e33f..bb70369d10 100644 --- a/src/storage/knn_index/sparse/bmp_alg_serialize.cpp +++ b/src/storage/knn_index/sparse/bmp_alg_serialize.cpp @@ -15,7 +15,7 @@ module; #include -#include +#include "common/simd/simd_common_intrin_include.h" module bmp_alg; diff --git a/src/storage/knn_index/sparse/bmp_blockterms.cppm b/src/storage/knn_index/sparse/bmp_blockterms.cppm index 95ed2d64ac..d5d7cb11f2 100644 --- a/src/storage/knn_index/sparse/bmp_blockterms.cppm +++ b/src/storage/knn_index/sparse/bmp_blockterms.cppm @@ -14,7 +14,7 @@ module; -#include +#include "common/simd/simd_common_intrin_include.h" export module bmp_blockterms; diff --git a/src/storage/knn_index/sparse/bmp_posting.cpp b/src/storage/knn_index/sparse/bmp_posting.cpp index 7e34cfae65..bfab145e67 100644 --- a/src/storage/knn_index/sparse/bmp_posting.cpp +++ b/src/storage/knn_index/sparse/bmp_posting.cpp @@ -14,7 +14,7 @@ module; -#include +#include "common/simd/simd_common_intrin_include.h" module bm_posting; @@ -41,8 +41,8 @@ void BlockData::AddBlock(BMPBlockID bloc template void BlockData::Prefetch() const { - _mm_prefetch(block_ids_.data(), _MM_HINT_T0); - _mm_prefetch(max_scores_.data(), _MM_HINT_T0); + _mm_prefetch((const char*)block_ids_.data(), _MM_HINT_T0); + _mm_prefetch((const char*)max_scores_.data(), _MM_HINT_T0); } template struct BlockData; @@ -67,7 +67,7 @@ void BlockData::AddBlock(BMPBlockID block_id, D template void BlockData::Prefetch() const { - _mm_prefetch(max_scores_.data(), _MM_HINT_T0); + _mm_prefetch((const char*)max_scores_.data(), _MM_HINT_T0); } template struct BlockData; diff --git a/src/unit_test/storage/knnindex/emvb_search/test_simd.cpp b/src/unit_test/storage/knnindex/emvb_search/test_simd.cpp index cdc080c012..8a9a8896bd 100644 --- a/src/unit_test/storage/knnindex/emvb_search/test_simd.cpp +++ b/src/unit_test/storage/knnindex/emvb_search/test_simd.cpp @@ -13,7 +13,7 @@ // limitations under the License. #include -#include +#include "common/simd/simd_common_intrin_include.h" #include "gtest/gtest.h" import base_test; @@ -25,6 +25,15 @@ using namespace infinity; class SIMDTest : public BaseTest {}; +#if defined(__aarch64__) +inline float hsum256_ps_avx(__m256 v) { + const __m128 x128 = _mm_add_ps(_mm256_extractf128_ps(v, 1), _mm256_castps256_ps128(v)); + const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128)); + const __m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); + return _mm_cvtss_f32(x32); +} +#endif + TEST_F(SIMDTest, testsum256) { constexpr u32 test_sum256_loop = 20; diff --git a/test/sql/dql/type/math.slt b/test/sql/dql/type/math.slt new file mode 100644 index 0000000000..5bac2fb666 --- /dev/null +++ b/test/sql/dql/type/math.slt @@ -0,0 +1,49 @@ +statement ok +DROP TABLE IF EXISTS test_math; + +statement ok +CREATE TABLE test_math (c1 integer, c2 float); + +# insert + +statement ok +INSERT INTO test_math VALUES (1, 4), (4, 9), (9, 16), (10, 2.4), (10, -2.4), (10, 2.5), (10, -2.5); + +query I +SELECT *, sqrt(c1), sqrt(c2) FROM test_math WHERE c1 < 10; +---- +1 4.000000 1.000000 2.000000 +4 9.000000 2.000000 3.000000 +9 16.000000 3.000000 4.000000 + +query II +SELECT c1, sqrt(c1) FROM test_math WHERE sqrt(c1) = 2; +---- +4 2.000000 + +query III +SELECT c2, round(c2) FROM test_math WHERE c1 > 9; +---- +2.400000 2.000000 +-2.400000 -2.000000 +2.500000 3.000000 +-2.500000 -3.000000 + +query IV +SELECT c2, ceil(c2) FROM test_math WHERE c1 > 9; +---- +2.400000 3.000000 +-2.400000 -2.000000 +2.500000 3.000000 +-2.500000 -2.000000 + +query V +SELECT c2, floor(c2) FROM test_math WHERE c1 > 9; +---- +2.400000 2.000000 +-2.400000 -3.000000 +2.500000 2.000000 +-2.500000 -3.000000 + +statement ok +DROP TABLE test_math; \ No newline at end of file