Skip to content

Commit

Permalink
🐛 Destination Weaviate: Multi Tenancy Support (#34229)
Browse files Browse the repository at this point in the history
Co-authored-by: Joe Reuter <joe@airbyte.io>
  • Loading branch information
Marcus0086 and Joe Reuter authored Jan 17, 2024
1 parent cbbbeb9 commit be09dfe
Show file tree
Hide file tree
Showing 6 changed files with 328 additions and 215 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class WeaviateIndexingConfigModel(BaseModel):
)
batch_size: int = Field(title="Batch Size", description="The number of records to send to Weaviate in each batch", default=128)
text_field: str = Field(title="Text Field", description="The field in the object that contains the embedded text", default="text")
tenant_id: str = Field(title="Tenant ID", description="The tenant ID to use for multi tenancy", airbyte_secret=True, default="")
default_vectorizer: str = Field(
title="Default Vectorizer",
description="The vectorizer to use if new classes need to be created",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ def _create_client(self):
batch_size=None, dynamic=False, weaviate_error_retries=weaviate.WeaviateErrorRetryConf(number_retries=5)
)

def _add_tenant_to_class_if_missing(self, class_name: str):
class_tenants = self.client.schema.get_class_tenants(class_name=class_name)
if class_tenants is not None and self.config.tenant_id not in [tenant.name for tenant in class_tenants]:
self.client.schema.add_class_tenants(class_name=class_name, tenants=[weaviate.Tenant(name=self.config.tenant_id)])
logging.info(f"Added tenant {self.config.tenant_id} to class {class_name}")
else:
logging.info(f"Tenant {self.config.tenant_id} already exists in class {class_name}")

def check(self) -> Optional[str]:
deployment_mode = os.environ.get("DEPLOYMENT_MODE", "")
if deployment_mode.casefold() == CLOUD_DEPLOYMENT_MODE and not self._uses_safe_config():
Expand All @@ -69,6 +77,11 @@ def pre_sync(self, catalog: ConfiguredAirbyteCatalog) -> None:
self._create_client()
classes = {c["class"]: c for c in self.client.schema.get().get("classes", [])}
self.has_record_id_metadata = defaultdict(lambda: False)

if self.config.tenant_id.strip():
for class_name in classes.keys():
self._add_tenant_to_class_if_missing(class_name)

for stream in catalog.streams:
class_name = self._stream_to_class_name(stream.stream.name)
schema = classes[class_name] if class_name in classes else None
Expand All @@ -78,24 +91,29 @@ def pre_sync(self, catalog: ConfiguredAirbyteCatalog) -> None:
self.client.schema.create_class(schema)
logging.info(f"Recreated class {class_name}")
elif class_name not in classes:
self.client.schema.create_class(
{
"class": class_name,
"vectorizer": self.config.default_vectorizer,
"properties": [
{
# Record ID is used for bookkeeping, not for searching
"name": METADATA_RECORD_ID_FIELD,
"dataType": ["text"],
"description": "Record ID, used for bookkeeping.",
"indexFilterable": True,
"indexSearchable": False,
"tokenization": "field",
}
],
}
)
config = {
"class": class_name,
"vectorizer": self.config.default_vectorizer,
"properties": [
{
# Record ID is used for bookkeeping, not for searching
"name": METADATA_RECORD_ID_FIELD,
"dataType": ["text"],
"description": "Record ID, used for bookkeeping.",
"indexFilterable": True,
"indexSearchable": False,
"tokenization": "field",
}
],
}
if self.config.tenant_id.strip():
config["multiTenancyConfig"] = {"enabled": True}

self.client.schema.create_class(config)
logging.info(f"Created class {class_name}")

if self.config.tenant_id.strip():
self._add_tenant_to_class_if_missing(class_name)
else:
self.has_record_id_metadata[class_name] = schema is not None and any(
prop.get("name") == METADATA_RECORD_ID_FIELD for prop in schema.get("properties", {})
Expand All @@ -105,10 +123,18 @@ def delete(self, delete_ids, namespace, stream):
if len(delete_ids) > 0:
class_name = self._stream_to_class_name(stream)
if self.has_record_id_metadata[class_name]:
self.client.batch.delete_objects(
class_name=class_name,
where={"path": [METADATA_RECORD_ID_FIELD], "operator": "ContainsAny", "valueStringArray": delete_ids},
)
where_filter = {"path": [METADATA_RECORD_ID_FIELD], "operator": "ContainsAny", "valueStringArray": delete_ids}
if self.config.tenant_id.strip():
self.client.batch.delete_objects(
class_name=class_name,
tenant=self.config.tenant_id,
where=where_filter,
)
else:
self.client.batch.delete_objects(
class_name=class_name,
where=where_filter,
)

def index(self, document_chunks, namespace, stream):
if len(document_chunks) == 0:
Expand All @@ -124,7 +150,12 @@ def index(self, document_chunks, namespace, stream):
weaviate_object[self.config.text_field] = chunk.page_content
object_id = str(uuid.uuid4())
class_name = self._stream_to_class_name(chunk.record.stream)
self.client.batch.add_data_object(weaviate_object, class_name, object_id, vector=chunk.embedding)
if self.config.tenant_id.strip():
self.client.batch.add_data_object(
weaviate_object, class_name, object_id, vector=chunk.embedding, tenant=self.config.tenant_id
)
else:
self.client.batch.add_data_object(weaviate_object, class_name, object_id, vector=chunk.embedding)
self._flush()

def _stream_to_class_name(self, stream_name: str) -> str:
Expand Down
Loading

0 comments on commit be09dfe

Please sign in to comment.