diff --git a/autoparallel/graph_utils.py b/autoparallel/graph_utils.py index 72e0d71f..1e708aea 100644 --- a/autoparallel/graph_utils.py +++ b/autoparallel/graph_utils.py @@ -80,7 +80,7 @@ def _add_alias(gm, version="v1"): """ graph = gm.graph - nodes = [n for n in graph.nodes if n.op == "call_function"] + nodes = list(graph.nodes) node_map = {node: idx for idx, node in enumerate(nodes)} def _insert_alias(node): @@ -94,10 +94,9 @@ def delete_user_cb(n): node.replace_all_uses_with(alias_node, delete_user_cb=delete_user_cb) - inputs = graph.find_nodes(op="placeholder") if version == "v1": # only on inputs - for node in inputs: + for node in graph.find_nodes(op="placeholder"): if len(node.users) == 0: # node is not used, don't add alias for it continue @@ -110,7 +109,7 @@ def delete_user_cb(n): _insert_alias(node) elif version == "v2": # for every node that has more than one user - for node in inputs + nodes: + for node in nodes: if len(node.users) < 2: continue # don't add alias for ops which return tuple for now @@ -121,6 +120,7 @@ def delete_user_cb(n): raise ValueError(f"Unknown version {version}") """ + nodes = [n for n in graph.nodes if n.op == "call_function"] for node in nodes: # skip ops which return tuple if not isinstance(node.meta["val"], torch.Tensor):