Skip to content

Commit 07f19af

Browse files
Merge branch 'pytorch:main' into pr_model_improve
2 parents 3fe89b1 + a073668 commit 07f19af

20 files changed

+78
-63
lines changed

backends/arm/_passes/cast_int64_pass.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from torch._export.utils import is_buffer
1313

1414
logger = logging.getLogger(__name__)
15-
logger.setLevel(logging.WARNING)
1615

1716

1817
class CastInt64BuffersToInt32Pass(ExportPass):

backends/arm/arm_backend.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,13 @@
1111
# JIT compiler flows.
1212
#
1313

14-
import logging
15-
1614
from typing import List, Optional
1715

1816
from executorch.backends.arm.tosa_specification import TosaSpecification
1917

2018
from executorch.exir.backend.compile_spec_schema import CompileSpec
2119

2220

23-
logger = logging.getLogger(__name__)
24-
logger.setLevel(logging.WARNING)
25-
26-
2721
class ArmCompileSpecBuilder:
2822
def __init__(self):
2923
self.compile_spec: List[CompileSpec] = []

backends/arm/operator_support/right_shift_support.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from executorch.exir.dialects._ops import ops as exir_ops
1818

1919
logger = logging.getLogger(__name__)
20-
logger.setLevel(logging.WARNING)
2120

2221

2322
@register_tosa_support_check

backends/arm/operator_support/slice_copy_support.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from executorch.exir.dialects._ops import ops as exir_ops
1717

1818
logger = logging.getLogger(__name__)
19-
logger.setLevel(logging.WARNING)
2019

2120

