Skip to content

Commit

Permalink
[Relay] [Error] Fix error in partial evaluator (apache#3693)
Browse files Browse the repository at this point in the history
* fix

* lint
  • Loading branch information
MarisaKirisame authored and wweic committed Aug 9, 2019
1 parent b96ddee commit 4890cae
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 13 deletions.
39 changes: 27 additions & 12 deletions src/relay/pass/partial_eval.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ Expr PostProcess(const Expr&);
/*! \brief The base container type of Relay values. */
class StaticNode : public RelayNode {
public:
static constexpr const char* _type_key = "relay.Value";
static constexpr const char* _type_key = "relay.Static";
TVM_DECLARE_BASE_NODE_INFO(ValueNode, RelayNode);
};

Expand Down Expand Up @@ -161,6 +161,7 @@ struct PStaticNode : Node {
PStaticNode(const Static& pstatic, const Expr& dynamic) :
pstatic(pstatic), dynamic(dynamic), created_time(time()) { }
explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) { }
static constexpr const char* _type_key = "relay.PStatic";
TVM_DECLARE_NODE_TYPE_INFO(PStaticNode, Node);
};

Expand All @@ -169,6 +170,7 @@ RELAY_DEFINE_NODE_REF(PStatic, PStaticNode, NodeRef);
struct STupleNode : StaticNode {
std::vector<PStatic> fields;
explicit STupleNode(const std::vector<PStatic>& fields) : fields(fields) { }
static constexpr const char* _type_key = "relay.STuple";
TVM_DECLARE_NODE_TYPE_INFO(STupleNode, StaticNode);
};

Expand All @@ -181,7 +183,8 @@ Static MkSTuple(const std::vector<PStatic>& fields) {
struct STensorNode : StaticNode {
runtime::NDArray data;
explicit STensorNode(const NDArray& data) : data(data) { }
TVM_DECLARE_NODE_TYPE_INFO(STupleNode, StaticNode);
static constexpr const char* _type_key = "relay.STensor";
TVM_DECLARE_NODE_TYPE_INFO(STensorNode, StaticNode);
};

RELAY_DEFINE_NODE_REF(STensor, STensorNode, Value);
Expand All @@ -195,6 +198,7 @@ struct SConstructorNode : StaticNode {
std::vector<PStatic> fields;
SConstructorNode(const Constructor& constructor, const std::vector<PStatic>& fields) :
constructor(constructor), fields(fields) { }
static constexpr const char* _type_key = "relay.SConstructor";
TVM_DECLARE_NODE_TYPE_INFO(SConstructorNode, StaticNode);
};

Expand All @@ -205,6 +209,7 @@ Static MkSConstructor(const Constructor& constructor, const std::vector<PStatic>
}

struct SRefNode : StaticNode {
static constexpr const char* _type_key = "relay.SRef";
// we will use the address as the guid for hashing
TVM_DECLARE_NODE_TYPE_INFO(SRefNode, StaticNode);
};
Expand All @@ -223,6 +228,7 @@ using Func = std::function<PStatic(const std::vector<PStatic>&,
struct SFuncNode : StaticNode {
Func func;
explicit SFuncNode(const Func& func) : func(func) { }
static constexpr const char* _type_key = "relay.SFunc";
TVM_DECLARE_NODE_TYPE_INFO(SFuncNode, StaticNode);
};

Expand Down Expand Up @@ -711,8 +717,14 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
return VisitFunc(GetRef<Function>(op), ll);
}

struct ReflectError : dmlc::Error {
ReflectError() : dmlc::Error("static value not found") { }
};

Expr Reflect(const PStatic& st) {
if (const STensorNode* op = st->pstatic.as<STensorNode>()) {
if (!st->pstatic.defined()) {
throw ReflectError();
} else if (const STensorNode* op = st->pstatic.as<STensorNode>()) {
return ConstantNode::make(op->data);
} else if (const STupleNode* op = st->pstatic.as<STupleNode>()) {
tvm::Array<Expr> fields;
Expand All @@ -721,7 +733,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
}
return TupleNode::make(fields);
} else {
LOG(FATAL) << "Unknown case";
LOG(FATAL) << "Unknown case: " << st->dynamic;
throw;
}
}
Expand Down Expand Up @@ -767,19 +779,22 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
for (const PStatic& ps : pv) {
ns_args.push_back(ps->dynamic);
}
PStatic ns = NoStatic(ll->Push(CallNode::make(expr, ns_args, attrs, type_args)));
auto ns = [&]() {
return NoStatic(ll->Push(CallNode::make(expr, ns_args, attrs, type_args)));
};
if (StatefulOp(expr)) {
return ns;
return ns();
}
tvm::Array<Expr> args;
for (const PStatic& ps : pv) {
if (ps->pstatic.defined()) {
try {
tvm::Array<Expr> args;
for (const PStatic& ps : pv) {
args.push_back(Reflect(ps));
} else {
return ns;
}
return ConstEvaluate(CallNode::make(expr, args, attrs, type_args), ll);
}
catch (const ReflectError&) {
return ns();
}
return ConstEvaluate(CallNode::make(expr, args, attrs, type_args), ll);
};
}

Expand Down
11 changes: 10 additions & 1 deletion tests/python/relay/test_pass_partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import numpy as np
import tvm
from tvm import relay
from tvm.relay.analysis import alpha_equal
from tvm.relay.analysis import alpha_equal, assert_alpha_equal
from tvm.relay.prelude import Prelude
from tvm.relay import op, create_executor, transform
from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead, RefWrite, RefCreate
Expand Down Expand Up @@ -306,6 +306,14 @@ def test_double():
assert alpha_equal(res.body, make_nat_expr(p, 6))


def test_concat():
t = relay.TensorType([10], "float32")
x = Var("x", t)
y = Var("x", t)
orig = run_infer_type(Function([x, y], op.concatenate([x, y], axis=0)))
assert_alpha_equal(orig, dcpe(orig))


if __name__ == '__main__':
test_ref()
test_tuple()
Expand All @@ -323,3 +331,4 @@ def test_double():
test_nat_id()
test_global_match_nat_id()
test_match_nat_id()
test_concat()

0 comments on commit 4890cae

Please sign in to comment.