Skip to content

Commit

Permalink
Use unordered_map
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 committed Nov 13, 2018
1 parent 853f571 commit f4b859f
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/relay/pass/combine_parallel_conv2d.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ Expr MakeCombinedConv2D(const Expr& data, const std::vector<const CallNode*>& co
Expr CombineParallelConv2D(const Expr& expr) {
// data -> array of conv2d with the same input
auto children_map = SiblingConv2DFinder().Find(expr);
Map<Expr, Expr> subst_map;
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> subst_map;

for (const auto& pair : children_map) {
Expr data = pair.first;
Expand Down Expand Up @@ -150,7 +150,7 @@ Expr CombineParallelConv2D(const Expr& expr) {
CHECK_NE(channel_index, std::string::npos);
auto take = MakeTake(new_conv2d, indices, channel_index);
start += *channels;
subst_map.Set(GetRef<Call>(conv2d), take);
subst_map[GetRef<Call>(conv2d)] = take;
}
}
}
Expand Down
5 changes: 3 additions & 2 deletions src/relay/pass/expr_subst.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ namespace relay {

class ExprSubstituter : public ExprMutator {
public:
explicit ExprSubstituter(tvm::Map<Expr, Expr> subst_map) : subst_map_(subst_map) {}
explicit ExprSubstituter(std::unordered_map<Expr, Expr, NodeHash, NodeEqual> subst_map)
: subst_map_(subst_map) {}

Expr VisitExpr(const Expr& expr) final {
auto it = subst_map_.find(expr);
Expand All @@ -26,7 +27,7 @@ class ExprSubstituter : public ExprMutator {
tvm::Map<Expr, Expr> subst_map_;
};

Expr ExprSubst(const Expr& expr, tvm::Map<Expr, Expr> subst_map) {
Expr ExprSubst(const Expr& expr, std::unordered_map<Expr, Expr, NodeHash, NodeEqual> subst_map) {
return ExprSubstituter(std::move(subst_map)).Mutate(expr);
}

Expand Down
3 changes: 2 additions & 1 deletion src/relay/pass/expr_subst.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
#ifndef TVM_RELAY_PASS_EXPR_SUBST_H_
#define TVM_RELAY_PASS_EXPR_SUBST_H_
#include <tvm/relay/expr.h>
#include <unordered_map>

namespace tvm {
namespace relay {

Expr ExprSubst(const Expr& expr, tvm::Map<Expr, Expr> subst_map);
Expr ExprSubst(const Expr& expr, std::unordered_map<Expr, Expr, NodeHash, NodeEqual> subst_map);

} // namespace relay
} // namespace tvm
Expand Down

0 comments on commit f4b859f

Please sign in to comment.