Skip to content

Commit 9763371

Browse files
adamnschFlorentinD
andcommitted
Start moving custom fields into properties dict
Only `Node`, `Relationship` and `from_neo4j` updated so far. Co-Authored-By: Florentin Dörre <florentin.dorre@neotechnology.com>
1 parent d0dae64 commit 9763371

File tree

6 files changed

+109
-93
lines changed

6 files changed

+109
-93
lines changed
Lines changed: 42 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

3-
from collections.abc import Iterable
4-
from typing import Any, Optional, Union
3+
from typing import Optional, Union
54

65
import neo4j.graph
76
from neo4j import Result
@@ -22,8 +21,8 @@ def from_neo4j(
2221
Create a VisualizationGraph from a Neo4j Graph or Neo4j Result object.
2322
2423
All node and relationship properties will be included in the visualization graph.
25-
If the property names are conflicting with those of `Node` and `Relationship` objects, they will be prefixed
26-
with `__`.
24+
If the properties are named as the fields of the `Node` or `Relationship` classes, they will be included as
25+
top level fields of the respective objects. Otherwise, they will be included in the `properties` dictionary.
2726
2827
Parameters
2928
----------
@@ -63,69 +62,64 @@ def from_neo4j(
6362

6463

6564
def _map_node(node: neo4j.graph.Node, size_property: Optional[str], caption_property: Optional[str]) -> Node:
66-
labels = sorted([label for label in node.labels])
65+
top_level_fields = {"id": node.element_id}
6766

6867
if size_property:
69-
size = node.get(size_property)
70-
else:
71-
size = None
68+
top_level_fields["size"] = node.get(size_property)
7269

70+
labels = sorted([label for label in node.labels])
7371
if caption_property:
7472
if caption_property == "labels":
7573
if len(labels) > 0:
76-
caption = ":".join([label for label in labels])
77-
else:
78-
caption = None
74+
top_level_fields["caption"] = ":".join([label for label in labels])
7975
else:
80-
caption = str(node.get(caption_property))
76+
top_level_fields["caption"] = str(node.get(caption_property))
77+
78+
properties = {}
79+
for prop, value in node.items():
80+
if prop not in Node.model_fields.keys():
81+
properties[prop] = value
82+
continue
83+
84+
if prop in top_level_fields:
85+
properties[prop] = value
86+
continue
8187

82-
base_node_props = dict(id=node.element_id, caption=caption, labels=labels, size=size)
88+
top_level_fields[prop] = value
8389

84-
protected_props = base_node_props.keys()
85-
additional_node_props = {k: v for k, v in node.items()}
86-
additional_node_props = _rename_protected_props(additional_node_props, protected_props)
90+
if "labels" in properties:
91+
properties["__labels"] = properties["labels"]
92+
properties["labels"] = labels
8793

88-
return Node(**base_node_props, **additional_node_props)
94+
return Node(**top_level_fields, properties=properties)
8995

9096

9197
def _map_relationship(rel: neo4j.graph.Relationship, caption_property: Optional[str]) -> Optional[Relationship]:
9298
if rel.start_node is None or rel.end_node is None:
9399
return None
94100

101+
top_level_fields = {"id": rel.element_id, "source": rel.start_node.element_id, "target": rel.end_node.element_id}
102+
95103
if caption_property:
96104
if caption_property == "type":
97-
caption = rel.type
105+
top_level_fields["caption"] = rel.type
98106
else:
99-
caption = str(rel.get(caption_property))
100-
else:
101-
caption = None
102-
103-
base_rel_props = dict(
104-
id=rel.element_id,
105-
source=rel.start_node.element_id,
106-
target=rel.end_node.element_id,
107-
_type=rel.type,
108-
caption=caption,
109-
)
110-
111-
protected_props = base_rel_props.keys()
112-
additional_rel_props = {k: v for k, v in rel.items()}
113-
additional_rel_props = _rename_protected_props(additional_rel_props, protected_props)
114-
115-
return Relationship(
116-
**base_rel_props,
117-
**additional_rel_props,
118-
)
119-
120-
121-
def _rename_protected_props(
122-
additional_props: dict[str, Any],
123-
protected_props: Iterable[str],
124-
) -> dict[str, Union[str, int, float]]:
125-
for prop in protected_props:
126-
if prop not in additional_props:
107+
top_level_fields["caption"] = str(rel.get(caption_property))
108+
109+
properties = {}
110+
for prop, value in rel.items():
111+
if prop not in Relationship.model_fields.keys():
112+
properties[prop] = value
127113
continue
128114

129-
additional_props[f"__{prop}"] = additional_props.pop(prop)
115+
if prop in top_level_fields:
116+
properties[prop] = value
117+
continue
118+
119+
top_level_fields[prop] = value
120+
121+
if "type" in properties:
122+
properties["__type"] = properties["type"]
123+
properties["type"] = rel.type
130124

131-
return additional_props
125+
return Relationship(**top_level_fields, properties=properties)

python-wrapper/src/neo4j_viz/node.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ class Node(BaseModel, extra="allow"):
4545
x: Optional[RealNumber] = Field(None, description="The x-coordinate of the node")
4646
#: The y-coordinate of the node
4747
y: Optional[RealNumber] = Field(None, description="The y-coordinate of the node")
48+
#: The properties of the node
49+
properties: dict[str, Any] = Field(default_factory=dict, description="The properties of the node")
4850

4951
@field_serializer("color")
5052
def serialize_color(self, color: Color) -> str:

python-wrapper/src/neo4j_viz/relationship.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ class Relationship(BaseModel, extra="allow"):
4343
)
4444
#: The color of the relationship. Allowed input is for example "#FF0000", "red" or (255, 0, 0)
4545
color: Optional[ColorType] = Field(None, description="The color of the relationship")
46+
#: The properties of the relationship
47+
properties: dict[str, Any] = Field(default_factory=dict, description="The properties of the relationship")
4648

4749
@field_serializer("color")
4850
def serialize_color(self, color: Color) -> str:

python-wrapper/tests/test_neo4j.py

Lines changed: 52 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def graph_setup(neo4j_session: Session) -> Generator[None, None, None]:
1919

2020

2121
@pytest.mark.requires_neo4j_and_gds
22-
def test_from_neo4j_graph(neo4j_session: Session) -> None:
22+
def test_from_neo4j_graph_basic(neo4j_session: Session) -> None:
2323
graph = neo4j_session.run("MATCH (a:_CI_A|_CI_B)-[r]->(b) RETURN a, b, r ORDER BY a").graph()
2424

2525
VG = from_neo4j(graph)
@@ -31,27 +31,31 @@ def test_from_neo4j_graph(neo4j_session: Session) -> None:
3131
Node(
3232
id=node_ids[0],
3333
caption="_CI_A",
34-
labels=["_CI_A"],
35-
name="Alice",
36-
height=20,
37-
__id=42,
38-
_id=1337,
39-
__caption="hello",
34+
properties=dict(
35+
labels=["_CI_A"],
36+
name="Alice",
37+
height=20,
38+
id=42,
39+
_id=1337,
40+
caption="hello",
41+
),
4042
),
4143
Node(
4244
id=node_ids[1],
4345
caption="_CI_A:_CI_B",
44-
labels=["_CI_A", "_CI_B"],
45-
name="Bob",
46-
height=10,
47-
__id=84,
48-
__size=11,
49-
__labels=[1, 2],
46+
size=11,
47+
properties=dict(
48+
labels=["_CI_A", "_CI_B"],
49+
name="Bob",
50+
height=10,
51+
id=84,
52+
__labels=[1, 2],
53+
),
5054
),
5155
]
5256

