From 092bc1ce7c3cbaca601ab2bb0a205501fd4c1f11 Mon Sep 17 00:00:00 2001
From: mbrookhart
Date: Thu, 28 Jan 2021 13:29:59 -0700
Subject: [PATCH 1/6] DynamicToStatic Refactor

---
 src/relay/transforms/dynamic_to_static.cc | 91 ++++++++++++++---------
 1 file changed, 57 insertions(+), 34 deletions(-)

diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc
index c580f60c2a68..bff333a79142 100644
--- a/src/relay/transforms/dynamic_to_static.cc
+++ b/src/relay/transforms/dynamic_to_static.cc
@@ -32,13 +32,35 @@
 namespace tvm {
 namespace relay {
 
+Expr PrepareInput(const Expr& expr) {
+  auto mod = IRModule::FromExpr(expr);
+  mod = transform::FoldConstant()(mod);
+  mod = transform::InferType()(mod);
+  mod = transform::FoldConstant()(mod);
+  mod = transform::InferType()(mod);
+  if (expr.as<FunctionNode>()) {
+    return mod->Lookup("main");
+  } else {
+    return mod->Lookup("main").as<FunctionNode>()->body;
+  }
+}
+
+std::vector<Expr> PrepareArgs(const CallNode* call_node) {
+    std::vector<Expr> args;
+    for (auto arg : call_node->args) {
+      args.emplace_back(PrepareInput(arg));
+    }
+    return args;
+}
+
 class DynamicToStaticMutator : public MixedModeMutator {
  public:
  DynamicToStaticMutator() {
    op_map_ = {
        {Op::Get("dyn.reshape"),
         [](const CallNode* call_node) {
-          if (const ConstantNode* shape = call_node->args[1].as<ConstantNode>()) {
+          auto args = PrepareArgs(call_node);
+          if (const ConstantNode* shape = args[1].as<ConstantNode>()) {
            ICHECK_EQ(shape->data->ndim, 1);
            return MakeReshape(call_node->args[0], ToVector(shape->data));
          }
@@ -46,7 +68,8 @@ class DynamicToStaticMutator : public MixedModeMutator {
         }},
        {Op::Get("dyn.tile"),
         [](const CallNode* call_node) {
-          if (const ConstantNode* reps = call_node->args[1].as<ConstantNode>()) {
+          auto args = PrepareArgs(call_node);
+          if (const ConstantNode* reps = args[1].as<ConstantNode>()) {
            ICHECK_EQ(reps->data->ndim, 1);
            return MakeTile(call_node->args[0], ToVector(reps->data));
          }
@@ -54,7 +77,8 @@ class DynamicToStaticMutator : public MixedModeMutator {
         }},
        {Op::Get("dyn.topk"),
         [](const CallNode* call_node) {
-          if (const ConstantNode* k = call_node->args[1].as<ConstantNode>()) {
+          auto args = PrepareArgs(call_node);
+          if (const ConstantNode* k = args[1].as<ConstantNode>()) {
            const TopKAttrs* param = call_node->attrs.as<TopKAttrs>();
            ICHECK(param);
            return MakeTopK(call_node->args[0], static_cast<int>(ToScalar(k->data, 0)),
@@ -64,7 +88,8 @@ class DynamicToStaticMutator : public MixedModeMutator {
         }},
        {Op::Get("dyn.broadcast_to"),
         [](const CallNode* call_node) {
-          if (const ConstantNode* shape = call_node->args[1].as<ConstantNode>()) {
+          auto args = PrepareArgs(call_node);
+          if (const ConstantNode* shape = args[1].as<ConstantNode>()) {
            ICHECK_EQ(shape->data->ndim, 1);
            return MakeBroadCastTo(call_node->args[0], ToVector(shape->data));
          }
@@ -72,7 +97,8 @@ class DynamicToStaticMutator : public MixedModeMutator {
         }},
        {Op::Get("dyn.zeros"),
         [](const CallNode* call_node) {
-          if (const ConstantNode* shape = call_node->args[0].as<ConstantNode>()) {
+          auto args = PrepareArgs(call_node);
+          if (const ConstantNode* shape = args[0].as<ConstantNode>()) {
            const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
            ICHECK(param);
            return MakeZeros(ToVector(shape->data), param->dtype);
@@ -81,7 +107,8 @@ class DynamicToStaticMutator : public MixedModeMutator {
         }},
        {Op::Get("dyn.ones"),
         [](const CallNode* call_node) {
-          if (const ConstantNode* shape = call_node->args[0].as<ConstantNode>()) {
+          auto args = PrepareArgs(call_node);
+          if (const ConstantNode* shape = args[0].as<ConstantNode>()) {
            const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
            ICHECK(param);
            return MakeOnes(ToVector(shape->data), param->dtype);
@@ -90,7 +117,8 @@ class DynamicToStaticMutator : public MixedModeMutator {
         }},
        {Op::Get("dyn.one_hot"),
         [](const CallNode* call_node) {
-          if (const ConstantNode* depth = call_node->args[3].as<ConstantNode>()) {
+          auto args = PrepareArgs(call_node);
+          if (const ConstantNode* depth = args[3].as<ConstantNode>()) {
            const OneHotAttrs* param = call_node->attrs.as<OneHotAttrs>();
            ICHECK(param);
            return MakeOneHot(call_node->args[0], call_node->args[1], call_node->args[2],
@@ -101,7 +129,8 @@ class DynamicToStaticMutator : public MixedModeMutator {
         }},
        {Op::Get("dyn.image.resize"),
         [](const CallNode* call_node) {
-          if (const ConstantNode* size = call_node->args[1].as<ConstantNode>()) {
+          auto args = PrepareArgs(call_node);
+          if (const ConstantNode* size = args[1].as<ConstantNode>()) {
            const ResizeAttrs* param = call_node->attrs.as<ResizeAttrs>();
            ICHECK(param);
            auto size_int = ToVector(size->data);
@@ -116,7 +145,8 @@ class DynamicToStaticMutator : public MixedModeMutator {
         }},
        {Op::Get("dyn.full"),
         [](const CallNode* call_node) {
-          if (const ConstantNode* shape = call_node->args[1].as<ConstantNode>()) {
+          auto args = PrepareArgs(call_node);
+          if (const ConstantNode* shape = args[1].as<ConstantNode>()) {
            ICHECK_EQ(shape->data->ndim, 1);
            const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
            ICHECK(param);
@@ -126,8 +156,9 @@ class DynamicToStaticMutator : public MixedModeMutator {
         }},
        {Op::Get("dyn.nn.upsampling"),
         [](const CallNode* call_node) {
-          const ConstantNode* scale_h = call_node->args[1].as<ConstantNode>();
-          const ConstantNode* scale_w = call_node->args[2].as<ConstantNode>();
+          auto args = PrepareArgs(call_node);
+          const ConstantNode* scale_h = args[1].as<ConstantNode>();
+          const ConstantNode* scale_w = args[2].as<ConstantNode>();
           if (scale_h && scale_w) {
             ICHECK_EQ(scale_h->data->ndim, 0);
             ICHECK_EQ(scale_w->data->ndim, 0);
@@ -141,9 +172,10 @@ class DynamicToStaticMutator : public MixedModeMutator {
         }},
        {Op::Get("dyn.nn.upsampling3d"),
         [](const CallNode* call_node) {
-          const ConstantNode* scale_d = call_node->args[1].as<ConstantNode>();
-          const ConstantNode* scale_h = call_node->args[2].as<ConstantNode>();
-          const ConstantNode* scale_w = call_node->args[3].as<ConstantNode>();
+          auto args = PrepareArgs(call_node);
+          const ConstantNode* scale_d = args[1].as<ConstantNode>();
+          const ConstantNode* scale_h = args[2].as<ConstantNode>();
+          const ConstantNode* scale_w = args[3].as<ConstantNode>();
           if (scale_d && scale_h && scale_w) {
             ICHECK_EQ(scale_d->data->ndim, 0);
             ICHECK_EQ(scale_h->data->ndim, 0);
@@ -160,8 +192,9 @@ class DynamicToStaticMutator : public MixedModeMutator {
         }},
        {Op::Get("dyn.nn.pad"),
         [](const CallNode* call_node) {
-          const ConstantNode* pad_width = call_node->args[1].as<ConstantNode>();
-          const ConstantNode* pad_fill = call_node->args[2].as<ConstantNode>();
+          auto args = PrepareArgs(call_node);
+          const ConstantNode* pad_width = args[1].as<ConstantNode>();
+          const ConstantNode* pad_fill = args[2].as<ConstantNode>();
           if (pad_width && pad_fill) {
             ICHECK_EQ(pad_fill->data->ndim, 0);   // pad_val is 1d
             ICHECK_EQ(pad_width->data->ndim, 2);  // pad_width is 2d
@@ -175,9 +208,10 @@ class DynamicToStaticMutator : public MixedModeMutator {
         }},
        {Op::Get("dyn.strided_slice"),
         [](const CallNode* call_node) {
-          const ConstantNode* begin = call_node->args[1].as<ConstantNode>();
-          const ConstantNode* end = call_node->args[2].as<ConstantNode>();
-          const ConstantNode* stride = call_node->args[3].as<ConstantNode>();
+          auto args = PrepareArgs(call_node);
+          const ConstantNode* begin = args[1].as<ConstantNode>();
+          const ConstantNode* end = args[2].as<ConstantNode>();
+          const ConstantNode* stride = args[3].as<ConstantNode>();
           if (begin && end && stride) {
             ICHECK_EQ(begin->data->ndim, 1);
             ICHECK_EQ(end->data->ndim, 1);
@@ -222,6 +256,7 @@ class DynamicToStaticMutator : public MixedModeMutator {
    }
    return post;
  }
+
  std::unordered_map<Op, std::function<Expr(const CallNode*)>, ObjectPtrHash, ObjectPtrEqual>
      op_map_;
};
@@ -229,27 +264,15 @@ class DynamicToStaticMutator : public MixedModeMutator {
 Expr DynamicToStatic(Function f, IRModule m) {
   Expr pre = f;
   Expr expr = f;
-  auto fold_const = transform::FoldConstant();
-  auto infer_type = transform::InferType();
   DynamicToStaticMutator mutator;
   Map<BaseFunc, GlobalVar> vars;
   for (auto kv : m->functions) {
     vars.Set(kv.second, kv.first);
   }
   const auto gv = vars[f];
-  // Put a limit on the while loop
-  // Primarily used to prevent accidental infinite lops in development
-  const int loop_limit = 1000;
-  int i = 0;
-  do {
-    pre = expr;
-    // TODO(mbrookhart): Is it possible to run these passes JUST on the current function?
-    m = infer_type(m);
-    m = fold_const(m);
-    expr = mutator.Mutate(m->functions[gv]);
-    m->Update(gv, Downcast<BaseFunc>(expr));
-    i += 1;
-  } while (!StructuralEqual()(pre, expr) && i < loop_limit);
+  pre = expr;
+  expr = mutator.Mutate(m->functions[gv]);
+  expr = PrepareInput(expr);
   return expr;
 }
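The core of the refactor above: instead of requiring each dyn op's shape argument to already be a Constant, the mutator constant-folds the argument on demand, so any shape that is merely computable at compile time triggers the static rewrite. A minimal sketch of the effect through the Python API (illustrative only, not part of the patch):

import tvm
from tvm import relay
from tvm.relay import transform

# x has a fully static shape, but reshape receives a shape *expression*,
# so the frontend emits relay.dyn.reshape rather than the static op.
x = relay.var("x", shape=(2, 3, 4), dtype="float32")
y = relay.reshape(x, relay.shape_of(x))

mod = tvm.IRModule.from_expr(relay.Function([x], y))
mod = transform.InferType()(mod)
mod = transform.DynamicToStatic()(mod)
# shape_of(x) folds to a constant, so the call becomes a static reshape.
print(mod["main"])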
From b39623011da6e46e3238d3b6839091810688ac2a Mon Sep 17 00:00:00 2001
From: mbrookhart
Date: Thu, 28 Jan 2021 16:01:19 -0700
Subject: [PATCH 2/6] fix test

---
 tests/python/relay/test_pass_dynamic_to_static.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py
index 141023d77019..b2ed160e24d5 100644
--- a/tests/python/relay/test_pass_dynamic_to_static.py
+++ b/tests/python/relay/test_pass_dynamic_to_static.py
@@ -232,11 +232,10 @@ def verify_ones_zeros(shape, dtype):
         func = run_infer_type(relay.Function([x], y))
 
         func2 = run_opt_pass(
-            run_opt_pass(func, transform.DynamicToStatic()), transform.InferType()
+            run_opt_pass(func, transform.DynamicToStatic()), transform.InferType(),
         )
 
         zz = func2.body
-        assert isinstance(zz, relay.Constant)
         assert zz.checked_type == relay.ty.TensorType(shape, dtype)
 
         x_data = np.random.uniform(low=1, high=1, size=shape)

From 4d40a9f68765ee3c18456d88a28908961b50b02d Mon Sep 17 00:00:00 2001
From: mbrookhart
Date: Thu, 28 Jan 2021 16:17:37 -0700
Subject: [PATCH 3/6] add regression tests

---
 .../relay/test_pass_dynamic_to_static.py     | 43 ++++++++++++++++++-
 1 file changed, 42 insertions(+), 1 deletion(-)

diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py
index b2ed160e24d5..c9e047a38540 100644
--- a/tests/python/relay/test_pass_dynamic_to_static.py
+++ b/tests/python/relay/test_pass_dynamic_to_static.py
@@ -232,7 +232,8 @@ def verify_ones_zeros(shape, dtype):
         func = run_infer_type(relay.Function([x], y))
 
         func2 = run_opt_pass(
-            run_opt_pass(func, transform.DynamicToStatic()), transform.InferType(),
+            run_opt_pass(func, transform.DynamicToStatic()),
+            transform.InferType(),
         )
 
         zz = func2.body
@@ -517,5 +518,45 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_
     verify_sparse_to_dense(1, 3, None, [5], [0, 3, 0, 0, 0])  # default value not specified
 
 
+@tvm.testing.uses_gpu
+def test_dynamic_to_static_dynamic_rank():
+    def verify_full(fill_value, fill_shape, dtype):
+        x = relay.var("x", relay.scalar_type(dtype))
+        y = relay.var("y", relay.TensorType(fill_shape, "int64"))
+        shape = relay.shape_of(y)
+        shape = relay.strided_slice(shape, [0], relay.shape_of(shape))
+        z = relay.full(x, shape, dtype)
+
+        func = relay.Function([x, y], z)
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+
+        zz = func2.body
+        assert isinstance(zz, relay.Call)
+        assert zz.op == relay.op.get("full")
+
+        ref_res = np.full(fill_shape, fill_value).astype(dtype)
+        y_data = np.random.uniform(low=-1, high=1, size=fill_shape).astype("int64")
+        verify_func(func2, [fill_value, y_data], ref_res)
+
+    verify_full(4, (1, 2, 3, 4), "int32")
+    verify_full(4.0, (1, 2, 8, 10), "float32")
+
+
+@tvm.testing.uses_gpu
+def test_dynamic_to_static_dynamic_if():
+    x = relay.var("x", relay.TensorType((2, 2), "int64"))
+    cond = relay.const(1)
+    iff = relay.If(cond, relay.reshape(x, [1, 4]), relay.reshape(x, (4, 1)))
+
+    func = relay.Function([x], iff)
+    func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+
+    zz = func2.body
+    assert isinstance(zz, relay.Call)
+    assert zz.op == relay.op.get("reshape")
+    x_data = np.random.uniform(low=-1, high=1, size=(2, 2)).astype("int64")
+    verify_func(func2, [x_data], x_data.reshape(1, 4))
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
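The new tests cover two previously untested paths: a shape whose rank is itself dynamic (strided_slice over shape_of of a shape), which must stay a dynamic full call, and arguments nested under relay.If. They reuse a verify_func helper defined earlier in the test file; it is not shown in this diff, but it behaves roughly like the sketch below (names and executor choice are assumptions, not the exact helper):

import tvm
import tvm.testing
from tvm import relay

def verify_func(func, data, ref_res, rtol=1e-5, atol=1e-7):
    # Run the function on every enabled target and compare against the
    # NumPy reference result.
    assert isinstance(data, list)
    mod = tvm.IRModule.from_expr(func)
    for target, dev in tvm.testing.enabled_targets():
        op_res = relay.create_executor("vm", mod=mod, device=dev, target=target).evaluate()(*data)
        tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=rtol, atol=atol)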
From 2dc21ad91c524a0602ce038678005d1a3aa58b88 Mon Sep 17 00:00:00 2001
From: mbrookhart
Date: Thu, 28 Jan 2021 16:20:41 -0700
Subject: [PATCH 4/6] cleanup

---
 src/relay/transforms/dynamic_to_static.cc | 28 +++++++++--------
 1 file changed, 11 insertions(+), 17 deletions(-)

diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc
index bff333a79142..97e895a6a797 100644
--- a/src/relay/transforms/dynamic_to_static.cc
+++ b/src/relay/transforms/dynamic_to_static.cc
@@ -33,7 +33,11 @@ namespace tvm {
 namespace relay {
 
 Expr PrepareInput(const Expr& expr) {
+  // TODO(mbrookhart): Rewrite this to use incremental type inference
+  // when that feature is available
   auto mod = IRModule::FromExpr(expr);
+  // Perform FoldConstant->InferType twice due to nested control
+  // flow/dynamic rank issues in certain object models
   mod = transform::FoldConstant()(mod);
   mod = transform::InferType()(mod);
   mod = transform::FoldConstant()(mod);
@@ -46,11 +50,11 @@ Expr PrepareInput(const Expr& expr) {
 }
 
 std::vector<Expr> PrepareArgs(const CallNode* call_node) {
-    std::vector<Expr> args;
-    for (auto arg : call_node->args) {
-      args.emplace_back(PrepareInput(arg));
-    }
-    return args;
+  std::vector<Expr> args;
+  for (auto arg : call_node->args) {
+    args.emplace_back(PrepareInput(arg));
+  }
+  return args;
 }
 
 class DynamicToStaticMutator : public MixedModeMutator {
@@ -262,18 +266,8 @@ class DynamicToStaticMutator : public MixedModeMutator {
 };
 
 Expr DynamicToStatic(Function f, IRModule m) {
-  Expr pre = f;
-  Expr expr = f;
-  DynamicToStaticMutator mutator;
-  Map<BaseFunc, GlobalVar> vars;
-  for (auto kv : m->functions) {
-    vars.Set(kv.second, kv.first);
-  }
-  const auto gv = vars[f];
-  pre = expr;
-  expr = mutator.Mutate(m->functions[gv]);
-  expr = PrepareInput(expr);
-  return expr;
+  Expr expr = DynamicToStaticMutator().Mutate(f);
+  return PrepareInput(expr);
 }
 
 namespace transform {

From 7a33f9de340b193fee7188a4f55fcbaf7cc6df6a Mon Sep 17 00:00:00 2001
From: mbrookhart
Date: Fri, 29 Jan 2021 09:27:44 -0700
Subject: [PATCH 5/6] skip PrepareInput if the arg is already a constant

---
 src/relay/transforms/dynamic_to_static.cc | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc
index 97e895a6a797..36156e0ee986 100644
--- a/src/relay/transforms/dynamic_to_static.cc
+++ b/src/relay/transforms/dynamic_to_static.cc
@@ -52,7 +52,11 @@ Expr PrepareInput(const Expr& expr) {
 std::vector<Expr> PrepareArgs(const CallNode* call_node) {
   std::vector<Expr> args;
   for (auto arg : call_node->args) {
-    args.emplace_back(PrepareInput(arg));
+    if (arg.as<ConstantNode>()) {
+      args.emplace_back(arg);
+    } else {
+      args.emplace_back(PrepareInput(arg));
+    }
   }
   return args;
 }
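Taken together, patches 4 and 5 document and then trim the per-argument round trip: PrepareInput hoists an argument into a temporary module, runs FoldConstant and InferType twice (the new comment explains why once is not enough), and pulls the result back out, while PrepareArgs now skips that work for arguments that are already Constants. In Python pass terms the logic is roughly this (a paraphrase of the C++, not code from the patch):

import tvm
from tvm import relay
from tvm.relay import transform

def prepare_input(expr):
    # Hoist the expression into a module and run two FoldConstant/InferType
    # rounds, mirroring the C++ PrepareInput.
    mod = tvm.IRModule.from_expr(expr)
    seq = tvm.transform.Sequential(
        [transform.FoldConstant(), transform.InferType()] * 2
    )
    mod = seq(mod)
    return mod["main"] if isinstance(expr, relay.Function) else mod["main"].body

def prepare_args(call):
    # Constants are already as static as they can get; skip the
    # comparatively expensive module round trip for them (patch 5).
    return [a if isinstance(a, relay.Constant) else prepare_input(a) for a in call.args]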
From 8dc9e3fe95897a5803005dc9d7c88f16dc4e96a3 Mon Sep 17 00:00:00 2001
From: mbrookhart
Date: Mon, 1 Feb 2021 11:41:39 -0700
Subject: [PATCH 6/6] fix an issue with type inference with global functions

---
 src/relay/transforms/dynamic_to_static.cc | 110 +++++++++++++---------
 1 file changed, 63 insertions(+), 47 deletions(-)

diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc
index 36156e0ee986..815e4d224cc5 100644
--- a/src/relay/transforms/dynamic_to_static.cc
+++ b/src/relay/transforms/dynamic_to_static.cc
@@ -32,41 +32,12 @@
 namespace tvm {
 namespace relay {
 
-Expr PrepareInput(const Expr& expr) {
-  // TODO(mbrookhart): Rewrite this to use incremental type inference
-  // when that feature is available
-  auto mod = IRModule::FromExpr(expr);
-  // Perform FoldConstant->InferType twice due to nested control
-  // flow/dynamic rank issues in certain object models
-  mod = transform::FoldConstant()(mod);
-  mod = transform::InferType()(mod);
-  mod = transform::FoldConstant()(mod);
-  mod = transform::InferType()(mod);
-  if (expr.as<FunctionNode>()) {
-    return mod->Lookup("main");
-  } else {
-    return mod->Lookup("main").as<FunctionNode>()->body;
-  }
-}
-
-std::vector<Expr> PrepareArgs(const CallNode* call_node) {
-  std::vector<Expr> args;
-  for (auto arg : call_node->args) {
-    if (arg.as<ConstantNode>()) {
-      args.emplace_back(arg);
-    } else {
-      args.emplace_back(PrepareInput(arg));
-    }
-  }
-  return args;
-}
-
 class DynamicToStaticMutator : public MixedModeMutator {
  public:
-  DynamicToStaticMutator() {
+  DynamicToStaticMutator(IRModule mod, Function func) : mod_(mod), func_(func) {
    op_map_ = {
        {Op::Get("dyn.reshape"),
-         [](const CallNode* call_node) {
+         [this](const CallNode* call_node) {
          auto args = PrepareArgs(call_node);
          if (const ConstantNode* shape = args[1].as<ConstantNode>()) {
            ICHECK_EQ(shape->data->ndim, 1);
@@ -75,7 +46,7 @@ class DynamicToStaticMutator : public MixedModeMutator {
          return Expr(nullptr);
         }},
        {Op::Get("dyn.tile"),
-         [](const CallNode* call_node) {
+         [this](const CallNode* call_node) {
          auto args = PrepareArgs(call_node);
          if (const ConstantNode* reps = args[1].as<ConstantNode>()) {
            ICHECK_EQ(reps->data->ndim, 1);
@@ -84,7 +55,7 @@ class DynamicToStaticMutator : public MixedModeMutator {
          return Expr(nullptr);
         }},
        {Op::Get("dyn.topk"),
-         [](const CallNode* call_node) {
+         [this](const CallNode* call_node) {
          auto args = PrepareArgs(call_node);
          if (const ConstantNode* k = args[1].as<ConstantNode>()) {
            const TopKAttrs* param = call_node->attrs.as<TopKAttrs>();
@@ -95,7 +66,7 @@ class DynamicToStaticMutator : public MixedModeMutator {
          return Expr(nullptr);
         }},
        {Op::Get("dyn.broadcast_to"),
-         [](const CallNode* call_node) {
+         [this](const CallNode* call_node) {
          auto args = PrepareArgs(call_node);
          if (const ConstantNode* shape = args[1].as<ConstantNode>()) {
            ICHECK_EQ(shape->data->ndim, 1);
@@ -104,7 +75,7 @@ class DynamicToStaticMutator : public MixedModeMutator {
          return Expr(nullptr);
         }},
        {Op::Get("dyn.zeros"),
-         [](const CallNode* call_node) {
+         [this](const CallNode* call_node) {
          auto args = PrepareArgs(call_node);
          if (const ConstantNode* shape = args[0].as<ConstantNode>()) {
            const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
@@ -114,7 +85,7 @@ class DynamicToStaticMutator : public MixedModeMutator {
          return Expr(nullptr);
         }},
        {Op::Get("dyn.ones"),
-         [](const CallNode* call_node) {
+         [this](const CallNode* call_node) {
          auto args = PrepareArgs(call_node);
          if (const ConstantNode* shape = args[0].as<ConstantNode>()) {
            const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
@@ -124,7 +95,7 @@ class DynamicToStaticMutator : public MixedModeMutator {
          return Expr(nullptr);
         }},
        {Op::Get("dyn.one_hot"),
-         [](const CallNode* call_node) {
+         [this](const CallNode* call_node) {
          auto args = PrepareArgs(call_node);
          if (const ConstantNode* depth = args[3].as<ConstantNode>()) {
            const OneHotAttrs* param = call_node->attrs.as<OneHotAttrs>();
@@ -136,7 +107,7 @@ class DynamicToStaticMutator : public MixedModeMutator {
          return Expr(nullptr);
         }},
        {Op::Get("dyn.image.resize"),
-         [](const CallNode* call_node) {
+         [this](const CallNode* call_node) {
          auto args = PrepareArgs(call_node);
          if (const ConstantNode* size = args[1].as<ConstantNode>()) {
            const ResizeAttrs* param = call_node->attrs.as<ResizeAttrs>();
@@ -152,7 +123,7 @@ class DynamicToStaticMutator : public MixedModeMutator {
          return Expr(nullptr);
         }},
        {Op::Get("dyn.full"),
-         [](const CallNode* call_node) {
+         [this](const CallNode* call_node) {
          auto args = PrepareArgs(call_node);
          if (const ConstantNode* shape = args[1].as<ConstantNode>()) {
            ICHECK_EQ(shape->data->ndim, 1);
@@ -163,7 +134,7 @@ class DynamicToStaticMutator : public MixedModeMutator {
          return Expr(nullptr);
         }},
        {Op::Get("dyn.nn.upsampling"),
-         [](const CallNode* call_node) {
+         [this](const CallNode* call_node) {
          auto args = PrepareArgs(call_node);
          const ConstantNode* scale_h = args[1].as<ConstantNode>();
          const ConstantNode* scale_w = args[2].as<ConstantNode>();
@@ -179,7 +150,7 @@ class DynamicToStaticMutator : public MixedModeMutator {
          return Expr(nullptr);
         }},
        {Op::Get("dyn.nn.upsampling3d"),
-         [](const CallNode* call_node) {
+         [this](const CallNode* call_node) {
          auto args = PrepareArgs(call_node);
          const ConstantNode* scale_d = args[1].as<ConstantNode>();
          const ConstantNode* scale_h = args[2].as<ConstantNode>();
@@ -199,7 +170,7 @@ class DynamicToStaticMutator : public MixedModeMutator {
          return Expr(nullptr);
         }},
        {Op::Get("dyn.nn.pad"),
-         [](const CallNode* call_node) {
+         [this](const CallNode* call_node) {
          auto args = PrepareArgs(call_node);
          const ConstantNode* pad_width = args[1].as<ConstantNode>();
          const ConstantNode* pad_fill = args[2].as<ConstantNode>();
@@ -215,7 +186,7 @@ class DynamicToStaticMutator : public MixedModeMutator {
          return Expr(nullptr);
         }},
        {Op::Get("dyn.strided_slice"),
-         [](const CallNode* call_node) {
+         [this](const CallNode* call_node) {
          auto args = PrepareArgs(call_node);
          const ConstantNode* begin = args[1].as<ConstantNode>();
          const ConstantNode* end = args[2].as<ConstantNode>();
@@ -232,8 +203,9 @@ class DynamicToStaticMutator : public MixedModeMutator {
          return Expr(nullptr);
         }},
        {Op::Get("dyn.sparse_to_dense"),
-         [](const CallNode* call_node) {
-          const ConstantNode* output_shape = call_node->args[3].as<ConstantNode>();
+         [this](const CallNode* call_node) {
+          auto args = PrepareArgs(call_node);
+          const ConstantNode* output_shape = args[3].as<ConstantNode>();
          if (output_shape) {
            ICHECK_EQ(output_shape->data->ndim, 1);
            return MakeSparseToDense(call_node->args[0], ToVector(output_shape->data),
@@ -242,6 +214,45 @@ class DynamicToStaticMutator : public MixedModeMutator {
          return Expr(nullptr);
         }},
    };
+    Map<BaseFunc, GlobalVar> vars;
+    for (auto kv : mod_->functions) {
+      vars.Set(kv.second, kv.first);
+    }
+    gv_ = vars[func_];
+  }
+
+  Expr PrepareInput(const Expr& expr) {
+    BaseFunc func;
+    if (auto* func_node = expr.as<BaseFuncNode>()) {
+      func = GetRef<BaseFunc>(func_node);
+    } else {
+      func =
+          relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod_), {});
+    }
+    mod_->Update(gv_, func);
+    mod_ = transform::FoldConstant()(mod_);
+    mod_ = transform::InferType()(mod_);
+    mod_ = transform::FoldConstant()(mod_);
+    mod_ = transform::InferType()(mod_);
+    Expr out;
+    if (expr.as<FunctionNode>()) {
+      out = mod_->Lookup(gv_);
+    } else {
+      out = mod_->Lookup(gv_).as<FunctionNode>()->body;
+    }
+    return out;
+  }
+
+  std::vector<Expr> PrepareArgs(const CallNode* call_node) {
+    std::vector<Expr> args;
+    for (auto arg : call_node->args) {
+      if (arg.as<ConstantNode>()) {
+        args.emplace_back(arg);
+      } else {
+        args.emplace_back(PrepareInput(arg));
+      }
+    }
+    return args;
  }
 
  private:
@@ -267,11 +278,16 @@ class DynamicToStaticMutator : public MixedModeMutator {
 
  std::unordered_map<Op, std::function<Expr(const CallNode*)>, ObjectPtrHash, ObjectPtrEqual>
      op_map_;
+  IRModule mod_;
+  Function func_;
+  GlobalVar gv_;
};
 
 Expr DynamicToStatic(Function f, IRModule m) {
-  Expr expr = DynamicToStaticMutator().Mutate(f);
-  return PrepareInput(expr);
+  DynamicToStaticMutator mutator(m, f);
+  Expr expr = mutator.Mutate(f);
+  Expr out = mutator.PrepareInput(expr);
+  return out;
 }
 
 namespace transform {
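Patch 6 moves PrepareInput and PrepareArgs into the mutator so that constant folding happens inside the original IRModule, updating the function's own GlobalVar, rather than inside a fresh single-function module. That matters when the function being rewritten calls other global functions, which a fresh module cannot resolve during type inference. A hypothetical module that exercises this case (constructed for illustration; it is not one of the tests above):

import tvm
from tvm import relay
from tvm.relay import transform

# A helper global function, plus a main that calls it.
x = relay.var("x", shape=(2, 2), dtype="float32")
mod = tvm.IRModule()
mod["helper"] = relay.Function([x], x * relay.const(2.0))

y = relay.var("y", shape=(2, 2), dtype="float32")
call = mod.get_global_var("helper")(y)
body = relay.reshape(call, relay.shape_of(call))  # emits dyn.reshape
mod["main"] = relay.Function([y], body)

# With the module-aware mutator, type inference can resolve "helper" while
# the shape argument is folded, and dyn.reshape becomes a static reshape.
mod = transform.InferType()(mod)
mod = transform.DynamicToStatic()(mod)
print(mod["main"])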