Skip to content

Commit

Permalink
refactor show graph, allow for explicit calling of formats, add inter…
Browse files Browse the repository at this point in the history
…active widget for jupyter

Signed-off-by: Tim Paine <3105306+timkpaine@users.noreply.github.com>
  • Loading branch information
timkpaine committed Jul 25, 2024
1 parent fac1334 commit e0c990b
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 51 deletions.
1 change: 1 addition & 0 deletions conda/dev-environment-unix.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dependencies:
- python-graphviz
- gtest
- httpx>=0.20,<1
- ipydagred3
- isort>=5,<6
- libarrow=16
- libboost>=1.80.0
Expand Down
2 changes: 1 addition & 1 deletion csp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)
from csp.impl.wiring.context import clear_global_context, new_global_context
from csp.math import *
from csp.showgraph import show_graph
from csp.showgraph import *

from . import stats

Expand Down
12 changes: 5 additions & 7 deletions csp/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import csp.baselib
from csp.impl.wiring.edge import Edge
from csp.showgraph import show_graph

# Lazy declaration below to avoid perspective import
RealtimePerspectiveWidget = None
Expand Down Expand Up @@ -143,12 +144,7 @@ def _eval(self, starttime: datetime, endtime: datetime = None, realtime: bool =
return csp.run(self._eval_graph, starttime=starttime, endtime=endtime, realtime=realtime)

def show_graph(self):
    """Show the graph stored in ``self._eval_graph``.

    Delegates to the module-level ``show_graph`` imported from
    ``csp.showgraph`` with ``graph_filename=None``.

    Returns:
        Whatever ``csp.showgraph.show_graph`` returns for this call, so the
        result is not silently discarded.
    """
    # Inside a method body the bare name resolves to the module-level import,
    # not to this method (which is only reachable as an attribute).
    return show_graph(self._eval_graph, graph_filename=None)

def to_pandas(self, starttime: datetime, endtime: datetime):
import pandas
Expand Down Expand Up @@ -222,7 +218,9 @@ def join(self):
self._runner.join()

except ImportError:
raise ImportError("eval_perspective requires perspective-python installed")
raise ModuleNotFoundError(
"eval_perspective requires perspective-python installed. See https://perspective.finos.org for installation instructions."
)

if not realtime:
df = self.to_pandas(starttime, endtime)
Expand Down
15 changes: 3 additions & 12 deletions csp/impl/pandas_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from csp.impl.pandas_ext_type import TsDtype, is_csp_type
from csp.impl.struct import define_nested_struct
from csp.impl.wiring.edge import Edge
from csp.showgraph import show_graph

T = TypeVar("T")

Expand Down Expand Up @@ -375,12 +376,7 @@ def show_graph(self):
"""Show the graph corresponding to the evaluation of all the edges.
For large series, this may be very large, so it may be helpful to call .head() first.
"""
from PIL import Image

import csp.showgraph

buffer = csp.showgraph.generate_graph(self._eval_graph, "png")
return Image.open(buffer)
return show_graph(self._eval_graph, graph_filename=None)


@register_series_accessor("to_csp")
Expand Down Expand Up @@ -626,12 +622,7 @@ def show_graph(self):
"""Show the graph corresponding to the evaluation of all the edges.
For large series, this may be very large, so it may be helpful to call .head() first.
"""
from PIL import Image

import csp.showgraph

buffer = csp.showgraph.generate_graph(self._eval_graph, "png")
return Image.open(buffer)
show_graph(self._eval_graph, graph_filename=None)


@register_dataframe_accessor("to_csp")
Expand Down
4 changes: 2 additions & 2 deletions csp/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def initialize(self, adapter: GenericPushAdapter, display_graphs: bool):
try:
import matplotlib # noqa: F401
except ImportError:
raise Exception("You must have matplotlib installed to display profiling data graphs.")
raise ModuleNotFoundError("You must have matplotlib installed to display profiling data graphs.")

def get(self):
try:
Expand Down Expand Up @@ -478,7 +478,7 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
def init_profiler(self):
if self.http_port is not None:
if not HAS_TORNADO:
raise Exception("You must have tornado installed to use the HTTP profiling extension.")
raise ModuleNotFoundError("You must have tornado installed to use the HTTP profiling extension.")

adapter = GenericPushAdapter(Future)
application = tornado.web.Application(
Expand Down
180 changes: 151 additions & 29 deletions csp/showgraph.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,54 @@
from collections import deque, namedtuple
from io import BytesIO
from typing import Dict, Literal

from csp.impl.wiring.runtime import build_graph

NODE = namedtuple("NODE", ["name", "label", "color", "shape"])
EDGE = namedtuple("EDGE", ["start", "end"])
_KIND = Literal["output", "input", ""]
_NODE = namedtuple("NODE", ["name", "label", "kind"])
_EDGE = namedtuple("EDGE", ["start", "end"])

_GRAPHVIZ_COLORMAP: Dict[_KIND, str] = {"output": "red", "input": "cadetblue1", "": "white"}

def _build_graphviz_graph(graph_func, *args, **kwargs):
from graphviz import Digraph
# Graphviz node shape used for each node kind (looked up by `_build_graphviz_graph`).
_GRAPHVIZ_SHAPEMAP: Dict[_KIND, str] = {"output": "rarrow", "input": "rarrow", "": "box"}

# Fill colour used for each node kind when rendering with ipydagred3
# (interpolated into the node's `fill: ...` style by `show_graph_widget`).
_DAGRED3_COLORMAP: Dict[_KIND, str] = {
    "output": "red",
    "input": "#98f5ff",
    "": "lightgrey",
}
# ipydagred3 node shape used for each node kind.
_DAGRED3_SHAPEMAP: Dict[_KIND, str] = {"output": "diamond", "input": "diamond", "": "rect"}

# Possible results of `_notebook_kind()`; "" means neither of the two recognised shells.
_NOTEBOOK_KIND = Literal["", "terminal", "notebook"]

# Names exported by `from csp.showgraph import *`.
__all__ = (
    "generate_graph",
    "show_graph_pil",
    "show_graph_graphviz",
    "show_graph_widget",
    "show_graph",
)


def _notebook_kind() -> _NOTEBOOK_KIND:
    """Classify the IPython shell the current process is running under.

    Returns:
        "notebook" when the active shell class is ``ZMQInteractiveShell``,
        "terminal" when it is ``TerminalInteractiveShell``, and "" for any
        other shell class or when IPython cannot be imported.
    """
    try:
        from IPython import get_ipython

        shell_name = get_ipython().__class__.__name__
    except (ImportError, NameError):
        return ""

    # Any shell class not listed here maps to the "unknown" kind.
    known_shells = {
        "ZMQInteractiveShell": "notebook",
        "TerminalInteractiveShell": "terminal",
    }
    return known_shells.get(shell_name, "")


def _build_graph_for_viz(graph_func, *args, **kwargs):
graph = build_graph(graph_func, *args, **kwargs)
digraph = Digraph(strict=True)
digraph.attr(rankdir="LR", size="150,150")

rootnames = set()
q = deque()
Expand All @@ -29,56 +65,142 @@ def _build_graphviz_graph(graph_func, *args, **kwargs):
name = str(id(nodedef))
visited.add(nodedef)
if name in rootnames: # output node
color = "red"
shape = "rarrow"
kind = "output"
elif not sum(1 for _ in nodedef.ts_inputs()): # input node
color = "cadetblue1"
shape = "rarrow"
kind = "input"
else:
color = "white"
shape = "box"
kind = ""

label = nodedef.__name__ if hasattr(nodedef, "__name__") else type(nodedef).__name__
nodes.append(NODE(name=name, label=label, color=color, shape=shape))
nodes.append(_NODE(name=name, label=label, kind=kind))

for input in nodedef.ts_inputs():
if input[1].nodedef not in visited:
q.append(input[1].nodedef)
edges.append(EDGE(start=str(id(input[1].nodedef)), end=name))
edges.append(_EDGE(start=str(id(input[1].nodedef)), end=name))
return nodes, edges


def _build_graphviz_graph(graph_func, *args, **kwargs):
    """Build a graphviz ``Digraph`` describing the graph made by ``graph_func``.

    Args:
        graph_func: graph definition, forwarded to ``_build_graph_for_viz``.
        *args: positional arguments forwarded along with ``graph_func``.
        **kwargs: keyword arguments forwarded along with ``graph_func``.

    Returns:
        A strict, left-to-right ``graphviz.Digraph`` containing one filled
        node per collected node and one edge per collected edge.
    """
    from graphviz import Digraph

    # graph_func must be passed positionally: writing `graph_func=graph_func, *args`
    # binds the first element of `args` to the same parameter and raises
    # "TypeError: got multiple values for argument 'graph_func'".
    nodes, edges = _build_graph_for_viz(graph_func, *args, **kwargs)

    digraph = Digraph(strict=True)
    digraph.attr(rankdir="LR", size="150,150")

    for node in nodes:
        # Colour and shape are derived from the node kind ("output", "input" or "").
        digraph.node(
            node.name,
            node.label,
            style="filled",
            fillcolor=_GRAPHVIZ_COLORMAP[node.kind],
            shape=_GRAPHVIZ_SHAPEMAP[node.kind],
        )
    for edge in edges:
        digraph.edge(edge.start, edge.end)

    return digraph


def _graphviz_to_buffer(digraph, image_format="png") -> BytesIO:
    """Render ``digraph`` and return the output in a ``BytesIO`` positioned at 0.

    Args:
        digraph: graphviz graph object; its ``format`` attribute is set to
            ``image_format`` before rendering.
        image_format: output format handed to graphviz (default "png").

    Raises:
        ModuleNotFoundError: when graphviz reports that its executable
            cannot be found.
    """
    from graphviz import ExecutableNotFound

    digraph.format = image_format

    try:
        rendered = digraph.pipe()
    except ExecutableNotFound as exc:
        raise ModuleNotFoundError(
            "Must install graphviz and have `dot` available on your PATH. See https://graphviz.org for installation instructions"
        ) from exc

    # A BytesIO built from initial bytes starts at offset 0, ready for reading.
    return BytesIO(rendered)


def generate_graph(graph_func, *args, image_format="png", **kwargs):
    """Generate a BytesIO image representation of the given graph.

    Args:
        graph_func: graph definition to render.
        *args: positional arguments forwarded with ``graph_func``.
        image_format: graphviz output format (default "png").
        **kwargs: keyword arguments forwarded with ``graph_func``.

    Returns:
        The buffer produced by ``_graphviz_to_buffer`` for the built graph.
    """
    return _graphviz_to_buffer(
        digraph=_build_graphviz_graph(graph_func, *args, **kwargs),
        image_format=image_format,
    )


def show_graph(graph_func, *args, graph_filename=None, **kwargs):
def show_graph_pil(graph_func, *args, **kwargs):
    """Render the graph to PNG and display it with pillow's ``Image.show``.

    Args:
        graph_func: graph definition to render.
        *args: positional arguments forwarded to ``generate_graph``.
        **kwargs: keyword arguments forwarded to ``generate_graph``.

    Raises:
        ModuleNotFoundError: when ``PIL`` cannot be imported.
    """
    # The image is generated before pillow is imported, so rendering errors
    # surface ahead of a missing-pillow error.
    png_buffer = generate_graph(graph_func, *args, image_format="png", **kwargs)
    try:
        from PIL import Image
    except ImportError:
        raise ModuleNotFoundError(
            "csp requires `pillow` to display images. Install `pillow` with your python package manager, or pass `graph_filename` to generate a file output."
        )
    Image.open(png_buffer).show()


def show_graph_graphviz(graph_func, *args, graph_filename=None, interactive=False, **kwargs):
    """Build the graph with graphviz and, outside interactive mode, write it to a file.

    Args:
        graph_func: graph definition to render.
        *args: positional arguments forwarded with ``graph_func``.
        graph_filename: destination path; the image format is taken from its
            extension and falls back to "png" when there is none.
        interactive: when true, skip file output and return the ``Digraph``
            so the caller can render it directly.
        **kwargs: keyword arguments forwarded with ``graph_func``.

    Returns:
        The ``graphviz.Digraph`` built for the graph.
    """
    import os

    # Generate graph with graphviz
    digraph = _build_graphviz_graph(graph_func, *args, **kwargs)

    # Interactive callers render the Digraph themselves. Without a filename
    # there is nowhere to write, so hand back the Digraph rather than failing
    # inside open(None, "wb").
    if interactive or not graph_filename:
        return digraph

    # Take the format from the real extension: splitting the whole path on "."
    # yields a bogus format for "graph" or "out.d/graph".
    image_format = os.path.splitext(graph_filename)[1].lstrip(".") or "png"

    buffer = _graphviz_to_buffer(digraph=digraph, image_format=image_format)
    with open(graph_filename, "wb") as f:
        f.write(buffer.read())
    return digraph


def show_graph_widget(graph_func, *args, **kwargs):
    """Build an ``ipydagred3`` widget showing the graph made by ``graph_func``.

    Args:
        graph_func: graph definition, forwarded to ``_build_graph_for_viz``.
        *args: positional arguments forwarded along with ``graph_func``.
        **kwargs: keyword arguments forwarded along with ``graph_func``.

    Returns:
        An ``ipydagred3.DagreD3Widget`` wrapping a directed, left-to-right graph.

    Raises:
        ModuleNotFoundError: when ``ipydagred3`` cannot be imported.
    """
    try:
        import ipydagred3
    except ImportError:
        raise ModuleNotFoundError(
            "csp requires `ipydagred3` to display graph widget. Install `ipydagred3` with your python package manager, or pass `graph_filename` to generate a file output."
        )

    # graph_func must be passed positionally: writing `graph_func=graph_func, *args`
    # binds the first element of `args` to the same parameter and raises
    # "TypeError: got multiple values for argument 'graph_func'".
    nodes, edges = _build_graph_for_viz(graph_func, *args, **kwargs)

    graph = ipydagred3.Graph(directed=True, attrs=dict(rankdir="LR"))

    for node in nodes:
        # Shape and fill colour are derived from the node kind ("output", "input" or "").
        graph.addNode(
            ipydagred3.Node(
                name=node.name,
                label=node.label,
                shape=_DAGRED3_SHAPEMAP[node.kind],
                style=f"fill: {_DAGRED3_COLORMAP[node.kind]}",
            )
        )
    for edge in edges:
        graph.addEdge(edge.start, edge.end)
    return ipydagred3.DagreD3Widget(graph=graph)


def show_graph(graph_func, *args, graph_filename=None, **kwargs):
    """Display or save the graph, picking a renderer from ``graph_filename``.

    Args:
        graph_func: graph definition to render.
        *args: positional arguments forwarded to the chosen renderer.
        graph_filename: ``None`` displays the image with pillow, ``"widget"``
            returns the ipydagred3 widget (notebook only), and any other
            value is passed to ``show_graph_graphviz`` as the output filename.
        **kwargs: keyword arguments forwarded to the chosen renderer.

    Raises:
        RuntimeError: when ``"widget"`` is requested outside a notebook.
    """
    # check if we're in jupyter
    in_notebook = _notebook_kind() == "notebook"

    # render with ipydagred3
    if graph_filename == "widget":
        if not in_notebook:
            raise RuntimeError("Interactive graph viewer only works in Jupyter.")
        return show_graph_widget(graph_func, *args, **kwargs)

    # render with pillow
    if graph_filename is None:
        return show_graph_pil(graph_func, *args, **kwargs)

    # TODO we can show graphviz in jupyter without a filename, but preserving existing behavior for now
    return show_graph_graphviz(
        graph_func,
        *args,
        graph_filename=graph_filename,
        interactive=in_notebook,
        **kwargs,
    )

0 comments on commit e0c990b

Please sign in to comment.