From 044556ecd2a27887d5b9e86c8d69552624fc8c79 Mon Sep 17 00:00:00 2001 From: Max Isom Date: Tue, 22 Oct 2024 14:43:21 -0700 Subject: [PATCH] [TST]: update test_persist to create 1+ collections (#2933) --- chromadb/test/property/test_persist.py | 96 +++++++++++++++++--------- 1 file changed, 63 insertions(+), 33 deletions(-) diff --git a/chromadb/test/property/test_persist.py b/chromadb/test/property/test_persist.py index 92f65b27714..bd926fb0b4c 100644 --- a/chromadb/test/property/test_persist.py +++ b/chromadb/test/property/test_persist.py @@ -3,7 +3,7 @@ from multiprocessing.connection import Connection import multiprocessing.context import time -from typing import Generator, Callable +from typing import Generator, Callable, List, Tuple from uuid import UUID from hypothesis import given import hypothesis.strategies as st @@ -66,6 +66,7 @@ def settings(request: pytest.FixtureRequest) -> Generator[Settings, None, None]: with_hnsw_params=True, with_persistent_hnsw_params=st.just(True), # Makes it more likely to find persist-related bugs (by default these are set to 2000). + # Lower values make it more likely that a test will trigger a persist to disk. max_hnsw_batch_size=10, max_hnsw_sync_threshold=10, ), @@ -73,37 +74,62 @@ def settings(request: pytest.FixtureRequest) -> Generator[Settings, None, None]: ) +@st.composite +def collection_and_recordset_strategy( + draw: st.DrawFn, +) -> Tuple[strategies.Collection, strategies.RecordSet]: + collection = draw( + strategies.collections( + with_hnsw_params=True, + with_persistent_hnsw_params=st.just(True), + # Makes it more likely to find persist-related bugs (by default these are set to 2000). + max_hnsw_batch_size=10, + max_hnsw_sync_threshold=10, + ) + ) + recordset = draw(strategies.recordsets(st.just(collection))) + return collection, recordset + + @given( - collection_strategy=collection_st, - embeddings_strategy=strategies.recordsets(collection_st), + collection_and_recordset_strategies=st.lists( + collection_and_recordset_strategy(), + min_size=1, + unique_by=(lambda x: x[0].name, lambda x: x[0].name), + ) ) def test_persist( settings: Settings, - collection_strategy: strategies.Collection, - embeddings_strategy: strategies.RecordSet, + collection_and_recordset_strategies: List[ + Tuple[strategies.Collection, strategies.RecordSet] + ], ) -> None: system_1 = System(settings) system_1.start() client_1 = ClientCreator.from_system(system_1) client_1.reset() - coll = client_1.create_collection( - name=collection_strategy.name, - metadata=collection_strategy.metadata, # type: ignore[arg-type] - embedding_function=collection_strategy.embedding_function, - ) + for ( + collection_strategy, + recordset_strategy, + ) in collection_and_recordset_strategies: + coll = client_1.create_collection( + name=collection_strategy.name, + metadata=collection_strategy.metadata, # type: ignore[arg-type] + embedding_function=collection_strategy.embedding_function, + ) - coll.add(**embeddings_strategy) # type: ignore[arg-type] + coll.add(**recordset_strategy) # type: ignore[arg-type] - invariants.count(coll, embeddings_strategy) - invariants.metadatas_match(coll, embeddings_strategy) - invariants.documents_match(coll, embeddings_strategy) - invariants.ids_match(coll, embeddings_strategy) - invariants.ann_accuracy( - coll, - embeddings_strategy, - embedding_function=collection_strategy.embedding_function, - ) + invariants.count(coll, recordset_strategy) + invariants.metadatas_match(coll, recordset_strategy) + invariants.documents_match(coll, recordset_strategy) + invariants.ids_match(coll, recordset_strategy) + invariants.ann_accuracy( + coll, + recordset_strategy, + embedding_function=collection_strategy.embedding_function, + ) system_1.stop() del client_1 @@ -113,19 +139,23 @@ def test_persist( system_2.start() client_2 = ClientCreator.from_system(system_2) - coll = client_2.get_collection( - name=collection_strategy.name, - embedding_function=collection_strategy.embedding_function, - ) - invariants.count(coll, embeddings_strategy) - invariants.metadatas_match(coll, embeddings_strategy) - invariants.documents_match(coll, embeddings_strategy) - invariants.ids_match(coll, embeddings_strategy) - invariants.ann_accuracy( - coll, - embeddings_strategy, - embedding_function=collection_strategy.embedding_function, - ) + for ( + collection_strategy, + recordset_strategy, + ) in collection_and_recordset_strategies: + coll = client_2.get_collection( + name=collection_strategy.name, + embedding_function=collection_strategy.embedding_function, + ) + invariants.count(coll, recordset_strategy) + invariants.metadatas_match(coll, recordset_strategy) + invariants.documents_match(coll, recordset_strategy) + invariants.ids_match(coll, recordset_strategy) + invariants.ann_accuracy( + coll, + recordset_strategy, + embedding_function=collection_strategy.embedding_function, + ) system_2.stop() del client_2