Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SharedVariable.default_update graphs to debugprint #1412

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 90 additions & 13 deletions aesara/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def debugprint(
print_destroy_map: bool = False,
print_view_map: bool = False,
print_fgraph_inputs: bool = False,
print_default_updates: bool = False,
ids: Optional[IDTypesType] = None,
) -> Union[str, TextIO]:
r"""Print a graph as text.
Expand Down Expand Up @@ -177,6 +178,8 @@ def debugprint(
Whether to print the `view_map`\s of printed objects
print_fgraph_inputs
Print the inputs of `FunctionGraph`\s.
print_default_updates
Print the `SharedVariable.default_update` values.

Returns
-------
Expand Down Expand Up @@ -263,6 +266,7 @@ def debugprint(
raise TypeError(f"debugprint cannot print an object type {type(obj)}")

inner_graph_vars: List[Variable] = []
default_updates: List[Variable] = []

if any(p for p in profile_list if p is not None and p.fct_callcount > 0):
print(
Expand Down Expand Up @@ -297,14 +301,16 @@ def debugprint(
print_type=print_type,
file=_file,
id_type=id_type,
inner_graph_ops=inner_graph_vars,
inner_graph_vars=inner_graph_vars,
stop_on_name=stop_on_name,
used_ids=used_ids,
op_information=op_information,
parent_node=var.owner,
print_op_info=print_op_info,
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
print_default_updates=print_default_updates,
default_updates=default_updates,
)

for var, profile, storage_map, topo_order in zip(
Expand All @@ -325,7 +331,7 @@ def debugprint(
file=_file,
topo_order=topo_order,
id_type=id_type,
inner_graph_ops=inner_graph_vars,
inner_graph_vars=inner_graph_vars,
stop_on_name=stop_on_name,
profile=profile,
storage_map=storage_map,
Expand All @@ -335,6 +341,8 @@ def debugprint(
print_op_info=print_op_info,
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
print_default_updates=print_default_updates,
default_updates=default_updates,
)

if len(inner_graph_vars) > 0:
Expand Down Expand Up @@ -384,7 +392,7 @@ def debugprint(
print_type=print_type,
file=_file,
id_type=id_type,
inner_graph_ops=inner_graph_vars,
inner_graph_vars=inner_graph_vars,
stop_on_name=stop_on_name,
inner_to_outer_inputs=inner_to_outer_inputs,
used_ids=used_ids,
Expand All @@ -393,6 +401,8 @@ def debugprint(
print_op_info=print_op_info,
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
print_default_updates=print_default_updates,
default_updates=default_updates,
)

if print_fgraph_inputs:
Expand All @@ -406,7 +416,7 @@ def debugprint(
file=_file,
id_type=id_type,
stop_on_name=stop_on_name,
inner_graph_ops=inner_graph_vars,
inner_graph_vars=inner_graph_vars,
inner_to_outer_inputs=inner_to_outer_inputs,
used_ids=used_ids,
op_information=op_information,
Expand All @@ -415,6 +425,8 @@ def debugprint(
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
inner_graph_node=ig_var.owner,
print_default_updates=print_default_updates,
default_updates=default_updates,
)
inner_to_outer_inputs = None

Expand All @@ -436,7 +448,7 @@ def debugprint(
id_type=id_type,
stop_on_name=stop_on_name,
prefix_child=new_prefix_child,
inner_graph_ops=inner_graph_vars,
inner_graph_vars=inner_graph_vars,
inner_to_outer_inputs=inner_to_outer_inputs,
used_ids=used_ids,
op_information=op_information,
Expand All @@ -445,8 +457,43 @@ def debugprint(
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
inner_graph_node=ig_var.owner,
print_default_updates=print_default_updates,
default_updates=default_updates,
)

if len(default_updates) > 0:
print("", file=_file)
print("Default updates:", file=_file)

inner_to_outer_inputs = {}

for var in default_updates:

print("", file=_file)

update_var = var.default_update
inner_to_outer_inputs[update_var] = var

_debugprint(
update_var,
depth=depth,
done=done,
print_type=print_type,
file=_file,
id_type=id_type,
inner_graph_vars=inner_graph_vars,
stop_on_name=stop_on_name,
inner_to_outer_inputs=inner_to_outer_inputs,
used_ids=used_ids,
op_information=op_information,
parent_node=None,
print_op_info=print_op_info,
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
print_default_updates=print_default_updates,
default_updates=default_updates,
)

if file is _file:
return file
elif file == "str":
Expand All @@ -470,7 +517,7 @@ def _debugprint(
id_type: IDTypesType = "CHAR",
stop_on_name: bool = False,
prefix_child: Optional[str] = None,
inner_graph_ops: Optional[List[Variable]] = None,
inner_graph_vars: Optional[List[Variable]] = None,
profile: Optional[ProfileStats] = None,
inner_to_outer_inputs: Optional[Dict[Variable, Variable]] = None,
storage_map: Optional[StorageMapType] = None,
Expand All @@ -479,6 +526,8 @@ def _debugprint(
parent_node: Optional[Apply] = None,
print_op_info: bool = False,
inner_graph_node: Optional[Apply] = None,
print_default_updates: bool = False,
default_updates: Optional[List[Variable]] = None,
) -> TextIO:
r"""Print the graph represented by `var`.

