Skip to content

Commit 16723ba

Browse files
committed
[REFACTOR] Phase out LegacyReprPrinter and improve CommonSubExprElim
This PR phases out LegacyReprPrinter. Previously common subexpr elim relies on sorting on legacy repr for determism, which is hacky. This PR introduces an ordered_map impl in support to ensure determinism and migrates the CSE pass to use that instead.
1 parent 2d63574 commit 16723ba

21 files changed

+246
-1121
lines changed

include/tvm/node/repr_printer.h

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -52,32 +52,6 @@ class ReprPrinter {
5252
TVM_DLL static FType& vtable();
5353
};
5454

55-
/*! \brief Legacy behavior of ReprPrinter. */
56-
class ReprLegacyPrinter {
57-
public:
58-
/*! \brief The indentation level. */
59-
int indent{0};
60-
61-
explicit ReprLegacyPrinter(std::ostream& stream) // NOLINT(*)
62-
: stream(stream) {}
63-
64-
/*! \brief The node to be printed. */
65-
TVM_DLL void Print(const ObjectRef& node);
66-
/*! \brief Print indent to the stream */
67-
TVM_DLL void PrintIndent();
68-
/*! \brief Could the LegacyPrinter dispatch the node */
69-
TVM_DLL static bool CanDispatch(const ObjectRef& node);
70-
/*! \brief Return the ostream it maintains */
71-
TVM_DLL std::ostream& Stream() const;
72-
// Allow registration to be printer.
73-
using FType = NodeFunctor<void(const ObjectRef&, ReprLegacyPrinter*)>;
74-
TVM_DLL static FType& vtable();
75-
76-
private:
77-
/*! \brief The output stream */
78-
std::ostream& stream;
79-
};
80-
8155
/*!
8256
* \brief Dump the node to stderr, used for debug purposes.
8357
* \param node The input node
@@ -113,12 +87,6 @@ inline std::ostream& operator<<(std::ostream& os, const Variant<V...>& n) { //
11387
return os;
11488
}
11589

116-
inline std::string AsLegacyRepr(const ObjectRef& n) {
117-
std::ostringstream os;
118-
ReprLegacyPrinter(os).Print(n);
119-
return os.str();
120-
}
12190
} // namespace ffi
122-
using ffi::AsLegacyRepr;
12391
} // namespace tvm
12492
#endif // TVM_NODE_REPR_PRINTER_H_

src/ir/analysis.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ namespace ir {
3131
Map<GlobalVar, Array<GlobalVar>> CollectCallMap(const IRModule& mod) {
3232
struct CalleeCollectorImpl : CalleeCollector {
3333
void Mark(GlobalVar gvar) override { gvars.push_back(gvar); }
34-
support::OrderedSet<GlobalVar> gvars;
34+
support::OrderedSet<GlobalVar, ObjectPtrHash, ObjectPtrEqual> gvars;
3535
};
3636

3737
Map<GlobalVar, Array<GlobalVar>> call_map;

src/node/repr_printer.cc

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -97,38 +97,6 @@ ReprPrinter::FType& ReprPrinter::vtable() {
9797
return inst;
9898
}
9999

100-
void ReprLegacyPrinter::Print(const ObjectRef& node) {
101-
static const FType& f = vtable();
102-
if (!node.defined()) {
103-
stream << "(nullptr)";
104-
} else if (f.can_dispatch(node)) {
105-
f(node, this);
106-
} else {
107-
try {
108-
stream << node; // Use ReprPrinter
109-
} catch (const tvm::Error& e) {
110-
LOG(WARNING) << "ReprPrinter fails";
111-
stream << node->GetTypeKey() << '(' << node.get() << ')';
112-
}
113-
}
114-
}
115-
116-
bool ReprLegacyPrinter::CanDispatch(const ObjectRef& node) {
117-
static const FType& f = vtable();
118-
return !node.defined() || f.can_dispatch(node);
119-
}
120-
121-
void ReprLegacyPrinter::PrintIndent() {
122-
for (int i = 0; i < indent; ++i) {
123-
stream << ' ';
124-
}
125-
}
126-
127-
ReprLegacyPrinter::FType& ReprLegacyPrinter::vtable() {
128-
static FType inst;
129-
return inst;
130-
}
131-
132100
void Dump(const runtime::ObjectRef& n) { std::cerr << n << "\n"; }
133101

134102
void Dump(const runtime::Object* n) { Dump(runtime::GetRef<runtime::ObjectRef>(n)); }
@@ -138,7 +106,4 @@ TVM_FFI_REGISTER_GLOBAL("node.AsRepr").set_body_typed([](ffi::Any obj) {
138106
os << obj;
139107
return os.str();
140108
});
141-
142-
TVM_FFI_REGISTER_GLOBAL("node.AsLegacyRepr").set_body_typed(ffi::AsLegacyRepr);
143-
144109
} // namespace tvm

src/node/script_printer.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ TVMScriptPrinter::FType& TVMScriptPrinter::vtable() {
3232

3333
std::string TVMScriptPrinter::Script(const ObjectRef& node, const Optional<PrinterConfig>& cfg) {
3434
if (!TVMScriptPrinter::vtable().can_dispatch(node)) {
35-
return AsLegacyRepr(node);
35+
TVM_FFI_THROW(RuntimeError) << "TVMScriptPrinter cannot dispatch node: " << node.GetTypeKey();
3636
}
3737
return TVMScriptPrinter::vtable()(node, cfg.value_or(PrinterConfig()));
3838
}

src/relax/analysis/computable_at_compile_time.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class CompileTimeCollector : ExprVisitor {
8383
}
8484
}
8585

86-
support::OrderedSet<Var> known_relax_vars_;
86+
support::OrderedSet<Var, ObjectPtrHash, ObjectPtrEqual> known_relax_vars_;
8787
std::unordered_set<tir::Var> known_tir_vars_;
8888
};
8989
} // namespace

src/relax/analysis/udchain.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ class UDChain : relax::ExprVisitor {
5656
private:
5757
Map<Var, Expr> bound_values;
5858
std::unordered_set<Var> forward_declarations;
59-
std::unordered_map<Var, support::OrderedSet<Var>> usage_map;
60-
support::OrderedSet<Var> outputs;
59+
std::unordered_map<Var, support::OrderedSet<Var, ObjectPtrHash, ObjectPtrEqual>> usage_map;
60+
support::OrderedSet<Var, ObjectPtrHash, ObjectPtrEqual> outputs;
6161

6262
Optional<Var> cur_user_;
6363

src/relax/ir/binding_rewrite.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,8 @@ Expr RemoveAllUnused(Expr expr) {
321321
auto var_usage = CollectVarUsage(expr);
322322

323323
// For the purpose of
324-
support::OrderedSet<Var> externally_exposed(var_usage.outputs.begin(), var_usage.outputs.end());
324+
support::OrderedSet<Var, ObjectPtrHash, ObjectPtrEqual> externally_exposed(
325+
var_usage.outputs.begin(), var_usage.outputs.end());
325326
for (const auto& [var, expr] : var_usage.bound_values) {
326327
if (ContainsImpureCall(expr)) {
327328
externally_exposed.insert(var);

src/relax/transform/inline_functions.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ class FunctionInliner : public ExprMutator {
138138
}
139139

140140
const Map<Variant<String, GlobalVar>, Function>& replacements_;
141-
support::OrderedSet<GlobalVar> inline_stack_;
141+
std::unordered_set<GlobalVar, ObjectPtrHash, ObjectPtrEqual> inline_stack_;
142142
};
143143
} // namespace
144144

src/relax/transform/run_codegen.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class CodeGenRunner : ExprMutator {
4444
Array<String> entry_function_names) {
4545
IRModule mod = builder_->GetContextIRModule();
4646

47-
support::OrderedSet<GlobalVar> entry_functions;
47+
support::OrderedSet<GlobalVar, ObjectPtrHash, ObjectPtrEqual> entry_functions;
4848
// Any user-provided functions are treated as entry functions.
4949
for (const auto& name : entry_function_names) {
5050
entry_functions.insert(mod->GetGlobalVar(name));

0 commit comments

Comments
 (0)