Skip to content

Commit

Permalink
Enhance AdvancedSearch attribute handling and Spanner schema
Browse files Browse the repository at this point in the history
- Added 'name' field to AdvancedSearchAttribute model for better attribute identification.
- Updated AdvancedSearchService to utilize the new 'name' field when organizing attributes by entry.
- Modified SpannerRepository to include 'name' in SQL queries and results, improving data retrieval.
- Altered Spanner schema to include 'Name' column in the AdvancedSearchAttribute table for consistency.
  • Loading branch information
syucream committed Jan 4, 2025
1 parent 75f95a9 commit 6652436
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 17 deletions.
12 changes: 2 additions & 10 deletions entry/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ def update_documents(kls, entity: Entity, is_update: bool = False):
entry_id=spanner_entry_id,
attribute_id=spanner_attr_id,
type=AttrType(attr.schema.type),
name=attr.name,
origin_entity_attr_id=entity_attr.id,
origin_attribute_id=attr.id,
)
Expand Down Expand Up @@ -584,22 +585,13 @@ def search_entries_v3(
# Get attributes and their values
attr_values = repo.get_entry_attributes(entry_ids, attr_names)

# Get all EntityAttr objects in a single query
# TODO bundle attr name in the attribute table in Spanner?
entity_attr_ids = {attr.origin_entity_attr_id for attr, _ in attr_values}
entity_attrs = {attr.id: attr for attr in EntityAttr.objects.filter(id__in=entity_attr_ids)}

# Organize attributes by entry
attrs_by_entry: dict[str, dict[str, dict]] = {}
for attr, value in attr_values:
if attr.entry_id not in attrs_by_entry:
attrs_by_entry[attr.entry_id] = {}

entity_attr = entity_attrs.get(attr.origin_entity_attr_id)
if not entity_attr:
raise RuntimeError(f"EntityAttr not found for id: {attr.origin_entity_attr_id}")

attrs_by_entry[attr.entry_id][entity_attr.name] = {
attrs_by_entry[attr.entry_id][attr.name] = {
"type": attr.type,
"value": value.value,
# TODO: Implement proper ACL check
Expand Down
1 change: 1 addition & 0 deletions entry/spanner/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ CREATE TABLE AdvancedSearchAttribute (
EntryId STRING(36) NOT NULL,
AttributeId STRING(36) NOT NULL,
Type INT64 NOT NULL,
Name STRING(200) NOT NULL,
OriginEntityAttrId INT64 NOT NULL,
OriginAttributeId INT64 NOT NULL
) PRIMARY KEY (EntryId, AttributeId),
Expand Down
32 changes: 25 additions & 7 deletions entry/spanner_advanced_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class AdvancedSearchAttribute(BaseModel):
entry_id: str = Field(description="UUID of the parent entry")
attribute_id: str = Field(description="UUID of the attribute")
type: AttrType = Field(description="Type of the attribute")
name: str = Field(max_length=200, description="Name of the attribute")
origin_entity_attr_id: int = Field(description="Original entity attribute ID from Django")
origin_attribute_id: int = Field(description="Original attribute ID from Django")

Expand Down Expand Up @@ -187,6 +188,7 @@ def insert_attributes(self, attrs: list[AdvancedSearchAttribute], batch: Spanner
"EntryId",
"AttributeId",
"Type",
"Name",
"OriginEntityAttrId",
"OriginAttributeId",
),
Expand All @@ -195,6 +197,7 @@ def insert_attributes(self, attrs: list[AdvancedSearchAttribute], batch: Spanner
attr.entry_id,
attr.attribute_id,
attr.type.value,
attr.name,
attr.origin_entity_attr_id,
attr.origin_attribute_id,
)
Expand Down Expand Up @@ -295,6 +298,13 @@ def search_entries(
"entity_ids": spanner_v1.param_types.Array(spanner_v1.param_types.INT64)
}

if attribute_names:
query += " AND a.Name IN UNNEST(@attribute_names)"
params["attribute_names"] = attribute_names
param_types["attribute_names"] = spanner_v1.param_types.Array(
spanner_v1.param_types.STRING
)

if entry_name_pattern:
query += " AND LOWER(e.Name) LIKE CONCAT('%', LOWER(@name_pattern), '%')"
params["name_pattern"] = entry_name_pattern
Expand Down Expand Up @@ -331,6 +341,13 @@ def get_entry_attributes(
params = {"entry_ids": entry_ids}
param_types = {"entry_ids": spanner_v1.param_types.Array(spanner_v1.param_types.STRING)}

if attribute_names:
query += " AND a.Name IN UNNEST(@attribute_names)"
params["attribute_names"] = attribute_names
param_types["attribute_names"] = spanner_v1.param_types.Array(
spanner_v1.param_types.STRING
)

with self.database.snapshot() as snapshot:
results = snapshot.execute_sql(query, params=params, param_types=param_types)
return [
Expand All @@ -339,15 +356,16 @@ def get_entry_attributes(
entry_id=row[0],
attribute_id=row[1],
type=AttrType(row[2]),
origin_entity_attr_id=row[3],
origin_attribute_id=row[4],
name=row[3],
origin_entity_attr_id=row[4],
origin_attribute_id=row[5],
),
AdvancedSearchAttributeValue(
entry_id=row[5],
attribute_id=row[6],
attribute_value_id=row[7],
value=str(row[8]),
raw_value=row[9],
entry_id=row[6],
attribute_id=row[7],
attribute_value_id=row[8],
value=str(row[9]),
raw_value=row[10],
),
)
for row in results
Expand Down

0 comments on commit 6652436

Please sign in to comment.