Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
nwlandry committed Nov 3, 2023
1 parent 5e83be1 commit aa4aafb
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 160 deletions.
17 changes: 4 additions & 13 deletions tests/drawing/test_draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,14 +410,10 @@ def test_draw_multilayer(edgelist8):
assert edge_coll.get_zorder() == 2 # edges

# node_fc
assert np.all(
node_coll.get_facecolor() == np.array([[1, 1, 1, 1]])
) # white
assert np.all(node_coll.get_facecolor() == np.array([[1, 1, 1, 1]])) # white

# node_ec
assert np.all(
node_coll.get_edgecolor() == np.array([[0, 0, 0, 1]])
) # black
assert np.all(node_coll.get_edgecolor() == np.array([[0, 0, 0, 1]])) # black

# node_lw
assert np.all(node_coll.get_linewidth() == np.array([1]))
Expand Down Expand Up @@ -499,22 +495,17 @@ def test_draw_multilayer(edgelist8):
)

# node_fc
assert np.all(
node_coll4.get_facecolor() == np.array([[1, 0, 0, 1]])
) # red
assert np.all(node_coll4.get_facecolor() == np.array([[1, 0, 0, 1]])) # red

# node_ec
assert np.all(
node_coll4.get_edgecolor() == np.array([[0, 0, 1, 1]])
) # blue
assert np.all(node_coll4.get_edgecolor() == np.array([[0, 0, 1, 1]])) # blue

# node_lw
assert np.all(node_coll4.get_linewidth() == np.array([1]))

# node_size
assert np.all(node_coll4.get_sizes() == np.array([10**2]))


plt.close()


Expand Down
251 changes: 105 additions & 146 deletions xgi/drawing/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from .. import convert
from ..algorithms import max_edge_order, unique_edge_sizes
from ..convert import to_bipartite_edgelist
from ..core import DiHypergraph, Hypergraph, SimplicialComplex
from ..exception import XGIError
from ..utils import subfaces
Expand Down Expand Up @@ -47,6 +48,7 @@
"draw_hypergraph_hull",
"draw_multilayer",
"draw_dihypergraph",
"draw_bipartite",
]


