diff --git a/qdrant_client/conversions/conversion.py b/qdrant_client/conversions/conversion.py index 00a8b900..ce73ff3d 100644 --- a/qdrant_client/conversions/conversion.py +++ b/qdrant_client/conversions/conversion.py @@ -2901,9 +2901,12 @@ def convert_search_points( def convert_query_request( cls, model: rest.QueryRequest, collection_name: str ) -> grpc.QueryPoints: + prefetch = ( + [model.prefetch] if isinstance(model.prefetch, rest.Prefetch) else model.prefetch + ) return grpc.QueryPoints( collection_name=collection_name, - prefetch=[cls.convert_prefetch_query(prefetch) for prefetch in model.prefetch] + prefetch=[cls.convert_prefetch_query(p) for p in prefetch] if model.prefetch is not None else None, query=cls.convert_query_interface(model.query) if model.query is not None else None, diff --git a/qdrant_client/local/async_qdrant_local.py b/qdrant_client/local/async_qdrant_local.py index d6186712..ab131cf3 100644 --- a/qdrant_client/local/async_qdrant_local.py +++ b/qdrant_client/local/async_qdrant_local.py @@ -442,7 +442,7 @@ async def query_batch_points( prefetch=request.prefetch, query_filter=request.filter, limit=request.limit, - offset=request.offset, + offset=request.offset or 0, with_payload=request.with_payload, with_vectors=request.with_vector, score_threshold=request.score_threshold, diff --git a/qdrant_client/local/qdrant_local.py b/qdrant_client/local/qdrant_local.py index 2cc86395..8a74d357 100644 --- a/qdrant_client/local/qdrant_local.py +++ b/qdrant_client/local/qdrant_local.py @@ -476,7 +476,7 @@ def query_batch_points( prefetch=request.prefetch, query_filter=request.filter, limit=request.limit, - offset=request.offset, + offset=request.offset or 0, with_payload=request.with_payload, with_vectors=request.with_vector, score_threshold=request.score_threshold, diff --git a/tests/congruence_tests/test_query_batch.py b/tests/congruence_tests/test_query_batch.py index f8367ac3..bc4cd30e 100644 --- a/tests/congruence_tests/test_query_batch.py +++ b/tests/congruence_tests/test_query_batch.py @@ -42,6 +42,9 @@ def __init__(self): self.dense_vector_query_batch_text.append( models.QueryRequest( query=np.random.random(text_vector_size).tolist(), + prefetch=models.Prefetch( + query=np.random.random(text_vector_size).tolist(), limit=5, using="text" + ), limit=5, using="text", with_payload=True, diff --git a/tests/conversions/test_validate_conversions.py b/tests/conversions/test_validate_conversions.py index 4d666015..352005d8 100644 --- a/tests/conversions/test_validate_conversions.py +++ b/tests/conversions/test_validate_conversions.py @@ -355,3 +355,25 @@ def test_convert_flat_filter(): assert recovered.must[0] == rest_filter.must assert recovered.should[0] == rest_filter.should assert recovered.must_not[0] == rest_filter.must_not + + +def test_query_points(): + from qdrant_client import models + from qdrant_client.conversions.conversion import GrpcToRest, RestToGrpc + + prefetch = models.Prefetch(query=models.NearestQuery(nearest=[1.0, 2.0])) + query_request = models.QueryRequest( + query=1, + limit=5, + using="test", + with_payload=True, + prefetch=prefetch, + ) + grpc_query_request = RestToGrpc.convert_query_request(query_request, "check") + recovered = GrpcToRest.convert_query_points(grpc_query_request) + + assert recovered.query == models.NearestQuery(nearest=query_request.query) + assert recovered.limit == query_request.limit + assert recovered.using == query_request.using + assert recovered.with_payload == query_request.with_payload + assert recovered.prefetch[0] == query_request.prefetch