Skip to content

Commit

Permalink
Fix foldconstant involving dropout (#7550)
Browse files Browse the repository at this point in the history
Co-authored-by: masa <masa@pop-os.localdomain>
  • Loading branch information
masahi and masa authored Mar 1, 2021
1 parent 2673309 commit 51dc332
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
3 changes: 2 additions & 1 deletion src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,8 @@ The whole array is rescaled by ``1/(1-p)`` to keep the expected sum of the input
.add_argument("data", "Tensor", "Input to which dropout will be applied.")
.set_support_level(1)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.add_type_rel("Dropout", DropoutRel);
.add_type_rel("Dropout", DropoutRel)
.set_attr<TOpIsStateful>("TOpIsStateful", true);

// batch_norm
TVM_REGISTER_NODE_TYPE(BatchNormAttrs);
Expand Down
26 changes: 24 additions & 2 deletions tests/python/relay/test_pass_fold_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# under the License.
import numpy as np
import tvm
from tvm import te
from tvm import relay
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name
Expand Down Expand Up @@ -276,12 +275,35 @@ def initializer(_, param):
assert tvm.ir.structural_equal(mod["main"], expect)


def test_fold_dropout():
    """FoldConstant must not fold through dropout (it is a stateful op)."""

    def make_before():
        # An all-constant graph: without the stateful-op guard, FoldConstant
        # would happily evaluate the whole thing at compile time.
        const_data = relay.const(np.arange(10).astype(np.float32))
        dropped = relay.nn.dropout(const_data)
        result = dropped + relay.const(1.0)
        return relay.Function(relay.analysis.free_vars(result), result)

    pipeline = tvm.transform.Sequential(
        [
            relay.transform.InferType(),
            relay.transform.FoldConstant(),
        ]
    )

    mod = tvm.IRModule.from_expr(make_before())

    with tvm.transform.PassContext(opt_level=3):
        folded_mod = pipeline(mod)

    # The graph must survive unchanged (modulo type inference): dropout is
    # marked TOpIsStateful, so constant evaluation has to skip it.
    assert tvm.ir.structural_equal(run_infer_type(mod["main"]), folded_mod["main"])


if __name__ == "__main__":
    # Run every constant-folding test in this file when executed directly
    # (i.e. without pytest). Keep this list in sync with the test
    # definitions above.
    test_fold_const()
    test_fold_let()
    test_fold_tuple()
    test_fold_concat()
    test_fold_shape_of()
    test_fold_full()
    test_fold_batch_norm()
    test_fold_ndarray_size()
    test_fold_dropout()

0 comments on commit 51dc332

Please sign in to comment.