diff --git a/HalideIR b/HalideIR index adfa66240265..30bf0f043e63 160000 --- a/HalideIR +++ b/HalideIR @@ -1 +1 @@ -Subproject commit adfa662402650e2f9b02ea600ffb70d6e7bb5adf +Subproject commit 30bf0f043e6388418958fd1f29259ee43c42b600 diff --git a/include/tvm/buffer.h b/include/tvm/buffer.h index beed7e9d1281..2e4d7debcf42 100644 --- a/include/tvm/buffer.h +++ b/include/tvm/buffer.h @@ -50,6 +50,9 @@ class Buffer : public NodeRef { * \return the pointer to the internal node container */ inline const BufferNode* operator->() const; + + /*! \brief specify container node */ + using ContainerType = BufferNode; }; /*! \brief Node to represent a buffer */ diff --git a/include/tvm/c_runtime_api.h b/include/tvm/c_runtime_api.h index 1a21adc41cd6..25b81d80ce5a 100644 --- a/include/tvm/c_runtime_api.h +++ b/include/tvm/c_runtime_api.h @@ -30,6 +30,7 @@ #endif #include +#include TVM_EXTERN_C { @@ -216,18 +217,45 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from, TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream); /*! - * \brief Launch a generated TVM function + * \brief TVM Function API: Get resource requirement + * + * By default TVM function try not to do internal allocations. + * Instead, TVMFuncRequirement can be called, given the input arguments. + * + * \param func function handle to be launched. + * \param args The arguments + * \param arg_type_ids The type id of the arguments + * \param num_args Number of arguments. + * \param out_workspace_size The workspace size needed to launch this function. + * \param out_workspace_align The alignment requirement of workspace. + * + * \note The data pointer in the arrays is not used by requirement. + */ +TVM_DLL int TVMFuncRequirement(TVMFunctionHandle func, + TVMArg* args, + int* arg_type_ids, + int num_args, + size_t* out_workspace_size, + size_t* out_workspace_align); + +/*! + * \brief TVM Function API: Launch generated function. + * * \param func function handle to be launched. 
* \param args The arguments * \param arg_type_ids The type id of the arguments * \param num_args Number of arguments. * \param stream The stream this function to be launched on. + * \param workspace Additional workspace used to launch this function. + * + * \sa TVMFuncRequirement */ -TVM_DLL int TVMLaunch(TVMFunctionHandle func, - TVMArg* args, - int* arg_type_ids, - int num_args, - TVMStreamHandle stream); +TVM_DLL int TVMFuncLaunch(TVMFunctionHandle func, + TVMArg* args, + int* arg_type_ids, + int num_args, + TVMStreamHandle stream, + TVMArrayHandle workspace); } // TVM_EXTERN_C #endif // TVM_C_RUNTIME_API_H_ diff --git a/include/tvm/codegen.h b/include/tvm/codegen.h new file mode 100644 index 000000000000..b4a15e5a5d7f --- /dev/null +++ b/include/tvm/codegen.h @@ -0,0 +1,68 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file codegen.h + * \brief Collection of Lowlevel IR pass to codegen. + */ +#ifndef TVM_CODEGEN_H_ +#define TVM_CODEGEN_H_ + +#include +#include "./base.h" +#include "./expr.h" +#include "./module.h" + +namespace tvm { +/*! \brief namespace for lowlevel IR pass and codegen */ +namespace codegen { +/*! + * \brief Make an user callable API LoweredFunc. + * + * The main task of this function is to create code to : + * - Map the values in the api_args to of Var that is required by body. + * - Insert assertions to check type/value of the passed arguments. + * + * \param body The body of the function. + * \param name The name of the function. + * \param api_args Arguments to the function, can be either Var, or Buffer + * \param num_packed_args Number of arguments that are processed in packed form. + * \return a LoweredFunc with the specified signiture. 
+ * + * \note + * The function signiture have two cases + * + * if num_packed_args is zero: + * f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args) + * + * if num_packed_args is not zero: + * f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args, + * api_arg_k, api_arg_k+1, ... api_arg_n) + * + * where n == len(api_args), k == num_packed_args + * + * There is no thread_axis in generated function. + */ +LoweredFunc MakeAPI(Stmt body, + std::string name, + Array api_args, + int num_packed_args); + +/*! + * \brief Find the undefined vars in f. + * \param f The function to be checked. + * \return Array of the undefined vars. + */ +Array UndefinedVars(const LoweredFunc& f); + +/*! + * \brief Split the function into a host function and device functions. + * \param func The function to be split. + * + * \return Array of functions, the first one is host function, + * the others are device functions. + */ +Array SplitHostDevice(LoweredFunc func); + +} // namespace codegen +} // namespace tvm + +#endif // TVM_CODEGEN_H_ diff --git a/include/tvm/ir.h b/include/tvm/ir.h index dd53d53b2c37..0676104213c9 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -49,6 +49,48 @@ struct Reduce : public ExprNode { static constexpr const char* Min = "Min"; }; +/*! \brief namespace of TVM Intrinsic functions */ +namespace intrinsic { +// Intrinsics emitted by MakeAPI and printed by the code generator. +/*! + * \brief See pseudo code + * + * Type tvm_api_load_arg(TVMArg* args, int* arg_type_ids, int i) { + * assert(arg_type_ids[i] == typeid(Type)); + * return args[i]; + * } + */ +constexpr const char* tvm_api_load_arg = "tvm_api_load_arg"; +/*! + * \brief See pseudo code + * + * Type tvm_array_get_field(TVMArray* arr, int field_id) { + * return arr->field; + * } + * \sa TVMArrayFieldKind + */ +constexpr const char* tvm_array_get_field = "tvm_array_get_field"; +/*!
+ * \brief See pesudo code + * + * bool tvm_handle_is_null(void* handle) { + * return handle == nullptr + * } + */ +constexpr const char* tvm_handle_is_null = "tvm_handle_is_null"; + +/*! \brief The field id of each field in array */ +enum TVMArrayFieldKind { + kData = 0, + kNDim = 1, + kShape = 2, + kStrides = 3, + kTypeCode = 4, + kTypeBits = 5, + kTypeLanes = 6 +}; +} // namespace intrinsic + // Reuse IR node defintiion from HalideIR using Halide::Internal::IntImm; using Halide::Internal::UIntImm; diff --git a/include/tvm/ir_mutator.h b/include/tvm/ir_mutator.h index 3106df1ffd02..b57bca25eb49 100644 --- a/include/tvm/ir_mutator.h +++ b/include/tvm/ir_mutator.h @@ -9,6 +9,7 @@ #include #include #include "./expr.h" +#include "./ir.h" namespace tvm { namespace ir { @@ -51,6 +52,20 @@ class IRMutator { static FMutateExpr& vtable_expr(); // NOLINT(*) /*! \return internal stmt of expr */ static FMutateStmt& vtable_stmt(); // NOLINT(*) + // Set of overloadable functions + // The underscore allows Mutate not to be shadowed by inheritance + virtual Stmt Mutate_(const LetStmt* op, const Stmt& s); + virtual Stmt Mutate_(const AttrStmt* op, const Stmt& s); + virtual Stmt Mutate_(const For* op, const Stmt& s); + virtual Stmt Mutate_(const Provide* op, const Stmt& s); + virtual Stmt Mutate_(const Allocate* op, const Stmt& s); + virtual Stmt Mutate_(const Realize* op, const Stmt& s); + virtual Stmt Mutate_(const Store* op, const Stmt& s); + virtual Stmt Mutate_(const Free* op, const Stmt& s); + virtual Expr Mutate_(const Call* op, const Expr& e); + virtual Expr Mutate_(const Load* op, const Expr& s); + virtual Expr Mutate_(const Variable* op, const Expr& e); + virtual Expr Mutate_(const Let* op, const Expr& e); }; /*! diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index a45bbbb91fd8..a2c2956a944a 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -56,6 +56,12 @@ Stmt ScheduleOps(Schedule s, Map dom_map); */ bool VerifySSA(const Stmt& ir); +/*! 
+ * \brief Whether the expression have side effect. + * \return whether expression have side effect + */ +bool HasSideEffect(const Expr& e); + /*! * \brief Convert a IR node to be SSA form. * \param stmt The source statement to be converted. @@ -79,7 +85,6 @@ Stmt Inline(Stmt stmt, Array args, Expr body); - /*! * \brief Flatten the multi-dimensional read/write * to single dimensional Load/Store diff --git a/include/tvm/ir_visitor.h b/include/tvm/ir_visitor.h index b64406d7ec4f..0df5d3e324f6 100644 --- a/include/tvm/ir_visitor.h +++ b/include/tvm/ir_visitor.h @@ -34,6 +34,17 @@ class IRVisitor { using FVisit = IRFunctor; /*! \return internal vtable*/ static FVisit& vtable(); + // overloadable visit function. + virtual void Visit_(const Variable* op); + virtual void Visit_(const AttrStmt* op); + virtual void Visit_(const LetStmt* op); + virtual void Visit_(const For* op); + virtual void Visit_(const Allocate* op); + virtual void Visit_(const Load* op); + virtual void Visit_(const Store* op); + virtual void Visit_(const Let* op); + virtual void Visit_(const Free* op); + virtual void Visit_(const Call* op); }; /*! diff --git a/include/tvm/module.h b/include/tvm/module.h new file mode 100644 index 000000000000..263fdc2f28f1 --- /dev/null +++ b/include/tvm/module.h @@ -0,0 +1,108 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file module.h + * \brief Low level IR module, + * Contains lowered function information. + */ +#ifndef TVM_MODULE_H_ +#define TVM_MODULE_H_ + +#include +#include +#include + +#include "./base.h" +#include "./expr.h" +#include "./tensor.h" + +namespace tvm { + +// Internal node container of lowered function. +class LoweredFuncNode; + +// Internal node container of module. +class ModuleNode; + +/*! + * \brief LoweredFunc represents function after lowering. + * This is the final IR representation before codegen. 
+ */ +class LoweredFunc : public FunctionRef { + public: + LoweredFunc() {} + explicit LoweredFunc(std::shared_ptr n) : FunctionRef(n) {} + /*! + * \brief access the internal node container + * \return the pointer to the internal node container + */ + inline const LoweredFuncNode* operator->() const; + /*! \brief specify container node */ + using ContainerType = LoweredFuncNode; +}; + +/*! \brief Node container of LoweredFunc */ +class LoweredFuncNode : public FunctionBaseNode { + public: + /*! \brief The name of the function */ + std::string name; + /*! + * \brief The arguments of the function + * This function can only take pod type(int, float) and void* as arguments. + */ + Array args; + /*! + * \brief The IterVar axis of threads + * Each axis need host function to specify a size. + * \note Calling convention into LoweredFunc + * + * Assume we have a LoweredFunc f, a call into f + * Call(f, arg1, arg2, ..., arg_n, + * size_axis_1, size_axis_2, ... size_axis_m) + * + * Here n = len(args), m = len(thread_axis) + * + * The CodeGen should take this and translate this call + * to corresponding API specific kernel launchs or function calls. + */ + Array thread_axis; + /*! + * \brief The hint data type of Var handles defined in LetStmt + * Can be used as hint when generating type signiture. + * The creation rule is given by + * handle_data_type[var_handle] = make_const(the_type, 0); + * + * \note Expr is used instead Type, because Type cannot be hold by Map. + * constant Expr of given type is used. + */ + Map handle_data_type; + /*! \brief The body statment of the function */ + Stmt body; + /*! \return name of the operation */ + const std::string& func_name() const final { + return name; + } + // there is no return value, but return 1 + // to enable Call into this function. 
+ int num_outputs() const final { + return 1; + } + void VisitAttrs(AttrVisitor* v) final { + v->Visit("name", &name); + v->Visit("args", &args); + v->Visit("thread_axis", &thread_axis); + v->Visit("handle_data_type", &handle_data_type); + v->Visit("body", &body); + } + + static constexpr const char* _type_key = "LoweredFunc"; + TVM_DECLARE_NODE_TYPE_INFO(LoweredFuncNode); +}; + +// Implementations of inline functions +inline const LoweredFuncNode* LoweredFunc::operator->() const { + return static_cast(node_.get()); +} + +} // namespace tvm + +#endif // TVM_MODULE_H_ diff --git a/python/tvm/collections.py b/python/tvm/collections.py index 85e629cc96da..2e43e2e6bec0 100644 --- a/python/tvm/collections.py +++ b/python/tvm/collections.py @@ -56,3 +56,9 @@ class IterVar(NodeBase, _expr.ExprOp): class Buffer(NodeBase): """Represent a Buffer in TVM.""" pass + + +@register_node +class LoweredFunc(NodeBase): + """Represent a LoweredFunc in TVM.""" + pass diff --git a/src/base/common.h b/src/base/common.h index ea2f4bdad9e5..432ec74db9af 100644 --- a/src/base/common.h +++ b/src/base/common.h @@ -7,6 +7,7 @@ #define TVM_BASE_COMMON_H_ #include +#include #include namespace tvm { @@ -30,7 +31,7 @@ inline Type String2Type(std::string s) { } else if (s.substr(0, 5) == "float") { code = Type::Float; s = s.substr(5); } else if (s == "handle") { - return Type(Type::Handle, 32, 1); + return Handle(); } else { LOG(FATAL) << "unknown type " << s; } diff --git a/src/c_api/c_api_codegen.cc b/src/c_api/c_api_codegen.cc index 365033ea445f..0fa5973a4f87 100644 --- a/src/c_api/c_api_codegen.cc +++ b/src/c_api/c_api_codegen.cc @@ -5,6 +5,7 @@ */ #include #include +#include #include "./c_api_registry.h" #include "../codegen/codegen_c.h" @@ -17,9 +18,19 @@ using RetValue = APIVariantValue; TVM_REGISTER_API(_codegen_CompileToC) .set_body([](const ArgStack& args, RetValue *ret) { - *ret = CodeGenC().Compile( + *ret = CodeGenC().Compile(args.at(0), args.at(1)); + }); + 
+TVM_REGISTER_API(_codegen_MakeAPI) +.set_body([](const ArgStack& args, RetValue *ret) { + *ret = MakeAPI( args.at(0), args.at(1), args.at(2), args.at(3)); }); +TVM_REGISTER_API(_codegen_SplitHostDevice) +.set_body([](const ArgStack& args, RetValue *ret) { + *ret = SplitHostDevice(args.at(0)); + }); + } // namespace codegen } // namespace tvm diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index a42569e9ad32..327778db05bb 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -9,24 +9,27 @@ namespace codegen { using namespace ir; -std::string CodeGenC::Compile( - Stmt stmt, std::string fun_name, - Array args, bool output_ssa) { +std::string CodeGenC::Compile(LoweredFunc f, + bool output_ssa) { print_ssa_form_ = output_ssa; // skip the first underscore, so SSA variable starts from _1 if (print_ssa_form_) GetUniqueName("_"); + // add to alloc buffer type. + for (const auto & kv : f->handle_data_type) { + HandleTypeRegister(kv.first.get(), kv.second.type()); + } this->indent += 2; - this->stream << "void " << fun_name << "("; - for (size_t i = 0; i < args.size(); ++i) { - Var v = args[i]; + this->stream << "void " << f->name << "("; + for (size_t i = 0; i < f->args.size(); ++i) { + Var v = f->args[i]; std::string vid = AllocVarID(v.get()); if (i != 0) stream << ", "; PrintType(v.type(), stream); stream << ' ' << vid; } stream << ") {\n"; - this->PrintStmt(stmt); + this->PrintStmt(f->body); this->indent -= 2; this->PrintIndent(); this->stream << "}\n"; @@ -104,12 +107,22 @@ std::string CodeGenC::GetVarID(const Variable* v) const { return it->second; } -bool CodeGenC::BufferTypeMatch(const Variable* buf_var, Type t) const { - auto it = alloc_buf_type_.find(buf_var); - if (it == alloc_buf_type_.end()) return false; +bool CodeGenC::HandleTypeMatch(const Variable* buf_var, Type t) const { + auto it = handle_data_type_.find(buf_var); + if (it == handle_data_type_.end()) return false; return it->second == t; } +void 
CodeGenC::HandleTypeRegister(const Variable* buf_var, Type t) { + auto it = handle_data_type_.find(buf_var); + if (it == handle_data_type_.end()) { + handle_data_type_[buf_var] = t; + } else { + CHECK(it->second == t) + << "conflicting buf var type"; + } +} + void CodeGenC::PrintIndent() { for (int i = 0; i < this->indent; ++i) { this->stream << ' '; @@ -234,6 +247,18 @@ inline void PrintBinaryExpr(const T* op, os << ')'; } +inline void PrintBinaryIntrinsitc(const Call* op, + const char *opstr, + std::ostream& os, // NOLINT(*) + CodeGenC* p) { + CHECK_EQ(op->args.size(), 2U); + os << '('; + p->PrintExpr(op->args[0], os); + os << opstr; + p->PrintExpr(op->args[1], os); + os << ')'; +} + TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr) .set_dispatch([](const Cast *op, std::ostream& os, CodeGenC *p) { // NOLINT(*) p->PrintType(op->type, os); @@ -300,24 +325,9 @@ TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr) .set_dispatch([](const Not *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) os << '!'; p->PrintExpr(op->a, os); - }) -.set_dispatch([](const Call *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - os << op->name << "("; - for (size_t i = 0; i < op->args.size(); i++) { - p->PrintExpr(op->args[i], os); - if (i < op->args.size() - 1) { - os << ", "; - } - } - os << ")"; }); TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_stmt) -.set_dispatch([](const AssertStmt *op, CodeGenC* p) { - std::string cond = p->PrintExpr(op->condition); - p->PrintIndent(); - p->stream << "assert(" << cond << ");\n"; - }) .set_dispatch([](const ProducerConsumer *op, CodeGenC* p) { p->PrintStmt(op->body); }) @@ -372,14 +382,95 @@ TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_stmt) TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr) .DISPATCH_EXPR(Load) +.DISPATCH_EXPR(Call) .DISPATCH_EXPR(Let) .DISPATCH_EXPR(Ramp) .DISPATCH_EXPR(Broadcast) .DISPATCH_EXPR(Select); + +void CodeGenC::PrintExpr(const Call *op, std::ostream& os) { // NOLINT(*) + CodeGenC* p = this; + if 
(op->is_intrinsic(Call::bitwise_and)) { + PrintBinaryIntrinsitc(op, " & ", os, p); + } else if (op->is_intrinsic(Call::bitwise_xor)) { + PrintBinaryIntrinsitc(op, " ^ ", os, p); + } else if (op->is_intrinsic(Call::bitwise_or)) { + PrintBinaryIntrinsitc(op, " | ", os, p); + } else if (op->is_intrinsic(Call::bitwise_not)) { + CHECK_EQ(op->args.size(), 1U); + os << "(~"; + p->PrintExpr(op->args[0], os); + os << ')'; + } else if (op->is_intrinsic(Call::shift_left)) { + PrintBinaryIntrinsitc(op, " << ", os, p); + } else if (op->is_intrinsic(Call::shift_right)) { + PrintBinaryIntrinsitc(op, " >> ", os, p); + } else if (op->is_intrinsic(Call::address_of)) { + const Load *l = op->args[0].as(); + CHECK(op->args.size() == 1 && l); + os << "(("; + p->PrintType(l->type.element_of(), os); + os << " *)" << p->GetVarID(l->buffer_var.get()) + << " + "; + p->PrintExpr(l->index, os); + os << ')'; + } else if (op->is_intrinsic(intrinsic::tvm_api_load_arg)) { + CHECK_EQ(op->args.size(), 3U); + if (!op->type.is_handle()) { + os << '('; + p->PrintType(op->type, os); + os << ')'; + } + os << "(((TVMArg*)"; + p->PrintExpr(op->args[0], os); + os << ")[" << op->args[2] << "]."; + if (op->type.is_handle()) { + os << "v_handle"; + } else if (op->type.is_float()) { + os << "v_double"; + } else if (op->type.is_int() || op->type.is_uint()) { + os << "v_long"; + } else { + LOG(FATAL) << "donot know how to handle type" << op->type; + } + os << ")"; + } else if (op->is_intrinsic(intrinsic::tvm_array_get_field)) { + CHECK_EQ(op->args.size(), 2U); + os << "(((TVMArray*)"; + p->PrintExpr(op->args[0], os); + os << ")->"; + switch (op->args[1].as()->value) { + case intrinsic::kData: os << "data"; break; + case intrinsic::kShape: os << "shape"; break; + case intrinsic::kStrides: os << "strides"; break; + case intrinsic::kNDim: os << "ndim"; break; + case intrinsic::kTypeCode: os << "dtype.type_code"; break; + case intrinsic::kTypeBits: os << "dtype.bits"; break; + case intrinsic::kTypeLanes: os << 
"dtype.lanes"; break; + default: LOG(FATAL) << "unknown field code"; + } + os << ')'; + } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) { + CHECK_EQ(op->args.size(), 1U); + os << "("; + p->PrintExpr(op->args[0], os); + os << " == NULL)"; + } else { + os << op->name << "("; + for (size_t i = 0; i < op->args.size(); i++) { + p->PrintExpr(op->args[i], os); + if (i < op->args.size() - 1) { + os << ", "; + } + } + os << ")"; + } +} + void CodeGenC::PrintExpr(const Load* op, std::ostream& os) { // NOLINT(*) std::string vid = GetVarID(op->buffer_var.get()); - if (!BufferTypeMatch(op->buffer_var.get(), op->type)) { + if (!HandleTypeMatch(op->buffer_var.get(), op->type)) { os << "((const "; PrintType(op->type, os); os << "*)" << vid << ')'; @@ -416,7 +507,8 @@ TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_stmt) .set_dispatch([](const LetStmt *op, CodeGenC* p) { p->PrintStmt(op); }) .set_dispatch([](const Store *op, CodeGenC* p) { p->PrintStmt(op); }) .set_dispatch([](const Allocate *op, CodeGenC* p) { p->PrintStmt(op); }) -.set_dispatch([](const AttrStmt *op, CodeGenC* p) { p->PrintStmt(op); }); +.set_dispatch([](const AttrStmt *op, CodeGenC* p) { p->PrintStmt(op); }) +.set_dispatch([](const AssertStmt *op, CodeGenC* p) { p->PrintStmt(op); }); void CodeGenC::PrintStmt(const LetStmt* op) { @@ -426,10 +518,20 @@ void CodeGenC::PrintStmt(const LetStmt* op) { var_idmap_[op->var.get()] = value; } else { PrintIndent(); - PrintType(op->var.type(), this->stream); - this->stream << ' ' - << AllocVarID(op->var.get()) - << " = " << value << ";\n"; + if (op->var.type() == Handle() && + handle_data_type_.count(op->var.get())) { + PrintType(handle_data_type_.at(op->var.get()), stream); + stream << "* " + << AllocVarID(op->var.get()) + << " = ("; + PrintType(handle_data_type_.at(op->var.get()), stream); + stream << "*)" << value << ";\n"; + } else { + PrintType(op->var.type(), this->stream); + this->stream << ' ' + << AllocVarID(op->var.get()) + << " = " << value << ";\n"; + } 
} PrintStmt(op->body); } @@ -439,7 +541,7 @@ void CodeGenC::PrintStmt(const Store* op) { std::string value = this->PrintExpr(op->value); this->PrintIndent(); std::string vid = GetVarID(op->buffer_var.get()); - if (!BufferTypeMatch(op->buffer_var.get(), op->value.type())) { + if (!HandleTypeMatch(op->buffer_var.get(), op->value.type())) { this->stream << "(("; PrintType(op->value.type(), this->stream); this->stream << "*)" << vid << ')'; @@ -452,16 +554,25 @@ void CodeGenC::PrintStmt(const Store* op) { } void CodeGenC::PrintStmt(const Allocate* op) { - this->PrintIndent(); - int32_t constant_size = op->constant_allocation_size(); - std::string vid = AllocVarID(op->buffer_var.get()); - CHECK(!op->new_expr.defined()); CHECK(!is_zero(op->condition)); - CHECK_GT(constant_size, 0) - << "Can only handle constant size stack allocation for now"; - PrintType(op->type, stream); - stream << ' '<< vid << '[' - << constant_size << "]\n;"; + std::string vid = AllocVarID(op->buffer_var.get()); + if (op->new_expr.defined()) { + // Prefer global static allocation for the program + CHECK_EQ(op->free_function, "nop"); + std::string new_data = PrintExpr(op->new_expr); + this->PrintIndent(); + PrintType(op->type, stream); + stream << "* "<< vid << '=' << new_data << ";\n"; + } else { + this->PrintIndent(); + int32_t constant_size = op->constant_allocation_size(); + CHECK_GT(constant_size, 0) + << "Can only handle constant size stack allocation for now"; + PrintType(op->type, stream); + stream << ' '<< vid << '[' + << constant_size << "];\n"; + } + HandleTypeRegister(op->buffer_var.get(), op->type); this->PrintStmt(op->body); } @@ -469,15 +580,29 @@ void CodeGenC::PrintStmt(const AttrStmt* op) { if (op->type_key == "scope") { IterVar iv(op->node.node_); if (iv->thread_tag.length() != 0) { - this->PrintIndent(); - PrintType(iv->var.type(), stream); - stream << ' ' - << AllocVarID(iv->var.get()) - << " = " << iv->thread_tag << ";\n"; + if (!var_idmap_.count(iv->var.get())) { +
this->PrintIndent(); + PrintType(iv->var.type(), stream); + stream << ' ' + << AllocVarID(iv->var.get()) + << " = " << iv->thread_tag << ";\n"; + } } } this->PrintStmt(op->body); } +void CodeGenC::PrintStmt(const AssertStmt* op) { + std::string cond = PrintExpr(op->condition); + PrintIndent(); + if (op->message.as()) { + // GLOG style check + stream << "CHECK(" << cond << ") << \"" + << op->message.as()->value << "\";\n"; + } else { + stream << "assert(" << cond << ");\n"; + } +} + } // namespace codegen } // namespace tvm diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h index a8ce1828e4b4..4630e9990b56 100644 --- a/src/codegen/codegen_c.h +++ b/src/codegen/codegen_c.h @@ -8,6 +8,7 @@ #include #include +#include #include #include @@ -23,16 +24,12 @@ class CodeGenC { public: /*! * \brief Generate the C code of statement - * \param body The body of the function. - * \param fun_name The name of the function. - * \param args The arguments to the function. + * \param f The function to be compiled * \param output_ssa Whether output ssa form. * \note Only call compile once, * create a new codegen object each time. */ - std::string Compile(Stmt body, - std::string fun_name, - Array args, + std::string Compile(LoweredFunc f, bool output_ssa); /*! * \brief Print the Stmt n to CodeGenC->stream @@ -49,7 +46,7 @@ class CodeGenC { * \brief Same as PrintExpr, but simply returns result string * \param n The expression to be printed. 
*/ - inline std::string PrintExpr(const Expr& n) { + std::string PrintExpr(const Expr& n) { std::ostringstream os; PrintExpr(n, os); return os.str(); @@ -85,7 +82,9 @@ class CodeGenC { virtual void PrintStmt(const ir::Store* op); virtual void PrintStmt(const ir::Allocate* op); virtual void PrintStmt(const ir::AttrStmt* op); + virtual void PrintStmt(const ir::AssertStmt* op); virtual void PrintExpr(const ir::Load* op, std::ostream& os); // NOLINT(*) + virtual void PrintExpr(const ir::Call* op, std::ostream& os); // NOLINT(*) virtual void PrintExpr(const ir::Let* op, std::ostream& os); // NOLINT(*) virtual void PrintExpr(const ir::Ramp* op, std::ostream& os); // NOLINT(*) virtual void PrintExpr(const ir::Broadcast* op, std::ostream& os); // NOLINT(*) @@ -116,7 +115,13 @@ class CodeGenC { * \param buf_var The buffer variable. * \param t The type to be checked. */ - bool BufferTypeMatch(const Variable* buf_var, Type t) const; + bool HandleTypeMatch(const Variable* buf_var, Type t) const; + /*! + * \brief Register the data type of buf_var + * \param buf_var The buffer variable. + * \param t The type to be checked. + */ + void HandleTypeRegister(const Variable* buf_var, Type t); /*! * \brief get a unique name with the corresponding prefix * \param prefix The prefix of the name @@ -128,7 +133,7 @@ class CodeGenC { /*! \brief name of each variable */ std::unordered_map var_idmap_; /*! \brief the data type of allocated buffers */ - std::unordered_map alloc_buf_type_; + std::unordered_map handle_data_type_; /*! \brief name allocation map */ std::unordered_map name_alloc_map_; /*! \brief assignment map of ssa */ diff --git a/src/codegen/make_api.cc b/src/codegen/make_api.cc new file mode 100644 index 000000000000..227faf37f410 --- /dev/null +++ b/src/codegen/make_api.cc @@ -0,0 +1,200 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file make_api.cc Build API function. 
+ */ +#include +#include +#include + +#include +#include +#include + +#include "../pass/ir_util.h" + +namespace tvm { +namespace codegen { +using namespace ir; + +inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMArrayFieldKind kind) { + return Call::make( + t, intrinsic::tvm_array_get_field, + {arr, IntImm::make(Int(32), kind)}, + Call::PureIntrinsic); +} + +inline Stmt AssertNull(Var handle, std::string msg) { + return AssertStmt::make(Call::make( + Bool(1), intrinsic::tvm_handle_is_null, + {handle}, Call::PureIntrinsic), msg); +} + +inline Stmt MakeAssertEQ(Expr lhs, Expr rhs, std::string msg) { + return AssertStmt::make(lhs == rhs, msg); +} + +LoweredFunc MakeAPI(Stmt body, + std::string name, + Array api_args, + int num_packed_args) { + const Type tvm_index_type = UInt(32); + const Stmt nop = Evaluate::make(0); + // Data field definitions + // The packed fields + Var v_packed_args("args", Handle()); + Var v_packed_arg_type_ids("arg_type_ids", Handle()); + Var v_num_packed_args("num_args", Int(32)); + // The arguments of the function. + Array args; + // seq_init gives sequence of initialization + // seq_check gives sequence of later checks after iniit + std::vector seq_init, seq_check; + std::unordered_set visited; + // the handle data types + Map handle_data_type; + // --------------------------- + // local function defintiions + // load i-th argument as type t + auto f_arg_value = [&](Type t, int i) { + Array call_args{ + v_packed_args, v_packed_arg_type_ids, IntImm::make(Int(32), i)}; + return Call::make( + t, intrinsic::tvm_api_load_arg, call_args, + Call::PureIntrinsic); + }; + // get declaration of argument i + auto f_arg_decl = [&](int i) { + std::ostringstream os; + os << "arg" << i; + const Variable* v = api_args[i].as(); + return Var(os.str(), v ? 
v->type: Handle()); + }; + // Push related into assertions or variable defintion + // given the symbolic declaration and concrete value + auto f_push = [&](Expr sym, Expr value, std::string field) { + if (sym.as()) { + // If sym is a Variable and this Variable is not yet defined + // add this to defintion. + Var v(sym.node_); + if (!visited.count(v.get())) { + seq_init.emplace_back(LetStmt::make(v, value, nop)); + visited.insert(v.get()); + return true; + } + } + // otherwise, assume sym is already defined, insert assertion. + std::ostringstream os; + os << "Field " << field << " has a unsatisfied constraint"; + seq_check.emplace_back(MakeAssertEQ(sym, value, os.str())); + return false; + }; + // --------------------------- + // start of logics + // add signiture for packed arguments. + if (num_packed_args != 0) { + args.push_back(v_packed_args); + args.push_back(v_packed_arg_type_ids); + args.push_back(v_num_packed_args); + std::ostringstream os; + os << "expected num_args to be " << num_packed_args; + seq_init.emplace_back( + MakeAssertEQ(v_num_packed_args, num_packed_args, os.str())); + } + + for (size_t i = 0; i < api_args.size(); ++i) { + Var v_arg = f_arg_decl(i); + if (i < static_cast(num_packed_args)) { + seq_init.emplace_back(LetStmt::make( + v_arg, f_arg_value(v_arg.type(), i), nop)); + } else { + args.push_back(v_arg); + } + // add checks for functions. 
+ if (api_args[i].as()) { + f_push(Var(api_args[i].node_), v_arg, v_arg->name_hint); + } else { + // Buffer checks + CHECK(api_args[i].as()) + << "api_args can only be Buffer or Var"; + Buffer buf(api_args[i].node_); + // dimension checks + Expr v_ndim = TVMArrayGet(tvm_index_type, v_arg, intrinsic::kNDim); + std::ostringstream ndim_err_msg; + ndim_err_msg << "arg_" << i + << ".ndim is expected to equal " + << buf->shape.size(); + seq_init.emplace_back( + MakeAssertEQ(v_ndim, UIntImm::make(tvm_index_type, buf->shape.size()), + ndim_err_msg.str())); + // type checks + Type dtype = buf->dtype; + std::ostringstream type_err_msg; + type_err_msg << "arg" << i << ".dtype is expected to be " << dtype; + Expr cond = (TVMArrayGet(UInt(8), v_arg, intrinsic::kTypeCode) == + UIntImm::make(UInt(8), dtype.code()) && + TVMArrayGet(UInt(8), v_arg, intrinsic::kTypeBits) == + UIntImm::make(UInt(8), dtype.bits()) && + TVMArrayGet(UInt(16), v_arg, intrinsic::kTypeLanes) == + UIntImm::make(UInt(16), dtype.lanes())); + seq_init.emplace_back(AssertStmt::make(cond, type_err_msg.str())); + // Data Field + if (f_push(buf->ptr, TVMArrayGet(Handle(), v_arg, intrinsic::kData), + v_arg->name_hint + ".data")) { + Var vptr(buf->ptr); + handle_data_type.Set(vptr, make_const(buf->dtype, 0)); + } + // shape field + Var v_shape(v_arg->name_hint + ".shape", Handle()); + handle_data_type.Set(v_shape, UIntImm::make(tvm_index_type, 0)); + seq_init.emplace_back(LetStmt::make( + v_shape, TVMArrayGet(Handle(), v_arg, intrinsic::kShape), nop)); + for (size_t k = 0; k < buf->shape.size(); ++k) { + std::ostringstream field_name; + field_name << v_shape->name_hint << '[' << k << ']'; + f_push(buf->shape[k], + cast(buf->shape[k].type(), + Load::make(tvm_index_type, v_shape, IntImm::make(Int(32), k))), + field_name.str()); + } + // strides field + Var v_strides(v_arg->name_hint + ".strides", Handle()); + handle_data_type.Set(v_strides, UIntImm::make(tvm_index_type, 0)); + seq_init.emplace_back(LetStmt::make( + 
v_strides, TVMArrayGet(Handle(), v_arg, intrinsic::kStrides), nop)); + if (buf->strides.size() == 0) { + std::ostringstream stride_err_msg; + stride_err_msg << "arg_" << i << ".strides:" + << " expected to be nullptr for contiguous array"; + seq_init.emplace_back(AssertNull(v_strides, stride_err_msg.str())); + } else { + for (size_t k = 0; k < buf->strides.size(); ++k) { + std::ostringstream field_name; + field_name << v_strides->name_hint << '[' << k << ']'; + f_push(buf->strides[k], + cast(buf->shape[k].type(), + Load::make(tvm_index_type, v_strides, IntImm::make(Int(32), k))), + field_name.str()); + } + } + } + } + + std::shared_ptr n = std::make_shared(); + n->name = name; + n->args = args; + n->handle_data_type = handle_data_type; + n->body = MergeNest({seq_init, seq_check}, body); + LoweredFunc f(n); + Array undefined = UndefinedVars(f); + if (undefined.size() != 0) { + std::ostringstream os; + for (Var v : undefined) { + os << " \'" << v->name_hint << "\' "; + } + os << " does not appeared in api_args"; + LOG(FATAL) << "Not all Vars are passed in api_args: " << os.str(); + } + return f; +} +} // namespace codegen +} // namespace tvm diff --git a/src/codegen/split_host_device.cc b/src/codegen/split_host_device.cc new file mode 100644 index 000000000000..1560fda4ee36 --- /dev/null +++ b/src/codegen/split_host_device.cc @@ -0,0 +1,218 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file split_host_device.cc + * \brief Split device function from host. + */ +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace codegen { + +using namespace ir; + +// use/def analysis, also delete unreferenced lets +class IRUseDefAnalysis : public IRMutator { + public: + Stmt Mutate_(const AttrStmt *op, const Stmt& s) final { + if (op->type_key == "thread_extent") { + IterVar iv(op->node.node_); + CHECK_NE(iv->thread_tag.length(), 0U); + // thread_extent can appear multiple times + // use the first appearance as def. 
+ if (!use_count_.count(iv->var.get())) { + this->HandleDef(iv->var.get()); + thread_axis_.push_back(iv); + thread_extent_.push_back(op->value); + } + + Expr value = op->value; + if (visit_thread_extent_) { + value = this->Mutate(value); + } + Stmt body = this->Mutate(op->body); + if (value.same_as(value) && body.same_as(body)) return s; + return AttrStmt::make(op->node, op->type_key, value, body); + } else { + return IRMutator::Mutate_(op, s); + } + } + + Stmt Mutate_(const LetStmt *op, const Stmt& s) final { + this->HandleDef(op->var.get()); + Stmt body = this->Mutate(op->body); + // eliminate unreferenced let + if (use_count_.at(op->var.get()) == 0 && + !HasSideEffect(op->value)) { + return body; + } else { + Expr value = this->Mutate(op->value); + if (body.same_as(op->body) && + value.same_as(op->value)) { + return s; + } else { + return LetStmt::make(op->var, value, body); + } + } + } + + Stmt Mutate_(const For *op, const Stmt& s) final { + this->HandleDef(op->loop_var.get()); + return IRMutator::Mutate_(op, s); + } + + Stmt Mutate_(const Allocate *op, const Stmt& s) final { + this->HandleDef(op->buffer_var.get()); + return IRMutator::Mutate_(op, s); + } + + Stmt Mutate_(const Store *op, const Stmt& s) final { + this->HandleUse(op->buffer_var); + return IRMutator::Mutate_(op, s); + } + + Expr Mutate_(const Let *op, const Expr& e) final { + this->HandleDef(op->var.get()); + Expr body = this->Mutate(op->body); + // eliminate unreferenced let + if (use_count_.at(op->var.get()) == 0 && + !HasSideEffect(op->value)) { + return body; + } else { + Expr value = this->Mutate(op->value); + if (body.same_as(op->body) && + value.same_as(op->value)) { + return e; + } else { + return Let::make(op->var, value, body); + } + } + } + + Expr Mutate_(const Variable *op, const Expr& e) final { + this->HandleUse(e); + return IRMutator::Mutate_(op, e); + } + + Expr Mutate_(const Load *op, const Expr& e) final { + this->HandleUse(op->buffer_var); + return IRMutator::Mutate_(op, e); + 
} + + void HandleDef(const Variable* v) { + CHECK(!use_count_.count(v)) + << "variable is already defined"; + use_count_[v] = 0; + } + + void HandleUse(const Expr& v) { + CHECK(v.as()); + Var var(v.node_); + auto it = use_count_.find(var.get()); + if (it != use_count_.end()) { + if (it->second >= 0) { + ++it->second; + } + } else { + undefined_.push_back(var); + use_count_[var.get()] = -1; + } + } + + // The fields are publically readible to + // be accessible to the users. + bool visit_thread_extent_{true}; + Array undefined_; + Array thread_axis_; + Array thread_extent_; + std::unordered_map use_count_; +}; + +class HostDeviceSplitter : public IRMutator { + public: + Stmt Mutate_(const AttrStmt *op, const Stmt& s) final { + if (op->type_key == "thread_extent") { + LOG(INFO) << "??"; + IterVar iv(op->node.node_); + return SplitDeviceFunc(s); + } + return IRMutator::Mutate_(op, s); + } + + Array Split(LoweredFunc f) { + for (auto kv : f->handle_data_type) { + handle_data_type_[kv.first.get()] = kv.second; + } + name_ = f->name; + std::shared_ptr n = + std::make_shared(*f.operator->()); + n->body = this->Mutate(f->body); + + Array ret{LoweredFunc(n)}; + for (LoweredFunc x : device_funcs_) { + ret.push_back(x); + } + return ret; + } + + private: + Stmt SplitDeviceFunc(Stmt body) { + std::ostringstream os; + os << name_ << "_kernel" << device_funcs_.size(); + std::shared_ptr n = std::make_shared(); + // isolate the device function. 
+ IRUseDefAnalysis m; + m.visit_thread_extent_ = false; + n->body = m.Mutate(body); + n->name = os.str(); + n->args = m.undefined_; + CHECK_NE(m.thread_extent_.size(), 0U); + + // improve the handle data type + for (Var arg : n->args) { + auto it = handle_data_type_.find(arg.get()); + if (it != handle_data_type_.end()) { + n->handle_data_type.Set(arg, it->second); + } + } + LoweredFunc f_device(n); + Array call_args; + for (Var arg : n->args) { + call_args.push_back(arg); + } + + for (Expr ext : m.thread_extent_) { + call_args.push_back(ext); + } + device_funcs_.emplace_back(f_device); + return Evaluate::make(Call::make( + Int(32), f_device->name, call_args, Call::Extern, f_device)); + } + + // function name + std::string name_; + // the device functions + std::vector device_funcs_; + std::unordered_map handle_data_type_; +}; + + +Array UndefinedVars(const LoweredFunc& f) { + IRUseDefAnalysis m; + for (Var arg : f->args) { + m.use_count_[arg.get()] = 0; + } + m.Mutate(f->body); + return m.undefined_; +} + +Array SplitHostDevice(LoweredFunc func) { + return HostDeviceSplitter().Split(func); +} + +} // namespace codegen +} // namespace tvm diff --git a/src/pass/inline.cc b/src/pass/inline.cc index 085fe738eaeb..de452c364cd8 100644 --- a/src/pass/inline.cc +++ b/src/pass/inline.cc @@ -17,36 +17,28 @@ class IRInline : public IRMutator { IRInline(FunctionRef f, Array args, Expr body) : f_(f), args_(args), body_(body) {} - Expr Mutate(Expr expr) final { - expr = IRMutator::Mutate(expr); - const Call* call = expr.as(); - if (call != nullptr && call->func == f_) { - CHECK_EQ(call->value_index, 0); - return InlineCall(call); - } else { + Expr Mutate_(const Call* op, const Expr& e) final { + Expr expr = IRMutator::Mutate_(op, e); + op = expr.as(); + + if (op->func == f_) { + CHECK_EQ(op->value_index, 0); + Expr expr = body_; + CHECK_EQ(args_.size(), op->args.size()) + << op->args.size() << " vs " << args_.size(); + for (size_t i = 0; i < args_.size(); ++i) { + expr = 
Let::make(args_[i], op->args[i], expr); + } return expr; + } else { + return e; } } - Stmt Mutate(Stmt stmt) final { - return IRMutator::Mutate(stmt); - } - private: FunctionRef f_; Array args_; Expr body_; - - Expr InlineCall(const Call* op) { - Expr expr = body_; - - CHECK_EQ(args_.size(), op->args.size()) - << op->args.size() << " vs " << args_.size(); - for (size_t i = 0; i < args_.size(); ++i) { - expr = Let::make(args_[i], op->args[i], expr); - } - return expr; - } }; Stmt Inline(Stmt stmt, diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc index ad0ace10fffa..85b0589ce60c 100644 --- a/src/pass/ir_mutator.cc +++ b/src/pass/ir_mutator.cc @@ -58,6 +58,183 @@ inline Array MutateRDom(Array rdom, IRMutator *m) { } } +#define DISPATCH_TO_MUTATE_STMT(OP) \ + set_dispatch([](const OP* op, const Stmt& s, IRMutator* m) { \ + return m->Mutate_(op, s); \ + }) + +#define DISPATCH_TO_MUTATE_EXPR(OP) \ + set_dispatch([](const OP* op, const Expr& e, IRMutator* m) { \ + return m->Mutate_(op, e); \ + }) + +TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) +.DISPATCH_TO_MUTATE_STMT(LetStmt) +.DISPATCH_TO_MUTATE_STMT(AttrStmt) +.DISPATCH_TO_MUTATE_STMT(Provide) +.DISPATCH_TO_MUTATE_STMT(Realize) +.DISPATCH_TO_MUTATE_STMT(Store) +.DISPATCH_TO_MUTATE_STMT(For) +.DISPATCH_TO_MUTATE_STMT(Free); + +Stmt IRMutator::Mutate_(const LetStmt *op, const Stmt& s) { + Expr value = this->Mutate(op->value); + Stmt body = this->Mutate(op->body); + if (value.same_as(op->value) && + body.same_as(op->body)) { + return s; + } else { + return LetStmt::make(op->var, value, body); + } +} + +Stmt IRMutator::Mutate_(const AttrStmt* op, const Stmt& s) { + Expr value = this->Mutate(op->value); + Stmt body = this->Mutate(op->body); + if (value.same_as(op->value) && + body.same_as(op->body)) { + return s; + } else { + return AttrStmt::make(op->node, op->type_key, value, body); + } +} + +Stmt IRMutator::Mutate_(const For *op, const Stmt& s) { + Expr min = this->Mutate(op->min); + Expr extent = 
this->Mutate(op->extent); + Stmt body = this->Mutate(op->body); + if (min.same_as(op->min) && + extent.same_as(op->extent) && + body.same_as(op->body)) { + return s; + } else { + return For::make( + op->loop_var, min, extent, op->for_type, op->device_api, body); + } +} + +Stmt IRMutator::Mutate_(const Allocate* op, const Stmt& s) { + IRMutator* m = this; + std::vector new_extents; + bool all_extents_unmodified = true; + for (size_t i = 0; i < op->extents.size(); i++) { + new_extents.push_back(m->Mutate(op->extents[i])); + all_extents_unmodified &= new_extents[i].same_as(op->extents[i]); + } + Stmt body = m->Mutate(op->body); + Expr condition = m->Mutate(op->condition); + Expr new_expr; + if (op->new_expr.defined()) { + new_expr = m->Mutate(op->new_expr); + } + if (all_extents_unmodified && + body.same_as(op->body) && + condition.same_as(op->condition) && + new_expr.same_as(op->new_expr)) { + return s; + } else { + return Allocate::make( + op->buffer_var, op->type, + new_extents, condition, body, + new_expr, op->free_function); + } +} + +Stmt IRMutator::Mutate_(const Provide* op, const Stmt& s) { + auto new_args = MutateArray(op->args, this); + auto new_value = this->Mutate(op->value); + if (op->args.same_as(new_args) && op->value.same_as(new_value)) { + return s; + } else { + return Provide::make(op->func, op->value_index, new_value, new_args); + } +} + +Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) { + IRMutator* m = this; + Halide::Internal::Region new_bounds; + bool bounds_changed = false; + + // Mutate the bounds + for (size_t i = 0; i < op->bounds.size(); i++) { + Expr old_min = op->bounds[i]->min; + Expr old_extent = op->bounds[i]->extent; + Expr new_min = m->Mutate(old_min); + Expr new_extent = m->Mutate(old_extent); + if (!new_min.same_as(old_min)) bounds_changed = true; + if (!new_extent.same_as(old_extent)) bounds_changed = true; + new_bounds.push_back( + Range::make_by_min_extent(new_min, new_extent)); + } + + Stmt body = m->Mutate(op->body); 
+ Expr condition = m->Mutate(op->condition); + if (!bounds_changed && + body.same_as(op->body) && + condition.same_as(op->condition)) { + return s; + } else { + return Realize::make(op->func, op->value_index, + op->type, new_bounds, + condition, body); + } +} + +Stmt IRMutator::Mutate_(const Store *op, const Stmt& s) { + Expr value = this->Mutate(op->value); + Expr index = this->Mutate(op->index); + if (value.same_as(op->value) && index.same_as(op->index)) { + return s; + } else { + return Store::make(op->buffer_var, value, index); + } +} + +Stmt IRMutator::Mutate_(const Free *op, const Stmt& s) { + return s; +} + +TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) +.DISPATCH_TO_MUTATE_EXPR(Call) +.DISPATCH_TO_MUTATE_EXPR(Let) +.DISPATCH_TO_MUTATE_EXPR(Load) +.DISPATCH_TO_MUTATE_EXPR(Variable); + +Expr IRMutator::Mutate_(const Call* op, const Expr& e) { + auto new_args = MutateArray(op->args, this); + if (op->args.same_as(new_args)) { + return e; + } else { + return Call::make(op->type, op->name, new_args, op->call_type, + op->func, op->value_index); + } +} + +Expr IRMutator::Mutate_(const Load *op, const Expr& e) { + Expr index = this->Mutate(op->index); + if (index.same_as(op->index)) { + return e; + } else { + return Load::make(op->type, op->buffer_var, index); + } +} + + +Expr IRMutator::Mutate_(const Variable *op, const Expr& e) { + return e; +} + +Expr IRMutator::Mutate_(const Let *op, const Expr& e) { + Expr value = this->Mutate(op->value); + Expr body = this->Mutate(op->body); + if (value.same_as(op->value) && + body.same_as(op->body)) { + return e; + } else { + return Let::make(op->var, value, body); + } +} + TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) .set_dispatch([](const Reduce* op, const Expr& e, IRMutator* m) { Array new_rdom = MutateRDom(op->rdom, m); @@ -70,24 +247,11 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) } }); -TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) -.set_dispatch([](const AttrStmt* op, const Stmt& s, IRMutator* m) { - Expr value = 
m->Mutate(op->value); - Stmt body = m->Mutate(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { - return s; - } else { - return AttrStmt::make(op->node, op->type_key, value, body); - } - }); - TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) .set_dispatch(ReturnSelfExpr) .set_dispatch(ReturnSelfExpr) .set_dispatch(ReturnSelfExpr) -.set_dispatch(ReturnSelfExpr) -.set_dispatch(ReturnSelfExpr); +.set_dispatch(ReturnSelfExpr); TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) .set_dispatch([](const Cast* op, const Expr& e, IRMutator* m) { @@ -150,14 +314,6 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) return Select::make(cond, t, f); } }) -.set_dispatch([](const Load *op, const Expr& e, IRMutator* m) { - Expr index = m->Mutate(op->index); - if (index.same_as(op->index)) { - return e; - } else { - return Load::make(op->type, op->buffer_var, index); - } - }) .set_dispatch([](const Ramp *op, const Expr& e, IRMutator* m) { Expr base = m->Mutate(op->base); Expr stride = m->Mutate(op->stride); @@ -175,38 +331,9 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) } else { return Broadcast::make(value, op->lanes); } - }) -.set_dispatch([](const Call *op, const Expr& e, IRMutator* m) { - auto new_args = MutateArray(op->args, m); - if (op->args.same_as(new_args)) { - return e; - } else { - return Call::make(op->type, op->name, new_args, op->call_type, - op->func, op->value_index); - } - }) -.set_dispatch([](const Let *op, const Expr& e, IRMutator* m) { - Expr value = m->Mutate(op->value); - Expr body = m->Mutate(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { - return e; - } else { - return Let::make(op->var, value, body); - } }); TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) -.set_dispatch([](const LetStmt *op, const Stmt& s, IRMutator* m) { - Expr value = m->Mutate(op->value); - Stmt body = m->Mutate(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { - return s; - } else { - return LetStmt::make(op->var, value, 
body); - } - }) .set_dispatch([](const AssertStmt *op, const Stmt& s, IRMutator* m) { Expr condition = m->Mutate(op->condition); Expr message = m->Mutate(op->message); @@ -225,93 +352,6 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) return ProducerConsumer::make(op->func, op->is_producer, body); } }) -.set_dispatch([](const For *op, const Stmt& s, IRMutator* m) { - Expr min = m->Mutate(op->min); - Expr extent = m->Mutate(op->extent); - Stmt body = m->Mutate(op->body); - if (min.same_as(op->min) && - extent.same_as(op->extent) && - body.same_as(op->body)) { - return s; - } else { - return For::make( - op->loop_var, min, extent, op->for_type, op->device_api, body); - } - }) -.set_dispatch([](const Store *op, const Stmt& s, IRMutator* m) { - Expr value = m->Mutate(op->value); - Expr index = m->Mutate(op->index); - if (value.same_as(op->value) && index.same_as(op->index)) { - return s; - } else { - return Store::make(op->buffer_var, value, index); - } - }) -.set_dispatch([](const Provide *op, const Stmt& s, IRMutator* m) { - auto new_args = MutateArray(op->args, m); - auto new_value = m->Mutate(op->value); - if (op->args.same_as(new_args) && op->value.same_as(new_value)) { - return s; - } else { - return Provide::make(op->func, op->value_index, new_value, new_args); - } - }) -.set_dispatch([](const Allocate *op, const Stmt& s, IRMutator* m) { - std::vector new_extents; - bool all_extents_unmodified = true; - for (size_t i = 0; i < op->extents.size(); i++) { - new_extents.push_back(m->Mutate(op->extents[i])); - all_extents_unmodified &= new_extents[i].same_as(op->extents[i]); - } - Stmt body = m->Mutate(op->body); - Expr condition = m->Mutate(op->condition); - Expr new_expr; - if (op->new_expr.defined()) { - new_expr = m->Mutate(op->new_expr); - } - if (all_extents_unmodified && - body.same_as(op->body) && - condition.same_as(op->condition) && - new_expr.same_as(op->new_expr)) { - return s; - } else { - return Allocate::make( - op->buffer_var, op->type, - new_extents, 
condition, body, - new_expr, op->free_function); - } - }) -.set_dispatch([](const Free *op, const Stmt& s, IRMutator* m) { - return s; - }) -.set_dispatch([](const Realize *op, const Stmt& s, IRMutator* m) { - Halide::Internal::Region new_bounds; - bool bounds_changed = false; - - // Mutate the bounds - for (size_t i = 0; i < op->bounds.size(); i++) { - Expr old_min = op->bounds[i]->min; - Expr old_extent = op->bounds[i]->extent; - Expr new_min = m->Mutate(old_min); - Expr new_extent = m->Mutate(old_extent); - if (!new_min.same_as(old_min)) bounds_changed = true; - if (!new_extent.same_as(old_extent)) bounds_changed = true; - new_bounds.push_back( - Range::make_by_min_extent(new_min, new_extent)); - } - - Stmt body = m->Mutate(op->body); - Expr condition = m->Mutate(op->condition); - if (!bounds_changed && - body.same_as(op->body) && - condition.same_as(op->condition)) { - return s; - } else { - return Realize::make(op->func, op->value_index, - op->type, new_bounds, - condition, body); - } - }) .set_dispatch([](const Block *op, const Stmt& s, IRMutator* m) { Stmt first = m->Mutate(op->first); Stmt rest = m->Mutate(op->rest); diff --git a/src/pass/ir_util.h b/src/pass/ir_util.h new file mode 100644 index 000000000000..794dcd820715 --- /dev/null +++ b/src/pass/ir_util.h @@ -0,0 +1,70 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file ir_util.h + * \brief Helper functions to construct and compose IR nodes. + */ +#ifndef TVM_PASS_IR_UTIL_H_ +#define TVM_PASS_IR_UTIL_H_ + +#include +#include + +namespace tvm { +namespace ir { + +/*! + * \brief combine the nest stmt, whose body is not defined. + * \param nest A list of For and LetStmt, whose body is not defined. 
+ * \param body body + * \return The combined Stmt + */ +inline Stmt MergeNest(std::vector nest, Stmt body) { + // use reverse iteration + for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) { + Stmt s = *ri; + if (s.as()) { + auto n = std::make_shared(*s.as()); + CHECK(is_no_op(n->body)); + n->body = body; + body = Stmt(n); + } else if (s.as()) { + auto n = std::make_shared(*s.as()); + CHECK(is_no_op(n->body)); + n->body = body; + body = Stmt(n); + } else if (s.as()) { + auto n = std::make_shared(*s.as()); + CHECK(is_no_op(n->body)); + n->body = body; + body = Stmt(n); + } else if (s.as()) { + auto n = std::make_shared(*s.as()); + CHECK(is_no_op(n->then_case)); + CHECK(!n->else_case.defined()); + n->then_case = body; + body = Stmt(n); + } else if (s.as()) { + body = Block::make(s, body); + } else { + LOG(FATAL) << "not supported nest type"; + } + } + return body; +} + +/*! + * \brief combine the nest stmt, whose body is not defined. + * \param nest A list of For and LetStmt, whose body is not defined. + * \param body body + * \return The combined Stmt + */ +inline Stmt MergeNest(std::vector > nest, Stmt body) { + for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) { + body = MergeNest(*ri, body); + } + return body; +} + +} // namespace ir +} // namespace tvm +#endif // TVM_PASS_IR_UTIL_H_ diff --git a/src/pass/ir_visitor.cc b/src/pass/ir_visitor.cc index 3bbcbbd002ad..77ce3928f2fe 100644 --- a/src/pass/ir_visitor.cc +++ b/src/pass/ir_visitor.cc @@ -8,7 +8,6 @@ namespace tvm { namespace ir { -namespace { // visitor to implement apply class IRApplyVisit : public IRVisitor { public: @@ -26,7 +25,6 @@ class IRApplyVisit : public IRVisitor { std::unordered_set visited_; }; -} // namespace void PostOrderVisit(const NodeRef& node, std::function fvisit) { IRApplyVisit(fvisit).Visit(node); @@ -36,12 +34,6 @@ IRVisitor::FVisit& IRVisitor::vtable() { // NOLINT(*) static FVisit inst; return inst; } - -// namespace to register the functors. 
-namespace { - -using namespace Halide::Internal; - void NoOp(const NodeRef& n, IRVisitor* v) { } @@ -59,24 +51,82 @@ inline void VisitRDom(const Array& rdom, IRVisitor* v) { } } -TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) -.set_dispatch([](const Reduce* op, IRVisitor* v) { - VisitRDom(op->rdom, v); - v->Visit(op->source); - }); +#define DISPATCH_TO_VISIT(OP) \ + set_dispatch([](const OP* op, IRVisitor* v) { \ + v->Visit_(op); \ + }) TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) -.set_dispatch([](const AttrStmt* op, IRVisitor* v) { - v->Visit(op->value); - v->Visit(op->body); - }); +.DISPATCH_TO_VISIT(Variable) +.DISPATCH_TO_VISIT(LetStmt) +.DISPATCH_TO_VISIT(For) +.DISPATCH_TO_VISIT(Allocate) +.DISPATCH_TO_VISIT(Load) +.DISPATCH_TO_VISIT(Store) +.DISPATCH_TO_VISIT(Let) +.DISPATCH_TO_VISIT(Call) +.DISPATCH_TO_VISIT(Free); + +void IRVisitor::Visit_(const Variable* op) {} + +void IRVisitor::Visit_(const LetStmt *op) { + this->Visit(op->value); + this->Visit(op->body); +} + +void IRVisitor::Visit_(const AttrStmt* op) { + this->Visit(op->value); + this->Visit(op->body); +} + +void IRVisitor::Visit_(const For *op) { + IRVisitor* v = this; + v->Visit(op->min); + v->Visit(op->extent); + v->Visit(op->body); +} + +void IRVisitor::Visit_(const Allocate *op) { + IRVisitor* v = this; + for (size_t i = 0; i < op->extents.size(); i++) { + v->Visit(op->extents[i]); + } + v->Visit(op->body); + v->Visit(op->condition); + if (op->new_expr.defined()) { + v->Visit(op->new_expr); + } +} + +void IRVisitor::Visit_(const Load *op) { + this->Visit(op->index); +} + +void IRVisitor::Visit_(const Store *op) { + this->Visit(op->value); + this->Visit(op->index); +} + +void IRVisitor::Visit_(const Let *op) { + this->Visit(op->value); + this->Visit(op->body); +} + +void IRVisitor::Visit_(const Free* op) {} + +void IRVisitor::Visit_(const Call *op) { + VisitArray(op->args, this); +} TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) +.set_dispatch([](const Reduce* op, IRVisitor* v) { + VisitRDom(op->rdom, v); + 
v->Visit(op->source); + }) .set_dispatch(NoOp) .set_dispatch(NoOp) .set_dispatch(NoOp) -.set_dispatch(NoOp) -.set_dispatch(NoOp); +.set_dispatch(NoOp); TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) .set_dispatch([](const Cast* op, IRVisitor* v) { @@ -116,29 +166,15 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) v->Visit(op->true_value); v->Visit(op->false_value); }) -.set_dispatch([](const Load *op, IRVisitor* v) { - v->Visit(op->index); - }) .set_dispatch([](const Ramp *op, IRVisitor* v) { v->Visit(op->base); v->Visit(op->stride); }) .set_dispatch([](const Broadcast *op, IRVisitor* v) { v->Visit(op->value); - }) -.set_dispatch([](const Call *op, IRVisitor* v) { - VisitArray(op->args, v); - }) -.set_dispatch([](const Let *op, IRVisitor* v) { - v->Visit(op->value); - v->Visit(op->body); }); TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) -.set_dispatch([](const LetStmt *op, IRVisitor* v) { - v->Visit(op->value); - v->Visit(op->body); - }) .set_dispatch([](const AssertStmt *op, IRVisitor* v) { v->Visit(op->condition); v->Visit(op->message); @@ -146,30 +182,10 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) .set_dispatch([](const ProducerConsumer *op, IRVisitor* v) { v->Visit(op->body); }) -.set_dispatch([](const For *op, IRVisitor* v) { - v->Visit(op->min); - v->Visit(op->extent); - v->Visit(op->body); - }) -.set_dispatch([](const Store *op, IRVisitor* v) { - v->Visit(op->value); - v->Visit(op->index); - }) .set_dispatch([](const Provide *op, IRVisitor* v) { VisitArray(op->args, v); v->Visit(op->value); }) -.set_dispatch([](const Allocate *op, IRVisitor* v) { - for (size_t i = 0; i < op->extents.size(); i++) { - v->Visit(op->extents[i]); - } - v->Visit(op->body); - v->Visit(op->condition); - if (op->new_expr.defined()) { - v->Visit(op->new_expr); - } - }) -.set_dispatch(NoOp) .set_dispatch([](const Realize *op, IRVisitor* v) { // Mutate the bounds for (size_t i = 0; i < op->bounds.size(); i++) { @@ -193,6 +209,5 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) v->Visit(op->value); }); -} // 
namespace } // namespace ir } // namespace tvm diff --git a/src/pass/schedule_ops.cc b/src/pass/schedule_ops.cc index a62cf678b8cf..c2332a819609 100644 --- a/src/pass/schedule_ops.cc +++ b/src/pass/schedule_ops.cc @@ -9,6 +9,7 @@ #include #include "./scope.h" +#include "./ir_util.h" #include "../schedule/graph.h" namespace tvm { @@ -32,18 +33,27 @@ void PassUpOffset(const Stage& s, Expr outer = state.at(s->outer); Expr inner = state.at(s->inner); Expr factor = dom_map.at(s->inner)->extent; - Expr offset = inner + outer * factor; - Expr outer_min = dom_map.at(s->parent)->min; - if (!is_zero(outer_min)) { - offset = outer_min + offset; + Expr parent_min = dom_map.at(s->parent)->min; + state[s->parent] = inner + outer * factor; + // add min if they exist + if (!is_zero(parent_min)) { + state[s->parent] = parent_min + state[s->parent]; } - state[s->parent] = offset; } else if (rel.as()) { const FuseNode* s = rel.as(); Expr value = state.at(s->fused); Expr factor = dom_map.at(s->inner)->extent; + Expr outer_min = dom_map.at(s->outer)->min; + Expr inner_min = dom_map.at(s->inner)->min; state[s->outer] = value / factor; state[s->inner] = value % factor; + // add min if they exist + if (!is_zero(outer_min)) { + state[s->outer] = outer_min + state[s->outer]; + } + if (!is_zero(inner_min)) { + state[s->inner] = outer_min + state[s->inner]; + } } else { LOG(FATAL) << "unknown relation type"; } @@ -81,45 +91,6 @@ void SplitByAdd(Expr expr, } } -/*! - * \brief combine the nest stmt, whose body is not defined. - * \param nest A list of For and LetStmt, whose body is not defined. 
- * \param body body - */ -Stmt MergeNest(std::vector > nest, Stmt body) { - // use reverse iteration - for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) { - for (auto rj = ri->rbegin(); rj != ri->rend(); ++rj) { - Stmt s = *rj; - if (s.as()) { - auto n = std::make_shared(*s.as()); - CHECK(is_no_op(n->body)); - n->body = body; - body = Stmt(n); - } else if (s.as()) { - auto n = std::make_shared(*s.as()); - CHECK(is_no_op(n->body)); - n->body = body; - body = Stmt(n); - } else if (s.as()) { - auto n = std::make_shared(*s.as()); - CHECK(is_no_op(n->body)); - n->body = body; - body = Stmt(n); - } else if (s.as()) { - auto n = std::make_shared(*s.as()); - CHECK(is_no_op(n->then_case)); - CHECK(!n->else_case.defined()); - n->then_case = body; - body = Stmt(n); - } else { - LOG(FATAL) << "not supported nest type"; - } - } - } - return body; -} - /*! * \brief Make the loop nest of the correspondings schedule. * \param sch The schedule. @@ -142,16 +113,32 @@ std::vector > MakeLoopNest( for (size_t i = 0; i < leaf_iter_vars.size(); ++i) { auto iv = leaf_iter_vars[i]; + Range dom = dom_map.at(iv); // initialize the offset and loop_level offset[iv] = iv->var; loop_level[iv->var.as()] = i + 1; // Mark the iter var in the IR, to remember the point if (iv->thread_tag.length() == 0) { - Range dom = dom_map.at(iv); + if (is_zero(dom->min)) { + nest[i + 1].emplace_back( + For::make(iv->var, 0, dom->extent, + ForType::Serial, DeviceAPI::None, no_op)); + } else { + Var idx(iv->var->name_hint + ".idx", iv->var.type()); + nest[i + 1].emplace_back( + For::make(idx, 0, dom->extent, + ForType::Serial, DeviceAPI::None, no_op)); + nest[i + 1].emplace_back( + LetStmt::make(iv->var, dom->min + idx, no_op)); + } + } else { + // Always restrict threaded IterVar to starts from 0. 
+ CHECK(is_zero(dom->min)); + // annotate the extent of the IterVar nest[i + 1].emplace_back( - For::make(iv->var, dom->min, dom->extent, - ForType::Serial, DeviceAPI::None, no_op)); + AttrStmt::make(iv, "thread_extent", dom->extent, no_op)); } + // annotate the extent of the IterVar nest[i + 1].emplace_back( AttrStmt::make(iv, "scope", iv->var, no_op)); } diff --git a/src/pass/simple_passes.cc b/src/pass/simple_passes.cc new file mode 100644 index 000000000000..38939459722b --- /dev/null +++ b/src/pass/simple_passes.cc @@ -0,0 +1,36 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file simple_passes.cc + * \brief Implementation of simple passes + */ +#include +#include +#include + +namespace tvm { +namespace ir { + +class IRSideEffect : public IRVisitor { + public: + void Visit(const NodeRef& e) final { + if (has_side_effect_) return; + } + + void Visit_(const Call* op) final { + if (!op->is_pure()) { + has_side_effect_ = true; return; + } else { + IRVisitor::Visit_(op); + } + } + + bool has_side_effect_{false}; +}; + +bool HasSideEffect(const Expr& e) { + IRSideEffect v; + v.Visit(e); + return v.has_side_effect_; +} +} // namespace ir +} // namespace tvm diff --git a/tests/python/test_codegen_cuda.py b/tests/python/test_codegen_cuda.py index 0f0a8df30506..dc20dda363d4 100644 --- a/tests/python/test_codegen_cuda.py +++ b/tests/python/test_codegen_cuda.py @@ -24,31 +24,15 @@ def mock_test_add(): Bb = tvm.Buffer(B.shape, B.dtype, name='B') Cb = tvm.Buffer(C.shape, C.dtype, name='C') stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}) + stmt = tvm.ir_pass.Simplify(stmt) print(stmt) output_ssa = False - code = tvm.codegen.CompileToC(stmt, "myadd", - [Ab.ptr, Bb.ptr, Cb.ptr, n], - output_ssa) - - print(code) - def codegen(): - # generate host/device code - host_code, device_code = tvm.codegen.GenCUDA( - s, - inputs={A: Ab, B:Bb}, - outputs={C: Cb}, - args=[A, B, C]) - # generate a function based on the code - f = tvm.cuda.build_function(host_code, 
device_code) - # create arrays - a = tvm.nd.array(np.ones(10), ctx=tvm.gpu(0)) - b = tvm.nd.array(np.ones(10), ctx=tvm.gpu(0)) - c = tvm.nd.array(np.zeros(10), ctx=tvm.gpu(0)) - # calll the generated code - f(a, b, c) - # sync the result - np.testing.assert_equal(c.asnumpy(), np.ones(10) * 2) + f = tvm.codegen.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 1) + f_list = tvm.codegen.SplitHostDevice(f) + for x in f_list: + code = tvm.codegen.CompileToC(x, output_ssa) + print(code) if __name__ == "__main__": mock_test_add() diff --git a/tests/python/test_codegen_makeapi.py b/tests/python/test_codegen_makeapi.py new file mode 100644 index 000000000000..ebe6f4e63da5 --- /dev/null +++ b/tests/python/test_codegen_makeapi.py @@ -0,0 +1,27 @@ +import tvm +import numpy + +def test_makeapi(): + """Not yet working, mock design""" + n = tvm.Var('n') + A = tvm.placeholder((n,), name='A') + B = tvm.placeholder((n,), name='B') + C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') + s = tvm.Schedule(C.op) + + bounds = tvm.schedule.InferBound(s) + stmt = tvm.ir_pass.ScheduleOps(s, bounds) + + Ab = tvm.Buffer(A.shape, A.dtype, name='A') + Bb = tvm.Buffer(B.shape, B.dtype, name='B') + Cb = tvm.Buffer(C.shape, C.dtype, name='C') + stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}) + num_packed_args = 2 + f = tvm.codegen.MakeAPI(stmt, "myadd", [n, Ab, Bb, Cb], num_packed_args) + assert(f.handle_data_type[Ab.ptr].dtype == Ab.dtype) + assert(len(f.args) == 5) + output_ssa = False + + +if __name__ == "__main__": + test_makeapi() diff --git a/tests/python/test_pass_storage_flatten.py b/tests/python/test_pass_storage_flatten.py index b7dff05d0f6f..98200bc7d528 100644 --- a/tests/python/test_pass_storage_flatten.py +++ b/tests/python/test_pass_storage_flatten.py @@ -18,6 +18,7 @@ def test_flatten2(): Ab = tvm.Buffer(A.shape, A.dtype, name='A') A2b = tvm.Buffer(A2.shape, A2.dtype, name='A2') stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}) + stmt = tvm.ir_pass.Simplify(stmt) 
print(stmt) if __name__ == "__main__":