Skip to content

Commit

Permalink
Added the ability for users to access the optional arguments of Netwo…
Browse files Browse the repository at this point in the history
…rkX layout functions. (#439)
  • Loading branch information
nwlandry authored Jul 28, 2023
1 parent 426c07d commit ea1aaf6
Showing 1 changed file with 43 additions and 13 deletions.
56 changes: 43 additions & 13 deletions xgi/drawing/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
]


def random_layout(H, center=None, dim=2, seed=None):
def random_layout(H, center=None, seed=None):
"""Position nodes uniformly at random in the unit square.
For every node, a position is generated by choosing each of dim coordinates
Expand All @@ -36,8 +36,6 @@ def random_layout(H, center=None, dim=2, seed=None):
center : array-like, optional
Coordinate pair around which to center the layout.
If None (default), does not center the positions.
dim : int, optional
Dimension of layout, by default 2.
seed : int, optional
Set the random state for deterministic node layouts.
If int, `seed` is the seed used by the random number generator,
Expand Down Expand Up @@ -76,15 +74,15 @@ def random_layout(H, center=None, dim=2, seed=None):
if seed is not None:
np.random.seed(seed)

H, center = nx.drawing.layout._process_params(H, center, dim)
pos = np.random.rand(len(H), dim) + center
H, center = nx.drawing.layout._process_params(H, center, 2)
pos = np.random.rand(len(H), 2) + center
pos = pos.astype(np.float32)
pos = dict(zip(H, pos))

return pos


def pairwise_spring_layout(H, seed=None):
def pairwise_spring_layout(H, seed=None, k=None, **kwargs):
"""
Position the nodes using Fruchterman-Reingold force-directed
algorithm using the graph projection of the hypergraph
Expand All @@ -99,6 +97,13 @@ def pairwise_spring_layout(H, seed=None):
If int, `seed` is the seed used by the random number generator,
If None (default), random numbers are sampled from the
numpy random number generator without initialization.
k : float
The spring constant of the links. When k=None (default),
k = 1/sqrt(N). For more information, see the documentation
for the NetworkX spring_layout() function.
kwargs :
Optional arguments for the NetworkX spring_layout() function.
See https://networkx.org/documentation/stable/reference/generated/networkx.drawing.layout.spring_layout.html
Returns
-------
Expand Down Expand Up @@ -131,7 +136,7 @@ def pairwise_spring_layout(H, seed=None):
if isinstance(H, SimplicialComplex):
H = convert.from_max_simplices(H)
G = convert.convert_to_graph(H)
pos = nx.spring_layout(G, seed=seed)
pos = nx.spring_layout(G, seed=seed, k=k, **kwargs)
return pos


Expand Down Expand Up @@ -190,7 +195,9 @@ def _augmented_projection(H, weighted=False):
return G


def barycenter_spring_layout(H, return_phantom_graph=False, seed=None):
def barycenter_spring_layout(
H, return_phantom_graph=False, seed=None, k=None, **kwargs
):
"""
Position the nodes using Fruchterman-Reingold force-directed
algorithm using an augmented version of the the graph projection
Expand All @@ -211,6 +218,14 @@ def barycenter_spring_layout(H, return_phantom_graph=False, seed=None):
If int, `seed` is the seed used by the random number generator,
If None (default), random numbers are sampled from the
numpy random number generator without initialization.
k : float
The spring constant of the links. When k=None (default),
k = 1/sqrt(N). For more information, see the documentation
for the NetworkX spring_layout() function.
kwargs :
Optional arguments for the NetworkX spring_layout() function.
See https://networkx.org/documentation/stable/reference/generated/networkx.drawing.layout.spring_layout.html
Returns
-------
Expand Down Expand Up @@ -241,7 +256,7 @@ def barycenter_spring_layout(H, return_phantom_graph=False, seed=None):

# Creating a dictionary for the position of the nodes with the standard spring
# layout
pos_with_phantom_nodes = nx.spring_layout(G, seed=seed)
pos_with_phantom_nodes = nx.spring_layout(G, seed=seed, k=k, **kwargs)

# Retaining only the positions of the real nodes
pos = {k: pos_with_phantom_nodes[k] for k in list(H.nodes)}
Expand All @@ -252,7 +267,9 @@ def barycenter_spring_layout(H, return_phantom_graph=False, seed=None):
return pos


def weighted_barycenter_spring_layout(H, return_phantom_graph=False, seed=None):
def weighted_barycenter_spring_layout(
H, return_phantom_graph=False, seed=None, k=None, **kwargs
):
"""Position the nodes using Fruchterman-Reingold force-directed algorithm.
This uses an augmented version of the the graph projection of the hypergraph (or
Expand All @@ -275,6 +292,14 @@ def weighted_barycenter_spring_layout(H, return_phantom_graph=False, seed=None):
If int, `seed` is the seed used by the random number generator,
If None (default), random numbers are sampled from the
numpy random number generator without initialization.
k : float
The spring constant of the links. When k=None (default),
k = 1/sqrt(N). For more information, see the documentation
for the NetworkX spring_layout() function.
kwargs :
Optional arguments for the NetworkX spring_layout() function.
See https://networkx.org/documentation/stable/reference/generated/networkx.drawing.layout.spring_layout.html
Returns
-------
Expand Down Expand Up @@ -305,7 +330,9 @@ def weighted_barycenter_spring_layout(H, return_phantom_graph=False, seed=None):
G = _augmented_projection(H, weighted=True)

# Creating a dictionary for node position with the standard spring layout
pos_with_phantom_nodes = nx.spring_layout(G, weight="weight", seed=seed)
pos_with_phantom_nodes = nx.spring_layout(
G, weight="weight", seed=seed, k=k, **kwargs
)

# Retaining only the positions of the real nodes
pos = {k: pos_with_phantom_nodes[k] for k in list(H.nodes)}
Expand Down Expand Up @@ -452,7 +479,7 @@ def spiral_layout(H, center=None, resolution=0.35, equidistant=False):
return pos


def barycenter_kamada_kawai_layout(H, return_phantom_graph=False):
def barycenter_kamada_kawai_layout(H, return_phantom_graph=False, **kwargs):
"""Position nodes using Kamada-Kawai path-length cost-function
using an augmented version of the the graph projection
of the hypergraph (or simplicial complex), where phantom nodes
Expand All @@ -467,6 +494,9 @@ def barycenter_kamada_kawai_layout(H, return_phantom_graph=False):
return_phantom_graph: bool (default=False)
If True the function returns also the augmented version of the
the graph projection of the hypergraph (or simplicial complex).
kwargs :
Optional arguments for the NetworkX spring_layout() function.
See https://networkx.org/documentation/stable/reference/generated/networkx.drawing.layout.kamada_kawai_layout.html
Returns
-------
Expand All @@ -479,7 +509,7 @@ def barycenter_kamada_kawai_layout(H, return_phantom_graph=False):
G = _augmented_projection(H)

# Creating a dictionary for the position of the nodes with the standard spring layout
pos_with_phantom_nodes = nx.kamada_kawai_layout(G)
pos_with_phantom_nodes = nx.kamada_kawai_layout(G, **kwargs)

# Retaining only the positions of the real nodes
pos = {k: pos_with_phantom_nodes[k] for k in list(H.nodes)}
Expand Down

0 comments on commit ea1aaf6

Please sign in to comment.