2221
@register_tosa_support_check

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def tosa_support_factory(
112112
# Negative checks: Remove nodes from partitioning
113113
negative_checks: list[OperatorSupportBase] = [
114114
CheckInt64Inputs(exported_program, reporter),
115+
CheckFloat64Inputs(exported_program, reporter),
115116
*[
116117
reporter.wrap_check(check, f"Rejected by {check.__class__.__name__}")
117118
for check in (additional_checks if additional_checks else [])
@@ -443,3 +444,26 @@ def is_node_supported(
443444
)
444445
return False
445446
return True
447+
448+
449+
class CheckFloat64Inputs(OperatorSupportBase):
450+
451+
def __init__(
452+
self, exported_program: ExportedProgram, reporter: WhyNoPartitionReporter
453+
):
454+
self.reporter = reporter
455+
super().__init__()
456+
457+
def is_node_supported(
458+
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
459+
) -> bool:
460+
461+
for input_node in node.all_input_nodes:
462+
tensor = get_first_fake_tensor(input_node)
463+
if tensor.dtype == torch.float64:
464+
self.reporter.report_reject(
465+
node,
466+
f"Had float64 input {input_node.name} that couldn't be handled.",
467+
)
468+
return False
469+
return True

backends/arm/operators/op_maximum.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,20 +36,27 @@ def define_node(
3636
inputs: List[TosaArg],
3737
output: TosaArg,
3838
) -> None:
39-
assert inputs[0].dtype == inputs[1].dtype
39+
if inputs[0].dtype != inputs[1].dtype and inputs[0].dtype != output.dtype:
40+
raise TypeError(
41+
f"Data type of inputs and output must be the same. Got input 0 dtype: "
42+
f"{inputs[0].dtype}, input 1 dtype: {inputs[1].dtype} and output "
43+
f"dtype: {output.dtype}"
44+
)
4045

4146
scale_back = 1.0
4247
max_output = output
4348
if inputs[0].dtype == ts.DType.INT8:
4449
input_qparams = get_input_qparams(node)
45-
assert (
46-
len(input_qparams) == 2
47-
), f"Both inputs needs to have quantization information for {node}"
48-
# insert RESCALEs to int32
49-
assert (
50-
input_qparams[0] == input_qparams[1]
51-
), "Both inputs must have same quantization for MAX"
50+
if len(input_qparams) != 2:
51+
raise ValueError(
52+
f"Both inputs need to have quantization information for {node}"
53+
)
54+
if input_qparams[0] != input_qparams[1]:
55+
raise ValueError(
56+
"Both inputs must have the same quantization parameters for MAX"
57+
)
5258

59+
# insert RESCALEs to int32
5360
operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
5461
tosa_graph, inputs, node
5562
)

backends/arm/operators/op_reciprocal.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,16 @@ def define_node(
3434
inputs: List[TosaArg],
3535
output: TosaArg,
3636
) -> None:
37-
assert inputs[0].dtype == output.dtype == ts.DType.FP32
37+
if len(node.all_input_nodes) != 1:
38+
raise ValueError(
39+
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
40+
)
41+
if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32:
42+
raise ValueError(
43+
f"Input and output for {self.target} need to be FP32, got "
44+
f"{inputs[0].dtype=} and {output.dtype=}"
45+
)
46+
3847
tosa_graph.addOperator(
3948
ts.TosaOp.Op().RECIPROCAL, [inputs[0].name], [output.name]
4049
)

backends/arm/operators/op_sub.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,19 @@ def define_node(
4040
) -> None:
4141
# Specification (0.80) states that input and output types
4242
# should all be the same
43-
assert inputs[0].dtype == inputs[1].dtype == output.dtype
43+
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
44+
raise TypeError(
45+
f"All IO needs to have the same data type, got input 1: "
46+
f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
47+
f"{output.dtype}"
48+
)
49+
4450
# Handle int8 (quantized) and int32
45-
assert inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]
51+
supported_dtypes = [ts.DType.INT8, ts.DType.INT32]
52+
if inputs[0].dtype not in supported_dtypes:
53+
raise TypeError(
54+
f'IO data type needs to be {supported_dtypes}, got "{inputs[0].dtype}"'
55+
)
4656

4757
if inputs[0].dtype == ts.DType.INT8:
4858
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
@@ -97,15 +107,27 @@ def define_node(
97107
) -> None:
98108
# Specification (0.80) states that input and output types
99109
# should all be the same
100-
assert inputs[0].dtype == inputs[1].dtype == output.dtype
110+
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
111+
raise TypeError(
112+
f"All IO needs to have the same data type, got input 1: "
113+
f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
114+
f"{output.dtype}"
115+
)
101116

102117
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
103118
# Call the inherited define_node for handling integers
104119
super().define_node(node, tosa_graph, inputs, output)
105120
else:
106121
# FP32 Sub lowering
107-
assert inputs[0].dtype == ts.DType.FP32
108-
assert output.dtype == ts.DType.FP32
122+
if (
123+
inputs[0].dtype != ts.DType.FP32
124+
or inputs[1].dtype != ts.DType.FP32
125+
or output.dtype != ts.DType.FP32
126+
):
127+
raise TypeError(
128+
f"All IO needs to have data type fp32. Got: {inputs[0].dtype}, "
129+
f"input 2: {inputs[1].dtype} and output: {output.dtype}"
130+
)
109131

110132
# MI lowering
111133
tosa_graph.addOperator(

backends/arm/test/misc/test_debug_feats.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import logging
87
import os
98
import shutil
109
import tempfile
@@ -15,9 +14,6 @@
1514

1615
from executorch.backends.arm.test.tester.arm_tester import ArmTester
1716

18-
logger = logging.getLogger(__name__)
19-
logger.setLevel(logging.INFO)
20-
2117

2218
class Linear(torch.nn.Module):
2319
def __init__(
@@ -205,7 +201,6 @@ def test_collate_tosa_BI_tests(self):
205201

206202

207203
def test_dump_tosa_ops(caplog):
208-
caplog.set_level(logging.INFO)
209204
model = Linear(20, 30)
210205
(
211206
ArmTester(
@@ -222,7 +217,6 @@ def test_dump_tosa_ops(caplog):
222217

223218

224219
def test_fail_dump_tosa_ops(caplog):
225-
caplog.set_level(logging.INFO)
226220

227221
class Add(torch.nn.Module):
228222
def forward(self, x):

backends/arm/test/models/test_conformer.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
import logging
76
import unittest
87

98
import torch
@@ -14,10 +13,6 @@
1413
from torchaudio.models import Conformer
1514

1615

17-
logger = logging.getLogger(__name__)
18-
logger.setLevel(logging.INFO)
19-
20-
2116
def get_test_inputs(dim, lengths, num_examples):
2217
return (torch.rand(num_examples, int(lengths.max()), dim), lengths)
2318

0 commit comments

Comments
 (0)