From 6a3be905ad112d862be82fc9b25a49a6b4054402 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Mon, 26 Aug 2024 10:28:23 +0700 Subject: [PATCH] fix file path when render model --- numpyro/infer/inspect.py | 9 +++++++-- test/test_model_rendering.py | 17 ++++++++++++++++- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/numpyro/infer/inspect.py b/numpyro/infer/inspect.py index abfb9bd71..6b8d5d058 100644 --- a/numpyro/infer/inspect.py +++ b/numpyro/infer/inspect.py @@ -629,9 +629,14 @@ def render_model( if filename is not None: filename = Path(filename) + # remove leading period from suffix + filename_without_suffix = filename.with_suffix("") graph.render( - filename.stem, view=False, cleanup=True, format=filename.suffix[1:] - ) # remove leading period from suffix + filename_without_suffix, + view=False, + cleanup=True, + format=filename.suffix[1:], + ) return graph diff --git a/test/test_model_rendering.py b/test/test_model_rendering.py index 62b7cf65f..a543add90 100644 --- a/test/test_model_rendering.py +++ b/test/test_model_rendering.py @@ -1,6 +1,8 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +import os + import numpy as np import pytest @@ -8,7 +10,11 @@ import numpyro import numpyro.distributions as dist -from numpyro.infer.inspect import generate_graph_specification, get_model_relations +from numpyro.infer.inspect import ( + generate_graph_specification, + get_model_relations, + render_model, +) def simple(data): @@ -129,3 +135,12 @@ def test_model_transformation(test_model, model_kwargs, expected_graph_spec): graph_spec = generate_graph_specification(relations) assert graph_spec == expected_graph_spec + + +def test_render_model_filename(): + def model(): + numpyro.sample("x", dist.Normal(0, 1)) + + render_model(model, filename="graph.png") + assert os.path.exists("graph.png") + os.remove("graph.png")