Skip to content

Commit

Permalink
Make fgraph_to_python process constant FunctionGraph outputs correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Oct 5, 2022
1 parent 3d96ee8 commit 6232637
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 5 deletions.
25 changes: 20 additions & 5 deletions aesara/link/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from keyword import iskeyword
from operator import itemgetter
from tempfile import NamedTemporaryFile
from textwrap import indent
from textwrap import dedent, indent
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -767,6 +767,19 @@ def fgraph_to_python(
assign_str = f"{', '.join(node_output_names)} = {local_compiled_func_name}({', '.join(node_input_names)})"
body_assigns.append(f"{assign_comment_str}\n{assign_str}")

# Handle `Constant`-only outputs (these don't have associated `Apply`
# nodes, so the above isn't applicable)
for out in fgraph.outputs:
if isinstance(out, Constant):
local_input_name = unique_name(out)
if local_input_name not in global_env:
global_env[local_input_name] = type_conversion_fn(
storage_map[out][0],
variable=out,
storage=storage_map[out],
**kwargs,
)

fgraph_input_names = [unique_name(v) for v in fgraph.inputs]
fgraph_output_names = [unique_name(v) for v in fgraph.outputs]
joined_body_assigns = indent("\n".join(body_assigns), " ")
Expand All @@ -778,11 +791,13 @@ def fgraph_to_python(
else:
fgraph_return_src = ", ".join(fgraph_output_names)

fgraph_def_src = f"""
def {fgraph_name}({", ".join(fgraph_input_names)}):
{joined_body_assigns}
return {fgraph_return_src}
fgraph_def_src = dedent(
f"""
def {fgraph_name}({", ".join(fgraph_input_names)}):
{indent(joined_body_assigns, " " * 4)}
return {fgraph_return_src}
"""
).strip()

if local_env is None:
local_env = locals()
Expand Down
13 changes: 13 additions & 0 deletions tests/link/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
unique_name_generator,
)
from aesara.scalar.basic import Add, float64
from aesara.tensor import constant
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.type import scalar, vector
from aesara.tensor.type_other import NoneConst
Expand Down Expand Up @@ -163,6 +164,18 @@ def func(*args, op=op):
)


def test_fgraph_to_python_constant_outputs():
"""Make sure that constant outputs are handled properly."""

y = constant(1)

out_fg = FunctionGraph([], [y], clone=False)

out_py = fgraph_to_python(out_fg, to_python)

assert out_py()[0] is y.data


def test_unique_name_generator():

unique_names = unique_name_generator(["blah"], suffix_sep="_")
Expand Down

0 comments on commit 6232637

Please sign in to comment.