Commit c66078c
Arm backend: Fix arg-type MyPy errors (pytorch#15367)
Introduce the function `ensure_type()`, which raises an exception when a value does not have the expected type. This resolves a number of mypy errors spread around the codebase. Additionally, fix the remaining arg-type mypy errors.

Signed-off-by: Sebastian Larsson <sebastian.larsson@arm.com>
1 parent 4b42bea commit c66078c

File tree: 11 files changed, +81 -37 lines
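For context, here is a minimal sketch (names hypothetical) of the `arg-type` pattern this commit addresses: elements of `fx.Node.args` are statically typed as `Argument`, so passing one to a function that expects `Node` trips mypy unless the value is narrowed first.

import torch.fx

def takes_node(n: torch.fx.Node) -> str:
    return n.name

def example(node: torch.fx.Node) -> str:
    arg = node.args[0]  # statically typed as Argument, not Node
    # takes_node(arg)   # mypy: incompatible type "Argument" [arg-type]
    assert isinstance(arg, torch.fx.Node)  # narrowing makes the call type-safe
    return takes_node(arg)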

backends/arm/_passes/annotate_decomposed_matmul.py (4 additions & 4 deletions)

@@ -51,7 +51,7 @@ def _match_partition_to_node(
         raise RuntimeError(f"Cannot find an input node which matches, {node}.")

     def call(self, graph_module: GraphModule) -> PassResult:
-        matmul_partitions = get_source_partitions(
+        matmul_partitions_map = get_source_partitions(
             graph_module.graph,
             [
                 torch.matmul,
@@ -60,7 +60,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
             None,
         )
         matmul_partitions = list(
-            itertools.chain.from_iterable(matmul_partitions.values())
+            itertools.chain.from_iterable(matmul_partitions_map.values())
         )
         matmul_targets = {
             exir_ops.edge.aten.bmm.default,
@@ -88,7 +88,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
                 # Create new dq-node before matmul
                 dq_node = create_node(
                     graph=graph_module.graph,
-                    op_target=cast(EdgeOpOverload, input_node.target),  # type: ignore[arg-type]
+                    op_target=cast(EdgeOpOverload, input_node.target),
                 )
                 dq_node.args = (node, *input_node.args[1:])
                 matmul_node.replace_input_with(node, dq_node)
@@ -109,7 +109,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
                 # Create q-node after matmul
                 q_node = create_node(
                     graph=graph_module.graph,
-                    op_target=cast(EdgeOpOverload, partition_output.target),  # type: ignore[arg-type]
+                    op_target=cast(EdgeOpOverload, partition_output.target),
                 )
                 matmul_node.replace_all_uses_with(q_node)
                 q_node.args = (matmul_node, *partition_output.args[1:])
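A note on the rename above: rebinding one variable from the dict returned by `get_source_partitions` to the flattened list gives it two incompatible types, which mypy flags, so the dict gets its own name. A minimal sketch of the pattern, assuming simplified types:

import itertools

def flatten_partitions(partitions_map: dict[str, list[int]]) -> list[int]:
    # Keeping the dict and the flattened list under distinct names gives each
    # variable a single, stable type that mypy can check.
    return list(itertools.chain.from_iterable(partitions_map.values()))

print(flatten_partitions({"matmul": [1, 2], "bmm": [3]}))  # [1, 2, 3]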

backends/arm/_passes/arm_pass_utils.py (6 additions & 3 deletions)

@@ -13,8 +13,10 @@
 import torch
 import torch.fx
 from executorch.backends.arm.common.debug import get_node_debug_info
+from executorch.backends.arm.common.type import ensure_type
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.dialects.edge._ops import EdgeOpOverload

 from torch._export.utils import (
     get_buffer,
@@ -81,17 +83,18 @@ def get_param_tensor(
     elif is_lifted_tensor_constant(exp_prog, node):
         return get_lifted_tensor_constant(exp_prog, node)
     elif is_get_attr_node(node):
+        target_node = ensure_type(str, node.target)
         # This is a hack to support both lifted and unlifted graph
         try:
-            return getattr(node.graph.owning_module, node.target)  # type: ignore[arg-type]
+            return getattr(node.graph.owning_module, target_node)
         except AttributeError:
-            return getattr(exp_prog.graph_module, node.target)  # type: ignore[arg-type]
+            return getattr(exp_prog.graph_module, target_node)
     raise RuntimeError(f"unsupported param type, {node.op}.")


 def create_node(
     graph: torch.fx.Graph,
-    op_target: OpOverload,
+    op_target: OpOverload | EdgeOpOverload,
     args: tuple = (),
     kwargs: Optional[dict] = None,
     quantize: bool = False,
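Why the narrowing helps here, as a sketch: `fx.Node.target` is a union that includes callables, while `getattr` needs a `str` attribute name; `ensure_type(str, ...)` both narrows the static type and fails loudly at runtime if the assumption is wrong.

import torch
import torch.fx
from executorch.backends.arm.common.type import ensure_type  # added by this commit

def get_attr_param(node: torch.fx.Node, module: torch.nn.Module):
    # node.target is Union[Callable[..., Any], str]; getattr requires str.
    target = ensure_type(str, node.target)
    return getattr(module, target)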

backends/arm/_passes/scalars_to_attribute_pass.py (2 additions & 2 deletions)

@@ -49,15 +49,15 @@ def call(self, graph_module: GraphModule) -> PassResult:
             shape = get_first_fake_tensor(arg).shape
             biggest_rank = max(biggest_rank, len(shape))

-        new_args = []
+        new_args: list[Node | int] = []
         for arg in n.args:
             if isinstance(arg, Node):
                 new_args.append(arg)
                 continue
             if isinstance(arg, int) and not torch.is_floating_point(
                 get_first_fake_tensor(n)
             ):
-                new_args.append(arg)  # type: ignore[arg-type]
+                new_args.append(arg)
                 continue

         prefix = "_tensor_constant_"
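The explicit annotation matters because mypy otherwise infers the empty list's element type from its first use and then rejects appends of the other type; a condensed sketch of the idea:

from torch.fx import Node

def collect_args(args: tuple) -> list:
    # Without the annotation, mypy infers a narrower element type and
    # flags the int append with an arg-type error.
    new_args: list[Node | int] = []
    for arg in args:
        if isinstance(arg, (Node, int)):
            new_args.append(arg)
    return new_args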

backends/arm/_passes/to_tosa_memory_format_pass.py (7 additions & 1 deletion)

@@ -259,13 +259,19 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):

         # Transpose outputs if they are in (N)NCHW format
         outputs = output_node.args[0]
+        if not isinstance(outputs, (list, tuple)):
+            raise TypeError(
+                f"Expected output node args to be a list or tuple, got {type(outputs)}"
+            )
         output_dim_orders = output_node.meta.get("original_dim_orders")
         if output_dim_orders is None:
             raise RuntimeError(
                 f"{AnnotateDecomposedMatmulPass.__name__} is required to run at the beginning of the pass pipeline when using {ToTosaMemoryFormatPass.__name__}."
             )

-        for output_node_input, output_dim_order in zip(outputs, output_dim_orders):  # type: ignore[arg-type]
+        for output_node_input, output_dim_order in zip(
+            outputs, output_dim_orders, strict=True
+        ):
             if output_dim_order in (
                 NCHW_ORDER,
                 NNCHW_ORDER,
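Two things changed here: the `isinstance` guard raises a real runtime error instead of suppressing a type error, and `strict=True` (Python 3.10+) makes `zip` raise on a length mismatch rather than silently truncating. A quick illustration of the latter:

outputs = ["out0", "out1", "out2"]
dim_orders = [(0, 1, 2, 3), (0, 1, 2, 3)]  # one entry short

try:
    for out, order in zip(outputs, dim_orders, strict=True):
        pass
except ValueError as exc:
    print(exc)  # zip() argument 2 is shorter than argument 1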

backends/arm/common/type.py (new file, 28 additions)

@@ -0,0 +1,28 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+"""Type checking utilities."""
+
+from typing import TypeVar
+
+T = TypeVar("T")
+
+
+def ensure_type(expected_type: type[T], arg: object) -> T:
+    """Ensure that the argument is of the expected type.
+
+    Args:
+        expected_type (type[T]): The expected type.
+        arg (object): The argument to check.
+
+    Returns:
+        T: The argument, if it is of the expected type.
+
+    """
+    if isinstance(arg, expected_type):
+        return arg
+
+    expected_name = getattr(expected_type, "__name__", str(expected_type))
+    actual_name = type(arg).__name__
+    raise TypeError(f"Expected value of type {expected_name}, got {actual_name!r}")
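A usage sketch of the new helper: on success the value comes back typed as `T`, so mypy can use it downstream; on failure it raises with both type names.

from executorch.backends.arm.common.type import ensure_type

value: object = "conv2d"
name = ensure_type(str, value)  # returns "conv2d", statically typed as str

try:
    ensure_type(int, value)
except TypeError as exc:
    print(exc)  # Expected value of type int, got 'str'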

backends/arm/operator_support/index_tensor_support.py (5 additions & 2 deletions)

@@ -14,6 +14,7 @@
 import torch
 import torch.fx as fx
 from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
+from executorch.backends.arm.common.type import ensure_type
 from executorch.backends.arm.operator_support.tosa_supported_operators import (
     register_tosa_support_check,
     SupportedTOSAOperatorCheck,
@@ -137,7 +138,8 @@ def is_node_tosa_supported(
             return False

         # Usage 1 guard
-        fake_tensor = get_first_fake_tensor(index)  # type: ignore[arg-type]
+        index = ensure_type(torch.fx.Node, index)
+        fake_tensor = get_first_fake_tensor(index)
         if len(fake_tensor.size()) > 3:
             self.reporter.report_reject(
                 node,
@@ -146,7 +148,8 @@ def is_node_tosa_supported(
             return False

         # Usage 3 guard
-        total_vals = math.prod(get_first_fake_tensor(node.args[0]).shape)  # type: ignore[arg-type]
+        input_node = ensure_type(torch.fx.Node, node.args[0])
+        total_vals = math.prod(get_first_fake_tensor(input_node).shape)
         if total_vals > torch.iinfo(torch.int32).max:
             self.reporter.report_reject(
                 node,
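For reference on the Usage 3 guard retained above: it rejects tensors whose element count exceeds what int32 indices can address. A quick check of the threshold:

import math
import torch

shape = (65536, 65536)  # 2**32 elements
total_vals = math.prod(shape)
print(total_vals > torch.iinfo(torch.int32).max)  # True: such a tensor is rejected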

backends/arm/operator_support/tosa_supported_operators.py (3 additions & 5 deletions)

@@ -219,7 +219,7 @@ def _is_matmul_node_supported(
         """
         for graph_module in submodules.values():
             graph_module = typing.cast(fx.GraphModule, graph_module)
-            matmul_partitions = get_source_partitions(
+            matmul_partitions_map = get_source_partitions(
                 graph_module.graph,
                 [
                     torch.matmul,
@@ -228,7 +228,7 @@ def _is_matmul_node_supported(
                 None,
             )
             matmul_partitions = list(
-                itertools.chain.from_iterable(matmul_partitions.values())
+                itertools.chain.from_iterable(matmul_partitions_map.values())
             )
             matched_partition = None
             for partition in matmul_partitions:
@@ -406,9 +406,7 @@ def is_node_supported(
             if input_node.target in ComputeConstantOpsAOT.targeted_ops:
                 # This is not perfect since the input_node can still be rejected by other checks but
                 # this should cover the majority of cases.
-                if self.is_node_supported(
-                    None, input_node  # type: ignore[arg-type] #(we don't use 'submodules')
-                ):
+                if self.is_node_supported({}, input_node):
                     continue
                 self.reporter.report_reject(
                     node, f"Non-constant int64 input {input_node.name}"
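The last hunk swaps `None` for `{}`: the recursive call does not use `submodules`, and an empty mapping satisfies the parameter's annotation where `None` does not. A generic sketch of the idea (simplified signature, not the real one):

from typing import Mapping

def is_supported(submodules: Mapping[str, object], value: int) -> bool:
    return value > 0

# is_supported(None, 1)  # mypy: incompatible type "None" [arg-type]
print(is_supported({}, 1))  # an empty mapping type-checks when unused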

backends/arm/quantizer/arm_quantizer.py (1 addition & 1 deletion)

@@ -374,7 +374,7 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         # TODO: Fix the need to lazily import this.
         from executorch.backends.arm._passes import ArmPassManager

-        return ArmPassManager(self.tosa_spec).transform_for_annotation_pipeline(  # type: ignore[arg-type]
+        return ArmPassManager(self.tosa_spec).transform_for_annotation_pipeline(
             graph_module=model
         )

backends/arm/quantizer/quantization_annotator.py (19 additions & 15 deletions)

@@ -12,6 +12,7 @@
 import torch.fx
 import torch.nn.functional as F
 from executorch.backends.arm.common.debug import get_node_debug_info
+from executorch.backends.arm.common.type import ensure_type
 from executorch.backends.arm.quantizer import QuantizationConfig
 from torch._subclasses import FakeTensor

@@ -510,7 +511,8 @@ def any_or_hardtanh_min_zero(n: Node):
        torch.ops.aten.minimum.default,
        torch.ops.aten.maximum.default,
    ):
-        shared_qspec = SharedQuantizationSpec((node.args[0], node))  # type: ignore[arg-type]
+        lhs_node = ensure_type(Node, node.args[0])
+        shared_qspec = SharedQuantizationSpec((lhs_node, node))
        quant_properties.quant_inputs = [
            _QuantProperty(0, input_act_qspec),
            _QuantProperty(
@@ -520,22 +522,24 @@ def any_or_hardtanh_min_zero(n: Node):
        ]
        quant_properties.quant_output = _QuantProperty(0, shared_qspec)
    elif node.target in (torch.ops.aten.where.self,):
-        shared_qspec = SharedQuantizationSpec(node.args[1])  # type: ignore[arg-type]
+        true_node = ensure_type(Node, node.args[1])
+        shared_qspec = SharedQuantizationSpec(true_node)
        quant_properties.quant_inputs = [
            _QuantProperty(1, shared_qspec),
            _QuantProperty(2, shared_qspec),
        ]
        quant_properties.quant_output = _QuantProperty(0, shared_qspec)
    elif node.target in _one_to_one_shared_input_or_input_act_qspec:
+        input_node = ensure_type(Node, node.args[0])
        input_qspec = (
-            SharedQuantizationSpec(node.args[0])  # type: ignore[arg-type]
-            if is_output_annotated(node.args[0])  # type: ignore[arg-type]
+            SharedQuantizationSpec(input_node)
+            if is_output_annotated(input_node)
            else input_act_qspec
        )
        quant_properties.quant_inputs = [_QuantProperty(0, input_qspec)]
        quant_properties.quant_output = _QuantProperty(
            0,
-            SharedQuantizationSpec((node.args[0], node)),  # type: ignore[arg-type]
+            SharedQuantizationSpec((input_node, node)),
        )
    elif node.target in (
        torch.ops.aten.cat.default,
@@ -550,26 +554,24 @@ def any_or_hardtanh_min_zero(n: Node):
        )
        if len(node.args[0]) == 0:
            raise ValueError("Expected non-empty list for node.args[0]")
-
-        shared_qspec = SharedQuantizationSpec((node.args[0][0], node))  # type: ignore[arg-type]
+        inputs = [ensure_type(Node, element) for element in node.args[0]]
+        shared_qspec = SharedQuantizationSpec((inputs[0], node))
        quant_properties.quant_inputs = [
            _QuantProperty(
                0,
-                [
-                    input_act_qspec if n == node.args[0][0] else shared_qspec  # type: ignore[misc]
-                    for n in node.args[0]
-                ],
+                [input_act_qspec if n == inputs[0] else shared_qspec for n in inputs],
            )
        ]
        quant_properties.quant_output = _QuantProperty(0, shared_qspec)
    elif node.target in _one_to_one:
        quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)]
        quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
    elif node.target in _one_to_one_shared_input_qspec:
+        input_node = ensure_type(Node, node.args[0])
        quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)]
        quant_properties.quant_output = _QuantProperty(
            0,
-            SharedQuantizationSpec((node.args[0], node)),  # type: ignore[arg-type]
+            SharedQuantizationSpec((input_node, node)),
        )
    elif node.target in [
        torch.ops.aten.eq.Tensor,
@@ -578,7 +580,8 @@ def any_or_hardtanh_min_zero(n: Node):
        torch.ops.aten.le.Tensor,
        torch.ops.aten.lt.Tensor,
    ]:
-        shared_qspec = SharedQuantizationSpec((node.args[0], node))  # type: ignore[arg-type]
+        input_node = ensure_type(Node, node.args[0])
+        shared_qspec = SharedQuantizationSpec((input_node, node))
        quant_properties.quant_inputs = [
            _QuantProperty(0, input_act_qspec),
            _QuantProperty(
@@ -596,9 +599,10 @@ def any_or_hardtanh_min_zero(n: Node):
        quant_properties.quant_inputs = []
        quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
    elif node.target in [operator.getitem]:
-        if not is_output_annotated(node.args[0]):  # type: ignore[arg-type]
+        input_node = ensure_type(Node, node.args[0])
+        if not is_output_annotated(input_node):
            return None
-        shared_qspec = SharedQuantizationSpec(node.args[0])  # type: ignore[arg-type]
+        shared_qspec = SharedQuantizationSpec(input_node)
        quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)]
        quant_properties.quant_output = _QuantProperty(0, shared_qspec)
    else:
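The recurring pattern in this file narrows `node.args[0]` to `Node` once and reuses the checked value everywhere a `Node` (or a `(Node, Node)` edge) is required; condensed into a sketch with a hypothetical helper name:

from torch.fx import Node
from executorch.backends.arm.common.type import ensure_type

def input_output_edge(node: Node) -> tuple[Node, Node]:
    # node.args[0] is typed Argument; SharedQuantizationSpec wants a Node or
    # an edge of two Nodes, so narrow once and reuse the result.
    input_node = ensure_type(Node, node.args[0])
    return (input_node, node)  # the edge passed to SharedQuantizationSpec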

backends/arm/test/tester/arm_tester.py (3 additions & 3 deletions)

@@ -604,9 +604,9 @@ def run_transform_for_annotation_pipeline(
         # We need to clone the artifact in order to ensure that the state_dict is preserved after passes are run.
         artifact = self.get_artifact(stage)
         if self.cur == StageType.EXPORT:
-            new_gm = ArmPassManager(self.compile_spec.tosa_spec).transform_for_annotation_pipeline(  # type: ignore[arg-type]
-                graph_module=artifact.graph_module
-            )
+            new_gm = ArmPassManager(
+                self.compile_spec.tosa_spec
+            ).transform_for_annotation_pipeline(graph_module=artifact.graph_module)
         else:
             raise RuntimeError("Can only run passes on Export stage.")
         _copy_module(artifact.graph_module, new_gm)
