@@ -383,8 +383,8 @@ def plot(self, filename=None, show=False,
383383
384384 :param str filename:
385385 Write diagram into a file.
386- The extension must be one of: ``.png .dot .jpg .jpeg .pdf .svg``
387- Prefer ``.pdf`` or ``.svg`` to see solution-values in tooltips .
386+ Common extensions are ``.png .dot .jpg .jpeg .pdf .svg``
387+ call :func:`network.supported_plot_formats()` for more .
388388 :param boolean show:
389389 If it evaluates to true, opens the diagram in a matplotlib window.
390390 :param inputs:
@@ -454,15 +454,10 @@ def get_data_node(name, graph):
454454 return None
455455
456456
457- def supported_plot_writers ():
458- return {
459- ".png" : lambda gplot : gplot .create_png (),
460- ".dot" : lambda gplot : gplot .to_string (),
461- ".jpg" : lambda gplot : gplot .create_jpeg (),
462- ".jpeg" : lambda gplot : gplot .create_jpeg (),
463- ".pdf" : lambda gplot : gplot .create_pdf (),
464- ".svg" : lambda gplot : gplot .create_svg (),
465- }
457+ def supported_plot_formats ():
458+ import pydot
459+
460+ return [".%s" % f for f in pydot .Dot ().formats ]
466461
467462
468463def plot_graph (graph , filename = None , show = False , steps = None ,
@@ -496,8 +491,8 @@ def plot_graph(graph, filename=None, show=False, steps=None,
496491 what to plot
497492 :param str filename:
498493 Write diagram into a file.
499- The extension must be one of: ``.png .dot .jpg .jpeg .pdf .svg``
500- Prefer ``.pdf`` or ``.svg`` to see solution-values in tooltips .
494+ Common extensions are ``.png .dot .jpg .jpeg .pdf .svg``
495+ call :func:`network.supported_plot_formats()` for more .
501496 :param boolean show:
502497 If it evaluates to true, opens the diagram in a matplotlib window.
503498 :param steps:
@@ -603,16 +598,15 @@ def get_node_name(a):
603598
604599 # save plot
605600 if filename :
601+ formats = supported_plot_formats ()
606602 _basename , ext = os .path .splitext (filename )
607- writers = supported_plot_writers ()
608- plot_writer = supported_plot_writers ().get (ext .lower ())
609- if not plot_writer :
603+ if not ext .lower () in formats :
610604 raise ValueError (
611605 "Unknown file format for saving graph: %s"
612606 " File extensions must be one of: %s"
613- % (ext , ' ' .join (writers )))
614- with open ( filename , "wb" ) as fh :
615- fh .write (plot_writer ( g ) )
607+ % (ext , " " .join (formats )))
608+
609+ g .write (filename , format = ext . lower ()[ 1 :] )
616610
617611 # display graph via matplotlib
618612 if show :
0 commit comments