Skip to content

Commit

Permalink
Do not always remap storage in fgraph_to_python
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Nov 11, 2022
1 parent 643c973 commit 31b77a2
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 19 deletions.
3 changes: 3 additions & 0 deletions aesara/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,9 @@ def perform(*inputs):

@numba_funcify.register(OpFromGraph)
def numba_funcify_OpFromGraph(op, node=None, **kwargs):

_ = kwargs.pop("storage_map", None)

fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))

if len(op.fgraph.outputs) == 1:
Expand Down
3 changes: 3 additions & 0 deletions aesara/link/numba/dispatch/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,9 @@ def clip(_x, _min, _max):
@numba_funcify.register(Composite)
def numba_funcify_Composite(op, node, **kwargs):
signature = create_numba_signature(node, force_scalar=True)

_ = kwargs.pop("storage_map", None)

composite_fn = numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
)
Expand Down
39 changes: 20 additions & 19 deletions aesara/link/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,8 +678,6 @@ def fgraph_to_python(
*,
type_conversion_fn: Callable = lambda x, **kwargs: x,
order: Optional[List[Apply]] = None,
input_storage: Optional["InputStorageType"] = None,
output_storage: Optional["OutputStorageType"] = None,
storage_map: Optional["StorageMapType"] = None,
fgraph_name: str = "fgraph_to_python",
global_env: Optional[Dict[Any, Any]] = None,
Expand All @@ -704,10 +702,6 @@ def fgraph_to_python(
``(value: Optional[Any], variable: Variable=None, storage: List[Optional[Any]]=None, **kwargs)``.
order
The `order` argument to `map_storage`.
input_storage
The `input_storage` argument to `map_storage`.
output_storage
The `output_storage` argument to `map_storage`.
storage_map
The `storage_map` argument to `map_storage`.
fgraph_name
Expand All @@ -730,9 +724,9 @@ def fgraph_to_python(

if order is None:
order = fgraph.toposort()
input_storage, output_storage, storage_map = map_storage(
fgraph, order, input_storage, output_storage, storage_map
)

if storage_map is None:
storage_map = {}

unique_name = unique_name_generator([fgraph_name])

Expand All @@ -752,31 +746,38 @@ def fgraph_to_python(
node_input_names = []
for i in node.inputs:
local_input_name = unique_name(i)
if storage_map[i][0] is not None or isinstance(i, Constant):
input_storage = storage_map.setdefault(
i, [None if not isinstance(i, Constant) else i.data]
)
if input_storage[0] is not None or isinstance(i, Constant):
# Constants need to be assigned locally and referenced
global_env[local_input_name] = type_conversion_fn(
storage_map[i][0], variable=i, storage=storage_map[i], **kwargs
input_storage[0], variable=i, storage=input_storage, **kwargs
)
# TODO: We could attempt to use the storage arrays directly
# E.g. `local_input_name = f"{local_input_name}[0]"`
node_input_names.append(local_input_name)

node_output_names = [unique_name(v) for v in node.outputs]

assign_comment_str = f"{indent(str(node), '# ')}"
assign_str = f"{', '.join(node_output_names)} = {local_compiled_func_name}({', '.join(node_input_names)})"
body_assigns.append(f"{assign_comment_str}\n{assign_str}")
assign_comment_str = f"{indent(str(node), '# ')}"
assign_block_str = f"{assign_comment_str}\n{assign_str}"
body_assigns.append(assign_block_str)

# Handle `Constant`-only outputs (these don't have associated `Apply`
# nodes, so the above isn't applicable)
for out in fgraph.outputs:
if isinstance(out, Constant):
local_input_name = unique_name(out)
if local_input_name not in global_env:
global_env[local_input_name] = type_conversion_fn(
storage_map[out][0],
local_output_name = unique_name(out)
if local_output_name not in global_env:
output_storage = storage_map.setdefault(
out, [None if not isinstance(out, Constant) else out.data]
)
global_env[local_output_name] = type_conversion_fn(
output_storage[0],
variable=out,
storage=storage_map[out],
storage=output_storage,
**kwargs,
)

Expand All @@ -794,7 +795,7 @@ def fgraph_to_python(
fgraph_def_src = dedent(
f"""
def {fgraph_name}({", ".join(fgraph_input_names)}):
{indent(joined_body_assigns, " " * 4)}
{indent(joined_body_assigns, " " * 4)}
return {fgraph_return_src}
"""
).strip()
Expand Down
19 changes: 19 additions & 0 deletions tests/link/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,25 @@ def test_fgraph_to_python_constant_outputs():
assert out_py()[0] is y.data


def test_fgraph_to_python_constant_inputs():
x = constant([1.0])
y = vector("y")

out = x + y
out_fg = FunctionGraph(outputs=[out], clone=False)

out_py = fgraph_to_python(out_fg, to_python, storage_map=None)

res = out_py(2.0)
assert res == (3.0,)

storage_map = {out: [None], x: [np.r_[2.0]], y: [None]}
out_py = fgraph_to_python(out_fg, to_python, storage_map=storage_map)

res = out_py(2.0)
assert res == (4.0,)


def test_unique_name_generator():

unique_names = unique_name_generator(["blah"], suffix_sep="_")
Expand Down

0 comments on commit 31b77a2

Please sign in to comment.