Skip to content

Commit

Permalink
changed pre-fuse heuristic a bit (#744)
Browse files Browse the repository at this point in the history
  • Loading branch information
Chillee authored Aug 9, 2022
1 parent 7181a8d commit 7b039cc
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 11 deletions.
4 changes: 2 additions & 2 deletions benchmarks/torchbench_models_list.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ LearningToPaint,1024
alexnet,1024
dcgan,1024
densenet121,64
hf_Albert,16
hf_Albert,32
hf_Bart,16
hf_Bert,16
hf_GPT2,16
Expand All @@ -19,7 +19,7 @@ resnext50_32x4d,128
shufflenet_v2_x1_0,512
squeezenet1_1,512
timm_efficientnet,128
timm_regnet,64
timm_regnet,128
timm_resnest,256
timm_vision_transformer,256
timm_vovnet,128
Expand Down
38 changes: 29 additions & 9 deletions torchinductor/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,7 @@ def toposort(graph):
if node.op == "call_function" and node.meta["fusion_meta"].type == "compute"
]

def is_fusible(src, dst):
def will_create_cycle(src, dst):
# Finds whether there's a path from src to dst that isn't a direct edge
cur_nodes = collections.deque(
[user for user in src.users if user != dst]
Expand All @@ -715,11 +715,11 @@ def is_fusible(src, dst):
if cur in vis:
continue
if cur == dst or cur == src:
return False
return True
vis.add(cur)
for user in cur.users:
cur_nodes.append(user)
return True
return False

def fuse_nodes(a: fx.Node, b: fx.Node):
a.args = tuple((set(a.args) | set(b.args)) - set([a, b]))
Expand All @@ -730,17 +730,37 @@ def fuse_nodes(a: fx.Node, b: fx.Node):
graph.erase_node(b)
return a

def is_fusible(groupA, groupB):
    """Return whether two fusion groups compare equal.

    Only the groups are compared here; whether fusing would introduce a
    cycle is not checked by this function.
    """
    same_group = groupA == groupB
    return same_group

def fusion_weight(nodeA, nodeB):
    """Score how much is gained by fusing ``nodeA`` with ``nodeB``.

    The score is the sum of three terms:

    * shared reads: the number of distinct arguments that appear in both
      nodes' ``args``;
    * one saved read when one node is a direct user of the other;
    * one saved write when, in addition, the consumer is the producer's
      only user.

    Args:
        nodeA: A node exposing ``args`` (iterable of hashable arguments)
            and ``users`` (a sized collection of nodes that consume it).
        nodeB: The other node, with the same attributes.

    Returns:
        A non-negative integer; ``0`` means the pair shares nothing and is
        not directly connected.
    """
    shared_reads = len(set(nodeA.args) & set(nodeB.args))

    # Orient the pair so that, if the nodes are directly connected, nodeA is
    # the producer and nodeB the consumer.
    if nodeA in nodeB.users:
        nodeA, nodeB = nodeB, nodeA

    saved_read = 0
    saved_write = 0
    if nodeB in nodeA.users:
        saved_read = 1
        # `users` is a collection of nodes, so its size must be compared;
        # comparing the collection itself to 1 is never true and the saved
        # write would never be counted.
        if len(nodeA.users) == 1:
            saved_write = 1

    return shared_reads + saved_write + saved_read

# Enumerates all fusion opportunities, and ranks them in descending order of fusion weight
fusion_opportunities = []
for idx, nodeA in enumerate(fusible_nodes):
for nodeB in fusible_nodes[idx + 1 :]:
if nodeA.meta["fusion_meta"].group == nodeB.meta[
"fusion_meta"
].group and set(nodeA.args) & set(nodeB.args):
fusion_opportunities.append(
[len(set(nodeA.args) & set(nodeB.args)), nodeA, nodeB]
if (
is_fusible(
nodeA.meta["fusion_meta"].group, nodeB.meta["fusion_meta"].group
)
and fusion_weight(nodeA, nodeB) > 0
):
fusion_opportunities.append([fusion_weight(nodeA, nodeB), nodeA, nodeB])

# NB: Python sort is stable, so this is deterministic
fusion_opportunities = sorted(
fusion_opportunities, key=lambda x: x[0], reverse=True
)
Expand All @@ -754,7 +774,7 @@ def fuse_nodes(a: fx.Node, b: fx.Node):
b = nodes[mapping.find(b.order)]
if a == b:
continue
if is_fusible(a, b):
if not will_create_cycle(a, b):
fused_node = fuse_nodes(a, b)
nodes[mapping.find(a.order)] = fused_node
nodes[mapping.find(b.order)] = fused_node
Expand Down

0 comments on commit 7b039cc

Please sign in to comment.