diff --git a/include/tvm/node/repr_printer.h b/include/tvm/node/repr_printer.h index 30bfe8e95193..e3baf397f25f 100644 --- a/include/tvm/node/repr_printer.h +++ b/include/tvm/node/repr_printer.h @@ -52,32 +52,6 @@ class ReprPrinter { TVM_DLL static FType& vtable(); }; -/*! \brief Legacy behavior of ReprPrinter. */ -class ReprLegacyPrinter { - public: - /*! \brief The indentation level. */ - int indent{0}; - - explicit ReprLegacyPrinter(std::ostream& stream) // NOLINT(*) - : stream(stream) {} - - /*! \brief The node to be printed. */ - TVM_DLL void Print(const ObjectRef& node); - /*! \brief Print indent to the stream */ - TVM_DLL void PrintIndent(); - /*! \brief Could the LegacyPrinter dispatch the node */ - TVM_DLL static bool CanDispatch(const ObjectRef& node); - /*! \brief Return the ostream it maintains */ - TVM_DLL std::ostream& Stream() const; - // Allow registration to be printer. - using FType = NodeFunctor; - TVM_DLL static FType& vtable(); - - private: - /*! \brief The output stream */ - std::ostream& stream; -}; - /*! * \brief Dump the node to stderr, used for debug purposes. * \param node The input node @@ -113,12 +87,6 @@ inline std::ostream& operator<<(std::ostream& os, const Variant& n) { // return os; } -inline std::string AsLegacyRepr(const ObjectRef& n) { - std::ostringstream os; - ReprLegacyPrinter(os).Print(n); - return os.str(); -} } // namespace ffi -using ffi::AsLegacyRepr; } // namespace tvm #endif // TVM_NODE_REPR_PRINTER_H_ diff --git a/src/ir/analysis.cc b/src/ir/analysis.cc index 3a54085c2290..26a348bceee1 100644 --- a/src/ir/analysis.cc +++ b/src/ir/analysis.cc @@ -31,7 +31,7 @@ namespace ir { Map> CollectCallMap(const IRModule& mod) { struct CalleeCollectorImpl : CalleeCollector { void Mark(GlobalVar gvar) override { gvars.push_back(gvar); } - support::OrderedSet gvars; + support::OrderedSet gvars; }; Map> call_map; diff --git a/src/node/repr_printer.cc b/src/node/repr_printer.cc index aa999655c03d..69cb05c12106 100644 --- a/src/node/repr_printer.cc +++ b/src/node/repr_printer.cc @@ -97,38 +97,6 @@ ReprPrinter::FType& ReprPrinter::vtable() { return inst; } -void ReprLegacyPrinter::Print(const ObjectRef& node) { - static const FType& f = vtable(); - if (!node.defined()) { - stream << "(nullptr)"; - } else if (f.can_dispatch(node)) { - f(node, this); - } else { - try { - stream << node; // Use ReprPrinter - } catch (const tvm::Error& e) { - LOG(WARNING) << "ReprPrinter fails"; - stream << node->GetTypeKey() << '(' << node.get() << ')'; - } - } -} - -bool ReprLegacyPrinter::CanDispatch(const ObjectRef& node) { - static const FType& f = vtable(); - return !node.defined() || f.can_dispatch(node); -} - -void ReprLegacyPrinter::PrintIndent() { - for (int i = 0; i < indent; ++i) { - stream << ' '; - } -} - -ReprLegacyPrinter::FType& ReprLegacyPrinter::vtable() { - static FType inst; - return inst; -} - void Dump(const runtime::ObjectRef& n) { std::cerr << n << "\n"; } void Dump(const runtime::Object* n) { Dump(runtime::GetRef(n)); } @@ -138,7 +106,4 @@ TVM_FFI_REGISTER_GLOBAL("node.AsRepr").set_body_typed([](ffi::Any obj) { os << obj; return os.str(); }); - -TVM_FFI_REGISTER_GLOBAL("node.AsLegacyRepr").set_body_typed(ffi::AsLegacyRepr); - } // namespace tvm diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc index ee7880f4485a..c81543579655 100644 --- a/src/node/script_printer.cc +++ b/src/node/script_printer.cc @@ -32,7 +32,10 @@ TVMScriptPrinter::FType& TVMScriptPrinter::vtable() { std::string TVMScriptPrinter::Script(const ObjectRef& node, const Optional& cfg) { if (!TVMScriptPrinter::vtable().can_dispatch(node)) { - return AsLegacyRepr(node); + std::ostringstream os; + ReprPrinter printer(os); + printer.Print(node); + return os.str(); } return TVMScriptPrinter::vtable()(node, cfg.value_or(PrinterConfig())); } diff --git a/src/relax/analysis/computable_at_compile_time.cc b/src/relax/analysis/computable_at_compile_time.cc index ba163b51d6c9..5825895db7d6 100644 --- a/src/relax/analysis/computable_at_compile_time.cc +++ b/src/relax/analysis/computable_at_compile_time.cc @@ -83,7 +83,7 @@ class CompileTimeCollector : ExprVisitor { } } - support::OrderedSet known_relax_vars_; + support::OrderedSet known_relax_vars_; std::unordered_set known_tir_vars_; }; } // namespace diff --git a/src/relax/analysis/udchain.cc b/src/relax/analysis/udchain.cc index f62254b6959d..2f04d8659405 100644 --- a/src/relax/analysis/udchain.cc +++ b/src/relax/analysis/udchain.cc @@ -56,8 +56,8 @@ class UDChain : relax::ExprVisitor { private: Map bound_values; std::unordered_set forward_declarations; - std::unordered_map> usage_map; - support::OrderedSet outputs; + std::unordered_map> usage_map; + support::OrderedSet outputs; Optional cur_user_; diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc index f35b443b5b39..11a0fd29a92f 100644 --- a/src/relax/ir/binding_rewrite.cc +++ b/src/relax/ir/binding_rewrite.cc @@ -321,7 +321,8 @@ Expr RemoveAllUnused(Expr expr) { auto var_usage = CollectVarUsage(expr); // For the purpose of - support::OrderedSet externally_exposed(var_usage.outputs.begin(), var_usage.outputs.end()); + support::OrderedSet externally_exposed( + var_usage.outputs.begin(), var_usage.outputs.end()); for (const auto& [var, expr] : var_usage.bound_values) { if (ContainsImpureCall(expr)) { externally_exposed.insert(var); diff --git a/src/relax/transform/inline_functions.cc b/src/relax/transform/inline_functions.cc index 26b106373ff0..e295226e9e72 100644 --- a/src/relax/transform/inline_functions.cc +++ b/src/relax/transform/inline_functions.cc @@ -138,7 +138,7 @@ class FunctionInliner : public ExprMutator { } const Map, Function>& replacements_; - support::OrderedSet inline_stack_; + std::unordered_set inline_stack_; }; } // namespace diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc index d29bdaacb9b0..33d3f485a5e0 100644 --- a/src/relax/transform/run_codegen.cc +++ b/src/relax/transform/run_codegen.cc @@ -44,7 +44,7 @@ class CodeGenRunner : ExprMutator { Array entry_function_names) { IRModule mod = builder_->GetContextIRModule(); - support::OrderedSet entry_functions; + support::OrderedSet entry_functions; // Any user-provided functions are treated as entry functions. for (const auto& name : entry_function_names) { entry_functions.insert(mod->GetGlobalVar(name)); diff --git a/src/script/printer/legacy_repr.cc b/src/script/printer/legacy_repr.cc deleted file mode 100644 index 57dd691b8897..000000000000 --- a/src/script/printer/legacy_repr.cc +++ /dev/null @@ -1,894 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include -#include -#include - -#include - -#include "../../support/str_escape.h" - -namespace tvm { - -#define TVM_LEGACY_REPR_PRINTER_DEF_OP(Type) \ - ReprLegacyPrinter& operator<<(ReprLegacyPrinter& p, Type value) { \ - p.Stream() << value; \ - return p; \ - } - -TVM_LEGACY_REPR_PRINTER_DEF_OP(int); -TVM_LEGACY_REPR_PRINTER_DEF_OP(int64_t); -TVM_LEGACY_REPR_PRINTER_DEF_OP(float); -TVM_LEGACY_REPR_PRINTER_DEF_OP(double); -TVM_LEGACY_REPR_PRINTER_DEF_OP(char); -TVM_LEGACY_REPR_PRINTER_DEF_OP(const char*); -TVM_LEGACY_REPR_PRINTER_DEF_OP(const std::string&); -TVM_LEGACY_REPR_PRINTER_DEF_OP(runtime::DataType); -TVM_LEGACY_REPR_PRINTER_DEF_OP(const void*); -TVM_LEGACY_REPR_PRINTER_DEF_OP(const String&); - -std::ostream& ReprLegacyPrinter::Stream() const { return stream; } - -ReprLegacyPrinter& operator<<(ReprLegacyPrinter& p, const ObjectRef& value) { - p.Stream() << AsLegacyRepr(value); - return p; -} - -ReprLegacyPrinter& operator<<(ReprLegacyPrinter& out, tir::ForKind type) { // NOLINT(*) - using tvm::tir::ForKind; - switch (type) { - case ForKind::kSerial: - out << "for"; - break; - case ForKind::kParallel: - out << "parallel"; - break; - case ForKind::kUnrolled: - out << "unrolled"; - break; - case ForKind::kVectorized: - out << "vectorized"; - break; - case ForKind::kThreadBinding: - out << "launch_thread"; - break; - } - return out; -} - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '['; - for (size_t i = 0; i < op->size(); ++i) { - if (i != 0) { - (*p) << ", "; - } - p->Print(op->at(i).cast()); - } - (*p) << ']'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '{'; - for (auto it = op->begin(); it != op->end(); ++it) { - if (it != op->begin()) { - (*p) << ", "; - } - if (it->first.as()) { - (*p) << '\"' << Downcast(it->first) << "\": "; - } else { - p->Print(it->first.cast()); - (*p) << ": "; - } - p->Print(it->second.cast()); - } - (*p) << '}'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '['; - for (size_t i = 0; i < op->size; ++i) { - if (i != 0) { - (*p) << ", "; - } - (*p) << op->data[i]; - } - (*p) << ']'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - if (op->dtype == DataType::Int(32)) { - (*p) << op->value; - } else { - (*p) << "(" << op->dtype << ")" << op->value; - } - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - switch (op->dtype.bits()) { - case 64: - (*p) << op->value; - break; - case 32: - (*p) << op->value << 'f'; - break; - case 16: - (*p) << op->value << 'h'; - break; - default: - LOG(FATAL) << "Unknown float type bits=" << op->dtype.bits(); - } - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << "range(min=" << op->min << ", ext=" << op->extent << ')'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { - auto* node = static_cast(ref.get()); - (*p) << node->dtype; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { - auto* node = static_cast(ref.get()); - if (!node->storage_scope.empty()) { - (*p) << node->storage_scope << " "; - } - p->Print(node->element_type); - (*p) << '*'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { - auto* node = static_cast(ref.get()); - (*p) << "TupleTypeNode(" << node->fields << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << op->dict; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { - auto* node = static_cast(ref.get()); - (*p) << "GlobalVar(" << node->name_hint << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { - auto* node = static_cast(ref.get()); - (*p) << "IRModule(" << node->functions << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { - auto* node = static_cast(ref.get()); - (*p) << "FuncType(" << node->arg_types << ", " << node->ret_type << ")"; - }); - -} // namespace tvm - -namespace tvm { -namespace tir { - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << "buffer(" << op->name << ", " << op << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - // omit the type - // stream << op->name << "." << op->type; - (*p) << op->name_hint; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << "{" << op->name_hint << "|" << op->name_hint << ">=0}"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << "iter_var("; - if (op->var->name_hint.length() != 0) { - (*p) << op->var->name_hint << ", "; - } - if (op->dom.defined()) { - (*p) << op->dom; - } - if (op->thread_tag.length() != 0) { - (*p) << ", " << op->thread_tag; - } - (*p) << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '\"' << support::StrEscape(op->value) << '\"'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << op->dtype << '('; - p->Print(op->value); - (*p) << ')'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '('; - p->Print(op->a); - (*p) << " + "; - p->Print(op->b); - (*p) << ')'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '('; - p->Print(op->a); - (*p) << " - "; - p->Print(op->b); - (*p) << ')'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '('; - p->Print(op->a); - (*p) << "*"; - p->Print(op->b); - (*p) << ')'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '('; - p->Print(op->a); - (*p) << "/"; - p->Print(op->b); - (*p) << ')'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '('; - p->Print(op->a); - (*p) << " % "; - p->Print(op->b); - (*p) << ')'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << "floordiv(" << op->a << ", " << op->b << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << "floormod(" << op->a << ", " << op->b << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << "min("; - p->Print(op->a); - (*p) << ", "; - p->Print(op->b); - (*p) << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << "max("; - p->Print(op->a); - (*p) << ", "; - p->Print(op->b); - (*p) << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '('; - p->Print(op->a); - (*p) << " == "; - p->Print(op->b); - (*p) << ')'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '('; - p->Print(op->a); - (*p) << " != "; - p->Print(op->b); - (*p) << ')'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '('; - p->Print(op->a); - (*p) << " < "; - p->Print(op->b); - (*p) << ')'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '('; - p->Print(op->a); - (*p) << " <= "; - p->Print(op->b); - (*p) << ')'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '('; - p->Print(op->a); - (*p) << " > "; - p->Print(op->b); - (*p) << ')'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '('; - p->Print(op->a); - (*p) << " >= "; - p->Print(op->b); - (*p) << ')'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '('; - p->Print(op->a); - (*p) << " && "; - p->Print(op->b); - (*p) << ')'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '('; - p->Print(op->a); - (*p) << " || "; - p->Print(op->b); - (*p) << ')'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '!'; - p->Print(op->a); - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << "select("; - p->Print(op->condition); - (*p) << ", "; - p->Print(op->true_value); - (*p) << ", "; - p->Print(op->false_value); - (*p) << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << "ramp("; - p->Print(op->base); - (*p) << ", "; - p->Print(op->stride); - (*p) << ", " << op->lanes << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << "x" << op->lanes << "("; - p->Print(op->value); - (*p) << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << "(let " << op->var << " = "; - p->Print(op->value); - (*p) << " in "; - p->Print(op->body); - (*p) << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - if (auto* ptr_op = op->op.as()) { - (*p) << ptr_op->name << "("; - } else { - auto* ptr_gvar = op->op.as(); - ICHECK(ptr_gvar != nullptr); - (*p) << "@" << ptr_gvar->name_hint << "("; - } - for (size_t i = 0; i < op->args.size(); ++i) { - p->Print(op->args[i]); - if (i < op->args.size() - 1) { - (*p) << ", "; - } - } - (*p) << ")"; - }); - -template -void PrintList(const Array& exprs, ReprLegacyPrinter* p) { - for (size_t i = 0; i < exprs.size(); ++i) { - p->Print(exprs[i]); - if (i < exprs.size() - 1) { - (*p) << ", "; - } - } -} - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << "shuffle("; - PrintList(op->vectors, p); - (*p) << ", "; - PrintList(op->indices, p); - (*p) << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << "comm_reducer(result=" << op->result << ", lhs=" << op->lhs << ", rhs=" << op->rhs - << ", identity_element=" << op->identity_element << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << "reduce(combiner=" << op->combiner; - (*p) << ", source=" << op->source; - (*p) << ", init=" << op->init; - (*p) << ", axis=" << op->axis; - (*p) << ", where=" << op->condition; - (*p) << ", value_index=" << op->value_index; - (*p) << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << op->buffer->name << "["; - for (size_t i = 0; i < op->indices.size(); ++i) { - p->Print(op->indices[i]); - if (i < op->indices.size() - 1) { - (*p) << ", "; - } - } - (*p) << "]"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << op->producer->GetNameHint() << "["; - for (size_t i = 0; i < op->indices.size(); ++i) { - p->Print(op->indices[i]); - if (i < op->indices.size() - 1) { - (*p) << ", "; - } - } - (*p) << "]"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { - auto* node = static_cast(ref.get()); - (*p) << "PrimFunc(" << node->params << ") "; - if (node->attrs.defined()) { - (*p) << "attrs=" << node->attrs; - } - (*p) << " {\n"; - p->indent += 2; - p->Print(node->body); - p->indent -= 2; - (*p) << "}\n"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - (*p) << "let " << op->var << " = "; - p->Print(op->value); - (*p) << '\n'; - p->Print(op->body); - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - (*p) << "// attr ["; - p->Print(op->node); - (*p) << "] " << op->attr_key << " = "; - p->Print(op->value); - (*p) << '\n'; - p->Print(op->body); - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - (*p) << "assert("; - p->Print(op->condition); - (*p) << ", "; - p->Print(op->message); - (*p) << ")\n"; - p->Print(op->body); - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - (*p) << op->kind << " (" << op->loop_var << ", "; - p->Print(op->min); - (*p) << ", "; - p->Print(op->extent); - (*p) << ") {\n"; - - p->indent += 2; - p->Print(op->body); - p->indent -= 2; - - p->PrintIndent(); - (*p) << "}\n"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - (*p) << "while(" << op->condition << ") {\n"; - p->indent += 2; - p->Print(op->body); - p->indent -= 2; - p->PrintIndent(); - (*p) << "}\n"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - const auto* ptr_type = op->buffer_var->type_annotation.as(); - ICHECK(ptr_type) << "The provided variable is not of pointer type"; - p->PrintIndent(); - (*p) << "allocate " << op->buffer_var << "[" << op->dtype; - for (size_t i = 0; i < op->extents.size(); ++i) { - (*p) << " * "; - p->Print(op->extents[i]); - } - (*p) << "], storage_scope = " << ptr_type->storage_scope; - if (!is_one(op->condition)) { - (*p) << " if "; - p->Print(op->condition); - } - (*p) << "\n"; - p->Print(op->body); - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - (*p) << "constant " << op->buffer_var << "[" << op->dtype; - for (size_t i = 0; i < op->extents.size(); ++i) { - (*p) << " * "; - p->Print(op->extents[i]); - } - (*p) << "]"; - (*p) << "\n"; - p->Print(op->body); - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - (*p) << "decl_buffer " << op->buffer << "\n"; - (*p) << op->body; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - for (Stmt stmt : op->seq) { - p->Print(stmt); - } - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - while (true) { - (*p) << "if (" << op->condition << ") {\n"; - p->indent += 2; - p->Print(op->then_case); - p->indent -= 2; - - if (!op->else_case) { - break; - } - - if (const IfThenElseNode* nested_if = op->else_case.as()) { - p->PrintIndent(); - (*p) << "} else "; - op = nested_if; - } else { - p->PrintIndent(); - (*p) << "} else {\n"; - p->indent += 2; - p->Print(op->else_case); - p->indent -= 2; - break; - } - } - p->PrintIndent(); - (*p) << "}\n"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->Print(op->value); - (*p) << "\n"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - (*p) << op->buffer->name << "["; - for (size_t i = 0; i < op->indices.size(); ++i) { - p->Print(op->indices[i]); - if (i < op->indices.size() - 1) (*p) << ", "; - } - (*p) << "]"; - (*p) << " = "; - p->Print(op->value); - (*p) << '\n'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - (*p) << "buffer_realize " << op->buffer->name << "("; - for (size_t i = 0; i < op->bounds.size(); ++i) { - (*p) << "["; - p->Print(op->bounds[i]->min); - (*p) << ", "; - p->Print(op->bounds[i]->extent); - (*p) << "]"; - if (i < op->bounds.size() - 1) (*p) << ", "; - } - (*p) << ")"; - if (!is_one(op->condition)) { - (*p) << " if "; - p->Print(op->condition); - } - (*p) << " {\n"; - - p->indent += 2; - p->Print(op->body); - p->indent -= 2; - - p->PrintIndent(); - (*p) << "}\n"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << op->buffer->name; - (*p) << "["; - for (size_t i = 0; i < op->region.size(); ++i) { - const auto& range = op->region[i]; - p->Print(range->min); - if (!is_one(range->extent)) { - (*p) << ":"; - p->Print(range->min + range->extent); - } - if (i != op->region.size() - 1) (*p) << ", "; - } - (*p) << "]"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - (*p) << op->buffer->name << " = match_buffer("; - p->Print(op->source); - (*p) << ")\n"; - }); - -void PrintBlockTitle(const BlockNode* op, ReprLegacyPrinter* p) { - (*p) << "block " << op->name_hint << "("; - for (size_t i = 0; i < op->iter_vars.size(); i++) { - p->Print(op->iter_vars[i]); - if (i < op->iter_vars.size() - 1) (*p) << ", "; - } - (*p) << ")"; -} - -void PrintBlockSignature(const BlockNode* op, ReprLegacyPrinter* p) { - // print read/write regions - p->PrintIndent(); - (*p) << "reads("; - p->Print(op->reads); - (*p) << ")\n"; - p->PrintIndent(); - (*p) << "writes("; - p->Print(op->writes); - (*p) << ")\n"; - // Print alloc_buffers - for (const auto& alloc_buf : op->alloc_buffers) { - p->PrintIndent(); - (*p) << alloc_buf->name << " = alloc_buffer(" << alloc_buf->dtype << "["; - for (size_t i = 0; i < alloc_buf->shape.size(); ++i) { - if (i > 0) (*p) << ", "; - p->Print(alloc_buf->shape[i]); - } - (*p) << "])\n"; - } - // Print match_buffer_regions - for (const auto& match_buf : op->match_buffers) { - p->Print(match_buf); - } - if (!op->annotations.empty()) { - p->PrintIndent(); - (*p) << "annotations(" << op->annotations << ")\n"; - } -} - -void PrintBlockBody(const BlockNode* op, ReprLegacyPrinter* p) { - // Print init - if (op->init.defined()) { - p->PrintIndent(); - (*p) << "with init() {\n"; - p->indent += 2; - p->Print(op->init.value()); - p->indent -= 2; - p->PrintIndent(); - (*p) << "}\n"; - } - // Print body - p->Print(op->body); -} - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - PrintBlockTitle(op, p); - (*p) << " {\n"; - p->indent += 2; - - // Print block elements (e.g. reads/writes, etc) - PrintBlockSignature(op, p); - // Print block init and body - PrintBlockBody(op, p); - - p->indent -= 2; - p->PrintIndent(); - (*p) << "}\n"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - auto* block_op = op->block.get(); - p->PrintIndent(); - PrintBlockTitle(block_op, p); - (*p) << " {\n"; - p->indent += 2; - - // Print binding iter_values - for (size_t i = 0; i < block_op->iter_vars.size(); ++i) { - p->PrintIndent(); - (*p) << "bind("; - p->Print(block_op->iter_vars[i]->var); - (*p) << ", "; - p->Print(op->iter_values[i]); - (*p) << ")\n"; - } - // Print predicate - if (!is_one(op->predicate)) { - p->PrintIndent(); - (*p) << "where("; - p->Print(op->predicate); - (*p) << ")\n"; - } - // Print block elements (e.g. reads/writes, etc) - PrintBlockSignature(block_op, p); - // Print block init and body - PrintBlockBody(block_op, p); - - p->indent -= 2; - p->PrintIndent(); - (*p) << "}\n"; - }); - -} // namespace tir -} // namespace tvm diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h index 03341c4cd90f..95d24c91c41e 100644 --- a/src/script/printer/utils.h +++ b/src/script/printer/utils.h @@ -42,18 +42,8 @@ inline void RedirectedReprPrinterMethod(const ObjectRef& obj, ReprPrinter* p) { try { p->stream << TVMScriptPrinter::Script(obj, std::nullopt); } catch (const tvm::Error& e) { - if (ReprLegacyPrinter::CanDispatch(obj)) { - LOG(WARNING) << "TVMScript printer falls back to the legacy ReprPrinter with the error:\n" - << e.what(); - try { - p->stream << AsLegacyRepr(obj); - } catch (const tvm::Error& e) { - LOG(WARNING) << "AsLegacyRepr fails. Falling back to the basic address printer"; - } - } else { - LOG(WARNING) << "TVMScript printer falls back to the basic address printer with the error:\n" - << e.what(); - } + LOG(WARNING) << "TVMScript printer falls back to the basic address printer with the error:\n" + << e.what(); p->stream << obj->GetTypeKey() << '(' << obj.get() << ')'; } } diff --git a/src/support/ordered_map.h b/src/support/ordered_map.h new file mode 100644 index 000000000000..81b0fd38a7a4 --- /dev/null +++ b/src/support/ordered_map.h @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file support/ordered_map.h + * \brief An STL-like map that preserves insertion order. + */ +#ifndef TVM_SUPPORT_ORDERED_MAP_H_ +#define TVM_SUPPORT_ORDERED_MAP_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace support { + +/** + * \brief An STL-like map that preserves insertion order. + * + * \tparam K The key type. + * \tparam V The value type. + * \tparam Hash The hash function. + * \tparam KeyEqual The key equality function. + * \note we don't support erase since it is less needed and vector backing is more efficient. + */ +template , + typename KeyEqual = std::equal_to> +class OrderedMap { + public: + OrderedMap() = default; + + /* \brief Explicit copy constructor + * + * The default copy constructor would copy both `elements_` and + * `elem_to_iter_`. While this is the correct behavior for + * `elements_`, the copy of `elem_to_iter_` would contain references + * to the original's `element_`, rather than to its own + */ + OrderedMap(const OrderedMap& other) : elements_(other.elements_) { + InitElementToIter(); + } + + /* \brief Explicit copy assignment + * + * Implemented in terms of the copy constructor, and the default + * move assignment. + */ + OrderedMap& operator=(const OrderedMap& other) { + return *this = OrderedMap(other); + } + + OrderedMap(OrderedMap&&) = default; + OrderedMap& operator=(OrderedMap&&) = default; + + template + OrderedMap(Iter begin, Iter end) : elements_(begin, end) { + InitElementToIter(); + } + + auto find(const K& k) { + auto it = elem_to_index_.find(k); + if (it != elem_to_index_.end()) { + return elements_.begin() + it->second; + } + return elements_.end(); + } + + auto find(const K& k) const { + auto it = elem_to_index_.find(k); + if (it != elem_to_index_.end()) { + return elements_.begin() + it->second; + } + return elements_.end(); + } + + V& operator[](const K& k) { + auto it = elem_to_index_.find(k); + if (it != elem_to_index_.end()) { + return elements_[it->second].second; + } + elements_.emplace_back(k, V()); + elem_to_index_[k] = elements_.size() - 1; + return elements_.back().second; + } + + void insert(const K& k, V v) { + auto it = elem_to_index_.find(k); + if (it != elem_to_index_.end()) { + elements_[it->second].second = std::move(v); + } else { + elements_.emplace_back(k, v); + elem_to_index_[k] = elements_.size() - 1; + } + } + + void clear() { + elements_.clear(); + elem_to_index_.clear(); + } + + size_t count(const K& k) const { return elem_to_index_.count(k); } + + auto begin() const { return elements_.begin(); } + auto end() const { return elements_.end(); } + auto begin() { return elements_.begin(); } + auto end() { return elements_.end(); } + + size_t size() const { return elements_.size(); } + bool empty() const { return elements_.empty(); } + + void reserve(size_t n) { elem_to_index_.reserve(n); } + + private: + void InitElementToIter() { + for (size_t i = 0; i < elements_.size(); i++) { + elem_to_index_[elements_[i].first] = i; + } + } + + std::vector> elements_; + std::unordered_map elem_to_index_; +}; + +} // namespace support +} // namespace tvm + +#endif // TVM_SUPPORT_ORDERED_MAP_H_ diff --git a/src/support/ordered_set.h b/src/support/ordered_set.h index 11acb8c3fef5..169f738e700d 100644 --- a/src/support/ordered_set.h +++ b/src/support/ordered_set.h @@ -26,30 +26,14 @@ #include -#include +#include #include +#include namespace tvm { namespace support { -namespace detail { -/* \brief Utility to allow use for standard and ObjectRef types - * - * \tparam T The type held by the OrderedSet - */ -template -struct OrderedSetLookupType { - using MapType = std::unordered_map::iterator>; -}; - -template -struct OrderedSetLookupType>> { - using MapType = std::unordered_map::iterator, runtime::ObjectPtrHash, - runtime::ObjectPtrEqual>; -}; -} // namespace detail - -template +template , typename KeyEqual = std::equal_to> class OrderedSet { public: OrderedSet() = default; @@ -61,17 +45,21 @@ class OrderedSet { * `elements_`, the copy of `elem_to_iter_` would contain references * to the original's `element_`, rather than to its own */ - OrderedSet(const OrderedSet& other) : elements_(other.elements_) { InitElementToIter(); } + OrderedSet(const OrderedSet& other) : elements_(other.elements_) { + InitElementToIter(); + } /* \brief Explicit copy assignment * * Implemented in terms of the copy constructor, and the default * move assignment. */ - OrderedSet& operator=(const OrderedSet& other) { return *this = OrderedSet(other); } + OrderedSet& operator=(const OrderedSet& other) { + return *this = OrderedSet(other); + } - OrderedSet(OrderedSet&&) = default; - OrderedSet& operator=(OrderedSet&&) = default; + OrderedSet(OrderedSet&&) = default; + OrderedSet& operator=(OrderedSet&&) = default; template OrderedSet(Iter begin, Iter end) : elements_(begin, end) { @@ -79,27 +67,20 @@ class OrderedSet { } void push_back(const T& t) { - if (!elem_to_iter_.count(t)) { + if (!elem_to_index_.count(t)) { elements_.push_back(t); - elem_to_iter_[t] = std::prev(elements_.end()); + elem_to_index_[t] = elements_.size() - 1; } } void insert(const T& t) { push_back(t); } - void erase(const T& t) { - if (auto it = elem_to_iter_.find(t); it != elem_to_iter_.end()) { - elements_.erase(it->second); - elem_to_iter_.erase(it); - } - } - void clear() { elements_.clear(); - elem_to_iter_.clear(); + elem_to_index_.clear(); } - size_t count(const T& t) const { return elem_to_iter_.count(t); } + size_t count(const T& t) const { return elem_to_index_.count(t); } auto begin() const { return elements_.begin(); } auto end() const { return elements_.end(); } @@ -108,13 +89,13 @@ class OrderedSet { private: void InitElementToIter() { - for (auto it = elements_.begin(); it != elements_.end(); it++) { - elem_to_iter_[*it] = it; + for (size_t i = 0; i < elements_.size(); ++i) { + elem_to_index_[elements_[i]] = i; } } - std::list elements_; - typename detail::OrderedSetLookupType::MapType elem_to_iter_; + std::vector elements_; + std::unordered_map elem_to_index_; }; } // namespace support diff --git a/src/tir/transforms/common_subexpr_elim.cc b/src/tir/transforms/common_subexpr_elim.cc index 42409efb0bd1..3fd78a523301 100644 --- a/src/tir/transforms/common_subexpr_elim.cc +++ b/src/tir/transforms/common_subexpr_elim.cc @@ -43,8 +43,7 @@ #include // For the algorithm std::find #include #include -#include // For the hashtable datatype -#include // For std::pair and std::move +#include #include #include "../analysis/check_contains.h" // For the visitor CheckContains @@ -131,41 +130,24 @@ bool CommonSubexpressionEliminator::CanContainEligibleComputations(const PrimExp * they appeared in the hashtable was based on some runtime addresses, so it can potentially * change with every execution. */ -bool CommonSubexpressionEliminator::OrderOnExprAndFrequency(std::pair a, - std::pair b) { +bool CommonSubexpressionEliminator::OrderOnExprAndFrequency(const std::pair& a, + const std::pair& b) { size_t a_size = CalculateExprComplexity(a.first); size_t b_size = CalculateExprComplexity(b.first); - - // Criteria 1 - Size of the expression comes first - // `a` comes before `b` if the size of `a` is bigger - if (a_size > b_size) { - return true; - } - // `a` does NOT come before `b` if the size of `b` is bigger - if (b_size > a_size) { - return false; - } - - // Criteria 2 - If they had the same size, use the lexicographic order as a last resort - // as we need a deterministic order - std::stringstream a_stream; - std::stringstream b_stream; - a_stream << AsLegacyRepr(a.first); - b_stream << AsLegacyRepr(b.first); - return (a_stream.str().compare(b_stream.str()) < 0); + return a_size > b_size; } /*! - * \brief Generates a new fresh variable, whose name will be cse_var_i. + * \brief Generates a new fresh variable, whose name will be cse_vi. * \param type_annotation The type of the new variable to generate - * \return A new variable of type `type_annotation` called cse_var_i where i is the first available + * \return A new variable of type `type_annotation` called cse_vi where i is the first available integer. */ Var CommonSubexpressionEliminator::GenerateNewVar(DataType type_annotation) { // Increase `num_last_try_` for this new attempt num_last_try_++; - // Builds the variable name, which is sce_var_i where i will go up from 1 - std::string prefix = "cse_var_"; + // Builds the variable name, which is cse_vi where i will go up from 1 + std::string prefix = "cse_v"; std::string name = prefix.append(std::to_string(num_last_try_)); // Builds a String using the std::string String string_name(name); @@ -241,8 +223,8 @@ PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) { SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr, identify_equiv_terms_); // Sort the vector of semantic entities by decreasing size - std::sort(semantic_comp_done_by_expr.begin(), semantic_comp_done_by_expr.end(), - OrderOnExprAndFrequency); + std::stable_sort(semantic_comp_done_by_expr.begin(), semantic_comp_done_by_expr.end(), + OrderOnExprAndFrequency); // For each computation done (considering them from biggest to smallest) for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) { @@ -421,8 +403,8 @@ Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) { SyntacticToSemanticComputations(table_syntactic_comp_done_by_stmt, identify_equiv_terms_); // Sort the vector of semantic entities by decreasing size - std::sort(semantic_comp_done_by_stmt.begin(), semantic_comp_done_by_stmt.end(), - OrderOnExprAndFrequency); + std::stable_sort(semantic_comp_done_by_stmt.begin(), semantic_comp_done_by_stmt.end(), + OrderOnExprAndFrequency); // For each computation done (considering them from biggest to smallest) for (size_t i = 0; i < semantic_comp_done_by_stmt.size(); i++) { diff --git a/src/tir/transforms/common_subexpr_elim.h b/src/tir/transforms/common_subexpr_elim.h index 5c14caf1a6e3..12a71458e13f 100644 --- a/src/tir/transforms/common_subexpr_elim.h +++ b/src/tir/transforms/common_subexpr_elim.h @@ -83,7 +83,8 @@ class CommonSubexpressionEliminator : public StmtExprMutator { static bool ForbiddenComputation(const PrimExpr& expr); static bool IsEligibleComputation(const PrimExpr& expr); static bool CanContainEligibleComputations(const PrimExpr& expr); - static bool OrderOnExprAndFrequency(std::pair a, std::pair b); + static bool OrderOnExprAndFrequency(const std::pair& a, + const std::pair& b); Var GenerateNewVar(DataType type_annotation); }; diff --git a/src/tir/transforms/common_subexpr_elim_tools.cc b/src/tir/transforms/common_subexpr_elim_tools.cc index ce8aef4587dd..f71d2cf42a02 100644 --- a/src/tir/transforms/common_subexpr_elim_tools.cc +++ b/src/tir/transforms/common_subexpr_elim_tools.cc @@ -797,7 +797,7 @@ std::vector> SyntacticToSemanticComputations( // normalized. This normalized table will keep the count for each set of equivalent terms // (i.e. each equivalence class), together with a term that did appear in this equivalence class // (in practice, the first term of the equivalence class that was encoutered). - std::unordered_map, StructuralHash, ExprDeepEqual> + support::OrderedMap, StructuralHash, ExprDeepEqual> norm_table; // In order to avoid frequent rehashing if the norm_table becomes big, we immediately ask for @@ -806,23 +806,7 @@ std::vector> SyntacticToSemanticComputations( // equivalence classes as there are elements) norm_table.reserve(table.size()); - // Transform the input hashtable to a vector and sort it according to some order, as we will be - // iterating through its items soon, and the order of appearance will be used to determine the - // individual representant for each class of equivalence, which we want to be deterministic - // (otherwise {x+y, y+x} could be both replaced by x+y, and on another run by y+x). - std::vector> sorted_items_of_table(table.begin(), table.end()); - - // We do the ordering by comparing the string repr of each expr to get a determinstic ordering - sort(sorted_items_of_table.begin(), sorted_items_of_table.end(), - [](std::pair a, std::pair b) { - std::stringstream a_stream; - std::stringstream b_stream; - a_stream << AsLegacyRepr(a.first); - b_stream << AsLegacyRepr(b.first); - return a_stream.str().compare(b_stream.str()) < 0; - }); - - for (const auto& elem : sorted_items_of_table) { + for (const auto& elem : table) { PrimExpr norm_elem = NormalizeTerm(elem.first, identify_equiv_terms); // If the normalized term is not already a key in the normalized table auto it_found = norm_table.find(norm_elem); @@ -831,7 +815,7 @@ std::vector> SyntacticToSemanticComputations( // (i.e. `norm_elem` has been seen `elem`.second many times so far, and the chosen element // to represent the equivalence class will be `elem`.first as it's the first element of the // class that we see) - norm_table[norm_elem] = elem; + norm_table.insert(norm_elem, elem); } else { // Otherwise, it's not the first time we see a term in this equivalence class, so we just // increase the count of this equivalence class as we now have `elem`.second additional items @@ -850,10 +834,8 @@ std::vector> SyntacticToSemanticComputations( // Careful : the pairs will never change (the canonical represantants chosen will always be the // same), but the order in which the pairs are produced can vary as we are iterating through the // hashtable `norm_table`. It is not an issue as the called will be sorting the result anyway. - std::unordered_map, StructuralHash, - ExprDeepEqual>::const_iterator it_norm_table; - for (it_norm_table = norm_table.begin(); it_norm_table != norm_table.end(); ++it_norm_table) { - result.push_back(it_norm_table->second); + for (const auto& kv : norm_table) { + result.push_back(kv.second); } return result; diff --git a/src/tir/transforms/common_subexpr_elim_tools.h b/src/tir/transforms/common_subexpr_elim_tools.h index 58014e6a406d..31a81dabdbf2 100644 --- a/src/tir/transforms/common_subexpr_elim_tools.h +++ b/src/tir/transforms/common_subexpr_elim_tools.h @@ -34,10 +34,12 @@ #include // For the class StmtExprVisitor #include -#include // For the hashtable datatype -#include // For pairs datatype +#include +#include // For pairs datatype #include +#include "../../support/ordered_map.h" + namespace tvm { namespace tir { @@ -50,7 +52,7 @@ namespace tir { not do variables remapping), so it is compatible with StructuralHash (intended to be used with StructuralEqual). */ -using ComputationTable = std::unordered_map; +using ComputationTable = support::OrderedMap; /*! * \brief A cache of computations is made of a pair of two hashtables, which respectively associate diff --git a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py b/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py index e7e64d89168e..1be5e57ba15a 100644 --- a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py +++ b/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py @@ -93,14 +93,14 @@ def test_cse(): assert body.var.name == "z2" assert body.value == 2 - # This is the let-in for the first variable generated cse_var_1 + # This is the let-in for the first variable generated cse_v1 assert isinstance(body.body, tvm.tir.LetStmt) body = body.body # And this is the name and value of this variable - cse_var_1 = body.var # Keep the variable accessible for later checking the replacements - assert body.var.name == "cse_var_1" + cse_v1 = body.var # Keep the variable accessible for later checking the replacements + assert body.var.name == "cse_v1" tvm.ir.assert_structural_equal(body.value, z1 + z2) assert isinstance(body.body, tvm.tir.SeqStmt) @@ -118,27 +118,27 @@ def test_cse(): assert body.var.name == "y" assert body.value == 1 - # This is the let-in for the second variable generated cse_var_2 + # This is the let-in for the second variable generated cse_v2 assert isinstance(body.body, tvm.tir.LetStmt) body = body.body # And this is the name and value of this variable - cse_var_2 = body.var # Keep the variable accessible for later checking the replacements - assert body.var.name == "cse_var_2" + cse_v2 = body.var # Keep the variable accessible for later checking the replacements + assert body.var.name == "cse_v2" tvm.ir.assert_structural_equal(body.value, x + y) body = body.body body.var.name == "a" # Check that the replacement has been done correctly! - tvm.ir.assert_structural_equal(body.value, cse_var_2 + cse_var_1) + tvm.ir.assert_structural_equal(body.value, cse_v2 + cse_v1) body = body.body body.var.name == "b" # Check that the replacement has been done correctly! - tvm.ir.assert_structural_equal(body.value, cse_var_2 + z3) + tvm.ir.assert_structural_equal(body.value, cse_v2 + z3) assert isinstance(body.body, tvm.tir.BufferStore) @@ -199,7 +199,7 @@ def test_cse_ifNode_1(): body = body.then_case # The let-in introduced by the CSE should appear now, inside the Then branch of the If node - assert body.var.name == "cse_var_1" + assert body.var.name == "cse_v1" # and it should contain the expression (y+z) that was redundant tvm.ir.assert_structural_equal(body.value, y + z) @@ -250,7 +250,7 @@ def test_cse_ifNode_2(): assert isinstance(body, tvm.tir.LetStmt) # The let-in introduced by the CSE should appear now, at the toplevel (i.e. before the If) - assert body.var.name == "cse_var_1" + assert body.var.name == "cse_v1" # and it should contain the expression (y+z) that was redundant tvm.ir.assert_structural_equal(body.value, y + z) @@ -291,8 +291,8 @@ def test_cse_cascade(): assert isinstance(body, tvm.tir.LetStmt) # The second let-in (by order introduced) introduced by the CSE should appear first - cse_var_2 = body.var # Keep the variable accessible for later checking the replacements - assert body.var.name == "cse_var_2" + cse_v2 = body.var # Keep the variable accessible for later checking the replacements + assert body.var.name == "cse_v2" # and it should contain the expression (x+y) tvm.ir.assert_structural_equal(body.value, (x + y)) @@ -301,10 +301,10 @@ def test_cse_cascade(): assert isinstance(body, tvm.tir.LetStmt) # The first let-in (by order introduced) introduced by the CSE should appear now, after the 2nd - cse_var_1 = body.var # Keep the variable accessible for later checking the replacements - assert body.var.name == "cse_var_1" - # and it should contain the expression cse_var_2+z - tvm.ir.assert_structural_equal(body.value, cse_var_2 + z) + cse_v1 = body.var # Keep the variable accessible for later checking the replacements + assert body.var.name == "cse_v1" + # and it should contain the expression cse_v2+z + tvm.ir.assert_structural_equal(body.value, cse_v2 + z) body = body.body @@ -317,9 +317,9 @@ def test_cse_cascade(): store2 = body[1] store3 = body[2] - tvm.ir.assert_structural_equal(store1.value, cse_var_1) - tvm.ir.assert_structural_equal(store2.value, cse_var_1) - tvm.ir.assert_structural_equal(store3.value, cse_var_2) + tvm.ir.assert_structural_equal(store1.value, cse_v1) + tvm.ir.assert_structural_equal(store2.value, cse_v1) + tvm.ir.assert_structural_equal(store3.value, cse_v2) # ----------------------------------------------------------------------------------------- @@ -360,9 +360,9 @@ def func_distributivity( def func_distributivity_expected( B: T.Buffer((50,), "int32"), i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 ) -> None: - with T.LetStmt(x * y + x * z) as cse_var_1: - B[i1] = cse_var_1 - B[i2] = cse_var_1 + with T.LetStmt((y + z) * x) as cse_v1: + B[i1] = cse_v1 + B[i2] = cse_v1 @T.prim_func @@ -377,9 +377,9 @@ def func_associativity( def func_associativity_expected( B: T.Buffer((50,), "int32"), i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 ) -> None: - with T.LetStmt((x + y) + z) as cse_var_1: - B[i1] = cse_var_1 - B[i2] = cse_var_1 + with T.LetStmt(x + y + z) as cse_v1: + B[i1] = cse_v1 + B[i2] = cse_v1 def _check(original, transformed): @@ -410,10 +410,10 @@ def test_deterministic_cse(): result = (x + 1) + (x + 2) + (x + 3) + (x + 1) + (x + 2) + (x + 3) --> - cse_var_3 = (x + 1) - cse_var_2 = (x + 2) - cse_var_1 = (x + 3) - result = cse_var_3 + cse_var_2 + cse_var_1 + cse_var_3 + cse_var_2 + cse_var_1 + cse_v3 = (x + 1) + cse_v2 = (x + 2) + cse_v1 = (x + 3) + result = cse_v3 + cse_v2 + cse_v1 + cse_v3 + cse_v2 + cse_v1 """ NUM_TERMS = 10 REPEATS = 10 diff --git a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py index da079f46e38e..13487b42f00f 100644 --- a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py @@ -329,11 +329,11 @@ def test_inject_async_copy_barrier(): __asm__ __volatile__("cp.async.commit_group;"); for (int i = 0; i < 13; ++i) { - bool cse_var_1 = (i < 12); + bool cse_v1 = (i < 12); { unsigned int addr = cast_smem_ptr_to_int(A_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x))); - int pred_guard = (int)cse_var_1; + int pred_guard = (int)cse_v1; __asm__ __volatile__( "{ .reg .pred p;" " setp.ne.b32 p, %0, 0;" @@ -356,7 +356,7 @@ def test_inject_async_copy_barrier(): { unsigned int addr = cast_smem_ptr_to_int(B_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x))); - int pred_guard = (int)cse_var_1; + int pred_guard = (int)cse_v1; __asm__ __volatile__( "{ .reg .pred p;" " setp.ne.b32 p, %0, 0;" @@ -954,10 +954,10 @@ def before(A: T.Buffer((32, 128), "float16")): T.attr("default", "async_scope", 1) for i in range(16): - cse_var_1: T.int64 = T.Cast("int64", i) - A_shared[ - T.Ramp(tx * T.int64(128) + cse_var_1 * T.int64(8), T.int64(1), 8) - ] = A_flattened[T.Ramp(tx * T.int64(128) + cse_var_1 * T.int64(8), T.int64(1), 8)] + cse_v1: T.int64 = T.Cast("int64", i) + A_shared[T.Ramp(tx * T.int64(128) + cse_v1 * T.int64(8), T.int64(1), 8)] = A_flattened[ + T.Ramp(tx * T.int64(128) + cse_v1 * T.int64(8), T.int64(1), 8) + ] T.ptx_commit_group() T.ptx_wait_group(0) @@ -965,13 +965,13 @@ def expected(A: T.Buffer((32, 128), "float16")): tx = T.launch_thread("threadIdx.x", T.int64(32)) A_shared = T.decl_buffer((4096,), "float16", scope="shared") for i in range(16): - cse_var_1: T.int64 = T.Cast("int64", i) + cse_v1: T.int64 = T.Cast("int64", i) T.ptx_cp_async( "float16", A_shared.data, - tx * T.int64(128) + cse_var_1 * T.int64(8), + tx * T.int64(128) + cse_v1 * T.int64(8), A.data, - tx * T.int64(128) + cse_var_1 * T.int64(8), + tx * T.int64(128) + cse_v1 * T.int64(8), 16, ) T.ptx_commit_group() diff --git a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py index c63d2f8a4137..299c19314654 100644 --- a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py +++ b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py @@ -154,8 +154,8 @@ def variance4(rxplaceholder: T.Buffer((T.int64(1), T.int64(32), T.int64(25690112 rxplaceholder_1 = T.Buffer((T.int64(822083584),), data=rxplaceholder.data) T_subtract_1 = T.Buffer((T.int64(822083584),), data=T_subtract) for ax1, ax2 in T.grid(32, 25690112): - cse_var_1: T.int32 = ax1 * 25690112 + ax2 - T_subtract_1[cse_var_1] = rxplaceholder_1[cse_var_1] - rxplaceholder_red_1[ax1] + cse_v1: T.int32 = ax1 * 25690112 + ax2 + T_subtract_1[cse_v1] = rxplaceholder_1[cse_v1] - rxplaceholder_red_1[ax1] func = variance4 tvm.compile(func, target="llvm") # should not crash diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index af2db34415f8..0e1b328844be 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -3679,6 +3679,7 @@ def merge_shape_var_def(): # uninitialized vars @T.prim_func(check_well_formed=False) def main(A: T.handle, B: T.handle): + # fmt: off T.func_attr({"global_symbol": "main", "tir.noalias": True}) m, n = T.int32(), T.int32() A_1 = T.match_buffer(A, (m, n), strides=("A_1_s0", "A_1_s1"), buffer_type="auto") @@ -3687,8 +3688,8 @@ def main(A: T.handle, B: T.handle): if T.likely(i_outer * 10 + i_inner < m): for j_inner in range(5): if T.likely(j_outer * 5 + j_inner < n): - cse_var_2: T.int32 = j_outer * 5 + j_inner - cse_var_1: T.int32 = i_outer * 10 + i_inner + cse_v2: T.int32 = j_outer * 5 + j_inner + cse_v1: T.int32 = i_outer * 10 + i_inner B_2 = T.Buffer( (B_1.strides[0] * m,), data=B_1.data, @@ -3701,9 +3702,10 @@ def main(A: T.handle, B: T.handle): strides=("A_2_s0",), buffer_type="auto", ) - B_2[cse_var_1 * B_1.strides[0] + cse_var_2 * B_1.strides[1]] = A_2[ - cse_var_1 * A_1.strides[0] + cse_var_2 * A_1.strides[1] + B_2[cse_v1 * B_1.strides[0] + cse_v2 * B_1.strides[1]] = A_2[ + cse_v1 * A_1.strides[0] + cse_v2 * A_1.strides[1] ] + # fmt: on return main