Skip to content

Commit 2b77f17

Browse files
authored
Merge pull request #182 from neo4j/gds-sampling
Sample large graphs by default with `from_gds`
2 parents f11d1be + 3c3a02d commit 2b77f17

File tree

4 files changed

+70
-18
lines changed

4 files changed

+70
-18
lines changed

changelog.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
## Breaking changes
55

66
* The `from_gds` method now fetches all node properties of a given GDS projection by default, instead of none.
7-
* The `from_gds` now adds node labels as captions for nodes.
7+
* The `from_gds` method now adds node labels as captions for nodes.
8+
* The `from_gds` method now samples large graphs before fetching them by default, but this can be overridden.
89

910

1011
## New features

docs/source/integration.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,14 @@ The ``from_gds`` method takes two mandatory positional parameters:
105105
* An initialized ``GraphDataScience`` object for the connection to the GDS instance, and
106106
* A ``Graph`` representing the projection that one wants to import.
107107

108+
The optional ``max_node_count`` parameter can be used to limit the number of nodes that are imported from the
109+
projection.
110+
By default, it is set to 10,000, meaning that if the projection has more than 10,000 nodes, ``from_gds`` will sample
111+
from it using random walk with restarts, to get a smaller graph that can be visualized.
112+
If you want to have more control over the sampling, such as choosing a specific start node for the sample, you can call
113+
a `sampling <https://neo4j.com/docs/graph-data-science/current/management-ops/graph-creation/sampling/>`_
114+
method yourself and pass the resulting projection to ``from_gds``.
115+
108116
We can also provide an optional ``size_property`` parameter, which should refer to a node property of the projection,
109117
and will be used to determine the sizes of the nodes in the visualization.
110118

python-wrapper/src/neo4j_viz/gds.py

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,17 @@
22

33
from itertools import chain
44
from typing import Optional
5+
from uuid import uuid4
56

67
import pandas as pd
78
from graphdatascience import Graph, GraphDataScience
9+
from pandas import Series
810

911
from .pandas import _from_dfs
1012
from .visualization_graph import VisualizationGraph
1113

1214

