fix bug with sequential backends #10708

Open · wants to merge 1 commit into base: main
18 changes: 12 additions & 6 deletions exir/backend/backend_api.py
@@ -204,12 +204,16 @@ def _insert_lowered_submodule(
owning_graph_module = call_submodule_node.graph.owning_module
# call delegate args should only use user_inputs
call_delegate_args = []
# Preserve input order as user_inputs
for inp_name in submodule_program.graph_signature.user_inputs:
for inp_node in call_submodule_node.all_input_nodes:
if inp_node.name == inp_name:
call_delegate_args.append(inp_node)
break
# names of input_specs to delete
input_specs_to_delete = toplevel_input_specs_to_delete
# Delete owned constants from the call_submodule_node args
for call_sm_input in call_submodule_node.args:
if (
isinstance(call_sm_input, torch.fx.Node)
and call_sm_input.name in input_specs_to_delete.keys()
):
continue
call_delegate_args.append(call_sm_input)
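Note: the replaced block above no longer reorders arguments to match the submodule's `user_inputs` by name; it keeps `call_submodule_node.args` in their original order and drops only the constants the lowered module now owns. A minimal sketch of that filtering idiom (a standalone restatement of the diff, with illustrative names):

```python
import torch.fx

def filter_owned_constants(
    call_submodule_node: torch.fx.Node,
    input_specs_to_delete: dict,
) -> list:
    # Keep arguments in their original order; skip placeholders whose input
    # specs are being deleted from the top-level program, since the lowered
    # module owns those constants now.
    call_delegate_args = []
    for call_sm_input in call_submodule_node.args:
        if (
            isinstance(call_sm_input, torch.fx.Node)
            and call_sm_input.name in input_specs_to_delete
        ):
            continue
        call_delegate_args.append(call_sm_input)
    return call_delegate_args
```

The practical difference: the old loop kept only arguments named in the submodule's `user_inputs`, so any other argument (e.g., a constant that stays in the top-level program) was dropped from the call; the new loop keeps every argument except constants explicitly deleted from the top-level program.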

def generate_debug_handle(ep: ExportedProgram) -> int:
"""
@@ -324,6 +328,7 @@ def _partition_and_lower_one_graph_module(
toplevel_input_specs_to_delete,
toplevel_output_specs_to_delete,
)
owning_program._validate()

return tagged_graph_module

@@ -742,6 +747,7 @@ def to_backend(
for method_name in method_to_edge_program.keys():
if method_name in method_to_tagged_exported_program:
tagged_exported_program = method_to_tagged_exported_program[method_name]
tagged_exported_program._validate()
partitioned_and_lowered_exported_programs[method_name] = ExportedProgram(
root=tagged_exported_program.graph_module,
graph=tagged_exported_program.graph_module.graph,
55 changes: 43 additions & 12 deletions exir/backend/test/backend_with_preprocess_all_demo.py
@@ -21,10 +21,30 @@
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.graph_module import get_control_flow_submodules
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
from torch.export.exported_program import ExportedProgram
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase


def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
return (
is_param(exp_prog, node)
or is_buffer(exp_prog, node)
or is_lifted_tensor_constant(exp_prog, node)
)


def get_total_num_ops_in_ep(edge_programs, supported_ops):
total_number_of_ops = 0
for edge_program in edge_programs.values():
for partitioned_program in edge_program:
for node in partitioned_program.graph.nodes:
if node.op == "call_function":
if node.target in supported_ops:
total_number_of_ops += 1
return total_number_of_ops
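A quick usage sketch for the extracted helper (values are placeholders; `edge_programs` maps each method name to its list of partitioned programs, as in `_preprocess_multimethod` below):

```python
# Hypothetical example: count the supported call_function ops across all
# partitioned programs of every method.
supported_ops = [
    exir_ops.edge.aten.add.Tensor,
    exir_ops.edge.aten.sin.default,
]
edge_programs = {"forward": [exported_program]}  # placeholder ExportedProgram
total = get_total_num_ops_in_ep(edge_programs, supported_ops)
```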


def _preprocess_multimethod(
edge_programs: Dict[str, List[ExportedProgram]],
compile_specs: Dict[str, List[List[CompileSpec]]],
@@ -37,13 +57,7 @@ def _preprocess_multimethod(
in testing for a partitioner which tags different partitions for different backends
to be lowered to
"""
total_number_of_ops = 0
for edge_program in edge_programs.values():
for partitioned_program in edge_program:
for node in partitioned_program.graph.nodes:
if node.op == "call_function":
if node.target in supported_ops:
total_number_of_ops += 1
total_number_of_ops = get_total_num_ops_in_ep(edge_programs, supported_ops)
all_processed_results = {key: [] for key in edge_programs.keys()}

for method_name, partitioned_programs in edge_programs.items():
@@ -67,6 +81,8 @@
raise RuntimeError(
f"{node.op} {node.target.__name__} is not supported in backend {backend_name}"
)
if is_param_node(partitioned_program, node):
processed_bytes += f"CONST{node.name}:"

processed_bytes += "#"
for cs in compile_spec_for_partition:
@@ -171,14 +187,30 @@ def preprocess_multimethod(


class AddSinOperatorSupport(OperatorSupportBase):
def __init__(self, original_program):
self.original_program = original_program
super().__init__()

def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
return node.op == "call_function" and node.target in [
supported_targets = [
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.sin.default,
]
if node.op == "call_function" and node.target in supported_targets:
return True

if node.op == "placeholder" and is_param_node(self.original_program, node):
for user in node.users.keys():
if user.target in supported_targets:
return True
return False
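The new `placeholder` branch means a parameter, buffer, or lifted constant is claimed by the partition whenever one of its users is a supported op, so owned constants travel with the delegate. For instance (a sketch mirroring the constants test below):

```python
class ConstAdd(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.const = torch.zeros(1)

    def forward(self, x):
        # After export, self.const becomes a lifted-constant placeholder whose
        # only user is the supported add op, so is_node_supported returns True
        # for it and it is tagged into the add/sin partition.
        return x + self.const
```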


class SubCosOperatorSupport(OperatorSupportBase):
def __init__(self, original_program):
self.original_program = original_program
super().__init__()

def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
return node.op == "call_function" and node.target in [
exir_ops.edge.aten.sub.Tensor,
@@ -199,11 +231,8 @@ class BackendWithPreprocessAllPartitioner(Partitioner):
"""

def __init__(self) -> None:
self.add_sin_support = any_chain(AddSinOperatorSupport())
self.add_sin_backend_id = FirstBackendWithPreprocessAll.__name__

self.sub_cos_support = any_chain(SubCosOperatorSupport())
self.sub_cos_backend_id = SecondBackendWithPreprocessAll.__name__
self.add_sin_backend_id = FirstBackendWithPreprocessAll.__name__

def _partition_graph_module(
self,
@@ -260,6 +289,8 @@ def _partition_graph_module(
return partition_tags, start_idx_for_submodules

def partition(self, exported_program: ExportedProgram) -> PartitionResult:
self.add_sin_support = any_chain(AddSinOperatorSupport(exported_program))
self.sub_cos_support = any_chain(SubCosOperatorSupport(exported_program))
partition_tags, _ = self._partition_graph_module(exported_program.graph_module)
return PartitionResult(
tagged_exported_program=exported_program, partition_tags=partition_tags
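Because the support checks now need the owning `ExportedProgram`, they are constructed per call to `partition()` instead of once in `__init__`. A usage sketch (assuming an edge-dialect program, as in the tests below):

```python
# Hypothetical usage; `edge_program` is an ExportedProgram in edge dialect.
partitioner = BackendWithPreprocessAllPartitioner()
result = partitioner.partition(edge_program)
# result.partition_tags maps each delegation tag to its DelegationSpec.
```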
71 changes: 71 additions & 0 deletions exir/backend/test/test_to_backend_multi_method.py
@@ -392,6 +392,77 @@ def forward(self, x):
}
self._test(test_set)

def test_multi_method_to_backend_sequential_delegates(self):
class SequentialBackendModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y, z):
# delegate one
x = x - x
y = y - y
z = z - z
# graph break
a = x * y * z
# delegate two uses outputs from delegate one and the
# output from the graph break
b = x + a
b = b + z + a
b = b + y + a
return b

module = SequentialBackendModule()
example_inputs = (torch.ones(1), torch.ones(1), torch.ones(1))
seq_edgeir_m = to_edge(torch.export.export(module, example_inputs))

test_set = {
"seq_edgeir": (
seq_edgeir_m.exported_program(),
BackendWithPreprocessAllPartitioner(),
[
"SecondBackendWithPreprocessAll#3#aten.sub.Tensor:aten.sub.Tensor:aten.sub.Tensor:#sub:b'\\x02';sub:b'\\x02';sub:b'\\x02';",
"FirstBackendWithPreprocessAll#5#aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:#add:b'\\x00';add:b'\\x00';add:b'\\x00';add:b'\\x00';add:b'\\x00';",
],
),
}
self._test(test_set)
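Each expected string encodes the backend class name, the number of lowered ops, the colon-separated op targets, and the per-node compile-spec bytes; this layout is inferred from the strings themselves and the demo backend above. An illustrative reconstruction (not the actual serialization code):

```python
# Rebuild the first expected blob by hand.
backend_name = "SecondBackendWithPreprocessAll"
ops = ["aten.sub.Tensor"] * 3
per_node_bytes = "sub:b'\\x02';" * 3
blob = f"{backend_name}#{len(ops)}#{':'.join(ops)}:#{per_node_bytes}"
```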

def test_multi_method_to_backend_constants(self):
class SequentialBackendModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.const = torch.zeros(1)

def forward(self, x, y, z):
# delegate one
x = x - x
y = y - y
z = z - z
# graph break
a = x * y * z * self.const
# delegate two uses outputs from delegate one and the
# output from the graph break
b = x + self.const + a
b = z + a + b
b = y + a + b
return b

module = SequentialBackendModule()
example_inputs = (torch.ones(1), torch.ones(1), torch.ones(1))
seq_const_m = to_edge(torch.export.export(module, example_inputs))

test_set = {
"seq_const": (
seq_const_m.exported_program(),
BackendWithPreprocessAllPartitioner(),
[
"SecondBackendWithPreprocessAll#3#aten.sub.Tensor:aten.sub.Tensor:aten.sub.Tensor:#sub:b'\\x02';sub:b'\\x02';sub:b'\\x02';",
"FirstBackendWithPreprocessAll#6#CONSTc_const_copy_0:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:#add:b'\\x00';add:b'\\x00';add:b'\\x00';add:b'\\x00';add:b'\\x00';add:b'\\x00';",
],
),
}
self._test(test_set)
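The extra `CONSTc_const_copy_0:` token in the second expected blob comes from the demo backend's new constant handling (`processed_bytes += f"CONST{node.name}:"` above): the lifted constant placeholder is tagged into the first partition and serialized alongside the six `add` ops.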

def test_multi_method_to_backend_not_found(self):
class SinModule(torch.nn.Module):
def __init__(self):
16 changes: 10 additions & 6 deletions exir/lowered_backend_module.py
@@ -381,7 +381,7 @@ def _fixup_output_node(gm: torch.fx.GraphModule) -> None:


def arrange_graph_placeholders(
gm: torch.fx.GraphModule, owning_program: ExportedProgram
gm: torch.fx.GraphModule, owning_program: ExportedProgram, tag
) -> torch.fx.GraphModule:
"""
Modifies the graph of the given graphmodule with one that contains the same nodes as the original,
@@ -411,9 +411,15 @@
if node.op != "placeholder":
continue

if node.name in graph_sign.inputs_to_parameters:
if (
node.name in graph_sign.inputs_to_parameters
and node.meta.get("delegation_tag", None) == tag
):
param_nodes.append(node)
elif node.name in graph_sign.inputs_to_buffers:
elif (
node.name in graph_sign.inputs_to_buffers
and node.meta.get("delegation_tag", None) == tag
):
buffer_nodes.append(node)
else:
input_nodes.append(node)
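With sequential delegates in one graph module, a parameter or buffer placeholder may be tagged for a different partition than the one being arranged; the added `delegation_tag` check keeps only this partition's constants grouped as params/buffers. The predicate reduces to the following (a restatement of the diff, with `node`, `graph_sign`, and `tag` as in the surrounding code):

```python
def owns_placeholder(node, graph_sign, tag) -> bool:
    # Treat a placeholder as this partition's parameter/buffer only when its
    # delegation tag matches; otherwise it stays an ordinary user input.
    is_param_or_buffer = (
        node.name in graph_sign.inputs_to_parameters
        or node.name in graph_sign.inputs_to_buffers
    )
    return is_param_or_buffer and node.meta.get("delegation_tag", None) == tag
```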
@@ -694,7 +700,7 @@ def create_exported_program_from_submodule(
removed from the toplevel ExportedProgram.
"""
# Arrange the submodule's placeholders in order
submodule = arrange_graph_placeholders(submodule, owning_program)
submodule = arrange_graph_placeholders(submodule, owning_program, tag)

# TODO: we probably need to arrange the outputs wrt buffer mutations.

@@ -958,5 +964,3 @@ def _unsafe_adjust_original_program(  # noqa: C901
if user_idx > idx:
user.args = (user.args[0], user_idx - (len(getitem_idxs) - i))
break

original_program._validate()
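Note that the trailing `original_program._validate()` is removed here rather than relocated within this function: validation now happens after partitioning and lowering instead, via the `_validate()` calls added in `backend_api.py` above.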