diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc index 3d9fb5891c2a9..60f643e436b2d 100644 --- a/src/relay/pass/combine_parallel_conv2d.cc +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -100,7 +100,7 @@ Expr MakeCombinedConv2D(const Expr& data, const std::vector& co Expr CombineParallelConv2D(const Expr& expr) { // data -> array of conv2d with the same input auto children_map = SiblingConv2DFinder().Find(expr); - Map subst_map; + std::unordered_map subst_map; for (const auto& pair : children_map) { Expr data = pair.first; @@ -150,7 +150,7 @@ Expr CombineParallelConv2D(const Expr& expr) { CHECK_NE(channel_index, std::string::npos); auto take = MakeTake(new_conv2d, indices, channel_index); start += *channels; - subst_map.Set(GetRef(conv2d), take); + subst_map[GetRef(conv2d)] = take; } } } diff --git a/src/relay/pass/expr_subst.cc b/src/relay/pass/expr_subst.cc index bac66bc0acf1c..586f748abef57 100644 --- a/src/relay/pass/expr_subst.cc +++ b/src/relay/pass/expr_subst.cc @@ -12,7 +12,8 @@ namespace relay { class ExprSubstituter : public ExprMutator { public: - explicit ExprSubstituter(tvm::Map subst_map) : subst_map_(subst_map) {} + explicit ExprSubstituter(std::unordered_map subst_map) + : subst_map_(subst_map) {} Expr VisitExpr(const Expr& expr) final { auto it = subst_map_.find(expr); @@ -26,7 +27,7 @@ class ExprSubstituter : public ExprMutator { tvm::Map subst_map_; }; -Expr ExprSubst(const Expr& expr, tvm::Map subst_map) { +Expr ExprSubst(const Expr& expr, std::unordered_map subst_map) { return ExprSubstituter(std::move(subst_map)).Mutate(expr); } diff --git a/src/relay/pass/expr_subst.h b/src/relay/pass/expr_subst.h index 02f4179dae66e..67892b3a0af7d 100644 --- a/src/relay/pass/expr_subst.h +++ b/src/relay/pass/expr_subst.h @@ -6,11 +6,12 @@ #ifndef TVM_RELAY_PASS_EXPR_SUBST_H_ #define TVM_RELAY_PASS_EXPR_SUBST_H_ #include +#include namespace tvm { namespace relay { -Expr ExprSubst(const Expr& expr, tvm::Map subst_map); +Expr ExprSubst(const Expr& expr, std::unordered_map subst_map); } // namespace relay } // namespace tvm