5357
assert len(VG.nodes) == 2
54-
assert sorted(VG.nodes, key=lambda x: x.name) == expected_nodes # type: ignore[attr-defined]
58+
assert sorted(VG.nodes, key=lambda x: x.properties["name"]) == expected_nodes
5559

5660
assert len(VG.relationships) == 2
5761
vg_rels = sorted([(e.source, e.target, e.caption) for e in VG.relationships], key=lambda x: x[2] if x[2] else "foo")
@@ -76,27 +80,31 @@ def test_from_neo4j_result(neo4j_session: Session) -> None:
7680
Node(
7781
id=node_ids[0],
7882
caption="_CI_A",
79-
labels=["_CI_A"],
80-
name="Alice",
81-
height=20,
82-
__id=42,
83-
_id=1337,
84-
__caption="hello",
83+
properties=dict(
84+
labels=["_CI_A"],
85+
name="Alice",
86+
height=20,
87+
id=42,
88+
_id=1337,
89+
caption="hello",
90+
),
8591
),
8692
Node(
8793
id=node_ids[1],
8894
caption="_CI_A:_CI_B",
89-
labels=["_CI_A", "_CI_B"],
90-
name="Bob",
91-
height=10,
92-
__id=84,
93-
__size=11,
94-
__labels=[1, 2],
95+
size=11,
96+
properties=dict(
97+
labels=["_CI_A", "_CI_B"],
98+
name="Bob",
99+
height=10,
100+
id=84,
101+
__labels=[1, 2],
102+
),
95103
),
96104
]
97105

98106
assert len(VG.nodes) == 2
99-
assert sorted(VG.nodes, key=lambda x: x.name) == expected_nodes # type: ignore[attr-defined]
107+
assert sorted(VG.nodes, key=lambda x: x.properties["name"]) == expected_nodes
100108

