Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to save graph as PNG #523

Merged
merged 6 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .ci_support/environment-mpich.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ dependencies:
- mpi4py =4.0.1
- pyzmq =26.2.0
- h5py =3.12.1
- matplotlib =3.10.0
- networkx =3.4.2
- pygraphviz =1.14
- ipython =8.30.0
Expand Down
1 change: 0 additions & 1 deletion .ci_support/environment-old.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ dependencies:
- mpi4py =3.1.4
- pyzmq =25.0.0
- h5py =3.6.0
- matplotlib =3.5.3
- networkx =2.8.8
- ipython =7.33.0
- pygraphviz =1.10
1 change: 0 additions & 1 deletion .ci_support/environment-openmpi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ dependencies:
- mpi4py =4.0.1
- pyzmq =26.2.0
- h5py =3.12.1
- matplotlib =3.10.0
- networkx =3.4.2
- pygraphviz =1.14
- pysqa =0.2.2
Expand Down
1 change: 0 additions & 1 deletion .ci_support/environment-win.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ dependencies:
- mpi4py =4.0.1
- pyzmq =26.2.0
- h5py =3.12.1
- matplotlib =3.10.0
- networkx =3.4.2
- pygraphviz =1.14
- ipython =8.30.0
Expand Down
5 changes: 5 additions & 0 deletions executorlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class Executor:
refresh_rate (float): Set the refresh rate in seconds, how frequently the input queue is checked.
plot_dependency_graph (bool): Plot the dependencies of multiple future objects without executing them. For
debugging purposes and to get an overview of the specified dependencies.
plot_dependency_graph_filename (str): Name of the file to store the plotted graph in.

