|
5 | 5 |
|
6 | 6 | # Create flows for Arm Backends used to test operator and model suits |
7 | 7 |
|
| 8 | +from collections.abc import Callable |
| 9 | + |
8 | 10 | from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec |
9 | 11 | from executorch.backends.arm.quantizer import get_symmetric_quantization_config |
10 | 12 | from executorch.backends.arm.test import common |
11 | 13 | from executorch.backends.arm.test.tester.arm_tester import ArmTester |
12 | | -from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec |
13 | 14 | from executorch.backends.arm.util._factory import create_quantizer |
14 | 15 | from executorch.backends.test.suite.flow import TestFlow |
15 | 16 | from executorch.backends.xnnpack.test.tester.tester import Quantize |
16 | 17 |
|
17 | 18 |
|
18 | 19 | def _create_arm_flow( |
19 | | - name, |
20 | | - compile_spec: ArmCompileSpec, |
| 20 | + name: str, |
| 21 | + compile_spec_factory: Callable[[], ArmCompileSpec], |
| 22 | + support_serialize: bool = True, |
| 23 | + quantize: bool = True, |
21 | 24 | symmetric_io_quantization: bool = False, |
22 | 25 | per_channel_quantization: bool = True, |
23 | 26 | use_portable_ops: bool = True, |
24 | 27 | timeout: int = 1200, |
25 | 28 | ) -> TestFlow: |
26 | 29 |
|
27 | 30 | def _create_arm_tester(*args, **kwargs) -> ArmTester: |
28 | | - kwargs["compile_spec"] = compile_spec |
| 31 | + spec = compile_spec_factory() |
| 32 | + kwargs["compile_spec"] = spec |
29 | 33 | return ArmTester( |
30 | 34 | *args, **kwargs, use_portable_ops=use_portable_ops, timeout=timeout |
31 | 35 | ) |
32 | 36 |
|
33 | | - support_serialize = not isinstance(compile_spec, TosaCompileSpec) |
34 | | - quantize = compile_spec.tosa_spec.support_integer() |
35 | | - |
36 | | - if quantize is True: |
| 37 | + if quantize: |
37 | 38 |
|
38 | 39 | def create_quantize_stage() -> Quantize: |
39 | | - quantizer = create_quantizer(compile_spec) |
| 40 | + spec = compile_spec_factory() |
| 41 | + quantizer = create_quantizer(spec) |
40 | 42 | quantization_config = get_symmetric_quantization_config( |
41 | 43 | is_per_channel=per_channel_quantization |
42 | 44 | ) |
43 | 45 | if symmetric_io_quantization: |
44 | 46 | quantizer.set_io(quantization_config) |
45 | | - return Quantize(quantizer, quantization_config) |
| 47 | + return Quantize(quantizer, quantization_config) # type: ignore |
46 | 48 |
|
47 | 49 | return TestFlow( |
48 | 50 | name, |
49 | 51 | backend="arm", |
50 | 52 | tester_factory=_create_arm_tester, |
51 | 53 | supports_serialize=support_serialize, |
52 | 54 | quantize=quantize, |
53 | | - quantize_stage_factory=(create_quantize_stage if quantize is True else False), |
| 55 | + quantize_stage_factory=(create_quantize_stage if quantize else False), # type: ignore |
54 | 56 | ) |
55 | 57 |
|
56 | 58 |
|
57 | 59 | ARM_TOSA_FP_FLOW = _create_arm_flow( |
58 | 60 | "arm_tosa_fp", |
59 | | - common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"), |
| 61 | + lambda: common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"), |
| 62 | + support_serialize=False, |
| 63 | + quantize=False, |
60 | 64 | ) |
61 | 65 | ARM_TOSA_INT_FLOW = _create_arm_flow( |
62 | 66 | "arm_tosa_int", |
63 | | - common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+INT"), |
| 67 | + lambda: common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+INT"), |
| 68 | + support_serialize=False, |
| 69 | + quantize=True, |
64 | 70 | ) |
65 | 71 | ARM_ETHOS_U55_FLOW = _create_arm_flow( |
66 | 72 | "arm_ethos_u55", |
67 | | - common.get_u55_compile_spec(), |
| 73 | + lambda: common.get_u55_compile_spec(), |
| 74 | + quantize=True, |
68 | 75 | ) |
69 | 76 | ARM_ETHOS_U85_FLOW = _create_arm_flow( |
70 | 77 | "arm_ethos_u85", |
71 | | - common.get_u85_compile_spec(), |
| 78 | + lambda: common.get_u85_compile_spec(), |
| 79 | + quantize=True, |
72 | 80 | ) |
0 commit comments