Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Qdrant Provider #36805

Merged
merged 37 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
e1d407f
feat: Qdrant Provider
Anush008 Jan 16, 2024
13fc647
ci: qdrant to providers_bug_report_yml
Anush008 Jan 16, 2024
105534d
Merge branch 'main' into main
Anush008 Jan 16, 2024
30bbcf5
Merge branch 'main' into main
Anush008 Jan 17, 2024
bc88099
refactor: remove redundant methods, only @property conn
Anush008 Jan 18, 2024
9245f24
test: Qdrant integration tests and CI
Anush008 Jan 18, 2024
c34ddbd
Merge branch 'main' into main
Anush008 Jan 18, 2024
c305c5e
Merge branch 'main' into main
Anush008 Jan 18, 2024
63bdd64
Merge branch 'main' into main
Anush008 Jan 18, 2024
3628b27
Merge branch 'main' into main
Anush008 Jan 19, 2024
b974184
chore: remove redundant call
Anush008 Jan 20, 2024
a681e20
Merge remote-tracking branch 'upstream/main'
Anush008 Jan 23, 2024
afef677
chore: remove timeout param
Anush008 Jan 24, 2024
cdf713e
chore: ready the provider
Anush008 Jan 24, 2024
17cac73
Merge branch 'main' into main
Anush008 Jan 24, 2024
b1a9416
Update airflow/providers/qdrant/provider.yaml
Anush008 Jan 24, 2024
38ce62d
chore: default to None https
Anush008 Jan 24, 2024
3d3359c
Merge branch 'main' into main
Anush008 Jan 24, 2024
41b201f
docs: removed timeout param
Anush008 Jan 25, 2024
08d7884
docs: improved https param description
Anush008 Jan 25, 2024
588b702
Merge remote-tracking branch 'upstream/main'
Anush008 Jan 26, 2024
86c50b2
Apply suggestions from code review
Anush008 Jan 30, 2024
01e5287
Merge remote-tracking branch 'upstream/main'
Anush008 Jan 30, 2024
1c4e48d
chore: new pre-commit run
Anush008 Jan 30, 2024
2135a97
Merge remote-tracking branch 'upstream/main'
Anush008 Feb 1, 2024
dfdfbdb
chore: update pre-commit
Anush008 Feb 1, 2024
a29dbd5
Update airflow/providers/qdrant/CHANGELOG.rst
eladkal Feb 2, 2024
2099c1d
Merge branch 'main' into main
Anush008 Feb 2, 2024
3584045
Merge branch 'main' into main
Anush008 Feb 3, 2024
16688d3
chore: pin Airflow >= 2.7.0
Anush008 Feb 3, 2024
e38876e
Merge branch 'main' into main
Anush008 Feb 4, 2024
d4da8cd
Merge branch 'main' into main
Anush008 Feb 5, 2024
7892ea9
Merge branch 'main' into main
Anush008 Feb 6, 2024
d6b3cf4
test: @pytest.mark.db_test
Anush008 Feb 7, 2024
9a2a213
Merge branch 'main' into main
Anush008 Feb 7, 2024
b6bd06d
Merge branch 'main' into main
Anush008 Feb 7, 2024
829a4ce
Merge branch 'main' into main
Anush008 Feb 8, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/ISSUE_TEMPLATE/airflow_providers_bug_report.yml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ body:
- pinecone
- postgres
- presto
- qdrant
- redis
- salesforce
- samba
Expand Down
6 changes: 6 additions & 0 deletions .github/boring-cyborg.yml
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,12 @@ labelPRBasedOnFilePath:
- tests/providers/presto/**/*
- tests/system/providers/presto/**/*

provider:qdrant:
- airflow/providers/qdrant/**/*
- docs/apache-airflow-providers-qdrant/**/*
- tests/providers/qdrant/**/*
- tests/system/providers/qdrant/**/*

