From 4fab7c2bd5f22dd38aaedaf79bdffd088fc87d24 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Tue, 6 May 2025 12:18:34 -0700 Subject: [PATCH] fix bug with sequential backends (#10708) Summary: https://github.com/pytorch/executorch/pull/10584/files#r2070213706 there's a bug described in this PR comment. I add some tests and a fix to cover it. Essentially when sequential partitions go through preprocess_all, the get_item nodes from the first partition in the sequence don't correctly get mapped to the arguments input into the second partition. This is because the name of these nodes change (the original node to a get_item node). Instead of checking for the names, we instead delete the nodes we know must be deleted from the inputspec Additionaly, there is an issue with validation. the _validate fails when there are call_module nodes still in the graph. Since preprocess_multimethod will lower the call_submodule nodes one-by-one calling _validate before all the call_submodule nodes are transformed to call_delegate nodes will fail. We remove the _validate call from unsafe_adjust_original_program and instead call _validate on the original program after all the submodule nodes have been converted to call_delegate Differential Revision: D74226258 --- exir/backend/backend_api.py | 18 +++-- .../test/backend_with_preprocess_all_demo.py | 55 ++++++++++---- .../test/test_to_backend_multi_method.py | 71 +++++++++++++++++++ exir/lowered_backend_module.py | 16 +++-- 4 files changed, 136 insertions(+), 24 deletions(-) diff --git a/exir/backend/backend_api.py b/exir/backend/backend_api.py index 310e5ea9379..838156498c4 100644 --- a/exir/backend/backend_api.py +++ b/exir/backend/backend_api.py @@ -204,12 +204,16 @@ def _insert_lowered_submodule( owning_graph_module = call_submodule_node.graph.owning_module # call delegate args should only use user_inputs call_delegate_args = [] - # Preserve input order as user_inputs - for inp_name in submodule_program.graph_signature.user_inputs: - for inp_node in call_submodule_node.all_input_nodes: - if inp_node.name == inp_name: - call_delegate_args.append(inp_node) - break + # names of input_specs to delete + input_specs_to_delete = toplevel_input_specs_to_delete + # Delete owned constants from the call_submodule_node args + for call_sm_input in call_submodule_node.args: + if ( + isinstance(call_sm_input, torch.fx.Node) + and call_sm_input.name in input_specs_to_delete.keys() + ): + continue + call_delegate_args.append(call_sm_input) def generate_debug_handle(ep: ExportedProgram) -> int: """ @@ -324,6 +328,7 @@ def _partition_and_lower_one_graph_module( toplevel_input_specs_to_delete, toplevel_output_specs_to_delete, ) + owning_program._validate() return tagged_graph_module @@ -742,6 +747,7 @@ def to_backend( for method_name in method_to_edge_program.keys(): if method_name in method_to_tagged_exported_program: tagged_exported_program = method_to_tagged_exported_program[method_name] + tagged_exported_program._validate() partitioned_and_lowered_exported_programs[method_name] = ExportedProgram( root=tagged_exported_program.graph_module, graph=tagged_exported_program.graph_module.graph, diff --git a/exir/backend/test/backend_with_preprocess_all_demo.py b/exir/backend/test/backend_with_preprocess_all_demo.py index ae9a8174be5..11941b703a0 100644 --- a/exir/backend/test/backend_with_preprocess_all_demo.py +++ b/exir/backend/test/backend_with_preprocess_all_demo.py @@ -21,10 +21,30 @@ ) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.graph_module import get_control_flow_submodules +from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param from torch.export.exported_program import ExportedProgram from torch.fx.passes.operator_support import any_chain, OperatorSupportBase +def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool: + return ( + is_param(exp_prog, node) + or is_buffer(exp_prog, node) + or is_lifted_tensor_constant(exp_prog, node) + ) + + +def get_total_num_ops_in_ep(edge_programs, supported_ops): + total_number_of_ops = 0 + for edge_program in edge_programs.values(): + for partitioned_program in edge_program: + for node in partitioned_program.graph.nodes: + if node.op == "call_function": + if node.target in supported_ops: + total_number_of_ops += 1 + return total_number_of_ops + + def _preprocess_multimethod( edge_programs: Dict[str, List[ExportedProgram]], compile_specs: Dict[str, List[List[CompileSpec]]], @@ -37,13 +57,7 @@ def _preprocess_multimethod( in testing for a partitioner which tags different partitions for different backends to be lowered to """ - total_number_of_ops = 0 - for edge_program in edge_programs.values(): - for partitioned_program in edge_program: - for node in partitioned_program.graph.nodes: - if node.op == "call_function": - if node.target in supported_ops: - total_number_of_ops += 1 + total_number_of_ops = get_total_num_ops_in_ep(edge_programs, supported_ops) all_processed_results = {key: [] for key in edge_programs.keys()} for method_name, partitioned_programs in edge_programs.items(): @@ -67,6 +81,8 @@ def _preprocess_multimethod( raise RuntimeError( f"{node.op} {node.target.__name__} is not supported in backend {backend_name}" ) + if is_param_node(partitioned_program, node): + processed_bytes += f"CONST{node.name}:" processed_bytes += "#" for cs in compile_spec_for_partition: @@ -171,14 +187,30 @@ def preprocess_multimethod( class AddSinOperatorSupport(OperatorSupportBase): + def __init__(self, original_program): + self.original_program = original_program + super().__init__() + def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: - return node.op == "call_function" and node.target in [ + supported_targets = [ exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.sin.default, ] + if node.op == "call_function" and node.target in supported_targets: + return True + + if node.op == "placeholder" and is_param_node(self.original_program, node): + for user in node.users.keys(): + if user.target in supported_targets: + return True + return False class SubCosOperatorSupport(OperatorSupportBase): + def __init__(self, original_program): + self.original_program = original_program + super().__init__() + def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: return node.op == "call_function" and node.target in [ exir_ops.edge.aten.sub.Tensor, @@ -199,11 +231,8 @@ class BackendWithPreprocessAllPartitioner(Partitioner): """ def __init__(self) -> None: - self.add_sin_support = any_chain(AddSinOperatorSupport()) - self.add_sin_backend_id = FirstBackendWithPreprocessAll.__name__ - - self.sub_cos_support = any_chain(SubCosOperatorSupport()) self.sub_cos_backend_id = SecondBackendWithPreprocessAll.__name__ + self.add_sin_backend_id = FirstBackendWithPreprocessAll.__name__ def _partition_graph_module( self, @@ -260,6 +289,8 @@ def _partition_graph_module( return partition_tags, start_idx_for_submodules def partition(self, exported_program: ExportedProgram) -> PartitionResult: + self.add_sin_support = any_chain(AddSinOperatorSupport(exported_program)) + self.sub_cos_support = any_chain(SubCosOperatorSupport(exported_program)) partition_tags, _ = self._partition_graph_module(exported_program.graph_module) return PartitionResult( tagged_exported_program=exported_program, partition_tags=partition_tags diff --git a/exir/backend/test/test_to_backend_multi_method.py b/exir/backend/test/test_to_backend_multi_method.py index d4f8fccb8f2..045de253e0f 100644 --- a/exir/backend/test/test_to_backend_multi_method.py +++ b/exir/backend/test/test_to_backend_multi_method.py @@ -392,6 +392,77 @@ def forward(self, x): } self._test(test_set) + def test_multi_method_to_backend_sequential_delegates(self): + class SequentialBackendModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y, z): + # delegate one + x = x - x + y = y - y + z = z - z + # graph break + a = x * y * z + # delegate two uses outputs from delegate one and the + # output from the graph break + b = x + a + b = b + z + a + b = b + y + a + return b + + module = SequentialBackendModule() + example_inputs = (torch.ones(1), torch.ones(1), torch.ones(1)) + seq_edgeir_m = to_edge(torch.export.export(module, example_inputs)) + + test_set = { + "seq_edgeir": ( + seq_edgeir_m.exported_program(), + BackendWithPreprocessAllPartitioner(), + [ + "SecondBackendWithPreprocessAll#3#aten.sub.Tensor:aten.sub.Tensor:aten.sub.Tensor:#sub:b'\\x02';sub:b'\\x02';sub:b'\\x02';", + "FirstBackendWithPreprocessAll#5#aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:#add:b'\\x00';add:b'\\x00';add:b'\\x00';add:b'\\x00';add:b'\\x00';", + ], + ), + } + self._test(test_set) + + def test_multi_method_to_backend_constants(self): + class SequentialBackendModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.const = torch.zeros(1) + + def forward(self, x, y, z): + # delegate one + x = x - x + y = y - y + z = z - z + # graph break + a = x * y * z * self.const + # delegate two uses outputs from delegate one and the + # output from the graph break + b = x + self.const + a + b = z + a + b + b = y + a + b + return b + + module = SequentialBackendModule() + example_inputs = (torch.ones(1), torch.ones(1), torch.ones(1)) + seq_const_m = to_edge(torch.export.export(module, example_inputs)) + + test_set = { + "seq_const": ( + seq_const_m.exported_program(), + BackendWithPreprocessAllPartitioner(), + [ + "SecondBackendWithPreprocessAll#3#aten.sub.Tensor:aten.sub.Tensor:aten.sub.Tensor:#sub:b'\\x02';sub:b'\\x02';sub:b'\\x02';", + "FirstBackendWithPreprocessAll#6#CONSTc_const_copy_0:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:#add:b'\\x00';add:b'\\x00';add:b'\\x00';add:b'\\x00';add:b'\\x00';add:b'\\x00';", + ], + ), + } + self._test(test_set) + def test_multi_method_to_backend_not_found(self): class SinModule(torch.nn.Module): def __init__(self): diff --git a/exir/lowered_backend_module.py b/exir/lowered_backend_module.py index 78b031a238e..6792626d4ac 100644 --- a/exir/lowered_backend_module.py +++ b/exir/lowered_backend_module.py @@ -381,7 +381,7 @@ def _fixup_output_node(gm: torch.fx.GraphModule) -> None: def arrange_graph_placeholders( - gm: torch.fx.GraphModule, owning_program: ExportedProgram + gm: torch.fx.GraphModule, owning_program: ExportedProgram, tag ) -> torch.fx.GraphModule: """ Modifies the graph of the given graphmodule with one that contains the same nodes as the original, @@ -411,9 +411,15 @@ def arrange_graph_placeholders( if node.op != "placeholder": continue - if node.name in graph_sign.inputs_to_parameters: + if ( + node.name in graph_sign.inputs_to_parameters + and node.meta.get("delegation_tag", None) == tag + ): param_nodes.append(node) - elif node.name in graph_sign.inputs_to_buffers: + elif ( + node.name in graph_sign.inputs_to_buffers + and node.meta.get("delegation_tag", None) == tag + ): buffer_nodes.append(node) else: input_nodes.append(node) @@ -694,7 +700,7 @@ def create_exported_program_from_submodule( removed from the toplevel ExportedProgram. """ # Arrange the submodule's placeholders in order - submodule = arrange_graph_placeholders(submodule, owning_program) + submodule = arrange_graph_placeholders(submodule, owning_program, tag) # TODO: we probably need to arrange the outputs wrt buffer mutations. @@ -958,5 +964,3 @@ def _unsafe_adjust_original_program( # noqa: C901 if user_idx > idx: user.args = (user.args[0], user_idx - (len(getitem_idxs) - i)) break - - original_program._validate()