Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions autoparallel/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _add_alias(gm, version="v1"):
"""
graph = gm.graph

nodes = [n for n in graph.nodes if n.op == "call_function"]
nodes = list(graph.nodes)
node_map = {node: idx for idx, node in enumerate(nodes)}

def _insert_alias(node):
Expand All @@ -94,10 +94,9 @@ def delete_user_cb(n):

node.replace_all_uses_with(alias_node, delete_user_cb=delete_user_cb)

inputs = graph.find_nodes(op="placeholder")
if version == "v1":
# only on inputs
for node in inputs:
for node in graph.find_nodes(op="placeholder"):
if len(node.users) == 0:
# node is not used, don't add alias for it
continue
Expand All @@ -110,7 +109,7 @@ def delete_user_cb(n):
_insert_alias(node)
elif version == "v2":
# for every node that has more than one user
for node in inputs + nodes:
for node in nodes:
if len(node.users) < 2:
continue
# don't add alias for ops which return tuple for now
Expand All @@ -121,6 +120,7 @@ def delete_user_cb(n):
raise ValueError(f"Unknown version {version}")

"""
nodes = [n for n in graph.nodes if n.op == "call_function"]
for node in nodes:
# skip ops which return tuple
if not isinstance(node.meta["val"], torch.Tensor):
Expand Down