Skip to content

Commit de56c81

Browse files
martinlsmMartin Lindström
andauthored
Arm backend: Merge passes that replace scalars (#15298)
Besides having obsolete names, ReplaceScalarWithTensorPassTOSAMI and ReplaceScalarWithTensorPassTOSABI was causing difficulties for defining chronological pass dependencies with the `_passes_required_after` attribute. This is because, with the current design, there is no way to distinguish which profile is referred to when defining `_passes_required_after`; a pass simply declares its chronological dependencies globally. Solve this by merging the two pass classes together into one called `ReplaceScalarWithTensorByProfilePass`. This means that a pass should include this new pass in `_passes_required_after` no matter which TOSA profile its working towards. ### Test plan This is a refactoring patch. No behavior is modified. Thus the tests already in place are enough to test this change. cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com> Co-authored-by: Martin Lindström <Martin.Lindstroem@arm.com>
1 parent 69f79b9 commit de56c81

12 files changed

+67
-37
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,7 @@
8888
from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa
8989
from .remove_noop_pass import RemoveNoopPass # noqa
9090
from .replace_scalar_with_tensor_pass import ( # noqa
91-
ReplaceScalarWithTensorArgPassTOSABI,
92-
ReplaceScalarWithTensorArgPassTOSAMI,
91+
ReplaceScalarWithTensorByProfilePass,
9392
)
9493
from .rewrite_conv2d_pass import RewriteConv2dPass # noqa
9594
from .rewrite_matmul import RewriteMatmulPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,7 @@
8787
QuantizeOperatorArguments,
8888
RemoveNoopPass,
8989
ReplaceInfValues,
90-
ReplaceScalarWithTensorArgPassTOSABI,
91-
ReplaceScalarWithTensorArgPassTOSAMI,
90+
ReplaceScalarWithTensorByProfilePass,
9291
RetraceFoldedDtypesPass,
9392
RewriteConv2dPass,
9493
RewriteMatmulPass,
@@ -172,7 +171,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
172171
self.add_pass(CastToInt32Pass())
173172

174173
self.add_pass(CastBoolToInt8Pass())
175-
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
174+
self.add_pass(ReplaceScalarWithTensorByProfilePass())
176175
self.add_pass(AnnotateDecomposedMatmulPass())
177176
self.add_pass(QuantizeOperatorArguments())
178177
self.add_pass(ConvertELUParamsPass())
@@ -242,7 +241,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
242241
self.add_pass(DecomposeSinhPass())
243242
self.add_pass(DecomposeSignPass())
244243
self.add_pass(DecomposeDivTensorModePass())
245-
self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
244+
self.add_pass(ReplaceScalarWithTensorByProfilePass())
246245
self.add_pass(DecomposeEmbeddingPass())
247246
self.add_pass(FuseQuantizedActivationPass())
248247
self.add_pass(RemoveGetItemPass())
@@ -335,7 +334,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
335334
self.add_pass(DecomposeAddmmPass())
336335
self.add_pass(DecomposeDivTensorModePass())
337336
self.add_pass(DecomposeAddSubAlphaPass())
338-
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
337+
self.add_pass(ReplaceScalarWithTensorByProfilePass())
339338
self.add_pass(ScalarsToAttributePass())
340339
self.add_pass(DecomposeGroupNormPass())
341340
self.add_pass(DecomposeLayerNormPass())

backends/arm/_passes/decompose_acosh_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
1313
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
1414
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
15-
ReplaceScalarWithTensorArgPassTOSAMI,
15+
ReplaceScalarWithTensorByProfilePass,
1616
)
1717
from executorch.exir.dialects._ops import ops as exir_ops
1818
from executorch.exir.pass_base import ExportPass
@@ -32,7 +32,7 @@ class DecomposeAcoshPass(ArmPass):
3232
DecomposeSqrtPass,
3333
InsertTableOpsPass,
3434
MatchArgRanksPass,
35-
ReplaceScalarWithTensorArgPassTOSAMI,
35+
ReplaceScalarWithTensorByProfilePass,
3636
MatchArgDtypePass,
3737
}
3838

backends/arm/_passes/decompose_asin_and_acos_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
2020
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
2121
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
22-
ReplaceScalarWithTensorArgPassTOSAMI,
22+
ReplaceScalarWithTensorByProfilePass,
2323
)
2424
from executorch.exir.dialects._ops import ops as exir_ops
2525
from executorch.exir.pass_base import ExportPass
@@ -71,7 +71,7 @@ class DecomposeAsinAndAcosPass(ArmPass):
7171
ConvertFullLikeToFullPass,
7272
MatchArgRanksPass,
7373
MatchArgDtypePass,
74-
ReplaceScalarWithTensorArgPassTOSAMI,
74+
ReplaceScalarWithTensorByProfilePass,
7575
}
7676

