Skip to content

Commit

Permalink
new: add query interface tests, fix version checking (#670)
Browse files Browse the repository at this point in the history
* new: add query interface tests, fix version checking

* fix: add missing file
  • Loading branch information
joein authored Jun 28, 2024
1 parent 385e445 commit 00d40c8
Show file tree
Hide file tree
Showing 3 changed files with 248 additions and 94 deletions.
40 changes: 36 additions & 4 deletions tests/test_async_qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from qdrant_client.async_qdrant_client import AsyncQdrantClient
from qdrant_client.conversions.conversion import payload_to_grpc
from tests.fixtures.payload import one_random_payload_please
from tests.utils import read_version

NUM_VECTORS = 100
NUM_QUERIES = 100
Expand Down Expand Up @@ -96,7 +97,9 @@ async def test_async_grpc():
@pytest.mark.asyncio
@pytest.mark.parametrize("prefer_grpc", [True, False])
async def test_async_qdrant_client(prefer_grpc):
version = os.getenv("QDRANT_VERSION")
_, minor_version, patch_version, dev_version = read_version()
version_set = minor_version is not None or dev_version

client = AsyncQdrantClient(prefer_grpc=prefer_grpc, timeout=15)
collection_params = dict(
collection_name=COLLECTION_NAME,
Expand All @@ -115,7 +118,7 @@ async def test_async_qdrant_client(prefer_grpc):

await client.get_collection(COLLECTION_NAME)
await client.get_collections()
if version is None or (version >= "v1.8.0" or version == "dev"):
if not version_set or dev_version or minor_version >= 8:
await client.collection_exists(COLLECTION_NAME)

await client.update_collection(
Expand Down Expand Up @@ -217,6 +220,20 @@ async def test_async_qdrant_client(prefer_grpc):
== 7
)

if not version_set or dev_version or minor_version >= 10:
assert (
len(
(
await client.query_points(COLLECTION_NAME, query=np.random.rand(10).tolist())
).points
)
== 10
)
query_responses = await client.query_batch_points(
COLLECTION_NAME, requests=[models.QueryRequest(query=np.random.rand(10).tolist())]
)
assert len(query_responses) == 1 and len(query_responses[0].points) == 10

assert len(await client.retrieve(COLLECTION_NAME, ids=[3, 5])) == 2

await client.create_payload_index(
Expand Down Expand Up @@ -331,7 +348,8 @@ async def test_async_qdrant_client(prefer_grpc):

@pytest.mark.asyncio
async def test_async_qdrant_client_local():
version = os.getenv("QDRANT_VERSION")
_, minor_version, patch_version, dev_version = read_version()
version_set = minor_version is not None or dev_version
client = AsyncQdrantClient(":memory:")

collection_params = dict(
Expand All @@ -344,7 +362,7 @@ async def test_async_qdrant_client_local():

await client.get_collection(COLLECTION_NAME)
await client.get_collections()
if version is None or (version >= "v1.8.0" or version == "dev"):
if not version_set or (dev_version or minor_version >= 8):
await client.collection_exists(COLLECTION_NAME)
await client.update_collection(
COLLECTION_NAME, hnsw_config=models.HnswConfigDiff(m=32, ef_construct=120)
Expand Down Expand Up @@ -422,6 +440,20 @@ async def test_async_qdrant_client_local():
== 4
)

if not version_set or dev_version or minor_version >= 10:
assert (
len(
(
await client.query_points(COLLECTION_NAME, query=np.random.rand(10).tolist())
).points
)
== 10
)
query_responses = await client.query_batch_points(
COLLECTION_NAME, requests=[models.QueryRequest(query=np.random.rand(10).tolist())]
)
assert len(query_responses) == 1 and len(query_responses[0].points) == 10

assert len(await client.recommend(COLLECTION_NAME, positive=[0], limit=5)) == 5
assert (
len(
Expand Down
Loading

0 comments on commit 00d40c8

Please sign in to comment.