Skip to content

Commit

Permalink
refactor: remove redundant methods, only @Property conn
Browse files Browse the repository at this point in the history
  • Loading branch information
Anush008 committed Jan 18, 2024
1 parent 30bbcf5 commit bc88099
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 235 deletions.
194 changes: 6 additions & 188 deletions airflow/providers/qdrant/hooks/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
from __future__ import annotations

from functools import cached_property
from typing import Any, Iterable, Mapping, Sequence
from typing import Any

from grpc import RpcError
from qdrant_client import QdrantClient, models
from qdrant_client import QdrantClient
from qdrant_client.http.exceptions import UnexpectedResponse

from airflow.hooks.base import BaseHook
Expand Down Expand Up @@ -52,6 +52,7 @@ def get_connection_form_widgets(cls) -> dict[str, Any]:
widget=BS3TextFieldWidget(),
description="Optional. Qualified URL of the Qdrant instance."
"Example: https://xyz-example.eu-central.aws.cloud.qdrant.io:6333",
default=None,
),
"grpc_port": IntegerField(
lazy_gettext("GPRC Port"),
Expand Down Expand Up @@ -98,7 +99,7 @@ def __init__(self, conn_id: str = default_conn_name, **kwargs) -> None:
self.get_conn()

def get_conn(self) -> QdrantClient:
"""Get a Qdrant client instance."""
"""Get a Qdrant client instance for interfacing with the database."""
connection = self.get_connection(self.conn_id)
host = connection.host or None
port = connection.port or 6333
Expand All @@ -108,7 +109,7 @@ def get_conn(self) -> QdrantClient:
grpc_port = extra.get("grpc_port", 6334)
prefer_gprc = extra.get("prefer_gprc", False)
https = extra.get("https", False)
prefix = extra.get("prefix", "")
prefix = extra.get("prefix", None)
timeout = extra.get("timeout", None)

return QdrantClient(
Expand All @@ -125,7 +126,7 @@ def get_conn(self) -> QdrantClient:

@cached_property
def conn(self) -> QdrantClient:
"""Get a Qdrant client instance."""
"""Get a Qdrant client instance for interfacing with the database."""
return self.get_conn()

def verify_connection(self) -> tuple[bool, str]:
Expand All @@ -135,186 +136,3 @@ def verify_connection(self) -> tuple[bool, str]:
return True, "Connection established!"
except (UnexpectedResponse, RpcError, ValueError) as e:
return False, str(e)

def list_collections(self) -> list[str]:
"""Get a list of collections in the Qdrant instance."""
return [collection.name for collection in self.conn.get_collections().collections]

def upsert(
self,
collection_name: str,
vectors: Iterable[models.VectorStruct],
payload: Iterable[dict[str, Any]] | None = None,
ids: Iterable[str | int] | None = None,
batch_size: int = 64,
parallel: int = 1,
method: str | None = None,
max_retries: int = 3,
wait: bool = True,
) -> None:
"""
Upload points to a Qdrant collection.
:param collection_name: Name of the collection to upload points to.
:param vectors: An iterable over vectors to upload.
:param payload: Iterable of vectors payload, Optional. Defaults to None.
:param ids: Iterable of custom vectors ids, Optional. Defaults to None.
:param batch_size: Number of points to upload per-request. Defaults to 64.
:param parallel: Number of parallel upload processes. Defaults to 1.
:param method: Start method for parallel processes. Defaults to forkserver.
:param max_retries: Number of retries for failed requests. Defaults to 3.
:param wait: Await for the results to be applied on the server side. Defaults to True.
"""
return self.conn.upload_collection(
collection_name=collection_name,
vectors=vectors,
payload=payload,
ids=ids,
batch_size=batch_size,
parallel=parallel,
method=method,
max_retries=max_retries,
wait=wait,
)

def delete(
self,
collection_name: str,
points_selector: models.PointsSelector,
wait: bool = True,
ordering: models.WriteOrdering | None = None,
shard_key_selector: models.ShardKeySelector | None = None,
) -> None:
"""
Delete points from a Qdrant collection.
:param collection_name: Name of the collection to delete points from.
:param points_selector: Selector for points to delete.
:param wait: Await for the results to be applied on the server side. Defaults to True.
:param ordering: Ordering of the write operation. Defaults to None.
:param shard_key_selector: Selector for the shard key. Defaults to None.
"""
self.conn.delete(
collection_name=collection_name,
points_selector=points_selector,
wait=wait,
ordering=ordering,
shard_key_selector=shard_key_selector,
)

def search(
self,
collection_name: str,
query_vector: Sequence[float]
| tuple[str, list[float]]
| models.NamedVector
| models.NamedSparseVector,
query_filter: models.Filter | None = None,
search_params: models.SearchParams | None = None,
limit: int = 10,
offset: int | None = None,
with_payload: bool | Sequence[str] | models.PayloadSelector = True,
with_vectors: bool | Sequence[str] = False,
score_threshold: float | None = None,
consistency: models.ReadConsistency | None = None,
shard_key_selector: models.ShardKeySelector | None = None,
timeout: int | None = None,
):
"""
Search for the closest points in a Qdrant collection.
:param collection_name: Name of the collection to upload points to.
:param quey_vector: Query vector to search for.
:param query_filter: Filter for the query. Defaults to None.
:param search_params: Additional search parameters. Defaults to None.
:param limit: Number of results to return. Defaults to 10.
:param offset: Offset of the first results to return. Defaults to None.
:param with_payload: To specify which stored payload should be attached to the result. Defaults to True.
:param with_vectors: To specify whether vectors should be attached to the result. Defaults to False.
:param score_threshold: To specify the minimum score threshold of the results. Defaults to None.
:param consistency: Defines how many replicas should be queried before returning the result. Defaults to None.
:param shard_key_selector: To specify which shards should be queried.. Defaults to None.
:param wait: Await for the results to be applied on the server side. Defaults to True.
"""
return self.conn.search(
collection_name=collection_name,
query_vector=query_vector,
query_filter=query_filter,
search_params=search_params,
limit=limit,
offset=offset,
with_payload=with_payload,
with_vectors=with_vectors,
score_threshold=score_threshold,
consistency=consistency,
shard_key_selector=shard_key_selector,
timeout=timeout,
)

def create_collection(
self,
collection_name: str,
vectors_config: models.VectorParams | Mapping[str, models.VectorParams],
sparse_vectors_config: Mapping[str, models.SparseVectorParams] | None = None,
shard_number: int | None = None,
sharding_method: models.ShardingMethod | None = None,
replication_factor: int | None = None,
write_consistency_factor: int | None = None,
on_disk_payload: bool | None = None,
hnsw_config: models.HnswConfigDiff | None = None,
optimizers_config: models.OptimizersConfigDiff | None = None,
wal_config: models.WalConfigDiff | None = None,
quantization_config: models.QuantizationConfig | None = None,
init_from: models.InitFrom | None = None,
timeout: int | None = None,
) -> bool:
"""
Create a new Qdrant collection.
:param collection_name: Name of the collection to upload points to.
:param vectors_config: Configuration of the vector storage contains size and distance for the vectors.
:param sparse_vectors_config: Configuration of the sparse vector storage. Defaults to None.
:param shard_number: Number of shards in collection. Default is 1, minimum is 1.
:param sharding_method: Defines strategy for shard creation. Defaults to auto.
:param replication_factor: Replication factor for collection. Default is 1, minimum is 1.
:param write_consistency_factor: Write consistency factor for collection. Default is 1, minimum is 1.
:param on_disk_payload: If true - point`s payload will not be stored in memory.
:param hnsw_config: Parameters for HNSW index.
:param optimizers_config: Parameters for optimizer.
:param wal_config: Parameters for Write-Ahead-Log.
:param quantization_config: Parameters for quantization, if None - quantization will be disabled.
:param init_from: Whether to use data stored in another collection to initialize this collection.
:param timeout: Timeout for the request. Defaults to None.
"""
return self.conn.create_collection(
collection_name=collection_name,
vectors_config=vectors_config,
sparse_vectors_config=sparse_vectors_config,
shard_number=shard_number,
sharding_method=sharding_method,
replication_factor=replication_factor,
write_consistency_factor=write_consistency_factor,
on_disk_payload=on_disk_payload,
hnsw_config=hnsw_config,
optimizers_config=optimizers_config,
wal_config=wal_config,
quantization_config=quantization_config,
init_from=init_from,
timeout=timeout,
)

def get_collection(self, collection_name: str) -> models.CollectionInfo:
"""
Get information about a Qdrant collection.
:param collection_name: Name of the collection to get information about.
"""
return self.conn.get_collection(collection_name=collection_name)

def delete_collection(self, collection_name: str, timeout: int | None) -> bool:
"""
Delete a Qdrant collection.
:param collection_name: Name of the collection to delete.
"""
return self.conn.delete_collection(collection_name=collection_name, timeout=timeout)
2 changes: 1 addition & 1 deletion airflow/providers/qdrant/operators/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def hook(self) -> QdrantHook:

def execute(self, context: Context) -> None:
"""Upload points to a Qdrant collection."""
self.hook.upsert(
self.hook.conn.upload_collection(
collection_name=self.collection_name,
vectors=self.vectors,
payload=self.payload,
Expand Down
69 changes: 40 additions & 29 deletions tests/providers/qdrant/hooks/test_qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,91 +28,102 @@ def setup_method(self):
mock_conn = Mock()
mock_conn.host = "localhost"
mock_conn.port = 6333
mock_conn.extra_dejson = {}
mock_conn.password = "some_test_api_key"
mock_get_connection.return_value = mock_conn
self.qdrant_hook = QdrantHook
self.qdrant_hook = QdrantHook()

self.collection_name = "test_collection"

@patch("airflow.providers.qdrant.hooks.qdrant.QdrantHook.upsert")
def test_upsert(self, mock_upsert):
@patch("airflow.providers.qdrant.hooks.qdrant.QdrantHook.conn")
def test_verify_connection(self, mock_conn):
"""Test the verify_connection of the QdrantHook."""
self.qdrant_hook.verify_connection()

mock_conn.get_collections.assert_called_once()

@patch("airflow.providers.qdrant.hooks.qdrant.QdrantHook.conn")
def test_upsert(self, conn):
"""Test the upsert method of the QdrantHook with appropriate arguments."""
vectors = [[0.732, 0.611, 0.289], [0.217, 0.526, 0.416], [0.326, 0.483, 0.376]]
ids = [32, 21, "b626f6a9-b14d-4af9-b7c3-43d8deb719a6"]
payloads = [{"meta": "data"}, {"meta": "data_2"}, {"meta": "data_3", "extra": "data"}]
parallel = 2
self.qdrant_hook.upsert(
self.qdrant_hook.conn.upsert(
collection_name=self.collection_name,
vectors=vectors,
ids=ids,
payloads=payloads,
parallel=parallel,
)
mock_upsert.assert_called_once_with(
conn.upsert.assert_called_once_with(
collection_name=self.collection_name,
vectors=vectors,
ids=ids,
payloads=payloads,
parallel=parallel,
)

@patch("airflow.providers.qdrant.hooks.qdrant.QdrantHook.list_collections")
def test_list_collections(self, mock_list_collections):
@patch("airflow.providers.qdrant.hooks.qdrant.QdrantHook.conn")
def test_list_collections(self, conn):
"""Test that the list_collections is called correctly."""
self.qdrant_hook.list_collections()
mock_list_collections.assert_called_once()
self.qdrant_hook.conn.list_collections()
conn.list_collections.assert_called_once()

@patch("airflow.providers.qdrant.hooks.qdrant.QdrantHook.create_collection")
def test_create_collection(self, mock_create_collection):
@patch("airflow.providers.qdrant.hooks.qdrant.QdrantHook.conn")
def test_create_collection(self, conn):
"""Test that the create_collection is called with correct arguments."""

from qdrant_client.models import Distance, VectorParams

self.qdrant_hook.create_collection(
self.qdrant_hook.conn.create_collection(
collection_name=self.collection_name,
vectors_config=VectorParams(size=384, distance=Distance.COSINE),
)
mock_create_collection.assert_called_once_with(
conn.create_collection.assert_called_once_with(
collection_name=self.collection_name,
vectors_config=VectorParams(size=384, distance=Distance.COSINE),
)

@patch("airflow.providers.qdrant.hooks.qdrant.QdrantHook.delete")
def test_delete(self, mock_delete):
@patch("airflow.providers.qdrant.hooks.qdrant.QdrantHook.conn")
def test_delete(self, conn):
"""Test that the delete is called with correct arguments."""

self.qdrant_hook.delete(collection_name=self.collection_name, points_selector=[32, 21], wait=False)
self.qdrant_hook.conn.delete(
collection_name=self.collection_name, points_selector=[32, 21], wait=False
)

mock_delete.assert_called_once_with(
conn.delete.assert_called_once_with(
collection_name=self.collection_name, points_selector=[32, 21], wait=False
)

@patch("airflow.providers.qdrant.hooks.qdrant.QdrantHook.search")
def test_search(self, mock_search):
@patch("airflow.providers.qdrant.hooks.qdrant.QdrantHook.conn")
def test_search(self, conn):
"""Test that the search is called with correct arguments."""

self.qdrant_hook.search(
self.qdrant_hook.conn.search(
collection_name=self.collection_name,
query_vector=[1.0, 2.0, 3.0],
limit=10,
with_vectors=True,
)

mock_search.assert_called_once_with(
conn.search.assert_called_once_with(
collection_name=self.collection_name, query_vector=[1.0, 2.0, 3.0], limit=10, with_vectors=True
)

@patch("airflow.providers.qdrant.hooks.qdrant.QdrantHook.get_collection")
def test_get_collection(self, mock_get_collection):
@patch("airflow.providers.qdrant.hooks.qdrant.QdrantHook.conn")
def test_get_collection(self, conn):
"""Test that the get_collection is called with correct arguments."""

self.qdrant_hook.get_collection(collection_name=self.collection_name)
self.qdrant_hook.conn.get_collection(collection_name=self.collection_name)

mock_get_collection.assert_called_once_with(collection_name=self.collection_name)
conn.get_collection.assert_called_once_with(collection_name=self.collection_name)

@patch("airflow.providers.qdrant.hooks.qdrant.QdrantHook.delete_collection")
def test_delete_collection(self, mock_delete_collection):
@patch("airflow.providers.qdrant.hooks.qdrant.QdrantHook.conn")
def test_delete_collection(self, conn):
"""Test that the delete_collection is called with correct arguments."""

self.qdrant_hook.delete_collection(collection_name=self.collection_name)
self.qdrant_hook.conn.delete_collection(collection_name=self.collection_name)

mock_delete_collection.assert_called_once_with(collection_name=self.collection_name)
conn.delete_collection.assert_called_once_with(collection_name=self.collection_name)
Loading

0 comments on commit bc88099

Please sign in to comment.