provider:redis:
- airflow/providers/redis/**/*
- docs/apache-airflow-providers-redis/**/*
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1540,6 +1540,9 @@ jobs:
breeze testing integration-tests --integration kafka
breeze down
if: needs.build-info.outputs.is-airflow-runner != 'true'
- name: "Integration Tests Postgres: Qdrant"
run: breeze testing integration-tests --integration qdrant
if: needs.build-info.outputs.is-airflow-runner == 'true'
- name: "Integration Tests Postgres: all-testable"
run: breeze testing integration-tests --integration all-testable
if: needs.build-info.outputs.is-airflow-runner == 'true'
Expand Down
7 changes: 4 additions & 3 deletions INSTALL
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,10 @@ gcp_api, github, github-enterprise, google, google-auth, graphviz, grpc, hashico
http, imap, influxdb, jdbc, jenkins, kerberos, kubernetes, ldap, leveldb, microsoft-azure,
microsoft-mssql, microsoft-psrp, microsoft-winrm, mongo, mssql, mysql, neo4j, odbc, openai,
openfaas, openlineage, opensearch, opsgenie, oracle, otel, pagerduty, pandas, papermill, password,
pgvector, pinecone, pinot, postgres, presto, rabbitmq, redis, s3, s3fs, salesforce, samba, saml,
segment, sendgrid, sentry, sftp, singularity, slack, smtp, snowflake, spark, sqlite, ssh, statsd,
tableau, tabular, telegram, trino, vertica, virtualenv, weaviate, webhdfs, winrm, yandex, zendesk
pgvector, pinecone, pinot, postgres, presto, qdrant, rabbitmq, redis, s3, s3fs, salesforce, samba,
saml, segment, sendgrid, sentry, sftp, singularity, slack, smtp, snowflake, spark, sqlite, ssh,
statsd, tableau, tabular, telegram, trino, vertica, virtualenv, weaviate, webhdfs, winrm, yandex,
zendesk

# END REGULAR EXTRAS HERE

Expand Down
26 changes: 26 additions & 0 deletions airflow/providers/qdrant/CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at

.. http://www.apache.org/licenses/LICENSE-2.0

.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.

``apache-airflow-providers-qdrant``

Changelog
---------

1.0.0
.....

