[microNPU] Add support for scalar values
PR apache#9515 enabled support for scalar constants, but didn't consider the
case of a scalar value where the underlying constant data does not have
a shape, i.e. `constant.shape == []`. See the test case for a visual
difference when the scalar value is 1.

Change-Id: Id7a238cb5bf999dd5a8428c097202f9fb940a5f0
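For reference, the distinction is visible directly in NumPy; a minimal sketch (assuming only `numpy`, using the same constants as the updated test below):

import numpy as np

# A "scalar" stored as a rank-4 tensor of ones: it still has a shape.
print(np.ones((1, 1, 1, 1)).shape)  # (1, 1, 1, 1)

# A true scalar value: the underlying data has no shape at all,
# i.e. the empty tuple -- the case this commit adds support for.
print(np.array(1).shape)  # ()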
lhutton1 committed Jan 4, 2022
1 parent 9cc1df6 commit 1abe534
Showing 4 changed files with 21 additions and 11 deletions.
4 changes: 2 additions & 2 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -665,8 +665,8 @@ def callback(
             ifm2_zero_point=int(params.ifm2.q_params.zero_point),
             ofm_scale=float(params.ofm.q_params.scale_f32),
             ofm_zero_point=int(params.ofm.q_params.zero_point),
-            ifm_channels=params.ifm.shape[-1],
-            ifm2_channels=params.ifm2.shape[-1],
+            ifm_channels=params.ifm.shape[-1] if params.ifm.shape else 1,
+            ifm2_channels=params.ifm2.shape[-1] if params.ifm2.shape else 1,
             reversed_operands=params.reversed_operands,
             ofm_dtype=params.ofm.dtype,
             activation=activation,
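The guarded expressions fall back to a single channel when the shape is empty. A minimal sketch of the pattern (the `num_channels` helper is hypothetical, for illustration only):

def num_channels(shape):
    # An empty shape ([] or ()) denotes a scalar, treated as one channel.
    return shape[-1] if shape else 1

assert num_channels([1, 4, 4, 8]) == 8  # last dimension is the channel count
assert num_channels([]) == 1            # scalar constant with no shape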
11 changes: 5 additions & 6 deletions python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
@@ -123,12 +123,11 @@ def __init__(self):
 
     def visit_constant(self, const):
         if isinstance(const.checked_type, relay.ty.TensorType):
-            if const.checked_type.concrete_shape != ():
-                self.constants.append(const.data.asnumpy())
-                name = "p" + str(len(self.constants))
-                var = relay.var(type_annotation=const.checked_type, name_hint=name)
-                self.const_vars.append(var)
-                return var
+            self.constants.append(const.data.asnumpy())
+            name = "p" + str(len(self.constants))
+            var = relay.var(type_annotation=const.checked_type, name_hint=name)
+            self.const_vars.append(var)
+            return var
 
         return const
 
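With the `concrete_shape != ()` guard removed, scalar constants (whose concrete shape is the empty tuple) are now also extracted into constant variables instead of falling through to `return const`. A minimal sketch of the shape check the old guard performed, assuming a TVM installation:

from tvm import relay

# The type of a scalar constant has an empty concrete shape.
scalar_type = relay.ty.TensorType([], "int8")
print(scalar_type.concrete_shape)  # ()

# The old guard `concrete_shape != ()` therefore skipped scalar constants,
# so they were never hoisted into parameters.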
12 changes: 11 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py
@@ -135,7 +135,17 @@ def _visit(tensor, reader, lut):
         if tensor not in planned:
             planned.add(tensor)
             if isinstance(tensor.op, tvm.te.PlaceholderOp) and tensor != lut:
-                index = list(cached_func.inputs).index(tensor)
+                # Find index of input using 'same_as' check to prevent equality
+                # ambiguity when encountering a scalar.
+                index = -1
+                for i, var in enumerate(cached_func.inputs):
+                    if var.same_as(tensor):
+                        index = i
+                        break
+                assert (
+                    index >= 0
+                ), f"Tensor {tensor} was not found in inputs: {cached_func.inputs}"
+
                 if index in const_dict:
                     sch.cache_read(tensor, "global", [reader])
 
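For background on the "equality ambiguity" the new comment mentions: `list.index` compares with `==`, and in TVM comparing two rank-0 (scalar) tensors with `==` is rejected as ambiguous, whereas `same_as` is a plain reference-identity check. A rough sketch, assuming `tvm.te` and that this behaviour matches the TVM version at the time of the commit:

from tvm import te

a = te.placeholder((), name="a")  # rank-0 (scalar) placeholder
b = te.placeholder((), name="b")

print(a.same_as(a))  # True  -- reference identity, never ambiguous
print(a.same_as(b))  # False

# `a == b` on rank-0 tensors raises a ValueError ("Equal == comparison
# among rank-0 tensor is ambiguous"), which is why the `list.index` call
# above was replaced with an explicit `same_as` loop.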
5 changes: 3 additions & 2 deletions tests/python/contrib/test_ethosu/test_codegen.py
@@ -629,12 +629,13 @@ def create_mod_from_relay():
 
 @pytest.mark.parametrize("accel_type", ACCEL_TYPES)
 @pytest.mark.parametrize("dtype", ["int8", "uint8"])
-def test_elementwise_add_from_constant_scalar(accel_type, dtype):
+@pytest.mark.parametrize("constant", [np.ones((1, 1, 1, 1)), np.array(1)])
+def test_elementwise_add_from_constant_scalar(accel_type, dtype, constant):
     ifm_shape = (1, 4, 4, 8)
 
     def create_relay_graph():
         inp = relay.var("input", shape=ifm_shape, dtype=dtype)
-        scalar = relay.const(np.ones((1, 1, 1, 1), dtype=dtype), dtype=dtype)
+        scalar = relay.const(constant, dtype=dtype)
         add = relay.qnn.op.add(
             inp,
             scalar,
