Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[microNPU] Allow constants to be given as input to an operator #9515

Merged
merged 2 commits into from
Nov 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions python/tvm/relay/backend/contrib/ethosu/tir/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,14 +392,19 @@ def _visit_rewrite(stmt):
# For extern calls, we need to rewrite pairs of arguments corresponding to
# base address load and the length of the load.
new_args = [stmt.args[0]]
new_buffers = rewrite_buffer.values()
for i in range(1, len(stmt.args)):
# If the previous argument was a load, the current should be a length
if isinstance(stmt.args[i - 1], tvm.tir.Load):
load = stmt.args[i - 1]
pointer = load.buffer_var
if pointer in pointer_to_buffer:
new_args.append(np.prod(list(pointer_to_buffer[pointer].shape)))
continue
buffer = pointer_to_buffer[pointer]
# Only rewrite the arguments of buffers that have been encoded
if buffer in new_buffers:
new_arg = np.prod(list(pointer_to_buffer[pointer].shape))
new_args.append(new_arg)
continue
new_args.append(stmt.args[i])

return tvm.tir.Call(stmt.dtype, stmt.op, new_args, stmt.span)
Expand Down
50 changes: 50 additions & 0 deletions tests/python/contrib/test_ethosu/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,56 @@ def representative_dataset():
infra.verify_source(compiled_models, accel_type)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
def test_binary_add_from_constant_scalar(accel_type):
dtype = "uint8"
ifm_shape = (1, 4, 4, 8)

def create_relay_graph():
inp = relay.var("input", shape=ifm_shape, dtype=dtype)
scalar = relay.const(np.ones((1, 1, 1, 1), dtype=dtype), dtype=dtype)
add = relay.qnn.op.add(
inp,
scalar,
relay.const(1.0, dtype="float32"),
relay.const(0, dtype="int32"),
relay.const(1.0, dtype="float32"),
relay.const(0, dtype="int32"),
relay.const(1.0, dtype="float32"),
relay.const(0, dtype="int32"),
)
func = relay.Function(relay.analysis.free_vars(add), add)
Comment on lines +443 to +456
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason we start from Relay there instead of TFLite?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't sure if there was a way to generate similar relay from TFLite, although admittedly I didn't really check. I'll have a look into it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, Relay is also needed due to the uint8 restriction so I'll leave this for now

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok yes, that makes sense!

return tvm.IRModule.from_expr(func)

mod = create_relay_graph()
partitioned_mod = partition_for_ethosu(mod)

# Generate reference data
input_data = {"input": np.random.randint(low=0, high=255, size=ifm_shape, dtype=dtype)}
output_data = generate_ref_data(mod, input_data)

compiled_models = infra.build_source(
partitioned_mod,
input_data,
output_data,
accel_type,
output_tolerance=0,
)

# Assumes only two runtime.Modules are created -- i.e. single offload module
imported_modules = compiled_models[0].executor_factory.lib.imported_modules
assert len(imported_modules) == 2
ethosu_module = imported_modules[0]

# Verify generated C source
get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
cmms = get_cs(ethosu_module)
cmms = bytes.fromhex(cmms)

