Optimize buffer mgr (#1444)
### What problem does this PR solve?

1. Fix the test script in mldr_benchmark.
2. Revert the LRU change.
3. Split the LRU into multiple chunks to reduce thread contention; the chunk count defaults to 7 (see the sketch below).

Running `./build/RelWithDebInfo/benchmark/local_infinity/knn_query_benchmark sift 300 true` (where "true" means rerank by reading the original data from the db), latency at thread_n = 8 improves from ~0.9 s to ~0.66 s.
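For context, a minimal sketch of the chunked-LRU idea: hash each buffer entry to one of `lru_num` independent LRU lists, each guarded by its own mutex, so concurrent accesses rarely contend on the same lock. The names below (`ShardedLru`, `Touch`, `EvictOne`) are hypothetical illustrations, not the actual `BufferManager` code in this PR:

```cpp
#include <cstddef>
#include <functional>
#include <list>
#include <mutex>
#include <unordered_map>
#include <vector>

// Hypothetical sketch: LRU bookkeeping split into `lru_num` shards, each with
// its own lock, so threads touching different entries do not serialize.
template <typename Key>
class ShardedLru {
public:
    explicit ShardedLru(std::size_t shard_count = 7) : shards_(shard_count) {}

    // Mark `key` as most recently used within its shard.
    void Touch(const Key &key) {
        Shard &shard = ShardOf(key);
        std::lock_guard<std::mutex> guard(shard.mtx);
        auto it = shard.index.find(key);
        if (it != shard.index.end()) {
            shard.lru.splice(shard.lru.begin(), shard.lru, it->second);  // move to front
        } else {
            shard.lru.push_front(key);
            shard.index.emplace(key, shard.lru.begin());
        }
    }

    // Pop the least recently used key of one shard; false if that shard is empty.
    bool EvictOne(std::size_t shard_id, Key &evicted) {
        Shard &shard = shards_[shard_id % shards_.size()];
        std::lock_guard<std::mutex> guard(shard.mtx);
        if (shard.lru.empty()) {
            return false;
        }
        evicted = shard.lru.back();
        shard.index.erase(evicted);
        shard.lru.pop_back();
        return true;
    }

private:
    struct Shard {
        std::mutex mtx;
        std::list<Key> lru;  // front = most recently used
        std::unordered_map<Key, typename std::list<Key>::iterator> index;
    };

    Shard &ShardOf(const Key &key) {
        return shards_[std::hash<Key>{}(key) % shards_.size()];
    }

    std::vector<Shard> shards_;
};
```

The trade-off is that eviction order is only approximately LRU across the whole cache (each shard evicts its own tail), which is usually acceptable for a buffer pool in exchange for shorter lock waits.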

### Type of change

- [x] Performance Improvement

---------

Co-authored-by: Jin Hai <haijin.chn@gmail.com>
small-turtle-1 and JinHai-CN authored Jul 5, 2024
1 parent 0778e98 commit e25e10e
Showing 24 changed files with 361 additions and 209 deletions.
6 changes: 2 additions & 4 deletions benchmark/local_infinity/knn/hnsw_benchmark.cpp
@@ -13,7 +13,6 @@
// limitations under the License.

#include "hnsw_benchmark_util.h"
- #include "CLI11.hpp"

import stl;
import third_party;
@@ -100,8 +99,7 @@ struct BenchmarkOption {
try {
app_.parse(argc, argv);
} catch (const CLI::ParseError &e) {
- std::cout << e.what() << std::endl;
- exit(1);
+ UnrecoverableError(e.what());
}
ParseInner();
}
@@ -152,7 +150,7 @@

using LabelT = i32;
using Hnsw = KnnHnsw<PlainL2VecStoreType<float>, LabelT>;
- using HnswLVQ = KnnHnsw<LVQL2VecStoreType<float, int8_t>, LabelT>;
+ using HnswLVQ = KnnHnsw<LVQL2VecStoreType<float, i8>, LabelT>;

template <typename HnswT, typename HnswT2>
void Build(const BenchmarkOption &option) {
2 changes: 2 additions & 0 deletions benchmark/local_infinity/knn/hnsw_benchmark_util.h
@@ -14,6 +14,8 @@

#pragma once

+ #include "CLI11.hpp"

import stl;
import file_system;
import local_file_system;
1 change: 1 addition & 0 deletions conf/infinity_conf.toml
@@ -35,6 +35,7 @@ mem_index_capacity = 1048576

[buffer]
buffer_manager_size = "4GB"
+ lru_num = 7
temp_dir = "/var/infinity/tmp"

[wal]
4 changes: 2 additions & 2 deletions python/benchmark/configs/infinity_gist.json
@@ -17,8 +17,8 @@
"type": "HNSW",
"params": {
"M": 16,
- "ef_construction": 200,
- "ef": 200,
+ "ef_construction": 800,
+ "ef": 800,
"metric": "l2",
"encode": "lvq"
}
47 changes: 27 additions & 20 deletions python/benchmark/legacy_benchmark/remote_benchmark_knn.py
@@ -112,15 +112,15 @@ def wrapped_func(*args, **kwargs):


@trace_unhandled_exceptions
- def work(queries, topk, metric_type, column_name, data_type, table_name="sift_benchmark"):
+ def work(queries, topk, metric_type, column_name, data_type, ef: int, table_name="sift_benchmark"):
conn = ThriftInfinityClient(LOCAL_HOST)
for query in queries:
# print(len(query))
table = RemoteTable(conn, "default_db", table_name)
# table.knn(column_name, query_vec, data_type, metric_type, topk).output(["_row_id"]).to_result()
query_builder = InfinityThriftQueryBuilder(table)
query_builder.output(["_row_id"])
- query_builder.knn(column_name, query, data_type, metric_type, topk)
+ query_builder.knn(column_name, query, data_type, metric_type, topk, {"ef": str(ef)})
query_builder.to_result()
conn.disconnect()

@@ -137,7 +137,7 @@ def fvecs_read(filename):
break


- def process_pool(threads, rounds, query_path, table_name):
+ def process_pool(threads, rounds, query_path, ef: int, table_name):
if not os.path.exists(query_path):
print(f"File: {query_path} doesn't exist")
raise Exception(f"File: {query_path} doesn't exist")
@@ -155,7 +155,7 @@ def process_pool(threads, rounds, query_path, table_name):
p = multiprocessing.Pool(threads)
start = time.time()
for idx in range(threads):
- p.apply_async(work, args=(queries[idx], 100, "l2", "col1", "float", table_name))
+ p.apply_async(work, args=(queries[idx], 100, "l2", "col1", "float", ef, table_name))
p.close()
p.join()
end = time.time()
@@ -169,7 +169,7 @@ def process_pool(threads, rounds, query_path, table_name):
print(result)


- def one_thread(rounds, query_path, ground_truth_path, table_name):
+ def one_thread(rounds, query_path, ground_truth_path, ef: int, table_name):
if not os.path.exists(query_path):
print(f"File: {query_path} doesn't exist")
raise Exception(f"File: {query_path} doesn't exist")
@@ -181,9 +181,10 @@ def one_thread(rounds, query_path, ground_truth_path, table_name):
table = RemoteTable(conn, "default_db", table_name)
query_builder = InfinityThriftQueryBuilder(table)
query_builder.output(["_row_id"])
- query_builder.knn('col1', queries[0], 'float', 'l2', 100, {'ef': '200'})
+ query_builder.knn('col1', queries[0], 'float', 'l2', 100, {'ef': str(ef)})
res, _ = query_builder.to_result()

+ dur_sum = 0
for i in range(rounds):

query_results = [[] for _ in range(len(queries))]
@@ -195,7 +196,7 @@

query_builder = InfinityThriftQueryBuilder(table)
query_builder.output(["_row_id"])
- query_builder.knn('col1', query_vec, 'float', 'l2', 100, {'ef': '200'})
+ query_builder.knn('col1', query_vec, 'float', 'l2', 100, {'ef': str(ef)})
res, _ = query_builder.to_result()
end = time.time()

@@ -220,13 +221,19 @@ def one_thread(rounds, query_path, ground_truth_path, table_name):
results.append(f"Recall@10: {recall_10}")
results.append(f"Recall@100: {recall_100}")

+ dur_sum += dur

+ dur_sum /= rounds
+ results.append(f"Avg total dur: {dur_sum:.2f} s")
+ results.append(f"Avg QPS: {(len(queries) / dur_sum):.2f}")

conn.disconnect()

for result in results:
print(result)


- def benchmark(threads, rounds, data_set, path):
+ def benchmark(threads, rounds, data_set, ef: int, path):
if not os.path.exists(path):
print(f"Path: {path} doesn't exist")
raise Exception(f"Path: {path} doesn't exist")
@@ -236,34 +243,28 @@ def benchmark(threads, rounds, data_set, path):
if threads > 1:
print(f"Multi-threads: {threads}")
print(f"Rounds: {rounds}")
- process_pool(threads, rounds, query_path, "sift_benchmark")
+ process_pool(threads, rounds, query_path, ef, "sift_benchmark")

else:
print(f"Single-thread")
print(f"Rounds: {rounds}")
- one_thread(rounds, query_path, ground_truth_path, "sift_benchmark")
+ one_thread(rounds, query_path, ground_truth_path, ef, "sift_benchmark")
elif data_set == "gist_1m":
query_path = path + "/gist_query.fvecs"
ground_truth_path = path + "/gist_groundtruth.ivecs"
if threads > 1:
print(f"Multi-threads: {threads}")
print(f"Rounds: {rounds}")
- process_pool(threads, rounds, query_path, "gist_benchmark")
+ process_pool(threads, rounds, query_path, ef, "gist_benchmark")

else:
print(f"Single-thread")
print(f"Rounds: {rounds}")
- one_thread(rounds, query_path, ground_truth_path, "gist_benchmark")
+ one_thread(rounds, query_path, ground_truth_path, ef, "gist_benchmark")


if __name__ == '__main__':
current_path = os.getcwd()
- parent_path = os.path.dirname(current_path)
- parent_path = os.path.dirname(parent_path)
- parent_path = os.path.dirname(parent_path)

- print(f"Current Path: {current_path}")
- print(f"Parent Path: {parent_path}")

parser = argparse.ArgumentParser(description="Benchmark Infinity")

@@ -288,10 +289,16 @@ def benchmark(threads, rounds, data_set, path):
default='sift_1m', # gist_1m
dest="data_set",
)
+ parser.add_argument(
+     "--ef",
+     type=int,
+     default=100,
+     dest="ef"
+ )

- data_dir = parent_path + "/test/data/benchmark/" + parser.parse_args().data_set
+ data_dir = current_path + "/test/data/benchmark/" + parser.parse_args().data_set
print(f"Data Dir: {data_dir}")

args = parser.parse_args()

- benchmark(args.threads, args.rounds, args.data_set, path=data_dir)
+ benchmark(args.threads, args.rounds, args.data_set, args.ef, path=data_dir)
44 changes: 22 additions & 22 deletions python/benchmark/legacy_benchmark/remote_benchmark_knn_import.py
@@ -10,16 +10,16 @@
from infinity.remote_thrift.table import RemoteTable


- def import_data(path):
-     if os.path.exists(path + "/sift_base.fvecs"):
-         import_sift_1m(path + "/sift_base.fvecs")
-     elif os.path.exists(path + "/gist_base.fvecs"):
-         import_gist_1m(path + "/gist_base.fvecs")
+ def import_data(path, dataset: str, ef_construction: int):
+     if dataset == "sift_1m":
+         import_sift_1m(path + "/sift_base.fvecs", ef_construction)
+     elif dataset == "gist_1m":
+         import_gist_1m(path + "/gist_base.fvecs", ef_construction)
else:
raise Exception("Invalid data set")


- def import_sift_1m(path):
+ def import_sift_1m(path, ef_construction: int):
infinity_obj = infinity.connect(LOCAL_HOST)
assert infinity_obj

@@ -41,13 +41,13 @@ def import_sift_1m(path):
assert res.error_code == ErrorCode.OK

start = time.time()
- create_index("sift_benchmark")
+ create_index("sift_benchmark", ef_construction)
end = time.time()
dur = end - start
print(f"Create index on sift_1m cost time: {dur} s")


- def import_gist_1m(path):
+ def import_gist_1m(path, ef_construction: int):
infinity_obj = infinity.connect(LOCAL_HOST)
assert infinity_obj

@@ -69,22 +69,22 @@ def import_gist_1m(path):
assert res.error_code == ErrorCode.OK

start = time.time()
- create_index("gist_benchmark")
+ create_index("gist_benchmark", ef_construction)
end = time.time()
dur = end - start
print(f"Create index on gist_1m cost time: {dur} s")


- def create_index(table_name):
+ def create_index(table_name, ef_construction):
conn = ThriftInfinityClient(LOCAL_HOST)
table = RemoteTable(conn, "default_db", table_name)
res = table.create_index("hnsw_index",
[index.IndexInfo("col1",
index.IndexType.Hnsw,
[
index.InitParameter("M", "16"),
- index.InitParameter("ef_construction", "200"),
- index.InitParameter("ef", "200"),
+ index.InitParameter("ef_construction", str(ef_construction)),
+ index.InitParameter("ef", str(ef_construction)),
index.InitParameter("metric", "l2"),
index.InitParameter("encode", "lvq")
])])
@@ -94,12 +94,6 @@ def create_index(table_name):

if __name__ == '__main__':
current_path = os.getcwd()
- parent_path = os.path.dirname(current_path)
- parent_path = os.path.dirname(parent_path)
- parent_path = os.path.dirname(parent_path)

- print(f"Current Path: {current_path}")
- print(f"Parent Path: {parent_path}")

parser = argparse.ArgumentParser(description="Benchmark Infinity")

@@ -110,10 +104,16 @@ def create_index(table_name):
default='sift_1m', # gist_1m
dest="data_set",
)

- data_dir = parent_path + "/test/data/benchmark/" + parser.parse_args().data_set
- print(f"Data Dir: {data_dir}")
+ parser.add_argument(
+     "--ef_construction",
+     type=int,
+     default=100,
+     dest="ef_construction"
+ )

args = parser.parse_args()

- import_data(path=data_dir)
+ data_dir = current_path + "/test/data/benchmark/" + args.data_set
+ print(f"Data Dir: {data_dir}")

+ import_data(data_dir, args.data_set, args.ef_construction)
48 changes: 48 additions & 0 deletions python/benchmark/mldr_benchmark/get_search_results.py
@@ -71,6 +71,34 @@ def prepare_sparse_embedding(embedding_file: str, model_args: ModelArgs, queries
return


+ def bm25_query_yield(queries: list[str], embedding_file: str):
+     for query in queries:
+         yield query.translate(query_translation_table)


+ def dense_query_yield(queries: list[str], embedding_file: str):
+     return fvecs_read_yield(embedding_file)


+ def sparse_query_yield(queries: list[str], embedding_file: str):
+     return read_mldr_sparse_embedding_yield(embedding_file)


+ def apply_bm25(table, query_str: str, max_hits: int):
+     return table.match('fulltext_col', query_str, f'topn={max_hits}')


+ def apply_dense(table, query_embedding, max_hits: int):
+     return table.knn("dense_col", query_embedding, "float", "ip", max_hits, {"ef": str(max_hits)})


+ def apply_sparse(table, query_embedding: dict, max_hits: int):
+     return table.match_sparse("sparse_col", query_embedding, "ip", max_hits, {"alpha": "0.9", "beta": "0.5"})


+ apply_funcs = {'bm25': apply_bm25, 'dense': apply_dense, 'sparse': apply_sparse}


def prepare_colbert_embedding(embedding_file: str, model_args: ModelArgs, queries: list[str], qids: list[int]):
model = get_colbert_model(model_args)
query_embedding = model.encode_query(queries)
@@ -109,13 +137,32 @@ def get_test_table(self, language_suffix: str):
self.infinity_table = self.infinity_db.get_table(table_name)
print(f"Get table {table_name} successfully.")

+ def bm25_query(self, query_str: str, max_hits: int):
+     # query_str = query.translate(query_translation_table)
+     result = self.infinity_table.output(["docid_col", "_score"]).match('fulltext_col', query_str, f'topn={max_hits}').to_pl()
+     return result['docid_col'], result['SCORE']

+ def dense_query(self, query_embedding, max_hits: int):
+     result = self.infinity_table.output(["docid_col", "_similarity"]).knn("dense_col", query_embedding, "float", "ip", max_hits, {"ef": str(max_hits)}).to_pl()
+     return result['docid_col'], result['SIMILARITY']

+ def sparse_query(self, query_embedding: dict, max_hits: int):
+     result = self.infinity_table.output(["docid_col", "_similarity"]).match_sparse("sparse_col", query_embedding, "ip", max_hits, {"alpha": "0.9", "beta": "0.5"}).to_pl()
+     return result['docid_col'], result['SIMILARITY']

+ def common_single_query_func(self, query_type: str, query_target, max_hits: int):
+     str_params = single_query_func_params[query_type]
+     result_table = self.infinity_table.output(["docid_col", str_params[0]])
+     result_table = apply_funcs[query_type](result_table, query_target, max_hits)
+     result = result_table.to_pl()
+     return result['docid_col'], result[str_params[1]]


def fusion_query(self, query_targets_list: list, apply_funcs_list: list, max_hits: int):
result_table = self.infinity_table.output(["docid_col", "_score"])
for query_target, apply_func in zip(query_targets_list, apply_funcs_list):
@@ -124,6 +171,7 @@ def fusion_query(self, query_targets_list: list, apply_funcs_list: list, max_hits: int):
result = result_table.to_pl()
return result['docid_col'], result['SCORE']


def main(self, languages: list[str], query_types: list[str], model_args: ModelArgs, save_dir: str):
for lang in languages:
print(f"Start to search for language: {lang}")
6 changes: 3 additions & 3 deletions python/benchmark/mldr_benchmark/insert_data_50000.py
@@ -106,8 +106,8 @@ def main(self):
res = self.infinity_table.create_index("hnsw_index", [index.IndexInfo("dense_col", index.IndexType.Hnsw,
[index.InitParameter("M", "16"),
- index.InitParameter("ef_construction", "200"),
- index.InitParameter("ef", "200"),
+ index.InitParameter("ef_construction", "1000"),
+ index.InitParameter("ef", "1000"),
index.InitParameter("metric", "ip"),
index.InitParameter("encode", "lvq")])],
ConflictType.Error)
@@ -120,7 +120,7 @@ def main(self):
"compress")])],
ConflictType.Error)
assert res.error_code == ErrorCode.OK
- self.infinity_table.optimize("bmp_index", {"topk": "1000", "bp_reorder": {}})
+ self.infinity_table.optimize("bmp_index", {"topk": "1000", "bp_reorder": ""})
print("Finish creating BMP index.")


