From 7b039ccbfdef2f4a50b9155ed5e8e1fd4a882664 Mon Sep 17 00:00:00 2001
From: Horace He
Date: Mon, 8 Aug 2022 19:02:36 -0700
Subject: [PATCH] changed pre-fuse heuristic a bit (#744)

---
Review note: fusion_weight() checks `len(nodeA.users) == 1`. `users` is a
collection (it is iterated and membership-tested elsewhere in this patch), so
comparing it to 1 directly can never be true and `saved_write` would stay 0.

 benchmarks/torchbench_models_list.txt |  4 +--
 torchinductor/scheduler.py            | 38 ++++++++++++++++++++-------
 2 files changed, 31 insertions(+), 11 deletions(-)

diff --git a/benchmarks/torchbench_models_list.txt b/benchmarks/torchbench_models_list.txt
index 034755adcb..36d3f6adc1 100644
--- a/benchmarks/torchbench_models_list.txt
+++ b/benchmarks/torchbench_models_list.txt
@@ -3,7 +3,7 @@ LearningToPaint,1024
 alexnet,1024
 dcgan,1024
 densenet121,64
-hf_Albert,16
+hf_Albert,32
 hf_Bart,16
 hf_Bert,16
 hf_GPT2,16
@@ -19,7 +19,7 @@ resnext50_32x4d,128
 shufflenet_v2_x1_0,512
 squeezenet1_1,512
 timm_efficientnet,128
-timm_regnet,64
+timm_regnet,128
 timm_resnest,256
 timm_vision_transformer,256
 timm_vovnet,128
diff --git a/torchinductor/scheduler.py b/torchinductor/scheduler.py
index 5528f1bf7c..d3c5fe37d2 100644
--- a/torchinductor/scheduler.py
+++ b/torchinductor/scheduler.py
@@ -703,7 +703,7 @@ def toposort(graph):
         if node.op == "call_function" and node.meta["fusion_meta"].type == "compute"
     ]
 
-    def is_fusible(src, dst):
+    def will_create_cycle(src, dst):
         # Finds whether there's a path from src to dst that isn't a direct edge
         cur_nodes = collections.deque(
             [user for user in src.users if user != dst]
@@ -715,11 +715,11 @@ def is_fusible(src, dst):
             if cur in vis:
                 continue
             if cur == dst or cur == src:
-                return False
+                return True
             vis.add(cur)
             for user in cur.users:
                 cur_nodes.append(user)
-        return True
+        return False
 
     def fuse_nodes(a: fx.Node, b: fx.Node):
         a.args = tuple((set(a.args) | set(b.args)) - set([a, b]))
@@ -730,17 +730,37 @@ def fuse_nodes(a: fx.Node, b: fx.Node):
         graph.erase_node(b)
         return a
 
+    def is_fusible(groupA, groupB):
+        return groupA == groupB
+
+    def fusion_weight(nodeA, nodeB):
+        shared_reads = len(set(nodeA.args) & set(nodeB.args))
+        saved_write = 0
+        saved_read = 0
+
+        if nodeA in nodeB.users:
+            nodeA, nodeB = nodeB, nodeA
+
+        if nodeB in nodeA.users:
+            saved_read += 1
+            if len(nodeA.users) == 1:
+                saved_write += 1
+
+        return shared_reads + saved_write + saved_read
+
     # Enumerates all fusion opportunities, and ranks them in descending order of shared reads
     fusion_opportunities = []
     for idx, nodeA in enumerate(fusible_nodes):
         for nodeB in fusible_nodes[idx + 1 :]:
-            if nodeA.meta["fusion_meta"].group == nodeB.meta[
-                "fusion_meta"
-            ].group and set(nodeA.args) & set(nodeB.args):
-                fusion_opportunities.append(
-                    [len(set(nodeA.args) & set(nodeB.args)), nodeA, nodeB]
+            if (
+                is_fusible(
+                    nodeA.meta["fusion_meta"].group, nodeB.meta["fusion_meta"].group
                 )
+                and fusion_weight(nodeA, nodeB) > 0
+            ):
+                fusion_opportunities.append([fusion_weight(nodeA, nodeB), nodeA, nodeB])
 
+    # NB: Python sort is stable, so this is deterministic
     fusion_opportunities = sorted(
         fusion_opportunities, key=lambda x: x[0], reverse=True
     )
@@ -754,7 +774,7 @@ def fuse_nodes(a: fx.Node, b: fx.Node):
         b = nodes[mapping.find(b.order)]
         if a == b:
             continue
-        if is_fusible(a, b):
+        if not will_create_cycle(a, b):
             fused_node = fuse_nodes(a, b)
             nodes[mapping.find(a.order)] = fused_node
             nodes[mapping.find(b.order)] = fused_node