13-
def _node_dfs(
15+
def _fetch_node_dfs(
1416
gds: GraphDataScience, G: Graph, node_properties: list[str], node_labels: list[str]
1517
) -> dict[str, pd.DataFrame]:
1618
return {
@@ -21,17 +23,17 @@ def _node_dfs(
2123
}
2224

2325

24-
def _rel_df(gds: GraphDataScience, G: Graph) -> pd.DataFrame:
26+
def _fetch_rel_df(gds: GraphDataScience, G: Graph) -> pd.DataFrame:
2527
relationship_properties = G.relationship_properties()
28+
assert isinstance(relationship_properties, Series)
2629

27-
if len(relationship_properties) > 0:
28-
if isinstance(relationship_properties, pd.Series):
29-
relationship_properties_per_type = relationship_properties.tolist()
30-
property_set: set[str] = set()
31-
for props in relationship_properties_per_type:
32-
if props:
33-
property_set.update(props)
30+
relationship_properties_per_type = relationship_properties.tolist()
31+
property_set: set[str] = set()
32+
for props in relationship_properties_per_type:
33+
if props:
34+
property_set.update(props)
3435

36+
if len(property_set) > 0:
3537
return gds.graph.relationshipProperties.stream(
3638
G, relationship_properties=list(property_set), separate_property_columns=True
3739
)
@@ -45,6 +47,7 @@ def from_gds(
4547
size_property: Optional[str] = None,
4648
additional_node_properties: Optional[list[str]] = None,
4749
node_radius_min_max: Optional[tuple[float, float]] = (3, 60),
50+
max_node_count: int = 10_000,
4851
) -> VisualizationGraph:
4952
"""
5053
Create a VisualizationGraph from a GraphDataScience object and a Graph object.
@@ -68,6 +71,9 @@ def from_gds(
6871
node_radius_min_max : tuple[float, float], optional
6972
Minimum and maximum node radius, by default (3, 60).
7073
To avoid tiny or huge nodes in the visualization, the node sizes are scaled to fit in the given range.
74+
max_node_count : int, optional
75+
The maximum number of nodes to fetch from the graph. The graph will be sampled using random walk with restarts
76+
if its node count exceeds this number.
7177
"""
7278
node_properties_from_gds = G.node_properties()
7379
assert isinstance(node_properties_from_gds, pd.Series)
@@ -86,14 +92,40 @@ def from_gds(
8692
node_properties = set()
8793
if additional_node_properties is not None:
8894
node_properties.update(additional_node_properties)
89-
9095
if size_property is not None:
9196
node_properties.add(size_property)
92-
9397
node_properties = list(node_properties)
94-
node_dfs = _node_dfs(gds, G, node_properties, G.node_labels())
98+
99+
node_count = G.node_count()
100+
if node_count > max_node_count:
101+
sampling_ratio = float(max_node_count) / node_count
102+
sample_name = f"neo4j-viz_sample_{uuid4()}"
103+
G_fetched, _ = gds.graph.sample.rwr(sample_name, G, samplingRatio=sampling_ratio, nodeLabelStratification=True)
104+
else:
105+
G_fetched = G
106+
107+
property_name = None
108+
try:
109+
# Since GDS does not allow us to only fetch node IDs, we add the degree property
110+
# as a temporary property to ensure that we have at least one property to fetch
111+
if len(actual_node_properties) == 0:
112+
property_name = f"neo4j-viz_property_{uuid4()}"
113+
gds.degree.mutate(G_fetched, mutateProperty=property_name)
114+
node_properties = [property_name]
115+
116+
node_dfs = _fetch_node_dfs(gds, G_fetched, node_properties, G_fetched.node_labels())
117+
rel_df = _fetch_rel_df(gds, G_fetched)
118+
finally:
119+
if G_fetched.name() != G.name():
120+
G_fetched.drop()
121+
elif property_name is not None:
122+
gds.graph.nodeProperties.drop(G_fetched, node_properties=[property_name])
123+
95124
for df in node_dfs.values():
96125
df.rename(columns={"nodeId": "id"}, inplace=True)
126+
if property_name is not None and property_name in df.columns:
127+
df.drop(columns=[property_name], inplace=True)
128+
rel_df.rename(columns={"sourceNodeId": "source", "targetNodeId": "target"}, inplace=True)
97129

98130
node_props_df = pd.concat(node_dfs.values(), ignore_index=True, axis=0).drop_duplicates()
99131
if size_property is not None:
@@ -114,9 +146,6 @@ def from_gds(
114146
if "caption" not in actual_node_properties:
115147
node_df["caption"] = node_df["labels"].astype(str)
116148

117-
rel_df = _rel_df(gds, G)
118-
rel_df.rename(columns={"sourceNodeId": "source", "targetNodeId": "target"}, inplace=True)
119-
120149
try:
121150
return _from_dfs(node_df, rel_df, node_radius_min_max=node_radius_min_max, rename_properties={"__size": "size"})
122151
except ValueError as e:

python-wrapper/tests/test_gds.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,9 +170,10 @@ def test_from_gds_mocked(mocker: MockerFixture) -> None:
170170
lambda x: pd.Series({lbl: node_properties for lbl in nodes.keys()}),
171171
)
172172
mocker.patch("graphdatascience.Graph.node_labels", lambda x: list(nodes.keys()))
173+
mocker.patch("graphdatascience.Graph.node_count", lambda x: sum(len(df) for df in nodes.values()))
173174
mocker.patch("graphdatascience.GraphDataScience.__init__", lambda x: None)
174-
mocker.patch("neo4j_viz.gds._node_dfs", return_value=nodes)
175-
mocker.patch("neo4j_viz.gds._rel_df", return_value=rels)
175+
mocker.patch("neo4j_viz.gds._fetch_node_dfs", return_value=nodes)
176+
mocker.patch("neo4j_viz.gds._fetch_rel_df", return_value=rels)
176177

177178
gds = GraphDataScience() # type: ignore[call-arg]
178179
G = Graph() # type: ignore[call-arg]
@@ -244,3 +245,16 @@ def test_from_gds_node_errors(gds: Any) -> None:
244245
additional_node_properties=["component", "size"],
245246
node_radius_min_max=None,
246247
)
248+
249+
250+
@pytest.mark.requires_neo4j_and_gds
251+
def test_from_gds_sample(gds: Any) -> None:
252+
from neo4j_viz.gds import from_gds
253+
254+
with gds.graph.generate("hello", node_count=11_000, average_degree=1) as G:
255+
VG = from_gds(gds, G)
256+
257+
assert len(VG.nodes) >= 9_500
258+
assert len(VG.nodes) <= 10_500
259+
assert len(VG.relationships) >= 9_500
260+
assert len(VG.relationships) <= 10_500

0 commit comments

Comments
 (0)