Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve error message in DAGCircuit.draw for invalid filenames #7447

Merged
merged 6 commits into from
Jan 14, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions qiskit/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,7 @@ def __init__(
def __str__(self) -> str:
"""Return the message."""
return repr(self.message)


class InvalidFileError(QiskitError):
"""Raised when the file provided is not valid for the specific task."""
67 changes: 66 additions & 1 deletion qiskit/visualization/dag_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import tempfile

from qiskit.dagcircuit.dagnode import DAGOpNode, DAGInNode, DAGOutNode
from qiskit.exceptions import MissingOptionalLibraryError
from qiskit.exceptions import MissingOptionalLibraryError, InvalidFileError
from .exceptions import VisualizationError

try:
Expand All @@ -31,6 +31,63 @@
except ImportError:
HAS_PIL = False

FILENAME_EXTENSIONS = [
"bmp",
"canon",
"cgimage",
"cmap",
"cmapx",
"cmapx_np",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where did you find this full list? Perhaps we could cut it down to some more common ones - I think things like cmapx_np might be a bit more noise than use.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pydot library was throwing a exception with all those extensions...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see - it's the failing output of dot -Tunknown.

"dot",
"dot_json",
"eps exr",
jakelishman marked this conversation as resolved.
Show resolved Hide resolved
"fig",
"gd",
"gd2",
"gif",
"gv",
"icns",
"ico",
"imap",
"imap_np",
"ismap",
"jp2",
"jpe",
"jpeg",
"jpg",
"json",
"json0",
"mp",
"pct",
"pdf",
"pic",
"pict",
"plain",
"plain-ext",
"png",
"pov",
"ps",
"ps2",
"psd",
"sgi",
"svg",
"svgz",
"tga",
"tif",
"tiff",
"tk",
"vdx",
"vml",
"vmlz",
"vrml",
"wbmp",
"webp",
"xdot",
"xdot1.2",
"xdot1.4",
"xdot_json",
]


def dag_drawer(dag, scale=0.7, filename=None, style="color"):
"""Plot the directed acyclic graph (dag) to represent operation dependencies
Expand Down Expand Up @@ -59,6 +116,7 @@ def dag_drawer(dag, scale=0.7, filename=None, style="color"):
Raises:
VisualizationError: when style is not recognized.
MissingOptionalLibraryError: when pydot or pillow are not installed.
InvalidFileError: when filename provided is not valid

Example:
.. jupyter-execute::
Expand Down Expand Up @@ -166,7 +224,14 @@ def edge_attr_func(edge):
dot = pydot.graph_from_dot_data(dot_str)[0]

if filename:
if "." not in filename:
raise InvalidFileError("Parameter 'filename' must be in format 'name.extension'")
extension = filename.split(".")[-1]
if extension not in FILENAME_EXTENSIONS:
jakelishman marked this conversation as resolved.
Show resolved Hide resolved
raise InvalidFileError(
"Filename extension must be one of: "
+ " ".join([str(elem) for elem in FILENAME_EXTENSIONS])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All the elements in FILENAME_EXTENSIONS are already strings, so this is just the same as " ".join(FILENAME_EXTENSIONS).

)
dot.write(filename, format=extension)
return None
elif ("ipykernel" in sys.modules) and ("spyder" not in sys.modules):
Expand Down
27 changes: 27 additions & 0 deletions test/python/visualization/test_dag_drawer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from qiskit import QuantumRegister, QuantumCircuit
from qiskit.test import QiskitTestCase
from qiskit.tools.visualization import dag_drawer
from qiskit.exceptions import InvalidFileError
from qiskit.visualization.exceptions import VisualizationError
from qiskit.converters import circuit_to_dag

Expand All @@ -36,6 +37,32 @@ def test_dag_drawer_invalid_style(self):
"""Test dag draw with invalid style."""
self.assertRaises(VisualizationError, dag_drawer, self.dag, style="multicolor")

def test_dag_drawer_checks_filename_correct_format(self):
"""filename must contain name and extension"""
try:
dag_drawer(self.dag, filename="aaabc")
self.fail("Expected error not raised!")
except InvalidFileError as exception_instance:
self.assertEqual(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, and in the next test, a clearer way of writing these tests of exceptions is

with self.assertRaisesRegex(InvalidFileError, "Parameter 'filename' must be in format 'name.extension'"):
    dag_drawer(self.dag, filename="aaabc")

exception_instance.message,
"Parameter 'filename' must be in format 'name.extension'",
)

def test_dag_drawer_checks_filename_extension(self):
"""filename must have a valid extension"""
try:
dag_drawer(self.dag, filename="aa.abc")
self.fail("Expected error not raised!")
except InvalidFileError as exception_instance:
self.assertEqual(
exception_instance.message,
"Filename extension must be one of: bmp canon cgimage cmap cmapx cmapx_np "
"dot dot_json eps exr fig gd gd2 gif gv icns ico imap imap_np ismap jp2 "
"jpe jpeg jpg json json0 mp pct pdf pic pict plain plain-ext png pov "
"ps ps2 psd sgi svg svgz tga tif tiff tk vdx vml vmlz vrml wbmp webp "
"xdot xdot1.2 xdot1.4 xdot_json",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might not want to test that exactly this set of extensions are supported - it'll be ok just to test that the message starts with "Filename extension must be one of:"

)


if __name__ == "__main__":
unittest.main(verbosity=2)