[microNPU] Allow constants to be given as input to an operator (apache#9515)

* [microNPU] Allow constants to be given as input to an operator

Currently the expectation is that all constants need to be encoded;
however, this is not always the case for scalar inputs. This PR
ensures that constants that don't need encoding are not treated
like encoded constants by the EncodeConstants pass.

Change-Id: I79cf4aa10d01c4ae9ce9cdafb6f21ebb2d028126

* address comments

Change-Id: I67b61a2d2f67de25c47d2ace0e3a22c59ba8ea15
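For illustration, a minimal sketch of the case this change covers, mirroring the new tests added below: a scalar relay.const fed straight into a quantized add. A constant like this carries no weights or biases, so EncodeConstants should pass it through rather than treat it as an encoded constant. The shapes and quantization parameters here are placeholder values, not part of the commit itself.

import numpy as np
import tvm
from tvm import relay

# Scalar constant used directly as an operator input -- it needs no
# weight-style encoding, so EncodeConstants should leave it untouched.
dtype = "uint8"
ifm = relay.var("ifm", shape=(1, 4, 4, 8), dtype=dtype)
scalar = relay.const(np.ones((1, 1, 1, 1), dtype=dtype), dtype=dtype)
add = relay.qnn.op.add(
    ifm,
    scalar,
    relay.const(1.0, dtype="float32"),  # ifm scale
    relay.const(0, dtype="int32"),      # ifm zero point
    relay.const(1.0, dtype="float32"),  # ifm2 scale
    relay.const(0, dtype="int32"),      # ifm2 zero point
    relay.const(1.0, dtype="float32"),  # ofm scale
    relay.const(0, dtype="int32"),      # ofm zero point
)
mod = tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(add), add))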
lhutton1 authored and ylc committed Jan 7, 2022
1 parent 8e8791b commit ff814b7
Showing 4 changed files with 150 additions and 3 deletions.
9 changes: 7 additions & 2 deletions python/tvm/relay/backend/contrib/ethosu/tir/passes.py
@@ -392,14 +392,19 @@ def _visit_rewrite(stmt):
             # For extern calls, we need to rewrite pairs of arguments corresponding to
             # base address load and the length of the load.
             new_args = [stmt.args[0]]
+            new_buffers = rewrite_buffer.values()
             for i in range(1, len(stmt.args)):
                 # If the previous argument was a load, the current should be a length
                 if isinstance(stmt.args[i - 1], tvm.tir.Load):
                     load = stmt.args[i - 1]
                     pointer = load.buffer_var
                     if pointer in pointer_to_buffer:
-                        new_args.append(np.prod(list(pointer_to_buffer[pointer].shape)))
-                        continue
+                        buffer = pointer_to_buffer[pointer]
+                        # Only rewrite the arguments of buffers that have been encoded
+                        if buffer in new_buffers:
+                            new_arg = np.prod(list(pointer_to_buffer[pointer].shape))
+                            new_args.append(new_arg)
+                            continue
                 new_args.append(stmt.args[i])

             return tvm.tir.Call(stmt.dtype, stmt.op, new_args, stmt.span)
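In isolation, the guard introduced above boils down to a membership check: only rewrite the length argument when the buffer behind the preceding load was actually encoded. A standalone sketch of that logic follows, using plain Python dictionaries in place of real TIR buffers and pointers; the helper name and toy data are illustrative only, not code from this commit.

import numpy as np

def rewrite_length_arg(arg, prev_pointer, pointer_to_buffer, encoded_buffers):
    """Return the encoded length for `arg`, or leave it unchanged.

    Mirrors the guard above: the length is only rewritten when the
    pointer maps to a buffer that was actually encoded.
    """
    buffer = pointer_to_buffer.get(prev_pointer)
    if buffer is not None and buffer in encoded_buffers:
        return int(np.prod(buffer["shape"]))
    return arg

# Toy data: one encoded weight buffer and one plain scalar constant.
weights = {"shape": (48, 1, 1, 3)}
scalar = {"shape": (1, 1, 1, 1)}
pointer_to_buffer = {"weights_ptr": weights, "scalar_ptr": scalar}
encoded_buffers = [weights]  # the scalar constant was never encoded

assert rewrite_length_arg(10, "weights_ptr", pointer_to_buffer, encoded_buffers) == 144
assert rewrite_length_arg(10, "scalar_ptr", pointer_to_buffer, encoded_buffers) == 10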
50 changes: 50 additions & 0 deletions tests/python/contrib/test_ethosu/test_codegen.py
@@ -435,6 +435,56 @@ def representative_dataset():
    infra.verify_source(compiled_models, accel_type)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
def test_binary_add_from_constant_scalar(accel_type):
    dtype = "uint8"
    ifm_shape = (1, 4, 4, 8)

    def create_relay_graph():
        inp = relay.var("input", shape=ifm_shape, dtype=dtype)
        scalar = relay.const(np.ones((1, 1, 1, 1), dtype=dtype), dtype=dtype)
        add = relay.qnn.op.add(
            inp,
            scalar,
            relay.const(1.0, dtype="float32"),
            relay.const(0, dtype="int32"),
            relay.const(1.0, dtype="float32"),
            relay.const(0, dtype="int32"),
            relay.const(1.0, dtype="float32"),
            relay.const(0, dtype="int32"),
        )
        func = relay.Function(relay.analysis.free_vars(add), add)
        return tvm.IRModule.from_expr(func)

    mod = create_relay_graph()
    partitioned_mod = partition_for_ethosu(mod)

    # Generate reference data
    input_data = {"input": np.random.randint(low=0, high=255, size=ifm_shape, dtype=dtype)}
    output_data = generate_ref_data(mod, input_data)

    compiled_models = infra.build_source(
        partitioned_mod,
        input_data,
        output_data,
        accel_type,
        output_tolerance=0,
    )

    # Assumes only two runtime.Modules are created -- i.e. single offload module
    imported_modules = compiled_models[0].executor_factory.lib.imported_modules
    assert len(imported_modules) == 2
    ethosu_module = imported_modules[0]

    # Verify generated C source
    get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
    cmms = get_cs(ethosu_module)
    cmms = bytes.fromhex(cmms)

    infra.print_payload(cmms)
    infra.verify_source(compiled_models, accel_type)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize(
    "ifm_shape, ifm2_shape",
47 changes: 46 additions & 1 deletion tests/python/contrib/test_ethosu/test_encode_constants.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import pytest
+import numpy as np

 pytest.importorskip("ethosu.vela")
 import tvm
@@ -23,8 +24,10 @@
 from tvm.relay.testing import run_opt_pass
 from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir
 from tvm.relay.backend.contrib.ethosu.tir.scheduler import Convolution2DCompute
+from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants
+from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator

-from .infra import make_ethosu_conv2d
+from .infra import make_ethosu_conv2d, make_ethosu_binary_elementwise


# fmt: off
@@ -270,5 +273,47 @@ def _get_func():
    assert reference_const_sizes == test_const_sizes


def test_constant_as_input():
    """Test to check that constants specified as inputs aren't
    interpreted as an encoded constant."""

    def get_graph():
        dtype = "uint8"
        ifm = relay.var("ifm", shape=(1, 16, 16, 32), dtype=dtype)
        conv1 = make_ethosu_conv2d(
            ifm,
            32,
            16,
            (1, 1),
            (0, 0),
            (1, 1),
            (1, 1),
        )
        scalar = relay.const(np.ones((1, 1, 1, 1), dtype=dtype), dtype=dtype)
        add1 = make_ethosu_binary_elementwise(
            conv1, scalar, ifm_channels=32, ifm2_channels=1, operator_type="ADD", ofm_dtype=dtype
        )
        func = relay.Function(relay.analysis.free_vars(add1), add1)
        func = run_opt_pass(func, relay.transform.InferType())
        return func

    tir_mod, params = lower_to_tir(get_graph(), copy_constants())

    # Check that the tile address for the scalar constant input hasn't been
    # overwritten.
    extern_calls = tir_mod["main"].body.body.body.body.body
    binary_elementwise = extern_calls[-1].value
    args = binary_elementwise.args

    reason = "Tile address overwritten"
    assert args[26] == 0, reason
    assert args[27] == 0, reason
    assert args[28] == 0, reason

    # More generally, check that this compiles successfully to make sure
    # nothing else was overwritten.
    tir_to_cs_translator.translate(tir_mod, params)


if __name__ == "__main__":
    pytest.main([__file__])
47 changes: 47 additions & 0 deletions tests/python/contrib/test_ethosu/test_legalize.py
@@ -693,6 +693,53 @@ def verify(ext_func):
    verify(mod["tvmgen_default_ethos_u_main_0"])


def test_binary_add_from_constant_scalar():
    dtype = "uint8"
    ifm_shape = (1, 4, 4, 8)

    def create_graph():
        inp = relay.var("input", shape=ifm_shape, dtype=dtype)
        scalar = relay.const(np.ones((1, 1, 1, 1), dtype=dtype), dtype=dtype)
        add = relay.qnn.op.add(
            inp,
            scalar,
            relay.const(1.0, dtype="float32"),
            relay.const(0, dtype="int32"),
            relay.const(1.0, dtype="float32"),
            relay.const(0, dtype="int32"),
            relay.const(1.0, dtype="float32"),
            relay.const(0, dtype="int32"),
        )
        func = relay.Function(relay.analysis.free_vars(add), add)
        return tvm.IRModule.from_expr(func)

    def verify(ext_func):
        op = ext_func.body
        assert list(op.args[0].checked_type.shape) == [1, 4, 4, 8]
        assert list(op.args[1].checked_type.shape) == [1, 1, 1, 1]
        assert op.args[0].checked_type.dtype == "uint8"
        assert list(op.checked_type.shape) == [1, 4, 4, 8]
        assert op.checked_type.dtype == "uint8"
        assert op.attrs.operator_type == "ADD"

    rewriter = legalize.AddRewriter()
    pattern_table = [
        (
            ethosu.AddParams.composite_name,
            ethosu.qnn_add_pattern(),
            lambda pat: ethosu.AddParams(pat).is_valid(),
        ),
    ]

    mod = create_graph()
    mod = partition_ethosu_by_table(mod, pattern_table)

    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
        rewriter, mod["tvmgen_default_ethos_u_main_0"]
    )
    verify(mod["tvmgen_default_ethos_u_main_0"])


@pytest.mark.parametrize(
    "ifm_shape, ifm2_shape, reversed_operands",
    [