Skip to content

Commit d978c96

Browse files
adamnschFlorentinD
andcommitted
Allow node creation from rel expression in from_gql_create
Co-Authored-By: Florentin Dörre <florentin.dorre@neotechnology.com>
1 parent 4ef1c2f commit d978c96

File tree

2 files changed

+68
-48
lines changed

2 files changed

+68
-48
lines changed

python-wrapper/src/neo4j_viz/gql_create.py

Lines changed: 63 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,6 @@ def from_gql_create(
193193
It also does not handle all possible GQL syntax, but it should work for most common cases.
194194
For more complex cases, we recommend using a Neo4j database and the `from_neo4j` method.
195195
196-
197196
Parameters
198197
----------
199198
query : str
@@ -213,11 +212,6 @@ def from_gql_create(
213212
if not re.match(r"(?i)^create\b", query):
214213
raise ValueError("Query must begin with 'CREATE' (case insensitive).")
215214

216-
nodes = []
217-
relationships = []
218-
alias_to_id = {}
219-
anonymous_count = 0
220-
221215
query = re.sub(r"(?i)^create\s*", "", query, count=1).rstrip(";").strip()
222216
parts = []
223217
paren_level = 0
@@ -252,8 +246,8 @@ def from_gql_create(
252246
snippet = _get_snippet(query, len(query) - 1)
253247
raise ValueError(f"Unbalanced square brackets near: `{snippet}`.")
254248

255-
node_pattern = re.compile(r"^\(([^)]*)\)$") # Changed here
256-
rel_pattern = re.compile(r"^\(([^)]+)\)-\s*\[\s*:(\w+)\s*(\{[^}]*\})?\s*\]->\(([^)]+)\)$")
249+
node_pattern = re.compile(r"^\(([^)]*)\)$")
250+
rel_pattern = re.compile(r"^\(([^)]*)\)-\s*\[\s*:(\w+)\s*(\{[^}]*\})?\s*\]->\(([^)]*)\)$")
257251

258252
node_top_level_keys = set(Node.model_fields.keys())
259253
node_top_level_keys.remove("id")
@@ -263,7 +257,10 @@ def from_gql_create(
263257
rel_top_level_keys.remove("source")
264258
rel_top_level_keys.remove("target")
265259

266-
empty_set: set[str] = set()
260+
nodes = []
261+
relationships = []
262+
alias_to_id = {}
263+
anonymous_count = 0
267264

268265
for part in parts:
269266
node_m = node_pattern.match(part)
@@ -276,45 +273,64 @@ def from_gql_create(
276273
if alias not in alias_to_id:
277274
alias_to_id[alias] = str(uuid.uuid4())
278275
nodes.append(Node(id=alias_to_id[alias], **top_level, properties=props))
279-
else:
280-
rel_m = rel_pattern.match(part)
281-
if rel_m:
282-
left_node = rel_m.group(1).strip()
283-
rel_type = rel_m.group(2).replace(":", "").strip()
284-
right_node = rel_m.group(4).strip()
285-
left_alias, _, _ = _parse_labels_and_props(query, left_node, empty_set)
286-
if not left_alias or left_alias not in alias_to_id:
287-
snippet = _get_snippet(query, query.index(left_node))
288-
raise ValueError(f"Relationship references unknown node alias: '{left_alias}' near: `{snippet}`.")
289-
right_alias, _, _ = _parse_labels_and_props(query, right_node, empty_set)
290-
if not right_alias or right_alias not in alias_to_id:
291-
snippet = _get_snippet(query, query.index(right_node))
292-
raise ValueError(f"Relationship references unknown node alias: '{right_alias}' near: `{snippet}`.")
293-
294-
rel_id = str(uuid.uuid4())
295-
rel_props_str = rel_m.group(3) or ""
296-
if rel_props_str:
297-
inner_str = rel_props_str.strip("{}").strip()
298-
prop_start = query.index(inner_str, query.index(inner_str))
299-
top_level, props = _parse_prop_str(query, inner_str, prop_start, rel_top_level_keys)
300-
else:
301-
top_level = {}
302-
props = {}
303-
if "type" in props:
304-
props["__type"] = props["type"]
305-
props["type"] = rel_type
306-
relationships.append(
307-
Relationship(
308-
id=rel_id,
309-
source=alias_to_id[left_alias],
310-
target=alias_to_id[right_alias],
311-
**top_level,
312-
properties=props,
313-
)
314-
)
276+
277+
continue
278+
279+
rel_m = rel_pattern.match(part)
280+
if rel_m:
281+
left_node = rel_m.group(1).strip()
282+
right_node = rel_m.group(4).strip()
283+
284+
# Parse left node pattern
285+
left_alias, left_top_level, left_props = _parse_labels_and_props(query, left_node, node_top_level_keys)
286+
if not left_alias:
287+
left_alias = f"_anon_{anonymous_count}"
288+
anonymous_count += 1
289+
if left_alias not in alias_to_id:
290+
alias_to_id[left_alias] = str(uuid.uuid4())
291+
nodes.append(Node(id=alias_to_id[left_alias], **left_top_level, properties=left_props))
292+
elif left_alias not in alias_to_id:
293+
snippet = _get_snippet(query, query.index(left_node))
294+
raise ValueError(f"Relationship references unknown node alias: '{left_alias}' near: `{snippet}`.")
295+
296+
# Parse right node pattern
297+
right_alias, right_top_level, right_props = _parse_labels_and_props(query, right_node, node_top_level_keys)
298+
if not right_alias:
299+
right_alias = f"_anon_{anonymous_count}"
300+
anonymous_count += 1
301+
if right_alias not in alias_to_id:
302+
alias_to_id[right_alias] = str(uuid.uuid4())
303+
nodes.append(Node(id=alias_to_id[right_alias], **right_top_level, properties=right_props))
304+
elif right_alias not in alias_to_id:
305+
snippet = _get_snippet(query, query.index(right_node))
306+
raise ValueError(f"Relationship references unknown node alias: '{right_alias}' near: `{snippet}`.")
307+
308+
rel_id = str(uuid.uuid4())
309+
rel_type = rel_m.group(2).replace(":", "").strip()
310+
rel_props_str = rel_m.group(3) or ""
311+
if rel_props_str:
312+
inner_str = rel_props_str.strip("{}").strip()
313+
prop_start = query.index(inner_str, query.index(inner_str))
314+
top_level, props = _parse_prop_str(query, inner_str, prop_start, rel_top_level_keys)
315315
else:
316-
snippet = part[:30]
317-
raise ValueError(f"Invalid element in CREATE near: `{snippet}`.")
316+
top_level = {}
317+
props = {}
318+
if "type" in props:
319+
props["__type"] = props["type"]
320+
props["type"] = rel_type
321+
relationships.append(
322+
Relationship(
323+
id=rel_id,
324+
source=alias_to_id[left_alias],
325+
target=alias_to_id[right_alias],
326+
**top_level,
327+
properties=props,
328+
)
329+
)
330+
continue
331+
332+
snippet = part[:30]
333+
raise ValueError(f"Invalid element in CREATE near: `{snippet}`.")
318334

319335
if size_property is not None:
320336
for node in nodes:

python-wrapper/tests/test_gql_create.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ def test_from_gql_create_syntax() -> None:
2121
({age: 29}),
2222
(a)-[:LINK {weight: 4}]->(wizardMan),
2323
(e)-[:LINK]->(d),
24-
(e)-[:OTHER_LINK {weight: -2, type: 1, source: 1337, caption: "Balloon"}]->(f);
24+
(e)-[:OTHER_LINK {weight: -2, type: 1, source: 1337, caption: "Balloon"}]->(f),
25+
()-[:LINK]->({name: 'Florentin'});
2526
"""
2627
expected_node_dicts: list[dict[str, dict[str, Any]]] = [
2728
{
@@ -49,6 +50,8 @@ def test_from_gql_create_syntax() -> None:
4950
{"top_level": {}, "properties": {"labels": []}},
5051
{"top_level": {}, "properties": {"name": "Fawad", "age": 78, "labels": ["Person", "User"]}},
5152
{"top_level": {}, "properties": {"age": 29, "labels": []}},
53+
{"top_level": {}, "properties": {"labels": []}},
54+
{"top_level": {}, "properties": {"name": "Florentin", "labels": []}},
5255
]
5356

5457
VG = from_gql_create(query, node_caption=None, relationship_caption=None)
@@ -70,6 +73,7 @@ def test_from_gql_create_syntax() -> None:
7073
"top_level": {"caption": "Balloon"},
7174
"properties": {"weight": -2, "type": "OTHER_LINK", "__type": 1, "source": 1337},
7275
},
76+
{"source_idx": 9, "target_idx": 10, "top_level": {}, "properties": {"type": "LINK"}},
7377
]
7478

7579
assert len(VG.relationships) == len(expected_relationships_dicts)

0 commit comments

Comments
 (0)