[inductor] Fix edge case in JIT vs. AOT fusion after finalizing MultiTemplateBuffer (pytorch#126622)

# Context

Here's a peripheral scenario causing the JIT-pass and AOT-pass to pick different fusions.

```py
# JIT -- buf3 is a MultiTemplateBuffer
V.graph.buffers = [buf0, buf1, buf2, buf3, buf4]
                                     ^     ^
# JIT pass calls finalize_multi_template_buffers()
V.graph.buffers = [buf0, buf1, buf2, buf4, *buf3*]

# AOT, note proximity_score(buf2, buf4) is "better" for fusion than JIT
V.graph.buffers = [buf0, buf1, buf2, buf4, *buf3*]
                               ^     ^
```

It happens like this:
* JIT starts with the original set of nodes using V.graph.buffers.
* In JIT, finalize_multi_template_buffers() is called, which can change the order of the buffers.
* This makes the order of buffers/scheduler nodes different.
* Now, each node's min/max-order is different from before.
* As a result, the proximity between two nodes is different.

https://github.com/pytorch/pytorch/blob/ad67553c5c1672d65b810acd7a6a01e11695098b/torch/_inductor/scheduler.py#L2316-L2335

# Error

```
$ TORCH_LOGS="+fusion" python test/inductor/test_max_autotune.py -k test_jit_fusion_matches_aot_fusion
======================================================================
FAIL: test_jit_fusion_matches_aot_fusion (__main__.TestMaxAutotune)
----------------------------------------------------------------------
Traceback (most recent call last):
  ...
  File "/data/users/colinpeppler/pytorch/torch/_inductor/graph.py", line 1718, in compile_to_fn
    code, linemap = self.codegen_with_cpp_wrapper()
  File "/data/users/colinpeppler/pytorch/torch/_inductor/graph.py", line 1618, in codegen_with_cpp_wrapper
    return self.codegen()
  File "/data/users/colinpeppler/pytorch/torch/_inductor/graph.py", line 1636, in codegen
    self.scheduler.codegen()
  File "/data/users/colinpeppler/pytorch/torch/_dynamo/utils.py", line 210, in time_wrapper
    r = func(*args, **kwargs)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 2602, in codegen
    self.get_backend(device).codegen_node(node)  # type: ignore[possibly-undefined]
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/cuda_combined_scheduling.py", line 66, in codegen_node
    return self._triton_scheduling.codegen_node(node)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3377, in codegen_node
    return self.codegen_node_schedule(node_schedule, buf_accesses, numel, rnumel)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3602, in codegen_node_schedule
    final_kernel.call_kernel(final_kernel.kernel_name)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3055, in call_kernel
    grid = wrapper.generate_default_grid(name, grid)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/cpp_wrapper_cuda.py", line 174, in generate_default_grid
    params is not None
AssertionError: cuda kernel parameters for triton_poi_fused_add_0 should already exist at this moment, only found dict_keys(['Placeholder.DESCRIPTIVE_NAME', 'triton_poi_fused_add_mul_0', 'triton_poi_fused_pow_1'])
```

Pull Request resolved: pytorch#126622
Approved by: https://github.com/chenyang78
ghstack dependencies: pytorch#125982
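To illustrate the mechanism described above — reordering `V.graph.buffers` shifts each node's position, which changes a position-based proximity score and therefore the fusion choice — here is a minimal, hypothetical sketch. It is not the actual inductor code: `proximity_score` here is a toy stand-in for the scheduler's real heuristic, and the buffer names are just strings.

```py
# Hypothetical sketch (not the real inductor scheduler): a proximity
# score derived from list position flips when the buffer order changes.

def proximity_score(buffers, a, b):
    # Toy heuristic: buffers closer together in the list score higher
    # (a smaller distance gives a less negative score).
    return -abs(buffers.index(a) - buffers.index(b))

# Order the JIT pass sees before finalize_multi_template_buffers().
jit_order = ["buf0", "buf1", "buf2", "buf3", "buf4"]

# finalize_multi_template_buffers() can move the MultiTemplateBuffer
# (buf3) to the end, shifting the positions of the buffers after it.
aot_order = ["buf0", "buf1", "buf2", "buf4", "buf3"]

# Same pair of nodes, different score before vs. after the reorder --
# so the two passes can rank candidate fusions differently.
print(proximity_score(jit_order, "buf2", "buf4"))  # -2
print(proximity_score(aot_order, "buf2", "buf4"))  # -1
```

Under this toy scoring, `(buf2, buf4)` looks strictly better for fusion in the AOT ordering than in the JIT ordering, which mirrors how the two passes could disagree.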