Skip to content

Commit

Permalink
Run pydot/graphviz tests in CI
Browse files Browse the repository at this point in the history
Closes #151
  • Loading branch information
michaelosthege authored and ferrine committed Dec 27, 2022
1 parent f92e109 commit f4de2fd
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 12 deletions.
3 changes: 2 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,5 @@ dependencies:
- typing_extensions
# optional
- cython

- graphviz
- pydot
2 changes: 1 addition & 1 deletion tests/d3viz/test_d3viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pytensor import compile
from pytensor.compile.function import function
from pytensor.configdefaults import config
from pytensor.d3viz.formatting import pydot_imported, pydot_imported_msg
from pytensor.printing import pydot_imported, pydot_imported_msg
from tests.d3viz import models


Expand Down
23 changes: 13 additions & 10 deletions tests/d3viz/test_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import pytest

from pytensor import config, function
from pytensor.d3viz.formatting import PyDotFormatter, pydot_imported, pydot_imported_msg
from pytensor.d3viz.formatting import PyDotFormatter
from pytensor.printing import pydot_imported, pydot_imported_msg


if not pydot_imported:
Expand All @@ -21,21 +22,23 @@ def node_counts(self, graph):
nc = dict(zip(a, b))
return nc

def test_mlp(self):
@pytest.mark.parametrize("mode", ["FAST_RUN", "FAST_COMPILE"])
def test_mlp(self, mode):
m = models.Mlp()
f = function(m.inputs, m.outputs)
f = function(m.inputs, m.outputs, mode=mode)
pdf = PyDotFormatter()
graph = pdf(f)
expected = 11
if config.mode == "FAST_COMPILE":
expected = 12
if mode == "FAST_RUN":
expected = 13
elif mode == "FAST_COMPILE":
expected = 14
assert len(graph.get_nodes()) == expected
nc = self.node_counts(graph)

if config.mode == "FAST_COMPILE":
assert nc["apply"] == 6
else:
assert nc["apply"] == 5
if mode == "FAST_RUN":
assert nc["apply"] == 7
elif mode == "FAST_COMPILE":
assert nc["apply"] == 8
assert nc["output"] == 1

def test_ofg(self):
Expand Down

0 comments on commit f4de2fd

Please sign in to comment.