-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[microNPU] Allow constants to be given as input to an operator #9515
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,7 @@ | |
# specific language governing permissions and limitations | ||
# under the License. | ||
import pytest | ||
import numpy as np | ||
|
||
pytest.importorskip("ethosu.vela") | ||
import tvm | ||
|
@@ -23,8 +24,10 @@ | |
from tvm.relay.testing import run_opt_pass | ||
from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir | ||
from tvm.relay.backend.contrib.ethosu.tir.scheduler import Convolution2DCompute | ||
from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants | ||
from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator | ||
|
||
from .infra import make_ethosu_conv2d | ||
from .infra import make_ethosu_conv2d, make_ethosu_binary_elementwise | ||
|
||
|
||
# fmt: off | ||
|
@@ -270,5 +273,47 @@ def _get_func(): | |
assert reference_const_sizes == test_const_sizes | ||
|
||
|
||
def test_constant_as_input(): | ||
"""Test to check that constants specified as inputs aren't | ||
interpreted as an encoded constant.""" | ||
|
||
def get_graph(): | ||
dtype = "uint8" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why does the constant need to be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For now this needs to be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, makes sense, thanks for clarifying! :) |
||
ifm = relay.var("ifm", shape=(1, 16, 16, 32), dtype=dtype) | ||
conv1 = make_ethosu_conv2d( | ||
ifm, | ||
32, | ||
16, | ||
(1, 1), | ||
(0, 0), | ||
(1, 1), | ||
(1, 1), | ||
) | ||
scalar = relay.const(np.ones((1, 1, 1, 1), dtype=dtype), dtype=dtype) | ||
add1 = make_ethosu_binary_elementwise( | ||
conv1, scalar, ifm_channels=32, ifm2_channels=1, operator_type="ADD", ofm_dtype=dtype | ||
) | ||
func = relay.Function(relay.analysis.free_vars(add1), add1) | ||
func = run_opt_pass(func, relay.transform.InferType()) | ||
return func | ||
|
||
tir_mod, params = lower_to_tir(get_graph(), copy_constants()) | ||
|
||
# Check tile address for the scalar constant input hasn't been | ||
# overwritten. | ||
extern_calls = tir_mod["main"].body.body.body.body.body | ||
binary_elementwise = extern_calls[-1].value | ||
args = binary_elementwise.args | ||
|
||
reason = "Tile address overwritten" | ||
assert args[26] == 0, reason | ||
assert args[27] == 0, reason | ||
assert args[28] == 0, reason | ||
|
||
# More generally, check compiles successfully to make sure | ||
# nothing else was overrwritten. | ||
tir_to_cs_translator.translate(tir_mod, params) | ||
|
||
|
||
if __name__ == "__main__": | ||
pytest.main([__file__]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reason we start from Relay there instead of TFLite?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wasn't sure if there was a way to generate similar relay from TFLite, although admittedly I didn't really check. I'll have a look into it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, Relay is also needed due to the
uint8
restriction so I'll leave this for nowThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok yes, that makes sense!