infra.print_payload(cmms)
infra.verify_source(compiled_models, accel_type)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize(
"ifm_shape, ifm2_shape",
Expand Down
47 changes: 46 additions & 1 deletion tests/python/contrib/test_ethosu/test_encode_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import pytest
import numpy as np

pytest.importorskip("ethosu.vela")
import tvm
Expand All @@ -23,8 +24,10 @@
from tvm.relay.testing import run_opt_pass
from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir
from tvm.relay.backend.contrib.ethosu.tir.scheduler import Convolution2DCompute
from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants
from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator

from .infra import make_ethosu_conv2d
from .infra import make_ethosu_conv2d, make_ethosu_binary_elementwise


# fmt: off
Expand Down Expand Up @@ -270,5 +273,47 @@ def _get_func():
assert reference_const_sizes == test_const_sizes


def test_constant_as_input():
"""Test to check that constants specified as inputs aren't
interpreted as an encoded constant."""

def get_graph():
dtype = "uint8"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does the constant need to be uint8? (just asking for enlightenment)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now this needs to be uint8 due to https://github.com/apache/tvm/blob/main/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py#L254. Removing this opens up another can of worms that I'll address in another PR :)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, makes sense, thanks for clarifying! :)

ifm = relay.var("ifm", shape=(1, 16, 16, 32), dtype=dtype)
conv1 = make_ethosu_conv2d(
ifm,
32,
16,
(1, 1),
(0, 0),
(1, 1),
(1, 1),
)
scalar = relay.const(np.ones((1, 1, 1, 1), dtype=dtype), dtype=dtype)
add1 = make_ethosu_binary_elementwise(
conv1, scalar, ifm_channels=32, ifm2_channels=1, operator_type="ADD", ofm_dtype=dtype
)
func = relay.Function(relay.analysis.free_vars(add1), add1)
func = run_opt_pass(func, relay.transform.InferType())
return func

tir_mod, params = lower_to_tir(get_graph(), copy_constants())

# Check tile address for the scalar constant input hasn't been
# overwritten.
extern_calls = tir_mod["main"].body.body.body.body.body
binary_elementwise = extern_calls[-1].value
args = binary_elementwise.args

reason = "Tile address overwritten"
assert args[26] == 0, reason
assert args[27] == 0, reason
assert args[28] == 0, reason

# More generally, check that the module compiles successfully to make
# sure nothing else was overwritten.
tir_to_cs_translator.translate(tir_mod, params)


if __name__ == "__main__":
pytest.main([__file__])
47 changes: 47 additions & 0 deletions tests/python/contrib/test_ethosu/test_legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,53 @@ def verify(ext_func):
verify(mod["tvmgen_default_ethos_u_main_0"])


def test_binary_add_from_constant_scalar():
    """Check legalization of a ``qnn.add`` whose second operand is a constant.

    A uint8 ``qnn.add`` of a (1, 4, 4, 8) input variable and a
    (1, 1, 1, 1) constant is partitioned with the Ethos-U add pattern and
    rewritten with ``legalize.AddRewriter``. The body of the resulting
    external function is then checked for the expected operand shapes,
    dtypes and operator type.
    """
    dtype = "uint8"
    ifm_shape = (1, 4, 4, 8)
    ext_func_name = "tvmgen_default_ethos_u_main_0"

    def _scale():
        # A fresh constant per call, so no node is shared between arguments.
        return relay.const(1.0, dtype="float32")

    def _zero_point():
        return relay.const(0, dtype="int32")

    def create_graph():
        """Build the Relay module containing the single ``qnn.add``."""
        inp = relay.var("input", shape=ifm_shape, dtype=dtype)
        scalar = relay.const(np.ones((1, 1, 1, 1), dtype=dtype), dtype=dtype)
        # The two operands are followed by three (scale, zero point) pairs.
        add = relay.qnn.op.add(
            inp,
            scalar,
            _scale(),
            _zero_point(),
            _scale(),
            _zero_point(),
            _scale(),
            _zero_point(),
        )
        free_vars = relay.analysis.free_vars(add)
        return tvm.IRModule.from_expr(relay.Function(free_vars, add))

    def verify(ext_func):
        """Assert the rewritten call has the expected types and operator."""
        call = ext_func.body
        assert list(call.args[0].checked_type.shape) == [1, 4, 4, 8]
        assert list(call.args[1].checked_type.shape) == [1, 1, 1, 1]
        assert call.args[0].checked_type.dtype == "uint8"
        assert list(call.checked_type.shape) == [1, 4, 4, 8]
        assert call.checked_type.dtype == "uint8"
        assert call.attrs.operator_type == "ADD"

    rewriter = legalize.AddRewriter()
    add_entry = (
        ethosu.AddParams.composite_name,
        ethosu.qnn_add_pattern(),
        lambda pat: ethosu.AddParams(pat).is_valid(),
    )
    pattern_table = [add_entry]

    mod = partition_ethosu_by_table(create_graph(), pattern_table)

    # Rewrite only the offloaded function, then inspect the result.
    mod[ext_func_name] = dataflow_pattern.rewrite(rewriter, mod[ext_func_name])
    verify(mod[ext_func_name])


@pytest.mark.parametrize(
"ifm_shape, ifm2_shape, reversed_operands",
[
Expand Down