diff --git a/providers/weaviate/src/airflow/providers/weaviate/hooks/weaviate.py b/providers/weaviate/src/airflow/providers/weaviate/hooks/weaviate.py index 7d5fa129492d1..c2fe0dba3c59d 100644 --- a/providers/weaviate/src/airflow/providers/weaviate/hooks/weaviate.py +++ b/providers/weaviate/src/airflow/providers/weaviate/hooks/weaviate.py @@ -41,6 +41,7 @@ import pandas as pd from weaviate.auth import AuthCredentials from weaviate.collections import Collection + from weaviate.collections.classes.batch import ErrorReference from weaviate.collections.classes.config import CollectionConfig, CollectionConfigSimple from weaviate.collections.classes.internal import ( Object, @@ -268,6 +269,61 @@ def _convert_dataframe_to_list(data: list[dict[str, Any]] | pd.DataFrame | None) data = json.loads(data.to_json(orient="records")) return cast("list[dict[str, Any]]", data) + def batch_create_links( + self, + collection_name: str, + data: list[dict[str, Any]] | pd.DataFrame | None, + from_property_col: str = "from_property", + from_uuid_col: str = "from_uuid", + to_uuid_col: str = "to", + retry_attempts_per_object: int = 5, + ) -> list[ErrorReference] | None: + """ + Batch create links from an object to another other object through cross-references (https://weaviate.io/developers/weaviate/manage-data/import#import-with-references). + + :param collection_name: The name of the collection containing the source objects. + :param data: list or dataframe of objects we want to create links. + :param from_property_col: name of the reference property column. + :param from_uuid_col: Name of the column containing the from UUID. + :param to_uuid_col: Name of the column containing the target UUID. + :param retry_attempts_per_object: number of time to try in case of failure before giving up. + """ + converted_data = self._convert_dataframe_to_list(data) + collection = self.get_collection(collection_name) + + with collection.batch.dynamic() as batch: + # Batch create links + for data_obj in converted_data: + for attempt in Retrying( + stop=stop_after_attempt(retry_attempts_per_object), + retry=( + retry_if_exception(lambda exc: check_http_error_is_retryable(exc)) + | retry_if_exception_type(REQUESTS_EXCEPTIONS_TYPES) + ), + ): + with attempt: + from_property = data_obj.pop(from_property_col, None) + from_uuid = data_obj.pop(from_uuid_col, None) + to_uuid = data_obj.pop(to_uuid_col, None) + self.log.debug( + "Attempt %s of create links between %s and %s using reference property %s", + attempt.retry_state.attempt_number, + from_uuid, + to_uuid, + from_property, + ) + batch.add_reference( + from_property=from_property, + from_uuid=from_uuid, + to=to_uuid, + ) + + failed_references = collection.batch.failed_references + if failed_references: + self.log.error("Number of failed imports: %s", len(failed_references)) + + return failed_references + def batch_data( self, collection_name: str, diff --git a/providers/weaviate/tests/unit/weaviate/hooks/test_weaviate.py b/providers/weaviate/tests/unit/weaviate/hooks/test_weaviate.py index 48abfb3ffee8a..7f615035a56a8 100644 --- a/providers/weaviate/tests/unit/weaviate/hooks/test_weaviate.py +++ b/providers/weaviate/tests/unit/weaviate/hooks/test_weaviate.py @@ -462,6 +462,94 @@ def test_create_collection(weaviate_hook): ) +@pytest.mark.parametrize( + argnames=["data", "expected_length"], + argvalues=[ + ( + [ + { + "from_uuid": "0fe86eae-45f7-456c-b19f-04fc59e9ce41", + "to_uuid": "360b6f5b-ed23-413c-a6e8-cb864a52e712", + "from_property": "hasCategory", + }, + { + "from_uuid": "34ccb2e1-1cfc-46e5-94d2-48c335e52c29", + "to_uuid": "a775ef49-a8ab-480d-ac85-b70197654072", + "from_property": "hasCategory", + }, + ], + 2, + ), + ( + pd.DataFrame.from_dict( + { + "from_uuid": [ + "0fe86eae-45f7-456c-b19f-04fc59e9ce41", + "34ccb2e1-1cfc-46e5-94d2-48c335e52c29", + ], + "to_uuid": [ + "360b6f5b-ed23-413c-a6e8-cb864a52e712", + "a775ef49-a8ab-480d-ac85-b70197654072", + ], + "from_property": ["hasCategory", "hasCategory"], + } + ), + 2, + ), + ], + ids=("batch create link data as list of dicts", "batch create link data as dataframe"), +) +def test_batch_create_links(data, expected_length, weaviate_hook): + """ + Test the batch_create_links method of WeaviateHook. + """ + # Mock the Weaviate Collection + mock_collection = MagicMock() + weaviate_hook.get_collection = MagicMock(return_value=mock_collection) + + # Define test data + test_collection_name = "TestCollection" + + # Test the batch_data method + weaviate_hook.batch_create_links(test_collection_name, data, to_uuid_col="to_uuid") + + mock_batch_context = mock_collection.batch.dynamic.return_value.__enter__.return_value + assert mock_batch_context.add_reference.call_count == expected_length + + +def test_batch_create_links_retry(weaviate_hook): + """Test to ensure retrying working as expected""" + # Mock the Weaviate Collection + mock_collection = MagicMock() + weaviate_hook.get_collection = MagicMock(return_value=mock_collection) + + data = [ + { + "from_uuid": "0fe86eae-45f7-456c-b19f-04fc59e9ce41", + "to": "360b6f5b-ed23-413c-a6e8-cb864a52e712", + "from_property": "hasCategory", + }, + { + "from_uuid": "34ccb2e1-1cfc-46e5-94d2-48c335e52c29", + "to": "a775ef49-a8ab-480d-ac85-b70197654072", + "from_property": "hasCategory", + }, + ] + response = requests.Response() + response.status_code = 429 + error = requests.exceptions.HTTPError() + error.response = response + side_effect = [None, error, error, None] + + mock_collection.batch.dynamic.return_value.__enter__.return_value.add_reference.side_effect = side_effect + + weaviate_hook.batch_create_links("TestCollection", data) + + assert mock_collection.batch.dynamic.return_value.__enter__.return_value.add_reference.call_count == len( + side_effect + ) + + @pytest.mark.parametrize( argnames=["data", "expected_length"], argvalues=[