From 589772635be1d468256f851be0f4c68168af1f43 Mon Sep 17 00:00:00 2001 From: George Date: Tue, 15 Oct 2024 11:25:06 +0200 Subject: [PATCH] fix: fix prefetch conversion, fix local mode query batch points offset, add tests (#812) --- qdrant_client/conversions/conversion.py | 5 ++++- qdrant_client/local/async_qdrant_local.py | 2 +- qdrant_client/local/qdrant_local.py | 2 +- tests/congruence_tests/test_query_batch.py | 3 +++ .../conversions/test_validate_conversions.py | 22 +++++++++++++++++++ 5 files changed, 31 insertions(+), 3 deletions(-) diff --git a/qdrant_client/conversions/conversion.py b/qdrant_client/conversions/conversion.py index 00a8b9007..ce73ff3d5 100644 --- a/qdrant_client/conversions/conversion.py +++ b/qdrant_client/conversions/conversion.py @@ -2901,9 +2901,12 @@ def convert_search_points( def convert_query_request( cls, model: rest.QueryRequest, collection_name: str ) -> grpc.QueryPoints: + prefetch = ( + [model.prefetch] if isinstance(model.prefetch, rest.Prefetch) else model.prefetch + ) return grpc.QueryPoints( collection_name=collection_name, - prefetch=[cls.convert_prefetch_query(prefetch) for prefetch in model.prefetch] + prefetch=[cls.convert_prefetch_query(p) for p in prefetch] if model.prefetch is not None else None, query=cls.convert_query_interface(model.query) if model.query is not None else None, diff --git a/qdrant_client/local/async_qdrant_local.py b/qdrant_client/local/async_qdrant_local.py index d6186712b..ab131cf3d 100644 --- a/qdrant_client/local/async_qdrant_local.py +++ b/qdrant_client/local/async_qdrant_local.py @@ -442,7 +442,7 @@ async def query_batch_points( prefetch=request.prefetch, query_filter=request.filter, limit=request.limit, - offset=request.offset, + offset=request.offset or 0, with_payload=request.with_payload, with_vectors=request.with_vector, score_threshold=request.score_threshold, diff --git a/qdrant_client/local/qdrant_local.py b/qdrant_client/local/qdrant_local.py index 2cc86395d..8a74d3576 100644 --- a/qdrant_client/local/qdrant_local.py +++ b/qdrant_client/local/qdrant_local.py @@ -476,7 +476,7 @@ def query_batch_points( prefetch=request.prefetch, query_filter=request.filter, limit=request.limit, - offset=request.offset, + offset=request.offset or 0, with_payload=request.with_payload, with_vectors=request.with_vector, score_threshold=request.score_threshold, diff --git a/tests/congruence_tests/test_query_batch.py b/tests/congruence_tests/test_query_batch.py index f8367ac3f..bc4cd30e5 100644 --- a/tests/congruence_tests/test_query_batch.py +++ b/tests/congruence_tests/test_query_batch.py @@ -42,6 +42,9 @@ def __init__(self): self.dense_vector_query_batch_text.append( models.QueryRequest( query=np.random.random(text_vector_size).tolist(), + prefetch=models.Prefetch( + query=np.random.random(text_vector_size).tolist(), limit=5, using="text" + ), limit=5, using="text", with_payload=True, diff --git a/tests/conversions/test_validate_conversions.py b/tests/conversions/test_validate_conversions.py index 4d666015b..352005d81 100644 --- a/tests/conversions/test_validate_conversions.py +++ b/tests/conversions/test_validate_conversions.py @@ -355,3 +355,25 @@ def test_convert_flat_filter(): assert recovered.must[0] == rest_filter.must assert recovered.should[0] == rest_filter.should assert recovered.must_not[0] == rest_filter.must_not + + +def test_query_points(): + from qdrant_client import models + from qdrant_client.conversions.conversion import GrpcToRest, RestToGrpc + + prefetch = models.Prefetch(query=models.NearestQuery(nearest=[1.0, 2.0])) + query_request = models.QueryRequest( + query=1, + limit=5, + using="test", + with_payload=True, + prefetch=prefetch, + ) + grpc_query_request = RestToGrpc.convert_query_request(query_request, "check") + recovered = GrpcToRest.convert_query_points(grpc_query_request) + + assert recovered.query == models.NearestQuery(nearest=query_request.query) + assert recovered.limit == query_request.limit + assert recovered.using == query_request.using + assert recovered.with_payload == query_request.with_payload + assert recovered.prefetch[0] == query_request.prefetch