* ``Initial version of the provider. (#36805)``
16 changes: 16 additions & 0 deletions airflow/providers/qdrant/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
16 changes: 16 additions & 0 deletions airflow/providers/qdrant/hooks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
128 changes: 128 additions & 0 deletions airflow/providers/qdrant/hooks/qdrant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from functools import cached_property
from typing import Any

from grpc import RpcError
from qdrant_client import QdrantClient
from qdrant_client.http.exceptions import UnexpectedResponse

from airflow.hooks.base import BaseHook


class QdrantHook(BaseHook):
"""
Hook for interfacing with a Qdrant instance.

:param conn_id: The connection id to use when connecting to Qdrant. Defaults to `qdrant_default`.
"""

conn_name_attr = "conn_id"
conn_type = "qdrant"
default_conn_name = "qdrant_default"
hook_name = "Qdrant"

@classmethod
def get_connection_form_widgets(cls) -> dict[str, Any]:
"""Returns connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
from flask_babel import lazy_gettext
from wtforms import BooleanField, IntegerField, StringField

return {
Anush008 marked this conversation as resolved.
Show resolved Hide resolved
"url": StringField(
lazy_gettext("URL"),
widget=BS3TextFieldWidget(),
description="Optional. Qualified URL of the Qdrant instance."
"Example: https://xyz-example.eu-central.aws.cloud.qdrant.io:6333",
),
"grpc_port": IntegerField(
lazy_gettext("GPRC Port"),
widget=BS3TextFieldWidget(),
description="Optional. Port of the gRPC interface.",
default=6334,
),
"prefer_gprc": BooleanField(
lazy_gettext("Prefer GRPC"),
widget=BS3TextFieldWidget(),
description="Optional. Whether to use gPRC interface whenever possible in custom methods.",
default=False,
),
"https": BooleanField(
lazy_gettext("HTTPS"),
widget=BS3TextFieldWidget(),
description="Optional. Whether to use HTTPS(SSL) protocol.",
),
"prefix": StringField(
lazy_gettext("Prefix"),
widget=BS3TextFieldWidget(),
description="Optional. Prefix to the REST URL path."
"Example: `service/v1` will result in http://localhost:6333/service/v1/{qdrant-endpoint} for REST API.",
),
}

@classmethod
def get_ui_field_behaviour(cls) -> dict[str, Any]:
"""Returns custom field behaviour."""
return {
"hidden_fields": ["schema", "login", "extra"],
"relabeling": {"password": "API Key"},
Anush008 marked this conversation as resolved.
Show resolved Hide resolved
}

def __init__(self, conn_id: str = default_conn_name, **kwargs) -> None:
super().__init__(**kwargs)
self.conn_id = conn_id

def get_conn(self) -> QdrantClient:
"""Get a Qdrant client instance for interfacing with the database."""
connection = self.get_connection(self.conn_id)
host = connection.host or None
port = connection.port or 6333
api_key = connection.password
extra = connection.extra_dejson
url = extra.get("url", None)
grpc_port = extra.get("grpc_port", 6334)
prefer_gprc = extra.get("prefer_gprc", False)
https = extra.get("https", None)
prefix = extra.get("prefix", None)

return QdrantClient(
host=host,
port=port,
url=url,
api_key=api_key,
grpc_port=grpc_port,
prefer_grpc=prefer_gprc,
https=https,
prefix=prefix,
)

@cached_property
def conn(self) -> QdrantClient:
"""Get a Qdrant client instance for interfacing with the database."""
return self.get_conn()

def verify_connection(self) -> tuple[bool, str]:
"""Check the connection to the Qdrant instance."""
try:
self.conn.get_collections()
return True, "Connection established!"
except (UnexpectedResponse, RpcError, ValueError) as e:
return False, str(e)
16 changes: 16 additions & 0 deletions airflow/providers/qdrant/operators/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
109 changes: 109 additions & 0 deletions airflow/providers/qdrant/operators/qdrant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Any, Iterable, Sequence

from airflow.models import BaseOperator
from airflow.providers.qdrant.hooks.qdrant import QdrantHook

if TYPE_CHECKING:
from qdrant_client.models import VectorStruct

from airflow.utils.context import Context


class QdrantIngestOperator(BaseOperator):
"""
Upload points to a Qdrant collection.

.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:QdrantIngestOperator`

:param conn_id: The connection id to connect to a Qdrant instance.
:param collection_name: The name of the collection to ingest data into.
:param vectors: An iterable over vectors to upload.
:param payload: Iterable of vector payloads, Optional. Defaults to None.
:param ids: Iterable of custom vector ids, Optional. Defaults to None.
:param batch_size: Number of points to upload per-request. Defaults to 64.
:param parallel: Number of parallel upload processes. Defaults to 1.
:param method: Start method for parallel processes. Defaults to 'forkserver'.
:param max_retries: Number of retries for failed requests. Defaults to 3.
:param wait: Await for the results to be applied on the server side. Defaults to True.
:param kwargs: Additional keyword arguments passed to the BaseOperator constructor.
"""

template_fields: Sequence[str] = (
"collection_name",
"vectors",
"payload",
"ids",
"batch_size",
"parallel",
"method",
"max_retries",
"wait",
)

def __init__(
self,
*,
conn_id: str = QdrantHook.default_conn_name,
collection_name: str,
vectors: Iterable[VectorStruct],
payload: Iterable[dict[str, Any]] | None = None,
ids: Iterable[int | str] | None = None,
batch_size: int = 64,
parallel: int = 1,
method: str | None = None,
max_retries: int = 3,
wait: bool = True,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.conn_id = conn_id
self.collection_name = collection_name
self.vectors = vectors
self.payload = payload
self.ids = ids
self.batch_size = batch_size
self.parallel = parallel
self.method = method
self.max_retries = max_retries
self.wait = wait

@cached_property
def hook(self) -> QdrantHook:
"""Return an instance of QdrantHook."""
return QdrantHook(conn_id=self.conn_id)

def execute(self, context: Context) -> None:
"""Upload points to a Qdrant collection."""
self.hook.conn.upload_collection(
collection_name=self.collection_name,
vectors=self.vectors,
payload=self.payload,
ids=self.ids,
batch_size=self.batch_size,
parallel=self.parallel,
method=self.method,
max_retries=self.max_retries,
wait=self.wait,
)
Loading