Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added the ability for users to access the optional arguments of NetworkX layout functions #439

Merged
merged 1 commit into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions xgi/drawing/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ def draw(
If True, draw ids on the hyperedges. If a dict, must contain (edge_id: label)
pairs. By default, False.
aspect : {"auto", "equal"} or float, optional
Set the aspect ratio of the axes scaling, i.e. y/x-scale. `aspect` is passed
directly to matplotlib's `ax.set_aspect()`. Default is `equal`. See full
Set the aspect ratio of the axes scaling, i.e. y/x-scale. `aspect` is passed
directly to matplotlib's `ax.set_aspect()`. Default is `equal`. See full
description at
https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.set_aspect.html
**kwargs : optional args
Expand Down Expand Up @@ -1283,8 +1283,8 @@ def draw_hypergraph_hull(
radius : float, optional
Radius of the convex hull in the vicinity of the nodes, by default 0.05.
aspect : {"auto", "equal"} or float, optional
Set the aspect ratio of the axes scaling, i.e. y/x-scale. `aspect` is passed
directly to matplotlib's `ax.set_aspect()`. Default is `equal`. See full
Set the aspect ratio of the axes scaling, i.e. y/x-scale. `aspect` is passed
directly to matplotlib's `ax.set_aspect()`. Default is `equal`. See full
description at
https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.set_aspect.html
**kwargs : optional args
Expand Down
56 changes: 43 additions & 13 deletions xgi/drawing/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
]


def random_layout(H, center=None, dim=2, seed=None):
def random_layout(H, center=None, seed=None):
"""Position nodes uniformly at random in the unit square.

For every node, a position is generated by choosing each of dim coordinates
Expand All @@ -36,8 +36,6 @@ def random_layout(H, center=None, dim=2, seed=None):
center : array-like, optional
Coordinate pair around which to center the layout.
If None (default), does not center the positions.
dim : int, optional
Dimension of layout, by default 2.
seed : int, optional
Set the random state for deterministic node layouts.
If int, `seed` is the seed used by the random number generator,
Expand Down Expand Up @@ -76,15 +74,15 @@ def random_layout(H, center=None, dim=2, seed=None):
if seed is not None:
np.random.seed(seed)

H, center = nx.drawing.layout._process_params(H, center, dim)
pos = np.random.rand(len(H), dim) + center
H, center = nx.drawing.layout._process_params(H, center, 2)
pos = np.random.rand(len(H), 2) + center
pos = pos.astype(np.float32)
pos = dict(zip(H, pos))

return pos


def pairwise_spring_layout(H, seed=None):
def pairwise_spring_layout(H, seed=None, k=None, **kwargs):
"""
Position the nodes using Fruchterman-Reingold force-directed
algorithm using the graph projection of the hypergraph
Expand All @@ -99,6 +97,13 @@ def pairwise_spring_layout(H, seed=None):
If int, `seed` is the seed used by the random number generator,
If None (default), random numbers are sampled from the
numpy random number generator without initialization.
k : float
The spring constant of the links. When k=None (default),
k = 1/sqrt(N). For more information, see the documentation
for the NetworkX spring_layout() function.
kwargs :
Optional arguments for the NetworkX spring_layout() function.
See https://networkx.org/documentation/stable/reference/generated/networkx.drawing.layout.spring_layout.html

Returns
-------
Expand Down Expand Up @@ -131,7 +136,7 @@ def pairwise_spring_layout(H, seed=None):
if isinstance(H, SimplicialComplex):
H = convert.from_max_simplices(H)
G = convert.convert_to_graph(H)
pos = nx.spring_layout(G, seed=seed)
pos = nx.spring_layout(G, seed=seed, k=k, **kwargs)
return pos


Expand Down Expand Up @@ -190,7 +195,9 @@ def _augmented_projection(H, weighted=False):
return G


def barycenter_spring_layout(H, return_phantom_graph=False, seed=None):
def barycenter_spring_layout(
H, return_phantom_graph=False, seed=None, k=None, **kwargs
):
"""
Position the nodes using Fruchterman-Reingold force-directed
algorithm using an augmented version of the the graph projection
Expand All @@ -211,6 +218,14 @@ def barycenter_spring_layout(H, return_phantom_graph=False, seed=None):
If int, `seed` is the seed used by the random number generator,
If None (default), random numbers are sampled from the
numpy random number generator without initialization.
k : float
The spring constant of the links. When k=None (default),
k = 1/sqrt(N). For more information, see the documentation
for the NetworkX spring_layout() function.
kwargs :
Optional arguments for the NetworkX spring_layout() function.
See https://networkx.org/documentation/stable/reference/generated/networkx.drawing.layout.spring_layout.html


Returns
-------
Expand Down Expand Up @@ -241,7 +256,7 @@ def barycenter_spring_layout(H, return_phantom_graph=False, seed=None):

# Creating a dictionary for the position of the nodes with the standard spring
# layout
pos_with_phantom_nodes = nx.spring_layout(G, seed=seed)
pos_with_phantom_nodes = nx.spring_layout(G, seed=seed, k=k, **kwargs)

# Retaining only the positions of the real nodes
pos = {k: pos_with_phantom_nodes[k] for k in list(H.nodes)}
Expand All @@ -252,7 +267,9 @@ def barycenter_spring_layout(H, return_phantom_graph=False, seed=None):
return pos


def weighted_barycenter_spring_layout(H, return_phantom_graph=False, seed=None):
def weighted_barycenter_spring_layout(
H, return_phantom_graph=False, seed=None, k=None, **kwargs
):
"""Position the nodes using Fruchterman-Reingold force-directed algorithm.

This uses an augmented version of the the graph projection of the hypergraph (or
Expand All @@ -275,6 +292,14 @@ def weighted_barycenter_spring_layout(H, return_phantom_graph=False, seed=None):
If int, `seed` is the seed used by the random number generator,
If None (default), random numbers are sampled from the
numpy random number generator without initialization.
k : float
The spring constant of the links. When k=None (default),
k = 1/sqrt(N). For more information, see the documentation
for the NetworkX spring_layout() function.
kwargs :
Optional arguments for the NetworkX spring_layout() function.
See https://networkx.org/documentation/stable/reference/generated/networkx.drawing.layout.spring_layout.html


Returns
-------
Expand Down Expand Up @@ -305,7 +330,9 @@ def weighted_barycenter_spring_layout(H, return_phantom_graph=False, seed=None):
G = _augmented_projection(H, weighted=True)

# Creating a dictionary for node position with the standard spring layout
pos_with_phantom_nodes = nx.spring_layout(G, weight="weight", seed=seed)
pos_with_phantom_nodes = nx.spring_layout(
G, weight="weight", seed=seed, k=k, **kwargs
)

# Retaining only the positions of the real nodes
pos = {k: pos_with_phantom_nodes[k] for k in list(H.nodes)}
Expand Down Expand Up @@ -452,7 +479,7 @@ def spiral_layout(H, center=None, resolution=0.35, equidistant=False):
return pos


def barycenter_kamada_kawai_layout(H, return_phantom_graph=False):
def barycenter_kamada_kawai_layout(H, return_phantom_graph=False, **kwargs):
"""Position nodes using Kamada-Kawai path-length cost-function
using an augmented version of the the graph projection
of the hypergraph (or simplicial complex), where phantom nodes
Expand All @@ -467,6 +494,9 @@ def barycenter_kamada_kawai_layout(H, return_phantom_graph=False):
return_phantom_graph: bool (default=False)
If True the function returns also the augmented version of the
the graph projection of the hypergraph (or simplicial complex).
kwargs :
Optional arguments for the NetworkX spring_layout() function.
See https://networkx.org/documentation/stable/reference/generated/networkx.drawing.layout.kamada_kawai_layout.html

Returns
-------
Expand All @@ -479,7 +509,7 @@ def barycenter_kamada_kawai_layout(H, return_phantom_graph=False):
G = _augmented_projection(H)

# Creating a dictionary for the position of the nodes with the standard spring layout
pos_with_phantom_nodes = nx.kamada_kawai_layout(G)
pos_with_phantom_nodes = nx.kamada_kawai_layout(G, **kwargs)

# Retaining only the positions of the real nodes
pos = {k: pos_with_phantom_nodes[k] for k in list(H.nodes)}
Expand Down