Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import pandas as pd
from weaviate.auth import AuthCredentials
from weaviate.collections import Collection
from weaviate.collections.classes.batch import ErrorReference
from weaviate.collections.classes.config import CollectionConfig, CollectionConfigSimple
from weaviate.collections.classes.internal import (
Object,
Expand Down Expand Up @@ -268,6 +269,61 @@ def _convert_dataframe_to_list(data: list[dict[str, Any]] | pd.DataFrame | None)
data = json.loads(data.to_json(orient="records"))
return cast("list[dict[str, Any]]", data)

def batch_create_links(
    self,
    collection_name: str,
    data: list[dict[str, Any]] | pd.DataFrame | None,
    from_property_col: str = "from_property",
    from_uuid_col: str = "from_uuid",
    to_uuid_col: str = "to",
    retry_attempts_per_object: int = 5,
) -> list[ErrorReference] | None:
    """
    Batch create links from one object to another through cross-references.

    See https://weaviate.io/developers/weaviate/manage-data/import#import-with-references

    :param collection_name: The name of the collection containing the source objects.
    :param data: List of dicts or dataframe where each row describes one link to create.
    :param from_property_col: Name of the column containing the reference property name.
    :param from_uuid_col: Name of the column containing the source ("from") UUID.
    :param to_uuid_col: Name of the column containing the target UUID.
    :param retry_attempts_per_object: Number of times to try adding a reference before giving up.
    :return: The value of ``collection.batch.failed_references`` read after the batch
        context has exited.
    """
    converted_data = self._convert_dataframe_to_list(data)
    collection = self.get_collection(collection_name)

    with collection.batch.dynamic() as batch:
        for data_obj in converted_data:
            # Read the link fields once, before the retry loop, and without
            # mutating the row. Popping them inside the attempt would leave
            # every retry with ``None`` values (the keys are gone after the
            # first try) and would also empty the caller's dicts.
            from_property = data_obj.get(from_property_col)
            from_uuid = data_obj.get(from_uuid_col)
            to_uuid = data_obj.get(to_uuid_col)

            for attempt in Retrying(
                stop=stop_after_attempt(retry_attempts_per_object),
                retry=(
                    retry_if_exception(lambda exc: check_http_error_is_retryable(exc))
                    | retry_if_exception_type(REQUESTS_EXCEPTIONS_TYPES)
                ),
            ):
                with attempt:
                    self.log.debug(
                        "Attempt %s of create links between %s and %s using reference property %s",
                        attempt.retry_state.attempt_number,
                        from_uuid,
                        to_uuid,
                        from_property,
                    )
                    batch.add_reference(
                        from_property=from_property,
                        from_uuid=from_uuid,
                        to=to_uuid,
                    )

    failed_references = collection.batch.failed_references
    if failed_references:
        self.log.error("Number of failed imports: %s", len(failed_references))

    return failed_references

def batch_data(
self,
collection_name: str,
Expand Down
88 changes: 88 additions & 0 deletions providers/weaviate/tests/unit/weaviate/hooks/test_weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,94 @@ def test_create_collection(weaviate_hook):
)


@pytest.mark.parametrize(
    argnames=["data", "expected_length"],
    argvalues=[
        # Links supplied as plain dictionaries, one per reference.
        (
            [
                {
                    "from_uuid": "0fe86eae-45f7-456c-b19f-04fc59e9ce41",
                    "to_uuid": "360b6f5b-ed23-413c-a6e8-cb864a52e712",
                    "from_property": "hasCategory",
                },
                {
                    "from_uuid": "34ccb2e1-1cfc-46e5-94d2-48c335e52c29",
                    "to_uuid": "a775ef49-a8ab-480d-ac85-b70197654072",
                    "from_property": "hasCategory",
                },
            ],
            2,
        ),
        # The same two links supplied column-wise as a dataframe.
        (
            pd.DataFrame.from_dict(
                {
                    "from_uuid": [
                        "0fe86eae-45f7-456c-b19f-04fc59e9ce41",
                        "34ccb2e1-1cfc-46e5-94d2-48c335e52c29",
                    ],
                    "to_uuid": [
                        "360b6f5b-ed23-413c-a6e8-cb864a52e712",
                        "a775ef49-a8ab-480d-ac85-b70197654072",
                    ],
                    "from_property": ["hasCategory", "hasCategory"],
                }
            ),
            2,
        ),
    ],
    ids=("batch create link data as list of dicts", "batch create link data as dataframe"),
)
def test_batch_create_links(data, expected_length, weaviate_hook):
    """Check that ``batch_create_links`` adds one reference per input row."""
    # Replace collection lookup with a mock so no Weaviate connection is needed.
    collection = MagicMock()
    weaviate_hook.get_collection = MagicMock(return_value=collection)

    # The fixtures name the target column "to_uuid" rather than the default "to".
    weaviate_hook.batch_create_links("TestCollection", data, to_uuid_col="to_uuid")

    # ``batch`` is the object yielded by ``with collection.batch.dynamic()``.
    batch = collection.batch.dynamic.return_value.__enter__.return_value
    assert batch.add_reference.call_count == expected_length


def test_batch_create_links_retry(weaviate_hook):
    """Check that a failing ``add_reference`` call is retried until it succeeds."""
    # Replace collection lookup with a mock so no Weaviate connection is needed.
    collection = MagicMock()
    weaviate_hook.get_collection = MagicMock(return_value=collection)
    # The mock yielded by ``with collection.batch.dynamic()``; attribute access on
    # a MagicMock always returns the same child, so binding it once is equivalent.
    add_reference = collection.batch.dynamic.return_value.__enter__.return_value.add_reference

    links = [
        {
            "from_uuid": "0fe86eae-45f7-456c-b19f-04fc59e9ce41",
            "to": "360b6f5b-ed23-413c-a6e8-cb864a52e712",
            "from_property": "hasCategory",
        },
        {
            "from_uuid": "34ccb2e1-1cfc-46e5-94d2-48c335e52c29",
            "to": "a775ef49-a8ab-480d-ac85-b70197654072",
            "from_property": "hasCategory",
        },
    ]

    # HTTP error carrying a 429 response; the expected call count below relies
    # on the hook treating it as retryable.
    too_many_requests = requests.Response()
    too_many_requests.status_code = 429
    http_error = requests.exceptions.HTTPError()
    http_error.response = too_many_requests

    # First link succeeds at once; second link fails twice and then succeeds.
    outcomes = [None, http_error, http_error, None]
    add_reference.side_effect = outcomes

    weaviate_hook.batch_create_links("TestCollection", links)

    # Every queued outcome must be consumed: 2 links + 2 retries.
    assert add_reference.call_count == len(outcomes)


@pytest.mark.parametrize(
argnames=["data", "expected_length"],
argvalues=[
Expand Down