Examples:
```
Expand Down Expand Up @@ -101,6 +102,7 @@ def __init__(
disable_dependencies: bool = False,
refresh_rate: float = 0.01,
plot_dependency_graph: bool = False,
plot_dependency_graph_filename: Optional[str] = None,
):
# Use __new__() instead of __init__(). This function is only implemented to enable auto-completion.
pass
Expand All @@ -122,6 +124,7 @@ def __new__(
disable_dependencies: bool = False,
refresh_rate: float = 0.01,
plot_dependency_graph: bool = False,
plot_dependency_graph_filename: Optional[str] = None,
):
"""
Instead of returning a executorlib.Executor object this function returns either a executorlib.mpi.PyMPIExecutor,
Expand Down Expand Up @@ -167,6 +170,7 @@ def __new__(
refresh_rate (float): Set the refresh rate in seconds, how frequently the input queue is checked.
plot_dependency_graph (bool): Plot the dependencies of multiple future objects without executing them. For
debugging purposes and to get an overview of the specified dependencies.
plot_dependency_graph_filename (str): Name of the file to store the plotted graph in.

"""
default_resource_dict = {
Expand Down Expand Up @@ -216,6 +220,7 @@ def __new__(
init_function=init_function,
refresh_rate=refresh_rate,
plot_dependency_graph=plot_dependency_graph,
plot_dependency_graph_filename=plot_dependency_graph_filename,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Add input validation for the filename parameter.

The plot_dependency_graph_filename parameter should be validated before being passed to _ExecutorWithDependencies.

Add a validation check similar to other parameters:

def _check_plot_dependency_graph_filename(filename: Optional[str]) -> None:
    if filename is not None and not isinstance(filename, str):
        raise TypeError("plot_dependency_graph_filename must be a string or None")

Then use it before the _ExecutorWithDependencies instantiation:

+            _check_plot_dependency_graph_filename(plot_dependency_graph_filename)
             return _ExecutorWithDependencies(
                 max_workers=max_workers,
                 ...

)
else:
_check_pysqa_config_directory(pysqa_config_directory=pysqa_config_directory)
Expand Down
15 changes: 13 additions & 2 deletions executorlib/interactive/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,15 @@ class ExecutorWithDependencies(ExecutorBase):
Args:
refresh_rate (float, optional): The refresh rate for updating the executor queue. Defaults to 0.01.
plot_dependency_graph (bool, optional): Whether to generate and plot the dependency graph. Defaults to False.
plot_dependency_graph_filename (str): Name of the file to store the plotted graph in.
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.

Attributes:
_future_hash_dict (Dict[str, Future]): A dictionary mapping task hash to future object.
_task_hash_dict (Dict[str, Dict]): A dictionary mapping task hash to task dictionary.
_generate_dependency_graph (bool): Whether to generate the dependency graph.
_generate_dependency_graph (str): Name of the file to store the plotted graph in.

"""

Expand All @@ -57,6 +59,7 @@ def __init__(
*args: Any,
refresh_rate: float = 0.01,
plot_dependency_graph: bool = False,
plot_dependency_graph_filename: Optional[str] = None,
**kwargs: Any,
) -> None:
super().__init__(max_cores=kwargs.get("max_cores", None))
Expand All @@ -75,7 +78,11 @@ def __init__(
)
self._future_hash_dict = {}
self._task_hash_dict = {}
self._generate_dependency_graph = plot_dependency_graph
self._plot_dependency_graph_filename = plot_dependency_graph_filename
if plot_dependency_graph_filename is None:
self._generate_dependency_graph = plot_dependency_graph
else:
self._generate_dependency_graph = True

def submit(
self,
Expand Down Expand Up @@ -142,7 +149,11 @@ def __exit__(
v: k for k, v in self._future_hash_dict.items()
},
)
return draw(node_lst=node_lst, edge_lst=edge_lst)
return draw(
node_lst=node_lst,
edge_lst=edge_lst,
filename=self._plot_dependency_graph_filename,
)
Comment on lines +152 to +156
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Add error handling for file operations.

The draw function call should handle potential file operation errors when saving the graph.

-            return draw(
-                node_lst=node_lst,
-                edge_lst=edge_lst,
-                filename=self._plot_dependency_graph_filename,
-            )
+            try:
+                return draw(
+                    node_lst=node_lst,
+                    edge_lst=edge_lst,
+                    filename=self._plot_dependency_graph_filename,
+                )
+            except IOError as e:
+                raise IOError(f"Failed to save dependency graph: {e}") from e
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
return draw(
node_lst=node_lst,
edge_lst=edge_lst,
filename=self._plot_dependency_graph_filename,
)
try:
return draw(
node_lst=node_lst,
edge_lst=edge_lst,
filename=self._plot_dependency_graph_filename,
)
except IOError as e:
raise IOError(f"Failed to save dependency graph: {e}") from e



def create_executor(
Expand Down
19 changes: 12 additions & 7 deletions executorlib/standalone/plot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os.path
from concurrent.futures import Future
from typing import Tuple
from typing import Optional, Tuple

import cloudpickle

Expand Down Expand Up @@ -106,23 +107,27 @@ def convert_arg(arg, future_hash_inverse_dict):
)


def draw(node_lst: list, edge_lst: list):
def draw(node_lst: list, edge_lst: list, filename: Optional[str] = None):
"""
Draw the graph visualization of nodes and edges.

Args:
node_lst (list): List of nodes.
edge_lst (list): List of edges.
filename (str): Name of the file to store the plotted graph in.
"""
from IPython.display import SVG, display # noqa
import matplotlib.pyplot as plt # noqa
import networkx as nx # noqa

graph = nx.DiGraph()
for node in node_lst:
graph.add_node(node["id"], label=node["name"], shape=node["shape"])
for edge in edge_lst:
graph.add_edge(edge["start"], edge["end"], label=edge["label"])
svg = nx.nx_agraph.to_agraph(graph).draw(prog="dot", format="svg")
display(SVG(svg))
plt.show()
if filename is not None:
file_format = os.path.splitext(filename)[-1][1:]
with open(filename, "wb") as f:
f.write(nx.nx_agraph.to_agraph(graph).draw(prog="dot", format=file_format))
Comment on lines +126 to +129
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Add error handling for file operations

The file operations should include error handling for common issues like permission errors or invalid paths.

     if filename is not None:
         file_format = os.path.splitext(filename)[-1][1:]
+        if not file_format:
+            raise ValueError("Filename must have an extension (e.g., .png, .svg, .pdf)")
+        if file_format not in ['png', 'svg', 'pdf']:
+            raise ValueError(f"Unsupported file format: {file_format}")
+        try:
             with open(filename, "wb") as f:
                 f.write(nx.nx_agraph.to_agraph(graph).draw(prog="dot", format=file_format))
+        except (OSError, IOError) as e:
+            raise IOError(f"Failed to save graph to {filename}: {str(e)}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if filename is not None:
file_format = os.path.splitext(filename)[-1][1:]
with open(filename, "wb") as f:
f.write(nx.nx_agraph.to_agraph(graph).draw(prog="dot", format=file_format))
if filename is not None:
file_format = os.path.splitext(filename)[-1][1:]
if not file_format:
raise ValueError("Filename must have an extension (e.g., .png, .svg, .pdf)")
if file_format not in ['png', 'svg', 'pdf']:
raise ValueError(f"Unsupported file format: {file_format}")
try:
with open(filename, "wb") as f:
f.write(nx.nx_agraph.to_agraph(graph).draw(prog="dot", format=file_format))
except (OSError, IOError) as e:
raise IOError(f"Failed to save graph to {filename}: {str(e)}")

else:
from IPython.display import SVG, display # noqa

display(SVG(nx.nx_agraph.to_agraph(graph).draw(prog="dot", format="svg")))
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ Repository = "https://github.com/pyiron/executorlib"
cache = ["h5py==3.12.1"]
graph = [
"pygraphviz==1.14",
"matplotlib==3.10.0",
"networkx==3.4.2",
]
graphnotebook = [
"pygraphviz==1.14",
"networkx==3.4.2",
"ipython==8.30.0",
]
Expand All @@ -53,7 +56,6 @@ all = [
"pysqa==0.2.2",
"h5py==3.12.1",
"pygraphviz==1.14",
"matplotlib==3.10.0",
"networkx==3.4.2",
"ipython==8.30.0",
]
Expand Down
21 changes: 21 additions & 0 deletions tests/test_dependencies_executor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from concurrent.futures import Future
import os
import unittest
from time import sleep
from queue import Queue
Expand Down Expand Up @@ -73,6 +74,26 @@ def test_executor_dependency_plot(self):
self.assertEqual(len(nodes), 5)
self.assertEqual(len(edges), 4)

@unittest.skipIf(
skip_graphviz_test,
"graphviz is not installed, so the plot_dependency_graph tests are skipped.",
)
def test_executor_dependency_plot_filename(self):
graph_file = os.path.join(os.path.dirname(__file__), "test.png")
with Executor(
max_cores=1,
backend="local",
plot_dependency_graph=False,
plot_dependency_graph_filename=graph_file,
) as exe:
cloudpickle_register(ind=1)
future_1 = exe.submit(add_function, 1, parameter_2=2)
future_2 = exe.submit(add_function, 1, parameter_2=future_1)
self.assertTrue(future_1.done())
self.assertTrue(future_2.done())
self.assertTrue(os.path.exists(graph_file))
# os.remove(graph_file)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Enhance test coverage for graph file output

The test case needs improvements:

  1. The commented cleanup code should be uncommented
  2. Add tests for other file formats
  3. Add negative test cases for invalid formats/paths
     def test_executor_dependency_plot_filename(self):
-        graph_file = os.path.join(os.path.dirname(__file__), "test.png")
-        with Executor(
-            max_cores=1,
-            backend="local",
-            plot_dependency_graph=False,
-            plot_dependency_graph_filename=graph_file,
-        ) as exe:
-            cloudpickle_register(ind=1)
-            future_1 = exe.submit(add_function, 1, parameter_2=2)
-            future_2 = exe.submit(add_function, 1, parameter_2=future_1)
-            self.assertTrue(future_1.done())
-            self.assertTrue(future_2.done())
-        self.assertTrue(os.path.exists(graph_file))
-        # os.remove(graph_file)
+        test_files = {
+            'png': os.path.join(os.path.dirname(__file__), "test.png"),
+            'svg': os.path.join(os.path.dirname(__file__), "test.svg"),
+            'pdf': os.path.join(os.path.dirname(__file__), "test.pdf")
+        }
+        
+        try:
+            # Test valid formats
+            for fmt, graph_file in test_files.items():
+                with Executor(
+                    max_cores=1,
+                    backend="local",
+                    plot_dependency_graph=False,
+                    plot_dependency_graph_filename=graph_file,
+                ) as exe:
+                    cloudpickle_register(ind=1)
+                    future_1 = exe.submit(add_function, 1, parameter_2=2)
+                    future_2 = exe.submit(add_function, 1, parameter_2=future_1)
+                    self.assertTrue(future_1.done())
+                    self.assertTrue(future_2.done())
+                self.assertTrue(os.path.exists(graph_file))
+
+            # Test invalid format
+            with self.assertRaises(ValueError):
+                with Executor(
+                    max_cores=1,
+                    backend="local",
+                    plot_dependency_graph_filename="test.invalid"
+                ) as exe:
+                    future_1 = exe.submit(add_function, 1, parameter_2=2)
+
+            # Test invalid path
+            with self.assertRaises(IOError):
+                with Executor(
+                    max_cores=1,
+                    backend="local",
+                    plot_dependency_graph_filename="/invalid/path/test.png"
+                ) as exe:
+                    future_1 = exe.submit(add_function, 1, parameter_2=2)
+
+        finally:
+            # Cleanup
+            for graph_file in test_files.values():
+                if os.path.exists(graph_file):
+                    os.remove(graph_file)

Committable suggestion skipped: line range outside the PR's diff.

def test_create_executor_error(self):
with self.assertRaises(ValueError):
create_executor(backend="toast", resource_dict={"cores": 1})
Expand Down
Loading