BUG: Look through on_device annotations when looking for shape constants (apache#9345)

apache#8788 introduced a perf regression
since a `shape.as<ConstantNode>` in `alloc_tensor` was always failing
due to the extra `on_device` annotation on the constant. Fixed that,
and introduced some helpers to make this situation easier to deal with.

(This is CORE-102 in OctoML JIRA).

(Second try -- test_crp.py failure seems unrelated)
mbs-octoml authored and ylc committed Jan 13, 2022
1 parent 02dfc4e commit e16b5e5
Showing 7 changed files with 52 additions and 19 deletions.
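To make the regression concrete: after apache#8788, device planning wraps the constant shape argument of `memory.alloc_tensor` in an `on_device` annotation, so a bare `args[2].as<ConstantNode>()` no longer sees the constant. The sketch below is illustrative only -- the `TryConstShape` wrapper and the include path are hypothetical and not part of this commit; it assumes only the `AsIgnoringOnDevice` helper added in annotation.h further down.

// Illustrative sketch, not the literal compiler.cc change.
#include <tvm/relay/expr.h>

#include "annotation.h"  // adjust the relative path to src/relay/op/annotation/annotation.h

namespace tvm {
namespace relay {

const ConstantNode* TryConstShape(const Array<Expr>& args) {
  // Before the fix: args[2].as<ConstantNode>() returned nullptr whenever the
  // shape had been rewritten to on_device(Constant, ...), so the VM compiler
  // silently fell back to the dynamic alloc_tensor_reg instruction.
  //
  // After the fix: look through any on_device annotation first, so a literal
  // shape is still recognised and the immediate-mode alloc_tensor is emitted.
  return AsIgnoringOnDevice<ConstantNode>(args[2]);
}

}  // namespace relay
}  // namespace tvm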
3 changes: 1 addition & 2 deletions src/relay/backend/aot_executor_codegen.cc
@@ -182,9 +182,8 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor {
    * \return The corresponding token.
    */
   StorageInfo GetStorage(const Expr& expr) {
-    auto props = GetOnDeviceProps(expr);
     // See through "on_device" calls.
-    Expr true_expr = props.body.defined() ? props.body : expr;
+    Expr true_expr = IgnoreOnDevice(expr);
     VisitExpr(true_expr);
     auto it = storage_device_map_.find(true_expr);
     ICHECK(it != storage_device_map_.end());
5 changes: 2 additions & 3 deletions src/relay/backend/graph_plan_memory.cc
@@ -146,10 +146,9 @@ class StorageAllocaBaseVisitor : public transform::DeviceAwareExprVisitor {
    * \return The corresponding token.
    */
   const std::vector<StorageToken*>& GetToken(const Expr& expr) {
-    this->VisitExpr(expr);
     // See through on_device calls.
-    auto props = GetOnDeviceProps(expr);
-    Expr real_expr = props.body.defined() ? props.body : expr;
+    Expr real_expr = IgnoreOnDevice(expr);
+    this->VisitExpr(real_expr);
     auto it = token_map_.find(real_expr.get());
     ICHECK(it != token_map_.end()) << "Expression not found in storage map:" << std::endl
                                    << PrettyPrint(real_expr);
7 changes: 4 additions & 3 deletions src/relay/backend/vm/compiler.cc
@@ -594,8 +594,9 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
           auto offset_register = last_register_;
 
           // If the shape is constant then we will emit a static tensor allocation
-          // instruction.
-          auto const_shape = args[2].as<ConstantNode>();
+          // instruction. It may be wrapped by an on_device, but it will be on the host
+          // which is assumed by the alloc_tensor instruction anyway.
+          auto const_shape = AsIgnoringOnDevice<ConstantNode>(args[2]);
 
           if (const_shape) {
             NDArray shape = const_shape->data;
@@ -619,7 +620,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
           this->VisitExpr(args[0]);
           auto size_register = last_register_;
 
-          ICHECK(args[1].as<ConstantNode>());
+          ICHECK(args[1].as<ConstantNode>());  // Always a literal.
           NDArray alignment_arr = args[1].as<ConstantNode>()->data;
           ICHECK_EQ(alignment_arr->dtype.code, 0U)
               << "The dtype of constant shape must be int32 or int64, but got "
26 changes: 26 additions & 0 deletions src/relay/op/annotation/annotation.h
@@ -85,6 +85,32 @@ OnDeviceProps GetOnDeviceProps(const CallNode* call_node);
  */
 OnDeviceProps GetOnDeviceProps(const Expr& expr);
 
+/*!
+ * \brief Returns the body of \p expr if it is an "on_device" annotation, otherwise returns
+ * \p expr directly.
+ */
+inline Expr IgnoreOnDevice(const Expr& expr) {
+  OnDeviceProps props = GetOnDeviceProps(expr);
+  return props.body.defined() ? props.body : expr;
+}
+
+/*!
+ * \brief Returns \p expr as \p NodeType, or null if it is not of that type. Looks through
+ * any "on_device" annotations.
+ */
+template <typename NodeType>
+const NodeType* AsIgnoringOnDevice(const Expr& expr) {
+  const auto* node = expr.as<NodeType>();
+  if (node != nullptr) {
+    return node;
+  }
+  OnDeviceProps props = GetOnDeviceProps(expr);
+  if (!props.body.defined()) {
+    return nullptr;
+  }
+  return props.body.as<NodeType>();
+}
+
 /*!
  * \brief Returns \p function annotated with "param_device_types" and "result_device_type"
  * attributes capturing parameter and result devices types respectively.
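As a usage note for the two helpers above: `IgnoreOnDevice` suits callers that only need the underlying expression (for example as a key into a storage map), while `AsIgnoringOnDevice<NodeType>` suits callers that need a node of one specific type or null. The callers below are hypothetical, named only for illustration, and assume the header is reachable as `annotation.h`.

// Hypothetical callers, not part of this commit.
#include <tvm/relay/expr.h>

#include "annotation.h"

namespace tvm {
namespace relay {

// IgnoreOnDevice: strip an "on_device" wrapper, if any, and inspect whatever
// expression is underneath.
bool IsVarIgnoringAnnotation(const Expr& expr) {
  return IgnoreOnDevice(expr).as<VarNode>() != nullptr;
}

// AsIgnoringOnDevice<NodeType>: recover the node as a specific type whether or
// not it is wrapped; returns nullptr for anything else (e.g. a computed shape).
bool HasLiteralShape(const Expr& shape) {
  return AsIgnoringOnDevice<ConstantNode>(shape) != nullptr;
}

}  // namespace relay
}  // namespace tvm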
10 changes: 3 additions & 7 deletions src/relay/op/memory/memory.cc
@@ -101,13 +101,9 @@ Expr AllocTensor(Expr storage, Expr offset, Expr shape, DataType dtype,
     attrs->assert_shape = assert_shape;
   } else {
     // Look through any on_device for the shape argument expression.
-    Expr literal_shape = shape;
-    auto props = GetOnDeviceProps(literal_shape);
-    if (props.body.defined()) {
-      // See through on_device calls.
-      literal_shape = props.body;
-    }
-    attrs->const_shape = Downcast<Constant>(literal_shape);
+    const auto* constant_node = AsIgnoringOnDevice<ConstantNode>(shape);
+    ICHECK(constant_node);
+    attrs->const_shape = GetRef<Constant>(constant_node);
   }
   static const Op& op = Op::Get("memory.alloc_tensor");
   return Call(op, {storage, offset, shape}, Attrs(attrs), {});
5 changes: 2 additions & 3 deletions src/relay/transforms/pass_utils.h
@@ -118,9 +118,8 @@ inline Expr TransformF(const std::function<Expr(const Expr&)>& func, const Expr&
  * is it atomic?
  * if so, the compute cost of the expression is bounded so it can be copy without graph mode.
  */
-inline bool IsAtomic(const Expr& e) {
-  auto props = GetOnDeviceProps(e);
-  Expr true_expr = props.body.defined() ? props.body : e;
+inline bool IsAtomic(const Expr& expr) {
+  Expr true_expr = IgnoreOnDevice(expr);
   return true_expr.as<VarNode>() || true_expr.as<OpNode>() || true_expr.as<ConstructorNode>() ||
          true_expr.as<GlobalVarNode>() ||
          true_expr.as<ConstantNode>();  // Constant is always by reference.
15 changes: 14 additions & 1 deletion tests/python/relay/test_vm.py
@@ -766,6 +766,19 @@ def test_vm_reshape_tensor(target, dev):
     check_result(target, dev, [x_np, y_np], x_np.reshape([8, 2, 8]), mod)
 
 
+def test_vm_reshape_and_copy(target, dev):
+    """Make sure the compiler notices the reshape result shape is a literal and can use
+    the immediate-mode alloc_tensor instruction instead of alloc_tensor_reg."""
+    x_np = np.random.uniform(size=(1, 1)).astype("float32")
+    x = relay.var("x", shape=(1, 1), dtype="float32")
+    mod = tvm.IRModule.from_expr(relay.Function([x], relay.copy(relay.reshape(x, [0, 1]))))
+    with tvm.transform.PassContext(opt_level=3):
+        exec = relay.vm.compile(mod, "llvm")
+    assert "alloc_tensor" in exec.bytecode
+    assert not "alloc_tensor_reg" in exec.bytecode
+    check_result(target, dev, [x_np], x_np.reshape([1, 1]), mod)
+
+
 def test_vm_reshape_tuple(target, dev, x_shape=(1, 4, 2), y_shape=(1, 2, 10)):
     tup = relay.var(
         "tup",
@@ -963,4 +976,4 @@ def test_benchmark_end_to_end_rpc():
 if __name__ == "__main__":
     import sys
 
-    sys.exit(pytest.main(sys.argv))
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))