Expand Down Expand Up @@ -373,7 +375,7 @@ def draw_nodes(
* "max_node_size" (default: 30)
* "min_node_lw" (default: 0)
* "max_node_lw" (default: 5)
kwargs : optional keywords
See `draw_node_labels` for a description of optional keywords.
Expand Down Expand Up @@ -569,7 +571,7 @@ def draw_hyperedges(
Default parameters. Keys that may be useful to override default settings:
* "min_dyad_lw" (default: 1)
* "max_dyad_lw" (default: 10)
kwargs : optional keywords
See `draw_hyperedge_labels` for a description of optional keywords.
Expand Down Expand Up @@ -800,7 +802,7 @@ def draw_simplices(
Default parameters. Keys that may be useful to override default settings:
* "min_dyad_lw" (default: 1)
* "max_dyad_lw" (default: 10)
kwargs : optional keywords
See `draw_hyperedge_labels` for a description of optional keywords.
Expand Down Expand Up @@ -1592,8 +1594,8 @@ def draw_multilayer(
ydiff = np.max(ys) - np.min(ys)
ymin = np.min(ys) - ydiff * 0.1
ymax = np.max(ys) + ydiff * 0.1
xmin = np.min(xs) - xdiff * 0.1 #* (width / height)
xmax = np.max(xs) + xdiff * 0.1 #* (width / height)
xmin = np.min(xs) - xdiff * 0.1 # * (width / height)
xmax = np.max(xs) + xdiff * 0.1 # * (width / height)
xx, yy = np.meshgrid([xmin, xmax], [ymin, ymax])

# plot layers
Expand Down Expand Up @@ -2140,20 +2142,20 @@ def draw_bipartite(
node_pos=None,
edge_pos=None,
ax=None,
dyad_color="black",
dyad_lw=1,
node_fc="white",
node_ec="black",
node_marker="o",
node_shape="o",
node_lw=1,
node_size=10,
edge_marker_fc="lightblue",
edge_marker_fc="blue",
edge_marker_ec="black",
edge_marker="s",
edge_marker_shape="s",
edge_marker_lw=1,
edge_marker_size=10,
dyad_color="black",
dyad_lw=1,
max_order=None,
settings=None,
rescale_sizes=True,
**kwargs,
):
"""Draw a hypergraph as a bipartite network.
Expand Down Expand Up @@ -2226,14 +2228,18 @@ def draw_bipartite(
* max_node_size
* min_node_lw
* max_node_lw
* min_edge_marker_size
* max_edge_marker_size
* min_edge_marker_lw
* max_edge_marker_lw
* node_fc_cmap
* node_ec_cmap
* min_lines_lw
* max_lines_lw
* lines_fc_cmap
* edge_fc_cmap
* edge_marker_fc_cmap
* edge_marker_ec_cmap
* min_dyad_lw
* max_dyad_lw
* dyad_color_cmap
Returns
-------
Expand All @@ -2254,153 +2260,106 @@ def draw_bipartite(
if not isinstance(H, Hypergraph):
raise XGIError("The input must be a Hypergraph")

if settings is None:
settings = {
"min_node_size": 10.0,
"max_node_size": 30.0,
"min_edge_marker_size": 10.0,
"max_edge_marker_size": 30.0,
"min_node_lw": 10.0,
"max_node_lw": 30.0,
"min_edge_marker_lw": 10.0,
"max_edge_marker_lw": 30.0,
"min_dyad_lw": 1.0,
"max_dyad_lw": 5.0,
"node_fc_cmap": cm.Reds,
"node_ec_cmap": cm.RdBu,
"dyad_fc_cmap": cm.Blues,
"edge_marker_fc_cmap": cm.Blues,
"edge_marker_ec_cmap": cm.Greys,
}
settings = {
"min_node_lw": 10.0,
"max_node_lw": 30.0,
"min_node_size": 10.0,
"max_node_size": 30.0,
"min_edge_marker_lw": 10.0,
"max_edge_marker_lw": 30.0,
"min_edge_marker_size": 10.0,
"max_edge_marker_size": 30.0,
"min_dyad_lw": 1.0,
"max_dyad_lw": 5.0,
"node_fc_cmap": cm.Reds,
"node_ec_cmap": cm.RdBu,
"dyad_color_cmap": cm.Blues,
"edge_marker_fc_cmap": cm.Greys,
"edge_marker_ec_cmap": cm.Blues,
}

settings.update(kwargs)

if ax is None:
ax = plt.gca()

ax.get_xaxis().set_ticks([])
ax.get_yaxis().set_ticks([])
ax.axis("off")
node_settings = {
"min_node_lw": settings["min_node_lw"],
"max_node_lw": settings["max_node_lw"],
"min_node_size": settings["min_node_size"],
"max_node_size": settings["max_node_size"],
"node_fc_cmap": settings["node_fc_cmap"],
"node_ec_cmap": settings["node_ec_cmap"],
}

if not max_order:
max_order = max_edge_order(H)
edge_marker_settings = {
"min_node_lw": settings["min_edge_marker_lw"],
"max_node_lw": settings["max_edge_marker_lw"],
"min_node_size": settings["min_edge_marker_size"],
"max_node_size": settings["max_edge_marker_size"],
"node_fc_cmap": settings["edge_marker_fc_cmap"],
"node_ec_cmap": settings["edge_marker_ec_cmap"],
}

if not node_pos or not edge_pos:
node_pos, edge_pos = bipartite_spring_layout(H)

dyad_lw = _scalar_arg_to_dict(
dyad_lw, H.edges, settings["min_dyad_lw"], settings["max_dyad_lw"]
)

if dyad_color is None:
dyad_color = H.edges.size

dyad_color = _color_arg_to_dict(dyad_color, H.edges, settings["dyad_fc_cmap"])

if edge_marker_fc is None:
edge_marker_fc = H.edges.size

edge_marker_fc = _color_arg_to_dict(
edge_marker_fc, H.edges, settings["edge_marker_fc_cmap"]
)

if edge_marker_ec is None:
edge_marker_ec = H.edges.size
if ax is None:
ax = plt.gca()

edge_marker_ec = _color_arg_to_dict(
edge_marker_ec, H.edges, settings["edge_marker_ec_cmap"]
ax, node_collection = draw_nodes(
H=H,
pos=node_pos,
ax=ax,
node_fc=node_fc,
node_ec=node_ec,
node_lw=node_lw,
node_size=node_size,
node_shape=node_shape,
zorder=2,
params=node_settings,
node_labels=None,
rescale_sizes=rescale_sizes,
**kwargs,
)

edge_marker_lw = _scalar_arg_to_dict(
edge_marker_lw,
H.edges,
settings["min_edge_marker_lw"],
settings["max_edge_marker_lw"],
ax, edge_marker_collection = draw_nodes(
H=H.dual(),
pos=edge_pos,
ax=ax,
node_fc=edge_marker_fc,
node_ec=edge_marker_ec,
node_lw=edge_marker_lw,
node_size=edge_marker_size,
node_shape=edge_marker_shape,
zorder=1,
params=edge_marker_settings,
node_labels=None,
rescale_sizes=rescale_sizes,
**kwargs,
)

edge_marker_size = _scalar_arg_to_dict(
edge_marker_size,
H.edges,
settings["min_edge_marker_size"],
settings["max_edge_marker_size"],
)
dyads = to_bipartite_edgelist(H)
dyad_lw = _draw_arg_to_arr(dyad_lw)
# dyad_color, dyad_c_mapped = _parse_color_arg(dyad_color, list(dyads)
# check validity of input values
if np.any(dyad_lw < 0):
raise ValueError("dyad_lw cannot contain negative values.")

node_size = _scalar_arg_to_dict(
node_size, H.nodes, settings["min_node_size"], settings["max_node_size"]
)
node_fc = _color_arg_to_dict(node_fc, H.nodes, settings["node_fc_cmap"])
node_ec = _color_arg_to_dict(node_ec, H.nodes, settings["node_ec_cmap"])
node_lw = _scalar_arg_to_dict(
node_lw,
H.nodes,
settings["min_node_lw"],
settings["max_node_lw"],
)
# interpolate if needed
if rescale_sizes and isinstance(dyad_lw, np.ndarray):
dyad_lw = _interp_draw_arg(
dyad_lw, settings["min_dyad_lw"], settings["max_dyad_lw"]
)

for id, e in H.edges.members(dtype=dict).items():
d = len(e) - 1
if d > 0:
x_edge, y_edge = edge_pos[id]
for n in e:
x_node, y_node = node_pos[n]
line = plt.Line2D(
[x_node, x_edge],
[y_node, y_edge],
color=dyad_color[id],
lw=dyad_lw[id],
zorder=max_order - d,
)
ax.add_line(line)

(xs, ys, s, c, ec, lw,) = zip(
*[
(
node_pos[i][0],
node_pos[i][1],
node_size[i],
node_fc[i],
node_ec[i],
node_lw[i],
)
for i in H.nodes
]
)
ax.scatter(
x=xs,
y=ys,
marker=node_marker,
s=s,
c=c,
edgecolors=ec,
linewidths=lw,
zorder=max_order,
)
dyad_pos = np.asarray([(node_pos[list(e)[0]], edge_pos[list(e)[1]]) for e in dyads])

(xs, ys, s, c, ec, lw,) = zip(
*[
(
edge_pos[i][0],
edge_pos[i][1],
edge_marker_size[i],
edge_marker_fc[i],
edge_marker_ec[i],
edge_marker_lw[i],
)
for i in H.edges
]
)
ax.scatter(
x=xs,
y=ys,
marker=edge_marker,
s=s,
c=c,
edgecolors=ec,
linewidths=lw,
zorder=max_order,
dyad_collection = LineCollection(
dyad_pos,
colors=dyad_color,
linewidths=dyad_lw,
antialiaseds=(1,),
cmap=settings["dyad_color_cmap"],
zorder=0,
)

# compute axis limits
_update_lims(node_pos, ax)
ax.add_collection(dyad_collection)

return ax
return ax, (node_collection, edge_marker_collection, dyad_collection)
Loading

0 comments on commit aa4aafb

Please sign in to comment.