Expand Down Expand Up @@ -506,8 +555,8 @@ def _debugprint(
See `debugprint`.
stop_on_name
Whether to print `Op` ``view_map``\s.
inner_graph_ops
A list of `Op`\s with inner graphs.
inner_graph_vars
A list of `Variables`\s with inner graphs.
inner_to_outer_inputs
A dictionary mapping an `Op`'s inner-inputs to its outer-inputs.
storage_map
Expand All @@ -522,6 +571,10 @@ def _debugprint(
See `debugprint`.
inner_graph_node
The inner-graph node in which `var` is contained.
print_default_updates
Print the `SharedVariable.default_update` values.
default_updates
A list of `Variables`\s with default updates.
"""
if depth == 0:
return file
Expand All @@ -534,8 +587,11 @@ def _debugprint(
else:
_done = done

if inner_graph_ops is None:
inner_graph_ops = []
if inner_graph_vars is None:
inner_graph_vars = []

if default_updates is None:
default_updates = []

if print_type:
type_str = f" <{var.type}>"
Expand Down Expand Up @@ -664,9 +720,9 @@ def get_id_str(
if hasattr(in_var, "owner") and hasattr(in_var.owner, "op"):
if (
isinstance(in_var.owner.op, HasInnerGraph)
and in_var not in inner_graph_ops
and in_var not in inner_graph_vars
):
inner_graph_ops.append(in_var)
inner_graph_vars.append(in_var)

_debugprint(
in_var,
Expand All @@ -679,7 +735,7 @@ def get_id_str(
id_type=id_type,
stop_on_name=stop_on_name,
prefix_child=new_prefix_child,
inner_graph_ops=inner_graph_ops,
inner_graph_vars=inner_graph_vars,
profile=profile,
inner_to_outer_inputs=inner_to_outer_inputs,
storage_map=storage_map,
Expand All @@ -690,6 +746,8 @@ def get_id_str(
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
inner_graph_node=inner_graph_node,
print_default_updates=print_default_updates,
default_updates=default_updates,
)
else:

Expand All @@ -705,6 +763,25 @@ def get_id_str(

var_output = f"{prefix}{var}{id_str}{type_str}{data}"

# `SharedVariable`s with default updates are considered "inner-graph" variables
if (
print_default_updates
and isinstance(var, SharedVariable)
and var.default_update is not None
):
update_obj = (
var.default_update
if var.default_update.owner is None
else var.default_update.owner
)
update_obj_id = get_id_str(update_obj)
var_output = f"{var_output} <- {update_obj_id}"

# We still want to print the graph later
if var not in default_updates:
default_updates.append(var)
del _done[update_obj]

if print_op_info and var.owner and var.owner not in op_information:
op_information.update(op_debug_information(var.owner.op, var.owner))

Expand Down
4 changes: 2 additions & 2 deletions tests/graph/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def is_variable(x):


class MyType(Type):
def filter(self, data):
def filter(self, data, **kwargs):
return data

def __eq__(self, other):
Expand All @@ -27,7 +27,7 @@ def __repr__(self):


class MyType2(Type):
def filter(self, data):
def filter(self, data, **kwargs):
return data

def __eq__(self, other):
Expand Down
93 changes: 92 additions & 1 deletion tests/test_printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import aesara
from aesara.compile.mode import get_mode
from aesara.compile.ops import deep_copy_op
from aesara.compile.sharedvalue import SharedVariable
from aesara.printing import (
PatternPrinter,
PPrinter,
Expand All @@ -25,7 +26,7 @@
)
from aesara.tensor import as_tensor_variable
from aesara.tensor.type import dmatrix, dvector, matrix
from tests.graph.utils import MyInnerGraphOp, MyOp, MyVariable
from tests.graph.utils import MyInnerGraphOp, MyOp, MyType, MyVariable


@pytest.mark.skipif(not pydot_imported, reason="pydot not available")
Expand Down Expand Up @@ -450,3 +451,93 @@ def test_Print(capsys):

stdout, stderr = capsys.readouterr()
assert "hello" in stdout


def test_debugprint_default_updates():

op1 = MyOp("op1")
op2 = MyOp("op2")

r1 = MyVariable("1")
s1 = SharedVariable(MyType(), None, None, name="s1")
s2 = SharedVariable(MyType(), None, None, name="s2")

s1.default_update = op1(r1, s2)
s2.default_update = op1(r1, s1)

out = op2(r1, s1)
out.name = "o1"

s = StringIO()
debugprint(out, file=s, print_default_updates=True)
s = s.getvalue()

reference = dedent(
r"""
op2 [id A] 'o1'
|1 [id B]
|s1 [id C] <- [id D]

Default updates:

op1 [id D]
|1 [id B]
|s2 [id E] <- [id F]

op1 [id F]
|1 [id B]
|s1 [id C] <- [id D]
"""
).lstrip()

assert s == reference


def test_debugprint_inner_graph_default_updates():
"""Test for updates on shared variables in an `OpFromGraph`."""

r1 = MyVariable("1")
r2 = MyVariable("2")
o1 = MyOp("op1")(r1, r2)
o1.name = "o1"

# Inner graph
igo_in_1 = MyVariable("4")
igo_in_s = SharedVariable(MyType(), None, None, name="s")
igo_in_s.default_update = o1
igo_out_1 = MyOp("op2")(igo_in_1, igo_in_s)
igo_out_1.name = "igo1"

from aesara.compile.builders import OpFromGraph

igo = OpFromGraph([igo_in_1], [igo_out_1])

r3 = MyVariable("3")
out = igo(r3)

s = StringIO()
debugprint(out, file=s, print_default_updates=True)
s = s.getvalue()

reference = dedent(
r"""
OpFromGraph{inline=False} [id A]
|3 [id B]
|s [id C] <- [id D]

Inner graphs:

OpFromGraph{inline=False} [id A]
>op2 [id E] 'igo1'
> |*0-<MyType()> [id F]
> |*1-<MyType()> [id G]

Default updates:

op1 [id D] 'o1'
|1 [id H]
|2 [id I]
"""
).lstrip()

assert s == reference