diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index 75d75a5935b4..b070b11c0bf5 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -392,14 +392,19 @@ def _visit_rewrite(stmt): # For extern calls, we need to rewrite pairs of arguments corresponding to # base address load and the length of the load. new_args = [stmt.args[0]] + new_buffers = rewrite_buffer.values() for i in range(1, len(stmt.args)): # If the previous argument was a load, the current should be a length if isinstance(stmt.args[i - 1], tvm.tir.Load): load = stmt.args[i - 1] pointer = load.buffer_var if pointer in pointer_to_buffer: - new_args.append(np.prod(list(pointer_to_buffer[pointer].shape))) - continue + buffer = pointer_to_buffer[pointer] + # Only rewrite the length argument of buffers that have been encoded + if buffer in new_buffers: + new_arg = np.prod(list(pointer_to_buffer[pointer].shape)) + new_args.append(new_arg) + continue new_args.append(stmt.args[i]) return tvm.tir.Call(stmt.dtype, stmt.op, new_args, stmt.span) diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index d62f9d161ad3..93af66da8194 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -435,6 +435,56 @@ def representative_dataset(): infra.verify_source(compiled_models, accel_type) +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +def test_binary_add_from_constant_scalar(accel_type): + dtype = "uint8" + ifm_shape = (1, 4, 4, 8) + + def create_relay_graph(): + inp = relay.var("input", shape=ifm_shape, dtype=dtype) + scalar = relay.const(np.ones((1, 1, 1, 1), dtype=dtype), dtype=dtype) + add = relay.qnn.op.add( + inp, + scalar, + relay.const(1.0, dtype="float32"), + relay.const(0, dtype="int32"), + relay.const(1.0, dtype="float32"), + relay.const(0, 
dtype="int32"), + relay.const(1.0, dtype="float32"), + relay.const(0, dtype="int32"), + ) + func = relay.Function(relay.analysis.free_vars(add), add) + return tvm.IRModule.from_expr(func) + + mod = create_relay_graph() + partitioned_mod = partition_for_ethosu(mod) + + # Generate reference data + input_data = {"input": np.random.randint(low=0, high=255, size=ifm_shape, dtype=dtype)} + output_data = generate_ref_data(mod, input_data) + + compiled_models = infra.build_source( + partitioned_mod, + input_data, + output_data, + accel_type, + output_tolerance=0, + ) + + # Assumes only two runtime.Modules are created -- i.e. single offload module + imported_modules = compiled_models[0].executor_factory.lib.imported_modules + assert len(imported_modules) == 2 + ethosu_module = imported_modules[0] + + # Verify generated C source + get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") + cmms = get_cs(ethosu_module) + cmms = bytes.fromhex(cmms) + + infra.print_payload(cmms) + infra.verify_source(compiled_models, accel_type) + + +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) @pytest.mark.parametrize( "ifm_shape, ifm2_shape", diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index 5b60102162be..cc3c68624242 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import pytest +import numpy as np pytest.importorskip("ethosu.vela") import tvm @@ -23,8 +24,10 @@ from tvm.relay.testing import run_opt_pass from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir from tvm.relay.backend.contrib.ethosu.tir.scheduler import Convolution2DCompute +from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants +from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator -from .infra import make_ethosu_conv2d +from .infra import make_ethosu_conv2d, make_ethosu_binary_elementwise # fmt: off @@ -270,5 +273,47 @@ def _get_func(): assert reference_const_sizes == test_const_sizes +def test_constant_as_input(): + """Test to check that constants specified as inputs aren't + interpreted as an encoded constant.""" + + def get_graph(): + dtype = "uint8" + ifm = relay.var("ifm", shape=(1, 16, 16, 32), dtype=dtype) + conv1 = make_ethosu_conv2d( + ifm, + 32, + 16, + (1, 1), + (0, 0), + (1, 1), + (1, 1), + ) + scalar = relay.const(np.ones((1, 1, 1, 1), dtype=dtype), dtype=dtype) + add1 = make_ethosu_binary_elementwise( + conv1, scalar, ifm_channels=32, ifm2_channels=1, operator_type="ADD", ofm_dtype=dtype + ) + func = relay.Function(relay.analysis.free_vars(add1), add1) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + tir_mod, params = lower_to_tir(get_graph(), copy_constants()) + + # Check tile address for the scalar constant input hasn't been + # overwritten. + extern_calls = tir_mod["main"].body.body.body.body.body + binary_elementwise = extern_calls[-1].value + args = binary_elementwise.args + + reason = "Tile address overwritten" + assert args[26] == 0, reason + assert args[27] == 0, reason + assert args[28] == 0, reason + + # More generally, check compiles successfully to make sure + # nothing else was overwritten. 
+ tir_to_cs_translator.translate(tir_mod, params) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 2247234e6d68..8c3e4e31c1ca 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -693,6 +693,53 @@ def verify(ext_func): verify(mod["tvmgen_default_ethos_u_main_0"]) +def test_binary_add_from_constant_scalar(): + dtype = "uint8" + ifm_shape = (1, 4, 4, 8) + + def create_graph(): + inp = relay.var("input", shape=ifm_shape, dtype=dtype) + scalar = relay.const(np.ones((1, 1, 1, 1), dtype=dtype), dtype=dtype) + add = relay.qnn.op.add( + inp, + scalar, + relay.const(1.0, dtype="float32"), + relay.const(0, dtype="int32"), + relay.const(1.0, dtype="float32"), + relay.const(0, dtype="int32"), + relay.const(1.0, dtype="float32"), + relay.const(0, dtype="int32"), + ) + func = relay.Function(relay.analysis.free_vars(add), add) + return tvm.IRModule.from_expr(func) + + def verify(ext_func): + op = ext_func.body + assert list(op.args[0].checked_type.shape) == [1, 4, 4, 8] + assert list(op.args[1].checked_type.shape) == [1, 1, 1, 1] + assert op.args[0].checked_type.dtype == "uint8" + assert list(op.checked_type.shape) == [1, 4, 4, 8] + assert op.checked_type.dtype == "uint8" + assert op.attrs.operator_type == "ADD" + + rewriter = legalize.AddRewriter() + pattern_table = [ + ( + ethosu.AddParams.composite_name, + ethosu.qnn_add_pattern(), + lambda pat: ethosu.AddParams(pat).is_valid(), + ), + ] + + mod = create_graph() + mod = partition_ethosu_by_table(mod, pattern_table) + + mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite( + rewriter, mod["tvmgen_default_ethos_u_main_0"] + ) + verify(mod["tvmgen_default_ethos_u_main_0"]) + + @pytest.mark.parametrize( "ifm_shape, ifm2_shape, reversed_operands", [