Skip to content

Commit

Permalink
Reorder dynamic to static and simplify inference, lower DynamicToStat…
Browse files Browse the repository at this point in the history
…ic Opt Level (#7213)

* reorder dynamic to static and simplify inference, add a dropout unit test

* lower dynamic to static opt level

* autoformat test

* raise DynamicToStatic to opt level 2 to match Constant Folding
  • Loading branch information
Matthew Brookhart authored Jan 13, 2021
1 parent 1410e68 commit 384714b
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -278,10 +278,11 @@ class RelayBuildModule : public runtime::ModuleNode {
pass_seqs.push_back(transform::Legalize());
}

pass_seqs.push_back(transform::SimplifyInference());

// Convert Dynamic ops to static versions
pass_seqs.push_back(transform::DynamicToStatic());

pass_seqs.push_back(transform::SimplifyInference());
PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
Expr expr = args[0];
*rv = false;
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/dynamic_to_static.cc
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ Pass DynamicToStatic() {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(DynamicToStatic(f, m));
};
return CreateFunctionPass(pass_func, 3, "DynamicToStatic", {});
return CreateFunctionPass(pass_func, 2, "DynamicToStatic", {});
}

TVM_REGISTER_GLOBAL("relay._transform.DynamicToStatic").set_body_typed([]() {
Expand Down
10 changes: 10 additions & 0 deletions tests/python/relay/test_op_level1.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,16 @@ def test_dropout():
yy = run_infer_type(y)
assert yy.checked_type == input_ty

in_np = np.random.random([4, 5, 6]).astype("float32")
x = relay.const(in_np)
y = relay.nn.dropout(x, rate=0.5)
func = relay.Function([], y)
for target, ctx in tvm.testing.enabled_targets():
for backend in ["debug", "graph"]:
intrp = relay.create_executor("debug", ctx=ctx, target=target)
op_res = intrp.evaluate(func)()
tvm.testing.assert_allclose(op_res.asnumpy(), in_np, rtol=0.01)


def test_batch_norm():
for dtype in ["float16", "float32"]:
Expand Down

0 comments on commit 384714b

Please sign in to comment.