Skip to content

Commit d56381e

Browse files
Do not always remap storage in fgraph_to_python
1 parent d1af711 commit d56381e

File tree

4 files changed

+45
-19
lines changed

4 files changed

+45
-19
lines changed

aesara/link/numba/dispatch/basic.py

+3
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,9 @@ def perform(*inputs):
377377

378378
@numba_funcify.register(OpFromGraph)
379379
def numba_funcify_OpFromGraph(op, node=None, **kwargs):
380+
381+
_ = kwargs.pop("storage_map", None)
382+
380383
fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))
381384

382385
if len(op.fgraph.outputs) == 1:

aesara/link/numba/dispatch/scalar.py

+3
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,9 @@ def clip(_x, _min, _max):
221221
@numba_funcify.register(Composite)
222222
def numba_funcify_Composite(op, node, **kwargs):
223223
signature = create_numba_signature(node, force_scalar=True)
224+
225+
_ = kwargs.pop("storage_map", None)
226+
224227
composite_fn = numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
225228
numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
226229
)

aesara/link/utils.py

+20-19
Original file line numberDiff line numberDiff line change
@@ -678,8 +678,6 @@ def fgraph_to_python(
678678
*,
679679
type_conversion_fn: Callable = lambda x, **kwargs: x,
680680
order: Optional[List[Apply]] = None,
681-
input_storage: Optional["InputStorageType"] = None,
682-
output_storage: Optional["OutputStorageType"] = None,
683681
storage_map: Optional["StorageMapType"] = None,
684682
fgraph_name: str = "fgraph_to_python",
685683
global_env: Optional[Dict[Any, Any]] = None,
@@ -704,10 +702,6 @@ def fgraph_to_python(
704702
``(value: Optional[Any], variable: Variable=None, storage: List[Optional[Any]]=None, **kwargs)``.
705703
order
706704
The `order` argument to `map_storage`.
707-
input_storage
708-
The `input_storage` argument to `map_storage`.
709-
output_storage
710-
The `output_storage` argument to `map_storage`.
711705
storage_map
712706
The `storage_map` argument to `map_storage`.
713707
fgraph_name
@@ -730,9 +724,9 @@ def fgraph_to_python(
730724

731725
if order is None:
732726
order = fgraph.toposort()
733-
input_storage, output_storage, storage_map = map_storage(
734-
fgraph, order, input_storage, output_storage, storage_map
735-
)
727+
728+
if storage_map is None:
729+
storage_map = {}
736730

737731
unique_name = unique_name_generator([fgraph_name])
738732

@@ -752,31 +746,38 @@ def fgraph_to_python(
752746
node_input_names = []
753747
for i in node.inputs:
754748
local_input_name = unique_name(i)
755-
if storage_map[i][0] is not None or isinstance(i, Constant):
749+
input_storage = storage_map.setdefault(
750+
i, [None if not isinstance(i, Constant) else i.data]
751+
)
752+
if input_storage[0] is not None or isinstance(i, Constant):
756753
# Constants need to be assigned locally and referenced
757754
global_env[local_input_name] = type_conversion_fn(
758-
storage_map[i][0], variable=i, storage=storage_map[i], **kwargs
755+
input_storage[0], variable=i, storage=input_storage, **kwargs
759756
)
760757
# TODO: We could attempt to use the storage arrays directly
761758
# E.g. `local_input_name = f"{local_input_name}[0]"`
762759
node_input_names.append(local_input_name)
763760

764761
node_output_names = [unique_name(v) for v in node.outputs]
765762

766-
assign_comment_str = f"{indent(str(node), '# ')}"
767763
assign_str = f"{', '.join(node_output_names)} = {local_compiled_func_name}({', '.join(node_input_names)})"
768-
body_assigns.append(f"{assign_comment_str}\n{assign_str}")
764+
assign_comment_str = f"{indent(str(node), '# ')}"
765+
assign_block_str = f"{assign_comment_str}\n{assign_str}"
766+
body_assigns.append(assign_block_str)
769767

770768
# Handle `Constant`-only outputs (these don't have associated `Apply`
771769
# nodes, so the above isn't applicable)
772770
for out in fgraph.outputs:
773771
if isinstance(out, Constant):
774-
local_input_name = unique_name(out)
775-
if local_input_name not in global_env:
776-
global_env[local_input_name] = type_conversion_fn(
777-
storage_map[out][0],
772+
local_output_name = unique_name(out)
773+
if local_output_name not in global_env:
774+
output_storage = storage_map.setdefault(
775+
out, [None if not isinstance(out, Constant) else out.data]
776+
)
777+
global_env[local_output_name] = type_conversion_fn(
778+
output_storage[0],
778779
variable=out,
779-
storage=storage_map[out],
780+
storage=output_storage,
780781
**kwargs,
781782
)
782783

@@ -794,7 +795,7 @@ def fgraph_to_python(
794795
fgraph_def_src = dedent(
795796
f"""
796797
def {fgraph_name}({", ".join(fgraph_input_names)}):
797-
{indent(joined_body_assigns, " " * 4)}
798+
{indent(joined_body_assigns, " " * 4)}
798799
return {fgraph_return_src}
799800
"""
800801
).strip()

tests/link/test_utils.py

+19
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,25 @@ def test_fgraph_to_python_constant_outputs():
176176
assert out_py()[0] is y.data
177177

178178

179+
def test_fgraph_to_python_constant_inputs():
180+
x = constant([1.0])
181+
y = vector("y")
182+
183+
out = x + y
184+
out_fg = FunctionGraph(outputs=[out], clone=False)
185+
186+
out_py = fgraph_to_python(out_fg, to_python, storage_map=None)
187+
188+
res = out_py(2.0)
189+
assert res == (3.0,)
190+
191+
storage_map = {out: [None], x: [np.r_[2.0]], y: [None]}
192+
out_py = fgraph_to_python(out_fg, to_python, storage_map=storage_map)
193+
194+
res = out_py(2.0)
195+
assert res == (4.0,)
196+
197+
179198
def test_unique_name_generator():
180199

181200
unique_names = unique_name_generator(["blah"], suffix_sep="_")

0 commit comments

Comments
 (0)