diff --git a/src/search/hnsw_indexer.cc b/src/search/hnsw_indexer.cc index bb2918d5d03..de0f519a9b5 100644 --- a/src/search/hnsw_indexer.cc +++ b/src/search/hnsw_indexer.cc @@ -172,14 +172,13 @@ StatusOr ComputeSimilarity(const VectorItem& left, const VectorItem& rig } } -HnswIndex::HnswIndex(const SearchKey& search_key, HnswVectorFieldMetadata* vector, engine::Storage* storage) +HnswIndex::HnswIndex(const SearchKey& search_key, HnswVectorFieldMetadata* vector, engine::Storage* storage, + std::random_device::result_type seed) : search_key(search_key), metadata(vector), storage(storage), - m_level_normalization_factor(1.0 / std::log(metadata->m)) { - std::random_device rand_dev; - generator = std::mt19937(rand_dev()); -} + generator(std::mt19937(seed)), + m_level_normalization_factor(1.0 / std::log(metadata->m)) {} uint16_t HnswIndex::RandomizeLayer() { std::uniform_real_distribution level_dist(0.0, 1.0); diff --git a/src/search/hnsw_indexer.h b/src/search/hnsw_indexer.h index cfe352ff709..579352a8b22 100644 --- a/src/search/hnsw_indexer.h +++ b/src/search/hnsw_indexer.h @@ -92,7 +92,8 @@ struct HnswIndex { std::mt19937 generator; double m_level_normalization_factor; - HnswIndex(const SearchKey& search_key, HnswVectorFieldMetadata* vector, engine::Storage* storage); + HnswIndex(const SearchKey& search_key, HnswVectorFieldMetadata* vector, engine::Storage* storage, + std::random_device::result_type seed = std::random_device()()); static StatusOr> DecodeNodesToVectorItems(engine::Context& ctx, const std::vector& node_key, diff --git a/tests/cppunit/hnsw_index_test.cc b/tests/cppunit/hnsw_index_test.cc index 332c1582afc..022f2a73880 100644 --- a/tests/cppunit/hnsw_index_test.cc +++ b/tests/cppunit/hnsw_index_test.cc @@ -66,6 +66,7 @@ struct HnswIndexTest : TestBase { std::string idx_name = "hnsw_test_idx"; std::string key = "vector"; std::unique_ptr hnsw_index; + const std::random_device::result_type seed = 14863; // fixed seed for reproducibility HnswIndexTest() { metadata.vector_type = redis::VectorType::FLOAT64; @@ -73,7 +74,7 @@ struct HnswIndexTest : TestBase { metadata.m = 3; metadata.distance_metric = redis::DistanceMetric::L2; auto search_key = redis::SearchKey(ns, idx_name, key); - hnsw_index = std::make_unique(search_key, &metadata, storage_.get()); + hnsw_index = std::make_unique(search_key, &metadata, storage_.get(), seed); } void TearDown() override { hnsw_index.reset(); }