From 51dc332646d90b77ae29e2e2dbe21f40008a0082 Mon Sep 17 00:00:00 2001 From: masahi Date: Mon, 1 Mar 2021 17:12:25 +0900 Subject: [PATCH] Fix foldconstant involving dropout (#7550) Co-authored-by: masa --- src/relay/op/nn/nn.cc | 3 ++- tests/python/relay/test_pass_fold_constant.py | 26 +++++++++++++++++-- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 97460ba4a98b..38c33b45936e 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -591,7 +591,8 @@ The whole array is rescaled by ``1/(1-p)`` to keep the expected sum of the input .add_argument("data", "Tensor", "Input to which dropout will be applied.") .set_support_level(1) .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) - .add_type_rel("Dropout", DropoutRel); + .add_type_rel("Dropout", DropoutRel) + .set_attr("TOpIsStateful", true); // batch_norm TVM_REGISTER_NODE_TYPE(BatchNormAttrs); diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py index 14ad419e80c6..7b4eb5231a2c 100644 --- a/tests/python/relay/test_pass_fold_constant.py +++ b/tests/python/relay/test_pass_fold_constant.py @@ -16,7 +16,6 @@ # under the License. import numpy as np import tvm -from tvm import te from tvm import relay from tvm.relay import transform from tvm.relay.build_module import bind_params_by_name @@ -276,12 +275,35 @@ def initializer(_, param): assert tvm.ir.structural_equal(mod["main"], expect) +def test_fold_dropout(): + def before(): + # A constant graph to fire fold constant + data = relay.const(np.arange(10).astype(np.float32)) + dropout = relay.nn.dropout(data) + add = dropout + relay.const(1.0) + return relay.Function(relay.analysis.free_vars(add), add) + + passes = tvm.transform.Sequential( + [ + relay.transform.InferType(), + relay.transform.FoldConstant(), + ] + ) + + before_mod = tvm.IRModule.from_expr(before()) + + with tvm.transform.PassContext(opt_level=3): + after_mod = passes(before_mod) + + assert tvm.ir.structural_equal(run_infer_type(before_mod["main"]), after_mod["main"]) + + if __name__ == "__main__": test_fold_const() test_fold_let() test_fold_tuple() test_fold_concat() test_fold_shape_of() - test_fold_full() test_fold_batch_norm() test_fold_ndarray_size() + test_fold_dropout()