Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BYOC] Allow custom codegens to register their own constant updater #6697

Merged
merged 6 commits into from
Oct 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/relay/backend/contrib/ethosn/codegen_ethosn.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,9 @@ runtime::Module CompileEthosn(const ObjectRef& ref) {

TVM_REGISTER_GLOBAL("relay.ext.ethos-n").set_body_typed(CompileEthosn);

TVM_REGISTER_GLOBAL("relay.ext.ethos-n.constant_updater")
.set_body_typed([](Expr expr, std::string symbol) { return Map<String, runtime::NDArray>(); });

} // namespace ethosn
} // namespace contrib
} // namespace relay
Expand Down
9 changes: 1 addition & 8 deletions src/relay/backend/graph_runtime_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -368,14 +368,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<G
CCacheKey key = (*pf0)(func, target);
CachedFunc ext_func = (*pf1)(compile_engine_, key);
ICHECK(ext_func.defined()) << "External function is not defined.";

// Step into the functions that are handled by external codegen to
// collect metadata.
const auto name_node = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
std::string symobl = std::string(name_node.value());
ConstantUpdater const_visit(symobl, &params_);
const_visit(func);

UpdateConstants(func, &params_);
return GraphAddCallNode(op, ext_func->func_name, ext_func->func_name);
}

Expand Down
31 changes: 31 additions & 0 deletions src/relay/backend/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,37 @@ struct ConstantUpdater : public ExprVisitor {
std::unordered_map<std::string, runtime::NDArray>* params_;
};

/*!
* \brief A function to update the params with constants found in an external function.
* \param func The function from which to get the constant params.
* \param params The params to update with the constants.
*/
inline void UpdateConstants(Function func,
std::unordered_map<std::string, runtime::NDArray>* params) {
auto codegen = func->GetAttr<String>(attr::kCompiler);
ICHECK(codegen.defined()) << "No external codegen is set";
std::string codegen_name = codegen.value();
const auto name_node = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
std::string symbol = std::string(name_node.value());
std::string const_update_name = "relay.ext." + codegen_name + ".constant_updater";
// Get the constant updater for the external codegen
auto pf = tvm::runtime::Registry::Get(const_update_name);
// If the backend hasn't registered a constant updater, use a default one
if (pf == nullptr) {
ConstantUpdater const_visit(symbol, params);
const_visit(func);
} else {
Map<String, tvm::runtime::NDArray> constants = (*pf)(func, symbol);
for (const auto& it : constants) {
std::string const_name(it.first);
// Constant names should begin this the compiler name (to avoid conflicts)
ICHECK(const_name.find(codegen_name) == 0)
<< "External constant names must start with compiler name";
(*params)[const_name] = it.second;
}
}
}

/*!
* \brief A simple wrapper around ExprFunctor for a single argument case.
* The result of visit is memoized.
Expand Down
4 changes: 2 additions & 2 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1124,8 +1124,8 @@ void VMCompiler::Codegen() {
if (target_str == "ext_dev") {
// Collect metadata in functions that are handled by external codegen.
ICHECK(mod->ContainGlobalVar(cfunc->func_name));
backend::ConstantUpdater const_visit(cfunc->func_name, &params_);
const_visit(Downcast<Function>(mod->Lookup(cfunc->func_name)));
Function func = Downcast<Function>(mod->Lookup(cfunc->func_name));
backend::UpdateConstants(func, &params_);
continue;
} else if (funcs.count(target_str) == 0) {
funcs.emplace(target_str, mod);
Expand Down
82 changes: 82 additions & 0 deletions tests/python/contrib/test_ethosn/test_constant_duplication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Test that constants aren't duplicated for Ethos-N"""

import numpy as np
import tvm
from tvm import relay
from tvm.relay.op.contrib.ethosn import ethosn_available
from . import infrastructure as tei


def _get_model():
"""Return a model and any parameters it may have"""
shape = (1, 4, 4, 4)
kernel_h = 3
kernel_w = 3
out_channels = 8

a = relay.var("a", shape=shape, dtype="uint8")
add_const_value = tvm.nd.array(np.random.randint(0, high=10, size=shape, dtype="uint8"))
add_const = relay.const(add_const_value, "uint8")
a = relay.add(a, add_const)
weight_shape = (kernel_h, kernel_w, shape[3], out_channels)
w = tvm.nd.array(np.random.randint(low=0, high=255, size=weight_shape, dtype="uint8"))
weights = relay.const(w, "uint8")
conv = relay.qnn.op.conv2d(
a,
weights,
input_zero_point=relay.const(0, "int32"),
kernel_zero_point=relay.const(0, "int32"),
input_scale=relay.const(0.3, "float32"),
kernel_scale=relay.const(0.4, "float32"),
kernel_size=(kernel_h, kernel_w),
data_layout="NHWC",
kernel_layout="HWIO",
dilation=(1, 1),
strides=(1, 1),
groups=1,
channels=out_channels,
padding=(0, 0, 0, 0),
out_dtype="int32",
)
b = tvm.nd.array(np.random.randint(0, high=10, size=(out_channels,), dtype="int32"))
biasc = relay.const(b, "int32")
bias = relay.nn.bias_add(conv, biasc, axis=3)
req = relay.qnn.op.requantize(
bias,
relay.const(0.3 * 0.4, "float32"), # input zero scale
relay.const(0, "int32"), # input zero point
relay.const(0.4, "float32"), # output zero scale
relay.const(0, "int32"), # output zero point
out_dtype="uint8",
)
params = {"w": w, "b": b}
return req, params


def test_constant_duplication():
if not ethosn_available():
return

model, params = _get_model()
mod = tei.make_module(model, params)
res = tei.build(mod, params, npu=True, expected_host_ops=1)
for key, value in res.params.items():
assert key == "p0"
assert value.asnumpy().size == 64
34 changes: 34 additions & 0 deletions tests/python/relay/test_external_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,39 @@ def test_extern_gcc():
check_result(mod, {"x": x_data, "y": y_data}, (2, 2), (y_data * y_data) - (x_data + x_data))


def test_extern_gcc_consts():
@tvm._ffi.register_func("relay.ext.ccompiler.constant_updater")
def constant_updater(expr, symbol):
"""A dummy constant updater just to test that a custom one works."""
return {"ccompiler_0_p0": tvm.nd.array(y0_data)}

x = relay.var("x", shape=(8, 8))
y0_data = np.random.uniform(0, 1, (8, 8)).astype("float32")

x0 = relay.var("x0", shape=(8, 8))
y0_const = relay.const(y0_data, "float32")
z = x0 + y0_const
f = relay.Function([x0], z)
f = set_external_func_attr(f, "ccompiler", "ccompiler_0")
call = relay.Call(f, [x])
mod = tvm.IRModule.from_expr(call)

with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
compiler = relay.backend.vm.VMCompiler()
compiler.lower(mod, "llvm")
compiler.codegen()
params = compiler.get_params()
assert len(params) == 1
assert "ccompiler_0_p0" in params.keys()

with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
_, _, params = relay.build(mod, target="llvm")
assert len(params) == 1
assert "ccompiler_0_p0" in params.keys()

tvm._ffi.registry.remove_global_func("relay.ext.ccompiler.constant_updater")


def test_extern_dnnl():
if not tvm.get_global_func("relay.ext.dnnl", True):
print("skip because DNNL codegen is not available")
Expand Down Expand Up @@ -301,5 +334,6 @@ def test_extern_dnnl_const():
test_extern_gcc_single_op()
test_extern_gcc_single_op_int()
test_extern_gcc()
test_extern_gcc_consts()
test_extern_dnnl()
test_extern_dnnl_const()