Skip to content

Commit

Permalink
Add WeaviateDocumentIngestOperator (#36402)
Browse files Browse the repository at this point in the history
  • Loading branch information
utkarsharma2 authored Dec 24, 2023
1 parent 63544e1 commit 97d2266
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 1 deletion.
85 changes: 85 additions & 0 deletions airflow/providers/weaviate/operators/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,88 @@ def execute(self, context: Context) -> list:
tenant=self.tenant,
)
return insertion_errors


class WeaviateDocumentIngestOperator(BaseOperator):
    """
    Create or replace objects belonging to documents.

    In real-world scenarios, information sources like Airflow docs, Stack Overflow, or other issues
    are considered 'documents' here. It's crucial to keep the database objects in sync with these
    sources. If any changes occur in these documents, this operator aims to reflect those changes
    in the database.

    .. note::

        This operator assumes responsibility for identifying changes in documents, dropping relevant
        database objects, and recreating them based on updated information. It's crucial to handle
        this process with care, ensuring backups and validation are in place to prevent data loss or
        inconsistencies.

    Provides users with multiple ways of dealing with existing values:

    * ``replace``: replace the existing objects with new objects. This option requires identifying
      the objects belonging to a document, which by default is done using the ``document_column``
      field.
    * ``skip``: skip the existing objects and only add the missing objects of a document.
    * ``error``: raise an error if an object belonging to an existing document is tried to be
      created.

    :param conn_id: Airflow connection id handed to :class:`WeaviateHook`.
    :param input_data: A single pandas DataFrame, a list of DataFrames, or a list of dicts to be
        ingested. This is the operator's only templated field.
    :param class_name: Name of the class in Weaviate schema where data is to be ingested.
    :param document_column: Column in the data that identifies the source document.
    :param existing: Strategy for handling existing data: 'skip', 'replace' or 'error'.
        Default is 'skip'.
    :param uuid_column: Column with pre-generated UUIDs. If not provided, UUIDs will be generated.
        Default is 'id'.
    :param vector_col: Column with embedding vectors for pre-embedded data. Default is 'Vector'.
        It is forwarded to the hook as its ``vector_column`` argument.
    :param batch_config_params: Additional parameters for Weaviate batch configuration.
    :param tenant: The tenant to which the object will be added.
    :param verbose: Flag to enable verbose output during the ingestion process.
    :param hook_params: Optional dict, passed through ``kwargs``, of extra keyword arguments
        forwarded to the :class:`WeaviateHook` constructor.
    """

    template_fields: Sequence[str] = ("input_data",)

    def __init__(
        self,
        conn_id: str,
        input_data: pd.DataFrame | list[dict[str, Any]] | list[pd.DataFrame],
        class_name: str,
        document_column: str,
        existing: str = "skip",
        uuid_column: str = "id",
        vector_col: str = "Vector",
        batch_config_params: dict | None = None,
        tenant: str | None = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> None:
        # ``hook_params`` is not a BaseOperator argument, so take it out of kwargs
        # before they are handed to the parent constructor.
        self.hook_params = kwargs.pop("hook_params", {})

        super().__init__(**kwargs)

        self.conn_id = conn_id
        self.input_data = input_data
        self.class_name = class_name
        self.document_column = document_column
        self.existing = existing
        self.uuid_column = uuid_column
        self.vector_col = vector_col
        self.batch_config_params = batch_config_params
        self.tenant = tenant
        self.verbose = verbose

    @cached_property
    def hook(self) -> WeaviateHook:
        """Return an instance of the WeaviateHook."""
        return WeaviateHook(conn_id=self.conn_id, **self.hook_params)

    def execute(self, context: Context) -> list:
        """
        Hand ``input_data`` to ``WeaviateHook.create_or_replace_document_objects``.

        :param context: Airflow task context; not used by this method.
        :return: list of UUID which failed to create
        """
        self.log.debug("Total input objects : %s", len(self.input_data))
        # The operator attribute is ``vector_col`` while the hook keyword is ``vector_column``.
        return self.hook.create_or_replace_document_objects(
            data=self.input_data,
            class_name=self.class_name,
            document_column=self.document_column,
            existing=self.existing,
            uuid_column=self.uuid_column,
            vector_column=self.vector_col,
            batch_config_params=self.batch_config_params,
            tenant=self.tenant,
            verbose=self.verbose,
        )
51 changes: 50 additions & 1 deletion tests/providers/weaviate/operators/test_weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@

import pytest

from airflow.providers.weaviate.operators.weaviate import WeaviateIngestOperator
from airflow.providers.weaviate.operators.weaviate import (
WeaviateDocumentIngestOperator,
WeaviateIngestOperator,
)


class TestWeaviateIngestOperator:
Expand Down Expand Up @@ -73,3 +76,49 @@ def test_templates(self, create_task_instance_of_operator):

assert dag_id == ti.task.input_json
assert dag_id == ti.task.input_data


class TestWeaviateDocumentIngestOperator:
    """Unit tests for ``WeaviateDocumentIngestOperator``."""

    # Values shared by the fixture and by the assertions below.
    SAMPLE_RECORDS = [{"data": "sample_data"}]
    BATCH_CONFIG = {"size": 1000}

    @pytest.fixture
    def operator(self):
        """Build an operator with every constructor argument under test set explicitly."""
        init_kwargs = {
            "task_id": "weaviate_task",
            "conn_id": "weaviate_conn",
            "input_data": [{"data": "sample_data"}],
            "class_name": "my_class",
            "document_column": "docLink",
            "existing": "skip",
            "uuid_column": "id",
            "vector_col": "vector",
            "batch_config_params": {"size": 1000},
        }
        return WeaviateDocumentIngestOperator(**init_kwargs)

    def test_constructor(self, operator):
        """Each constructor argument is stored on the attribute of the same name."""
        expected_attributes = {
            "conn_id": "weaviate_conn",
            "input_data": self.SAMPLE_RECORDS,
            "class_name": "my_class",
            "document_column": "docLink",
            "existing": "skip",
            "uuid_column": "id",
            "vector_col": "vector",
            "batch_config_params": self.BATCH_CONFIG,
            "hook_params": {},
        }
        for attribute, expected in expected_attributes.items():
            assert getattr(operator, attribute) == expected

    @patch("airflow.providers.weaviate.operators.weaviate.WeaviateDocumentIngestOperator.log")
    def test_execute_with_input_json(self, mock_log, operator):
        """``execute`` forwards the operator's settings to the hook and logs the input size."""
        ingest_mock = MagicMock()
        operator.hook.create_or_replace_document_objects = ingest_mock

        operator.execute(context=None)

        expected_call_kwargs = {
            "data": self.SAMPLE_RECORDS,
            "class_name": "my_class",
            "document_column": "docLink",
            "existing": "skip",
            "uuid_column": "id",
            "vector_column": "vector",
            "batch_config_params": self.BATCH_CONFIG,
            "tenant": None,
            "verbose": False,
        }
        ingest_mock.assert_called_once_with(**expected_call_kwargs)
        mock_log.debug.assert_called_once_with("Total input objects : %s", len(self.SAMPLE_RECORDS))

0 comments on commit 97d2266

Please sign in to comment.