7777
def _build_polynomial(

backends/arm/_passes/decompose_asinh_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
1313
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
1414
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
15-
ReplaceScalarWithTensorArgPassTOSAMI,
15+
ReplaceScalarWithTensorByProfilePass,
1616
)
1717
from executorch.exir.dialects._ops import ops as exir_ops
1818
from executorch.exir.pass_base import ExportPass
@@ -32,7 +32,7 @@ class DecomposeAsinhPass(ArmPass):
3232
DecomposeSqrtPass,
3333
InsertTableOpsPass,
3434
MatchArgRanksPass,
35-
ReplaceScalarWithTensorArgPassTOSAMI,
35+
ReplaceScalarWithTensorByProfilePass,
3636
MatchArgDtypePass,
3737
}
3838

backends/arm/_passes/decompose_atan_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
1313
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
1414
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
15-
ReplaceScalarWithTensorArgPassTOSAMI,
15+
ReplaceScalarWithTensorByProfilePass,
1616
)
1717
from executorch.exir.dialects._ops import ops as exir_ops
1818
from executorch.exir.pass_base import ExportPass
@@ -47,7 +47,7 @@ class DecomposeAtanPass(ArmPass):
4747
InsertTableOpsPass,
4848
MatchArgRanksPass,
4949
MatchArgDtypePass,
50-
ReplaceScalarWithTensorArgPassTOSAMI,
50+
ReplaceScalarWithTensorByProfilePass,
5151
}
5252

5353
def _rational_approximation(self, z, ops, meta):

backends/arm/_passes/decompose_atanh_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
1111
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
1212
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
13-
ReplaceScalarWithTensorArgPassTOSAMI,
13+
ReplaceScalarWithTensorByProfilePass,
1414
)
1515
from executorch.exir.dialects._ops import ops as exir_ops
1616
from executorch.exir.pass_base import ExportPass
@@ -43,7 +43,7 @@ class DecomposeAtanhPass(ArmPass):
4343
InsertTableOpsPass,
4444
MatchArgRanksPass,
4545
MatchArgDtypePass,
46-
ReplaceScalarWithTensorArgPassTOSAMI,
46+
ReplaceScalarWithTensorByProfilePass,
4747
}
4848

4949
def call_operator(self, op, args, kwargs, meta):

backends/arm/_passes/decompose_cosh_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
1111
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
1212
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
13-
ReplaceScalarWithTensorArgPassTOSAMI,
13+
ReplaceScalarWithTensorByProfilePass,
1414
)
1515
from executorch.exir.dialects._ops import ops as exir_ops
1616
from executorch.exir.pass_base import ExportPass
@@ -31,7 +31,7 @@ class DecomposeCoshPass(ArmPass):
3131
_passes_required_after: Set[Type[ExportPass]] = {
3232
InsertTableOpsPass,
3333
MatchArgRanksPass,
34-
ReplaceScalarWithTensorArgPassTOSAMI,
34+
ReplaceScalarWithTensorByProfilePass,
3535
MatchArgDtypePass,
3636
}
3737

backends/arm/_passes/decompose_expm1_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
1313
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
1414
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
15-
ReplaceScalarWithTensorArgPassTOSAMI,
15+
ReplaceScalarWithTensorByProfilePass,
1616
)
1717
from executorch.exir.dialects._ops import ops as exir_ops
1818
from executorch.exir.pass_base import ExportPass
@@ -83,7 +83,7 @@ class DecomposeExpm1Pass(ArmPass):
8383
ConvertIntPowToMuls,
8484
InsertTableOpsPass,
8585
DecomposeDivPass,
86-
ReplaceScalarWithTensorArgPassTOSAMI,
86+
ReplaceScalarWithTensorByProfilePass,
8787
MatchArgDtypePass,
8888
MatchArgRanksPass,
8989
}

backends/arm/_passes/decompose_logit_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
1313
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
1414
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
15-
ReplaceScalarWithTensorArgPassTOSAMI,
15+
ReplaceScalarWithTensorByProfilePass,
1616
)
1717
from executorch.exir.dialects._ops import ops as exir_ops
1818
from executorch.exir.pass_base import ExportPass
@@ -73,7 +73,7 @@ class DecomposeLogitPass(ArmPass):
7373
InsertTableOpsPass,
7474
MatchArgRanksPass,
7575
MatchArgDtypePass,
76-
ReplaceScalarWithTensorArgPassTOSAMI,
76+
ReplaceScalarWithTensorByProfilePass,
7777
}
7878

7979
def call_operator(self, op, args, kwargs, meta):

0 commit comments

Comments
 (0)