101109
assert len(VG.relationships) == 2
102110
vg_rels = sorted([(e.source, e.target, e.caption) for e in VG.relationships], key=lambda x: x[2] if x[2] else "foo")
@@ -119,29 +127,33 @@ def test_from_neo4j_graph_full(neo4j_session: Session) -> None:
119127
Node(
120128
id=node_ids[0],
121129
caption="Alice",
122-
labels=["_CI_A"],
123-
name="Alice",
124-
height=20,
125130
size=60.0,
126-
__id=42,
127-
_id=1337,
128-
__caption="hello",
131+
properties=dict(
132+
labels=["_CI_A"],
133+
name="Alice",
134+
height=20,
135+
id=42,
136+
_id=1337,
137+
caption="hello",
138+
),
129139
),
130140
Node(
131141
id=node_ids[1],
132142
caption="Bob",
133-
labels=["_CI_A", "_CI_B"],
134-
name="Bob",
135-
height=10,
136143
size=3.0,
137-
__id=84,
138-
__size=11,
139-
__labels=[1, 2],
144+
properties=dict(
145+
labels=["_CI_A", "_CI_B"],
146+
name="Bob",
147+
size=11,
148+
height=10,
149+
id=84,
150+
__labels=[1, 2],
151+
),
140152
),
141153
]
142154

143155
assert len(VG.nodes) == 2
144-
assert sorted(VG.nodes, key=lambda x: x.name) == expected_nodes # type: ignore[attr-defined]
156+
assert sorted(VG.nodes, key=lambda x: x.properties["name"]) == expected_nodes
145157

146158
assert len(VG.relationships) == 2
147159
vg_rels = sorted([(e.source, e.target, e.caption) for e in VG.relationships], key=lambda x: x[2] if x[2] else "foo")

python-wrapper/tests/test_node.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def test_nodes_with_all_options() -> None:
2626
"pinned": True,
2727
"x": 1,
2828
"y": 10,
29+
"properties": {},
2930
}
3031

3132

@@ -36,6 +37,7 @@ def test_nodes_minimal_node() -> None:
3637

3738
assert node.to_dict() == {
3839
"id": "1",
40+
"properties": {},
3941
}
4042

4143

@@ -48,6 +50,7 @@ def test_node_with_float_size() -> None:
4850
assert node.to_dict() == {
4951
"id": "1",
5052
"size": 10.2,
53+
"properties": {},
5154
}
5255

5356

@@ -60,6 +63,7 @@ def test_node_with_additional_fields() -> None:
6063
assert node.to_dict() == {
6164
"id": "1",
6265
"componentId": 2,
66+
"properties": {},
6367
}
6468

6569

@@ -69,6 +73,7 @@ def test_id_aliases(alias: str) -> None:
6973

7074
assert node.to_dict() == {
7175
"id": "1",
76+
"properties": {},
7277
}
7378

7479

python-wrapper/tests/test_relationship.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def test_rels_with_all_options() -> None:
2323
"captionAlign": "top",
2424
"captionSize": 12,
2525
"color": "#ff0000",
26+
"properties": {},
2627
}
2728

2829

@@ -34,7 +35,7 @@ def test_rels_minimal_rel() -> None:
3435

3536
rel_dict = rel.to_dict()
3637

37-
assert {"id", "from", "to"} == set(rel_dict.keys())
38+
assert {"id", "from", "to", "properties"} == set(rel_dict.keys())
3839
assert rel_dict["from"] == "1"
3940
assert rel_dict["to"] == "2"
4041

@@ -43,12 +44,12 @@ def test_rels_additional_fields() -> None:
4344
rel = Relationship(
4445
source="1",
4546
target="2",
46-
componentId=2,
47+
properties=dict(componentId=2),
4748
)
4849

4950
rel_dict = rel.to_dict()
50-
assert {"id", "from", "to", "componentId"} == set(rel_dict.keys())
51-
assert rel.componentId == 2 # type: ignore[attr-defined]
51+
assert {"id", "from", "to", "properties"} == set(rel_dict.keys())
52+
assert rel.properties["componentId"] == 2
5253

5354

5455
@pytest.mark.parametrize("src_alias", ["source", "sourceNodeId", "source_node_id", "from"])
@@ -63,6 +64,6 @@ def test_aliases(src_alias: str, trg_alias: str) -> None:
6364

6465
rel_dict = rel.to_dict()
6566

66-
assert {"id", "from", "to"} == set(rel_dict.keys())
67+
assert {"id", "from", "to", "properties"} == set(rel_dict.keys())
6768
assert rel_dict["from"] == "1"
6869
assert rel_dict["to"] == "2"

0 commit comments

Comments
 (0)