Skip to content

Commit

Permalink
feat(hugr-py)!: pretty printing for ops and types (#1482)
Browse files Browse the repository at this point in the history
Closes #1468 
Closes #1470

Previous testing wasn't covering label formatting because it requires
actually rendering the dot in to an image. Have enabled that, but put it
behind an env var because it significantly slows down tests.

BREAKING CHANGE: rename `Custom.name` to `Custom.op_name` and
`Func(Defn/Decl).name` to `f_name` to allow for new `name` method
  • Loading branch information
ss2165 authored Aug 30, 2024
1 parent 037005f commit aca403a
Show file tree
Hide file tree
Showing 17 changed files with 495 additions and 122 deletions.
7 changes: 5 additions & 2 deletions .github/workflows/ci-py.yml
Original file line number Diff line number Diff line change
Expand Up @@ -139,17 +139,20 @@ jobs:
- name: Setup dependencies
run: uv sync

- name: Setup Graphviz
uses: ts-graphviz/setup-graphviz@v2

- name: Run tests
if: github.event_name == 'merge_group' || !matrix.python-version.coverage
run: |
chmod +x $HUGR_BIN
uv run pytest
HUGR_RENDER_DOT=1 uv run pytest
- name: Run python tests with coverage instrumentation
if: github.event_name != 'merge_group' && matrix.python-version.coverage
run: |
chmod +x $HUGR_BIN
uv run pytest --cov=./ --cov-report=xml
HUGR_RENDER_DOT=1 uv run pytest --cov=./ --cov-report=xml
- name: Upload python coverage to codecov.io
if: github.event_name != 'merge_group' && matrix.python-version.coverage
Expand Down
1 change: 1 addition & 0 deletions devenv.nix
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ in
pkgs.llvmPackages_16.libllvm
# cargo-llvm-cov is currently marked broken on nixpkgs unstable
pkgs-stable.cargo-llvm-cov
pkgs.graphviz
] ++ lib.optionals
pkgs.stdenv.isDarwin
(with pkgs.darwin.apple_sdk; [
Expand Down
2 changes: 1 addition & 1 deletion hugr-py/src/hugr/_serialization/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ def display_name(self) -> str:
def deserialize(self) -> ops.Custom:
return ops.Custom(
extension=self.extension,
name=self.name,
op_name=self.name,
signature=self.signature.deserialize(),
args=deser_it(self.args),
)
Expand Down
7 changes: 7 additions & 0 deletions hugr-py/src/hugr/_serialization/tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,13 @@ def join(*bs: TypeBound) -> TypeBound:
res = b
return res

def __str__(self) -> str:
match self:
case TypeBound.Copyable:
return "Copyable"
case TypeBound.Any:
return "Any"


class Opaque(BaseType):
"""An opaque Type that can be downcasted by the extensions that define it."""
Expand Down
2 changes: 1 addition & 1 deletion hugr-py/src/hugr/build/dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,7 @@ class Function(DfBase[ops.FuncDefn]):
Examples:
>>> f = Function("f", [tys.Bool])
>>> f.parent_op
FuncDefn(name='f', inputs=[Bool], params=[])
FuncDefn(f_name='f', inputs=[Bool], params=[])
"""

def __init__(
Expand Down
6 changes: 6 additions & 0 deletions hugr-py/src/hugr/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,12 @@ def _to_serial(self) -> ext_s.OpDef:
lower_funcs=[f._to_serial() for f in self.lower_funcs],
)

def qualified_name(self) -> str:
ext_name = self._extension.name if self._extension else ""
if ext_name:
return f"{ext_name}.{self.name}"
return self.name


@dataclass
class ExtensionValue(ExtensionObject):
Expand Down
16 changes: 8 additions & 8 deletions hugr-py/src/hugr/hugr/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
from hugr import ext
from hugr.val import Value

from .render import RenderConfig


@dataclass()
class NodeData:
Expand Down Expand Up @@ -677,32 +679,30 @@ def load_json(cls, json_str: str) -> Hugr:
serial = SerialHugr.load_json(json_dict)
return cls._from_serial(serial)

def render_dot(self, palette: str | None = None) -> gv.Digraph:
def render_dot(self, config: RenderConfig | None = None) -> gv.Digraph:
"""Render the HUGR to a graphviz Digraph.
Args:
palette: The palette to use for rendering. See :obj:`PALETTE` for the
included options.
config: Render configuration.
Returns:
The graphviz Digraph.
"""
from .render import DotRenderer

return DotRenderer(palette).render(self)
return DotRenderer(config).render(self)

def store_dot(
self, filename: str, format: str = "svg", palette: str | None = None
self, filename: str, format: str = "svg", config: RenderConfig | None = None
) -> None:
"""Render the HUGR to a graphviz dot file.
Args:
filename: The file to render to.
format: The format used for rendering ('pdf', 'png', etc.).
Defaults to SVG.
palette: The palette to use for rendering. See :obj:`PALETTE` for the
included options.
config: Render configuration.
"""
from .render import DotRenderer

DotRenderer(palette).store(self, filename=filename, format=format)
DotRenderer(config).store(self, filename=filename, format=format)
71 changes: 42 additions & 29 deletions hugr-py/src/hugr/hugr/render.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""Visualise HUGR using graphviz."""

from collections.abc import Iterable
from dataclasses import dataclass
from dataclasses import dataclass, field

import graphviz as gv # type: ignore[import-untyped]
from graphviz import Digraph
from typing_extensions import assert_never

from hugr.hugr import Hugr
from hugr.ops import AsExtOp
from hugr.tys import CFKind, ConstKind, FunctionKind, Kind, OrderKind, ValueKind

from .node_port import InPort, Node, OutPort
Expand All @@ -26,6 +27,10 @@ class Palette:
node_border: str
port_border: str

@classmethod
def named(cls, name: str) -> "Palette":
return PALETTE[name]


PALETTE: dict[str, Palette] = {
"default": Palette(
Expand Down Expand Up @@ -61,22 +66,27 @@ class Palette:
}


@dataclass
class RenderConfig:
"""Configuration for rendering a HUGR to a graphviz dot file."""

#: The palette to use for rendering. See :obj:`PALETTE` for the included options.
palette: Palette = field(default_factory=lambda: PALETTE["default"])
#: If true prepend extension name to operation name.
qualify_op_name: bool = False


class DotRenderer:
"""Render a HUGR to a graphviz dot file.
Args:
palette: The palette to use for rendering. See :obj:`PALETTE` for the
included options.
config: Render config
"""

palette: Palette
config: RenderConfig

def __init__(self, palette: Palette | str | None = None) -> None:
if palette is None:
palette = "default"
if isinstance(palette, str):
palette = PALETTE[palette]
self.palette = palette
def __init__(self, config: RenderConfig | None = None) -> None:
self.config = config or RenderConfig()

def render(self, hugr: Hugr) -> Digraph:
"""Render a HUGR to a graphviz dot object."""
Expand All @@ -85,7 +95,7 @@ def render(self, hugr: Hugr) -> Digraph:
"ranksep": "0.1",
"nodesep": "0.15",
"margin": "0",
"bgcolor": self.palette.background,
"bgcolor": self.config.palette.background,
}
if not (name := hugr[hugr.root].metadata.get("name", None)):
name = ""
Expand Down Expand Up @@ -155,11 +165,11 @@ def store(self, hugr: Hugr, filename: str, format: str = "svg") -> None:

def _format_html_label(self, **kwargs: str) -> str:
_HTML_LABEL_DEFAULTS = {
"label_color": self.palette.dark,
"node_back_color": self.palette.node,
"label_color": self.config.palette.dark,
"node_back_color": self.config.palette.node,
"inputs_row": "",
"outputs_row": "",
"border_colour": self.palette.port_border,
"border_colour": self.config.palette.port_border,
"border_width": "1",
"fontface": self._FONTFACE,
"fontsize": 11.0,
Expand All @@ -174,10 +184,10 @@ def _html_ports(self, ports: Iterable[str], id_prefix: str) -> str:
# differentiate input and output node identifiers
# with a prefix
port_id=id_prefix + port,
back_colour=self.palette.background,
font_colour=self.palette.dark,
back_colour=self.config.palette.background,
font_colour=self.config.palette.dark,
border_width="1",
border_colour=self.palette.port_border,
border_colour=self.config.palette.port_border,
fontface=self._FONTFACE,
)
for port in ports
Expand Down Expand Up @@ -218,29 +228,32 @@ def _viz_node(self, node: Node, hugr: Hugr, graph: Digraph) -> None:
)

op = hugr[node].op

if isinstance(op, AsExtOp) and not self.config.qualify_op_name:
op_name = op.op_def().name
else:
op_name = op.name()
if hugr.children(node):
with graph.subgraph(name=f"cluster{node.idx}") as sub:
for child in hugr.children(node):
self._viz_node(child, hugr, sub)
html_label = self._format_html_label(
node_back_color=self.palette.edge,
node_label=str(op),
node_back_color=self.config.palette.edge,
node_label=op_name,
node_data=data,
border_colour=self.palette.port_border,
border_colour=self.config.palette.port_border,
inputs_row=inputs_row,
outputs_row=outputs_row,
)
sub.node(f"{node.idx}", shape="plain", label=f"<{html_label}>")
sub.attr(label="", margin="10", color=self.palette.edge)
sub.attr(label="", margin="10", color=self.config.palette.edge)
else:
html_label = self._format_html_label(
node_back_color=self.palette.node,
node_label=str(op),
node_back_color=self.config.palette.node,
node_label=op_name,
node_data=data,
inputs_row=inputs_row,
outputs_row=outputs_row,
border_colour=self.palette.background,
border_colour=self.config.palette.background,
)
graph.node(f"{node.idx}", label=f"<{html_label}>", shape="plain")

Expand All @@ -260,13 +273,13 @@ def _viz_link(
match kind:
case ValueKind(ty):
label = str(ty)
color = self.palette.edge
color = self.config.palette.edge
case OrderKind():
color = self.palette.dark
color = self.config.palette.dark
case ConstKind() | FunctionKind():
color = self.palette.const
color = self.config.palette.const
case CFKind():
color = self.palette.dark
color = self.config.palette.dark
case _:
assert_never(kind)

Expand Down
Loading

0 comments on commit aca403a

Please sign in to comment.