diff --git a/3rdparty/dmlc-core b/3rdparty/dmlc-core index 4f0564ec7694..946a54012d0c 160000 --- a/3rdparty/dmlc-core +++ b/3rdparty/dmlc-core @@ -1 +1 @@ -Subproject commit 4f0564ec769477c66d480dd966088f172050c874 +Subproject commit 946a54012d0c390675ab5b46cd990838d4183d6f diff --git a/include/tvm/expr.h b/include/tvm/expr.h index a199d656caf8..8bb432ab6641 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -118,6 +118,8 @@ class Range : public HalideIR::IR::Range { TVM_DLL static Range make_by_min_extent(Expr min, Expr extent); }; +using Region = Array<Range>; + /*! * \brief Type of iteration variable. * Each IterVar have a specific type. diff --git a/include/tvm/operation.h b/include/tvm/operation.h index c11242c0a55d..1a1d28ab71bb 100644 --- a/include/tvm/operation.h +++ b/include/tvm/operation.h @@ -49,7 +49,7 @@ class OperationNode : public FunctionBaseNode { } /*! * \return The list of iteration variable at root - * \note root_iter_vars dedides the shape of the outputs. + * \note root_iter_vars decides the shape of the outputs. */ virtual Array<IterVar> root_iter_vars() const = 0; /*! @@ -239,6 +239,74 @@ class TVM_DLL ComputeOpNode : public OperationNode { TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode, OperationNode); }; +/*! + * \brief A TensorCompute op that computes a tensor with a tensor intrinsic. + */ +class TensorComputeOpNode : public OperationNode { + public: + /*! \brief IterVar on each axis */ + Array<IterVar> axis; + /*! \brief IterVar on each reduction axis, if the intrin will use the reduce axis */ + Array<IterVar> reduce_axis; + /*! \brief number of axes that can be scheduled */ + int schedulable_ndim; + /*! \brief TensorIntrin used to compute */ + TensorIntrin intrin; + /*! \brief input tensors of intrin */ + Array<Tensor> inputs; + /*! \brief region of input tensors */ + Array<Region> input_regions; + /*! \brief constructor */ + TensorComputeOpNode() {} + // override functions + int num_outputs() const final; + Array<IterVar> root_iter_vars() const final; + Type output_dtype(size_t i) const final; + Array<Expr> output_shape(size_t i) const final; + Array<Tensor> InputTensors() const final; + Operation ReplaceInputs( + const Operation& self, + const std::unordered_map<Tensor, Tensor>& rmap) const final; + void PropBoundToInputs( + const Operation& self, + const std::unordered_map<const Variable*, IntSet>& dom_map, + std::unordered_map<Tensor, TensorDom>* out_dom_map) const final; + void GatherBound( + const Operation& self, + const std::unordered_map<Tensor, TensorDom>& tensor_dom, + std::unordered_map<IterVar, Range>* out_dom_map) const final; + Stmt BuildRealize( + const Stage& stage, + const std::unordered_map<IterVar, Range>& realize_map, + const Stmt& body) const final; + Stmt BuildProvide( + const Stage& stage, + const std::unordered_map<IterVar, Range>& dom_map, + bool debug_keep_trivial_loop) const final; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("name", &name); + v->Visit("tag", &tag); + v->Visit("axis", &axis); + v->Visit("reduce_axis", &reduce_axis); + v->Visit("schedulable_ndim", &schedulable_ndim); + v->Visit("intrin", &intrin); + v->Visit("inputs", &inputs); + v->Visit("input_regions", &input_regions); + } + static Operation make(std::string name, + std::string tag, + Array<IterVar> axis, + Array<IterVar> reduce_axis, + int schedulable_ndim, + TensorIntrin intrin, + Array<Tensor> tensors, + Array<Region> regions); + + static constexpr const char* _type_key = "TensorComputeOp"; + TVM_DECLARE_NODE_TYPE_INFO(TensorComputeOpNode, OperationNode); +}; + /*! * \brief Symbolic scan. */ @@ -326,7 +394,7 @@ class ExternOpNode : public OperationNode { public: /*! \brief The input tensors */ Array<Tensor> inputs; - /*! \brief Symbolic placeholder representationinputs */ + /*!
\brief Symbolic placeholder representation of inputs */ Array<Buffer> input_placeholders; /*! \brief Symbolic placeholder representation of outputs */ Array<Buffer> output_placeholders; diff --git a/include/tvm/tensor_intrin.h b/include/tvm/tensor_intrin.h index 944498d1e615..fbee4bccc0bf 100644 --- a/include/tvm/tensor_intrin.h +++ b/include/tvm/tensor_intrin.h @@ -89,5 +89,58 @@ class TensorIntrinNode : public Node { inline const TensorIntrinNode* TensorIntrin::operator->() const { return static_cast<const TensorIntrinNode*>(node_.get()); } + + +// Internal node container of tensor intrinsic calling. +class TensorIntrinCallNode; + +/*! \brief Tensor intrinsic calling node. */ +class TensorIntrinCall : public NodeRef { + public: + TensorIntrinCall() {} + explicit TensorIntrinCall(NodePtr<Node> n) : NodeRef(n) {} + /*! + * \brief access the internal node container + * \return the pointer to the internal node container + */ + inline const TensorIntrinCallNode* operator->() const; + + /*! \brief specify container node */ + using ContainerType = TensorIntrinCallNode; +}; + +class TensorIntrinCallNode : public Node { + public: + /*! \brief the tensor intrinsic */ + TensorIntrin intrin; + /*! \brief input tensors of the intrinsic */ + Array<Tensor> tensors; + /*! \brief regions of input tensors */ + Array<Region> regions; + /*! + * \brief IterVar on each reduction axis, if the + * intrin will use the reduce axis + */ + Array<IterVar> reduce_axis; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("intrin", &intrin); + v->Visit("tensors", &tensors); + v->Visit("regions", &regions); + v->Visit("reduce_axis", &reduce_axis); + } + static TensorIntrinCall make(TensorIntrin intrin, + Array<Tensor> tensors, + Array<Region> regions, + Array<IterVar> reduce_axis); + + static constexpr const char* _type_key = "TensorIntrinCall"; + TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinCallNode, Node); +}; + +inline const TensorIntrinCallNode* TensorIntrinCall::operator->() const { + return static_cast<const TensorIntrinCallNode*>(node_.get()); +} + } // namespace tvm #endif // TVM_TENSOR_INTRIN_H_ diff --git a/python/tvm/api.py b/python/tvm/api.py index 223e73eeb596..793afe52e5fd 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -243,24 +243,43 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None): raise ValueError("nested tag is not allowed for now") tag = _tag.TagScope.current.tag shape = (shape,) if isinstance(shape, _expr.Expr) else shape + # for python3 + shape = tuple([int(s) if isinstance(s, float) else s for s in shape]) ndim = len(shape) code = fcompute.__code__ - if fcompute.__code__.co_argcount == 0: + out_ndim = ndim + if code.co_argcount == 0: arg_names = ["i%d" % i for i in range(ndim)] else: arg_names = code.co_varnames[:code.co_argcount] + out_ndim = code.co_argcount - if ndim != len(arg_names): + if out_ndim != len(arg_names): raise ValueError("fcompute do not match dimension, ndim=%d" % ndim) - dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape)] + dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape[:out_ndim])] body = fcompute(*[v.var for v in dim_var]) - if not isinstance(body, (list, tuple)): - body = [body] - body = convert(body) - op_node = _api_internal._ComputeOp( - name, tag, attrs, dim_var, body) + + if isinstance(body, _tensor.TensorIntrinCall): + for i, s in enumerate(shape[out_ndim:]): + var_name = "ax" + str(i) + dim_var.append(_IterVar((0, s), var_name, 4)) + op_node = _api_internal._TensorComputeOp(name, + tag, + dim_var, + body.reduce_axis, + out_ndim, + body.intrin, + body.tensors, + body.regions) + else: + if not isinstance(body, (list, tuple)): + body =
[body] + body = convert(body) + op_node = _api_internal._ComputeOp( + name, tag, attrs, dim_var, body) + num = op_node.num_outputs outputs = tuple(op_node.output(i) for i in range(num)) return outputs[0] if num == 1 else outputs @@ -529,14 +548,14 @@ def decl_buffer(shape, dtype = float32 if dtype is None else dtype strides = () if strides is None else strides if offset_factor != 0 and elem_offset is None: - elem_offset = var('%s_elem_offset' % name, shape[0].dtype) + shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32" + elem_offset = var('%s_elem_offset' % name, shape_dtype) if data is None: data = var(name, "handle") return _api_internal._Buffer( data, dtype, shape, strides, elem_offset, name, scope, data_alignment, offset_factor) - def _IterVar(dom, name, iter_type, thread_tag=''): """Internal function to create IterVar diff --git a/python/tvm/tensor.py b/python/tvm/tensor.py index f0d60f514a37..f32b70eb9a12 100644 --- a/python/tvm/tensor.py +++ b/python/tvm/tensor.py @@ -30,6 +30,11 @@ def dtype(self): """Data content of the tensor.""" return self.tensor.dtype +@register_node +class TensorIntrinCall(NodeBase): + """Intermediate structure for calling a tensor intrinsic.""" + pass + itervar_cls = None @@ -106,6 +111,7 @@ def name(self): return "%s.v%d" % (op.name, self.value_index) + class Operation(NodeBase): """Represent an operation that generate a tensor""" @@ -155,6 +161,12 @@ def reduce_axis(self): return self.__getattr__("reduce_axis") +@register_node +class TensorComputeOp(Operation): + """Tensor operation.""" + pass + + @register_node class ScanOp(Operation): """Scan operation.""" diff --git a/python/tvm/tensor_intrin.py b/python/tvm/tensor_intrin.py index 193124b2f946..f1f26655fe27 100644 --- a/python/tvm/tensor_intrin.py +++ b/python/tvm/tensor_intrin.py @@ -6,9 +6,25 @@ from . import stmt as _stmt from . import make as _make from . import tensor as _tensor +from . import schedule as _schedule from .build_module import current_build_config from ._ffi.node import NodeBase, register_node + +def _get_region(tslice): + region = [] + for idx in tslice.indices: + if isinstance(idx, slice): + assert idx.step is None + region.append(_api.Range(idx.start, idx.stop)) + else: + if isinstance(idx, _schedule.IterVar): + begin = idx.var + else: + begin = idx + region.append(_make.range_by_min_extent(begin, 1)) + return region + @register_node class TensorIntrin(NodeBase): """Tensor intrinsic functions for certain computation. 
@@ -17,8 +33,16 @@ class TensorIntrin(NodeBase): -------- decl_tensor_intrin: Construct a TensorIntrin """ - pass - + def __call__(self, *args, **kwargs): + tensors = [x.tensor for x in args] + regions = [_get_region(x) for x in args] + reduce_axis = [] + if "reduce_axis" in kwargs: + reduce_axis = kwargs["reduce_axis"] + if not isinstance(reduce_axis, (list, tuple)): + reduce_axis = [reduce_axis] + reduce_axis = _api.convert(reduce_axis) + return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis) def decl_tensor_intrin(op, fcompute, diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 8ca49f19baec..75365da5bf50 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -239,6 +239,14 @@ TVM_REGISTER_API("_TensorIntrin") args[6]); }); +TVM_REGISTER_API("_TensorIntrinCall") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = TensorIntrinCallNode::make(args[0], + args[1], + args[2], + args[3]); + }); + TVM_REGISTER_API("_TensorEqual") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = args[0].operator Tensor() == args[1].operator Tensor(); @@ -278,6 +286,18 @@ TVM_REGISTER_API("_ScanOp") args[7]); }); +TVM_REGISTER_API("_TensorComputeOp") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = TensorComputeOpNode::make(args[0], + args[1], + args[2], + args[3], + args[4], + args[5], + args[6], + args[7]); + }); + TVM_REGISTER_API("_ExternOp") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = ExternOpNode::make(args[0], diff --git a/src/lang/tensor.cc b/src/lang/tensor.cc index 4f9c3e9d1782..9b1a58abcee4 100644 --- a/src/lang/tensor.cc +++ b/src/lang/tensor.cc @@ -10,6 +10,8 @@ namespace tvm { +// Tensor + Expr Tensor::operator()(Array indices) const { Array arr(indices.begin(), indices.end()); return operator()(arr); @@ -26,6 +28,15 @@ Expr Tensor::operator()(Array indices) const { return n; } +Tensor Operation::output(size_t i) const { + auto node = make_node(); + node->op = *this; + node->value_index = i; + node->dtype = (*this)->output_dtype(i); + node->shape = (*this)->output_shape(i); + return Tensor(node); +} + Tensor TensorNode::make(Array shape, Type dtype, Operation op, @@ -46,14 +57,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_REGISTER_NODE_TYPE(TensorNode); -Tensor Operation::output(size_t i) const { - auto node = make_node(); - node->op = *this; - node->value_index = i; - node->dtype = (*this)->output_dtype(i); - node->shape = (*this)->output_shape(i); - return Tensor(node); -} + +// TensorIntrin TensorIntrin TensorIntrinNode::make(std::string name, Operation op, @@ -79,4 +84,27 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); TVM_REGISTER_NODE_TYPE(TensorIntrinNode); + + +// TensorIntrinCall + +TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin, + Array tensors, + Array regions, + Array reduce_axis) { + auto n = make_node(); + n->intrin = std::move(intrin); + n->tensors = std::move(tensors); + n->regions = std::move(regions); + n->reduce_axis = std::move(reduce_axis); + return TensorIntrinCall(n); +} + +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch([](const TensorIntrinCallNode *n, IRPrinter *p) { + p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")"; + }); + +TVM_REGISTER_NODE_TYPE(TensorIntrinCallNode); + } // namespace tvm diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc index 6100c957e473..0c40882c0be2 100644 --- a/src/op/compute_op.cc +++ b/src/op/compute_op.cc @@ -13,6 +13,7 @@ #include "compute_op.h" #include "op_util.h" #include "../schedule/message_passing.h" +#include 
"../arithmetic/compute_expr.h" namespace tvm { @@ -542,4 +543,38 @@ static void VerifyComputeOp(const ComputeOpNode* op) { v.Run(); } +Stmt TransformUpdate(const Stage& stage, + const std::unordered_map& dom_map, + const ComputeLoopNest& n, + Stmt body, + Stmt update) { + Array conds; + std::unordered_set banned; + for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { + IterVar iv = stage->leaf_iter_vars[i]; + auto iit = stage->iter_var_attrs.find(iv); + if (iit != stage->iter_var_attrs.end()) { + const IterVarAttr& attr = (*iit).second; + if (attr->iter_type == kTensorized) { + break; + } + } + if (iv->iter_type == kCommReduce) { + auto vit = dom_map.find(iv); + CHECK(vit != dom_map.end()); + const Range& vrange = vit->second; + conds.push_back(likely(iv->var > vrange->min)); + banned.insert(iv->var.get()); + } + } + for (const Expr& pred : n.main_predicates) { + if (ir::ExprUseVar(pred, banned)) { + LOG(FATAL) << "Tensorize update transform failed, the condition " + << pred << " has a conflict with the reset condition"; + } + } + + return IfThenElse::make(arith::ComputeReduce(conds, const_true(1)), + update, body); +} } // namespace tvm diff --git a/src/op/compute_op.h b/src/op/compute_op.h index 996764c6cdc1..87b0814c1ad9 100644 --- a/src/op/compute_op.h +++ b/src/op/compute_op.h @@ -14,7 +14,7 @@ namespace tvm { // loop nest structure for general compute -// This the the loop nest structured used in compute. +// This the loop nest structured used in compute. // Does not include the loop body. struct ComputeLoopNest { // The common number of loops between init and main @@ -73,6 +73,21 @@ Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop); + +/*! + * \brief Transform the update part when there is no init func in tensorizing + * \param stage The stage for tensorizing. + * \param dom_map The range of each iter var. + * \param n The loop nest structured used in compute. + * \param body The body func in tensorize intrin + * \param update The update func in tensorize intrin + * \return Transformed result. + */ +Stmt TransformUpdate(const Stage& stage, + const std::unordered_map& dom_map, + const ComputeLoopNest& n, + Stmt body, + Stmt update); } // namespace tvm #endif // TVM_OP_COMPUTE_OP_H_ diff --git a/src/op/tensor_compute_op.cc b/src/op/tensor_compute_op.cc new file mode 100644 index 000000000000..f9b8188d4685 --- /dev/null +++ b/src/op/tensor_compute_op.cc @@ -0,0 +1,361 @@ +/*! + * Copyright (c) 2017 by Contributors + * \brief Tensor Compute Op. 
+ * \file tensor_compute_op.cc + */ +#include +#include +#include +#include +#include +#include +#include "./op_util.h" +#include "./compute_op.h" +#include "../arithmetic/compute_expr.h" + +namespace tvm { +using namespace ir; +// TensorComputeOpNode +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch([](const TensorComputeOpNode *op, + IRPrinter *p) { + p->stream << "tensor_compute_op(" << op->name << ", " << op << ")"; + }); + +TVM_REGISTER_NODE_TYPE(TensorComputeOpNode); + +int TensorComputeOpNode::num_outputs() const { + return static_cast(this->intrin->buffers.size() - this->inputs.size()); +} + +Array TensorComputeOpNode::root_iter_vars() const { + Array ret = axis; + for (IterVar iv : reduce_axis) { + ret.push_back(iv); + } + return ret; +} + +Type TensorComputeOpNode::output_dtype(size_t i) const { + return this->intrin->buffers[this->inputs.size() + i]->dtype; +} + +Array TensorComputeOpNode::output_shape(size_t i) const { + Array shape; + for (const auto& ivar : this->axis) { + shape.push_back(ivar->dom->extent); + } + return shape; +} + + +Operation TensorComputeOpNode::make(std::string name, + std::string tag, + Array axis, + Array reduce_axis, + int schedulable_ndim, + TensorIntrin intrin, + Array tensors, + Array regions) { + auto n = make_node(); + n->name = std::move(name); + n->tag = std::move(tag); + n->axis = std::move(axis); + n->reduce_axis = std::move(reduce_axis); + n->schedulable_ndim = std::move(schedulable_ndim); + n->intrin = std::move(intrin); + n->inputs = std::move(tensors); + n->input_regions = std::move(regions); + return Operation(n); +} + +Array TensorComputeOpNode::InputTensors() const { + return inputs; +} + +Operation TensorComputeOpNode::ReplaceInputs( + const Operation& self, + const std::unordered_map& rmap) const { + CHECK_EQ(self.operator->(), this); + auto n = make_node(*this); + auto intrin = make_node(*(this->intrin.operator->())); + intrin->body = op::ReplaceTensor(this->intrin->body, rmap); + if (intrin->reduce_init.defined()) { + intrin->reduce_init = op::ReplaceTensor(this->intrin->reduce_init, rmap); + } + if (intrin->reduce_update.defined()) { + intrin->reduce_update = op::ReplaceTensor(this->intrin->reduce_update, rmap); + } + for (size_t i = 0; i < n->inputs.size(); ++i) { + Tensor t = n->inputs[i]; + if (rmap.count(t)) { + n->inputs.Set(i, rmap.at(t)); + } + } + + if (intrin->body.same_as(n->intrin->body) && + intrin->reduce_init.same_as(n->intrin->reduce_init) && + intrin->reduce_update.same_as(n->intrin->reduce_update) && + inputs.same_as(n->inputs)) { + return self; + } else { + n->intrin = TensorIntrin(intrin); + return Operation(n); + } +} + +void TensorComputeOpNode::PropBoundToInputs( + const Operation& self, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const { + for (size_t i = 0; i < this->inputs.size(); ++i) { + Tensor t = this->inputs[i]; + Region region = input_regions[i]; + + auto it = out_dom_map->find(t); + if (it == out_dom_map->end()) continue; + TensorDom& dom = it->second; + for (size_t j = 0; j < t.ndim(); ++j) { + dom.data[j].emplace_back(EvalSet(region[j], dom_map)); + } + } +} + +void TensorComputeOpNode::GatherBound( + const Operation& self, + const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const { + const TensorDom& tdom = tensor_dom.at(self.output(0)); + for (size_t i = 0; i < this->axis.size(); ++i) { + Range r = arith::Union(tdom.data.at(i)).cover_range(this->axis[i]->dom); + CHECK(!out_dom_map->count(this->axis[i])); + (*out_dom_map)[this->axis[i]] = 
r; + } + for (size_t i = 0; i < this->reduce_axis.size(); ++i) { + CHECK(!out_dom_map->count(this->reduce_axis[i])); + (*out_dom_map)[this->reduce_axis[i]] = this->reduce_axis[i]->dom; + } +} + +Stmt TensorComputeOpNode::BuildRealize( + const Stage& stage, + const std::unordered_map<IterVar, Range>& realize_map, + const Stmt& body) const { + CHECK_EQ(stage->op.get(), this); + HalideIR::Internal::Region bounds; + for (IterVar iv : this->axis) { + bounds.push_back(realize_map.at(iv)); + } + Stmt realize = body; + for (int i = this->num_outputs(); i > 0; --i) { + Tensor t = stage->op.output(i-1); + realize = ir::Realize::make(t->op, t->value_index, + t->dtype, bounds, const_true(), realize); + // alignment requirement, only useful for compute + for (int i = 0; i < schedulable_ndim; ++i) { + auto it = stage->iter_var_attrs.find(this->axis[i]); + if (it != stage->iter_var_attrs.end()) { + IterVarAttr attr = (*it).second; + if (attr->dim_align_factor != 0) { + Array<Expr> tuple = {static_cast<int>(i), + attr->dim_align_factor, + attr->dim_align_offset}; + realize = ir::AttrStmt::make( + t, ir::attr::buffer_dim_align, + Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), + realize); + } + } + } + } + return realize; +} + +ComputeLoopNest MakeLoopNest( + const TensorComputeOpNode* self, + const Stage& stage, + const std::unordered_map<IterVar, Range>& dom_map, + bool debug_keep_trivial_loop) { + CHECK_EQ(stage->op.operator->(), self); + ComputeLoopNest ret; + // make main loop nest + ret.main_nest = op::MakeLoopNest( + stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap, + debug_keep_trivial_loop); + ret.main_predicates = schedule::MakeBoundCheck( + stage, dom_map, ret.main_vmap, false, + std::unordered_set<IterVar>()); + for (auto& e : ret.main_predicates) { + e = likely(e); + } + if (stage->store_predicate.defined()) { + ret.main_predicates.push_back(stage->store_predicate); + } + if (self->reduce_axis.size() != 0) { + // try to find the location to insert the initialization. + // Fuse the initialization and provide loop when possible. + std::unordered_map<IterVar, int> update_state; + for (IterVar iv : self->reduce_axis) { + update_state[iv] = 2; + } + for (int i = 0; i < self->schedulable_ndim; ++i) { + update_state[self->axis[i]] = 1; + } + // find which iter var is related to reduction and which is related to axis. + schedule::PassDownBitMaskOr(stage, &update_state); + auto leaf_iter_vars = stage->leaf_iter_vars; + // find the first loop that is related to reduction. + size_t begin_loop = leaf_iter_vars.size(); + for (size_t i = 0; i < leaf_iter_vars.size(); ++i) { + auto iv = leaf_iter_vars[i]; + int flag = update_state.at(iv); + if ((flag & 2) != 0) { + begin_loop = i; break; + } + ret.init_vmap[iv] = ret.main_vmap.at(iv); + } + ret.num_common_loop = begin_loop; + // skip loops that do not relate to axis. + std::unordered_set<IterVar> skip_iter; + for (auto kv : update_state) { + int flag = kv.second; + if ((flag & 1) == 0) skip_iter.insert(kv.first); + } + ret.init_nest = op::MakeLoopNest( + stage, dom_map, begin_loop, true, + skip_iter, &(ret.init_vmap), debug_keep_trivial_loop); + ret.init_predicates = schedule::MakeBoundCheck( + stage, dom_map, ret.init_vmap, true, skip_iter); + for (auto& e : ret.init_predicates) { + e = likely(e); + } + } else { + CHECK_EQ(ret.main_nest.size(), stage->leaf_iter_vars.size() + 1); + ret.num_common_loop = stage->leaf_iter_vars.size(); + } + // copy elision here.
+ return ret; +} + + +Stmt TensorComputeOpNode::BuildProvide( + const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const { + CHECK_EQ(stage->op.operator->(), this); + + // Start bind data. + Stmt nop = Evaluate::make(0); + std::vector input_bind_nest, output_bind_nest; + Array inputs = this->InputTensors(); + + // input binding + size_t num_inputs = inputs.size(); + for (size_t i = 0; i < num_inputs; ++i) { + Tensor tensor = inputs[i]; + Region region = this->input_regions[i]; + Buffer buffer = this->intrin->buffers[i]; + Array bind_spec{buffer, tensor}; + + Array tuple; + for (size_t i = 0; i < region.size(); ++i) { + tuple.push_back(region[i]->min); + tuple.push_back(region[i]->extent); + } + input_bind_nest.emplace_back(AttrStmt::make( + bind_spec, ir::attr::buffer_bind_scope, + Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop)); + } + + // output binding + for (int i = 0; i < this->num_outputs(); ++i) { + Tensor tensor = stage->op.output(i); + Buffer buffer = this->intrin->buffers[num_inputs + i]; + Array bind_spec{buffer, tensor}; + + Array tuple; + for (size_t i = 0; i < this->axis.size(); ++i) { + auto ivar = this->axis[i]; + if (i < static_cast(this->schedulable_ndim)) { + tuple.push_back(ivar->var); + tuple.push_back(1); + } else { + Range dom = ivar->dom; + tuple.push_back(dom->min); + tuple.push_back(dom->extent); + } + } + + output_bind_nest.emplace_back(AttrStmt::make( + bind_spec, ir::attr::buffer_bind_scope, + Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop)); + } + + // Check variable remap + std::unordered_map vmap; + ir::ArgBinder binder(&vmap); + + size_t tloc = stage->leaf_iter_vars.size(); + ComputeLoopNest n = MakeLoopNest(this, stage, dom_map, debug_keep_trivial_loop); + + if (this->reduce_axis.size() == 0) { + std::vector > nest( + n.main_nest.begin(), n.main_nest.begin() + tloc + 1); + nest.emplace_back(op::MakeIfNest(n.main_predicates)); + CHECK_EQ(n.init_predicates.size(), 0U); + CHECK(this->intrin->body.defined()) + << "Normal store op for intrin " << this << " is not defined"; + Stmt body = MergeNest(output_bind_nest, this->intrin->body); + body = MergeNest(input_bind_nest, body); + body = ir::Substitute(body, vmap); + body = MergeNest(binder.asserts(), body); + body = op::Substitute(body, n.main_vmap); + Stmt ret = MergeNest(nest, body); + return ret; + } else { + // Need to split reduction + CHECK(this->intrin->reduce_update.defined()) + << "Reduction update op is not defined"; + // Need init and update steps + CHECK_NE(this->reduce_axis.size(), 0U); + std::vector > common( + n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1); + std::vector > update_nest( + n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1); + update_nest.emplace_back(op::MakeIfNest(n.main_predicates)); + + if (this->intrin->reduce_init.defined()) { + // init nest + std::vector > init_nest( + n.init_nest.begin(), n.init_nest.begin() + tloc + 1); + init_nest.emplace_back(op::MakeIfNest(n.init_predicates)); + Stmt init = MergeNest(output_bind_nest, this->intrin->reduce_init); + init = op::Substitute(init, n.init_vmap); + init = MergeNest(init_nest, init); + // The update + Stmt update = MergeNest(output_bind_nest, this->intrin->reduce_update); + update = MergeNest(input_bind_nest, update); + update = ir::Substitute(update, vmap); + update = MergeNest(binder.asserts(), update); + update = op::Substitute(update, n.main_vmap); + update = MergeNest(update_nest, 
update); + return MergeNest(common, Block::make(init, update)); + } else { + // When init op is not available, use body op for reset in the first iter. + CHECK(this->intrin->body.defined()) + << "Normal body op is not defined"; + Stmt update = TransformUpdate(stage, dom_map, n, + this->intrin->body, + this->intrin->reduce_update); + update = MergeNest(output_bind_nest, update); + update = MergeNest(input_bind_nest, update); + update = ir::Substitute(update, vmap); + update = MergeNest(binder.asserts(), update); + update = op::Substitute(update, n.main_vmap); + update = MergeNest(update_nest, update); + return MergeNest(common, update); + } + } +} + +} // namespace tvm diff --git a/src/op/tensorize.cc b/src/op/tensorize.cc index 6daaedd16de1..a61aac422284 100644 --- a/src/op/tensorize.cc +++ b/src/op/tensorize.cc @@ -10,7 +10,6 @@ #include "op_util.h" #include "compute_op.h" #include "../schedule/message_passing.h" -#include "../arithmetic/compute_expr.h" namespace tvm { @@ -323,50 +322,6 @@ void VerifyTensorizeBody( } } -/*! - * \brief Transform the update part when there is no init func in tensorizing - * \param stage The stage for tensorizing. - * \param dom_map The range of each iter var. - * \param n The loop nest structured used in compute. - * \param body The body func in tensorize intrin - * \param update The update func in tensorize intrin - * \return Transformed result. - */ -Stmt TransformUpdate(const Stage& stage, - const std::unordered_map& dom_map, - const ComputeLoopNest& n, - Stmt body, - Stmt update) { - Array conds; - std::unordered_set banned; - for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { - IterVar iv = stage->leaf_iter_vars[i]; - auto iit = stage->iter_var_attrs.find(iv); - if (iit != stage->iter_var_attrs.end()) { - const IterVarAttr& attr = (*iit).second; - if (attr->iter_type == kTensorized) { - break; - } - } - if (iv->iter_type == kCommReduce) { - auto vit = dom_map.find(iv); - CHECK(vit != dom_map.end()); - const Range& vrange = vit->second; - conds.push_back(likely(iv->var > vrange->min)); - banned.insert(iv->var.get()); - } - } - for (const Expr& pred : n.main_predicates) { - if (ir::ExprUseVar(pred, banned)) { - LOG(FATAL) << "Tensorize update transform failed, the condition " - << pred << " has a conflict with the reset condition"; - } - } - - return IfThenElse::make(arith::ComputeReduce(conds, const_true(1)), - update, body); -} - Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, const std::unordered_map& dom_map, diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc index 0fac313c079b..623886c31b86 100644 --- a/src/pass/arg_binder.cc +++ b/src/pass/arg_binder.cc @@ -91,7 +91,9 @@ void ArgBinder::BindBuffer(const Buffer& arg, // bind pointer and offset. 
if (is_zero(arg->elem_offset)) { CHECK(is_zero(value->elem_offset)) - << "Trying to bind a Buffer with offset into one without offset"; + << "Trying to bind a Buffer with offset into one without offset " + << " required elem_offset=" << arg->elem_offset + << ", provided elem_offset=" << value->elem_offset; } this->Bind(arg->data, value->data, arg_name + ".data"); diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index 8591c77bd7cc..ccf7fd617194 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -135,29 +135,29 @@ Tensor Schedule::cache_read(const Tensor& tensor, return cache; } -// Cache write and relayout the data according to loop pattern -Array CacheWriteWithReLayout(Schedule sch, - const Array& tensor_array, - const std::string& scope) { - size_t tensor_size = tensor_array.size(); - sch->InvalidateCache(); - Tensor tensor = tensor_array[0]; - Stage orig_stage = sch[tensor->op]; - const ComputeOpNode* compute = orig_stage->op.as(); - std::unordered_set red_axis; - for (IterVar iv : compute->reduce_axis) { +template +void PrepareAxisMapping(Stage orig_stage, + OpType* op, + std::unordered_set* p_red_axis, + Array* p_new_axis, + std::unordered_map* p_dom_map, + std::unordered_map* p_vsub, + std::unordered_map* p_vsub2newvar, + std::vector* p_predicates) { + auto& red_axis = *p_red_axis; + auto& new_axis = *p_new_axis; + auto& dom_map = *p_dom_map; + auto& vsub = *p_vsub; + auto& vsub2newvar = *p_vsub2newvar; + auto& predicates = *p_predicates; + + for (IterVar iv : op->reduce_axis) { red_axis.insert(iv); } - std::unordered_map dom_map; - Array new_axis; - - for (IterVar iv : compute->axis) { + for (IterVar iv : op->axis) { dom_map[iv] = iv->dom; } schedule::PassDownDomain(orig_stage, &dom_map, true); - std::unordered_map vsub; - std::unordered_map vsub2newvar; - std::vector predicates; { // The source->cache std::unordered_map value_map; @@ -178,17 +178,85 @@ Array CacheWriteWithReLayout(Schedule sch, } // skip reduction iteration. 
std::unordered_set skip_bound_check; - for (IterVar iv : compute->reduce_axis) { + for (IterVar iv : op->reduce_axis) { skip_bound_check.insert(iv); } schedule::PassUpIndex(orig_stage, dom_map, &value_map, true); predicates = schedule::MakeBoundCheck( orig_stage, dom_map, value_map, true, skip_bound_check); // The root axis - for (IterVar iv : compute->axis) { - vsub[iv->var.get()] = value_map.at(iv); + for (IterVar iv : op->axis) { + if (value_map.count(iv)) { + vsub[iv->var.get()] = value_map.at(iv); + } // to handle tensor axis } } +} + +Array ReplaceOriginalOp(Schedule sch, + Stage orig_stage, + const std::string& scope, + Operation cache_op, + Operation orig_new_op, + size_t tensor_size) { + Array cache_tensor_list; + for (size_t i = 0; i < tensor_size; i++) { + Tensor cache_tensor = cache_op.output(i); + cache_tensor_list.push_back(cache_tensor); + } + // The replace of the dataflow + std::unordered_map vmap; + std::unordered_map rvmap; + vmap[orig_stage->op.output(0)] = orig_new_op.output(0); + rvmap[orig_new_op.output(0)] = orig_stage->op.output(0); + for (size_t i = 0; i < tensor_size; i++) { + vmap[orig_stage->op.output(0)] = orig_new_op.output(0); + rvmap[orig_new_op.output(0)] = orig_stage->op.output(0); + } + ReplaceDataFlow(sch->stages, &vmap, &rvmap); + // mutate orig stage + orig_stage->op = orig_new_op; + orig_stage->all_iter_vars = orig_stage->op->root_iter_vars(); + orig_stage->leaf_iter_vars = orig_stage->all_iter_vars; + orig_stage->relations = Array(); + // create schedule for new cached stage. + ArrayNode* stages = sch->stages.CopyOnWrite(); + size_t pos = FindNodeRef(stages, orig_stage); + Stage cache_stage = Stage(cache_op); + cache_stage.set_scope(scope); + CHECK_LT(pos, stages->data.size()); + stages->data.insert(stages->data.begin() + pos, + cache_stage.node_); + sch->stage_map.Set(cache_op, cache_stage); + // Update group + cache_stage->group = orig_stage->group; + if (cache_stage->group.defined()) { + ++cache_stage->group->num_child_stages; + } + return cache_tensor_list; +} + + +// Cache write and relayout the data according to loop pattern +Array CacheWriteWithReLayout(Schedule sch, + const Array& tensor_array, + const std::string& scope) { + size_t tensor_size = tensor_array.size(); + sch->InvalidateCache(); + Tensor tensor = tensor_array[0]; + Stage orig_stage = sch[tensor->op]; + const ComputeOpNode* compute = orig_stage->op.as(); + + std::unordered_set red_axis; + Array new_axis; + std::unordered_map dom_map; + + std::unordered_map vsub; + std::unordered_map vsub2newvar; + std::vector predicates; + + PrepareAxisMapping(orig_stage, compute, + &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates); Expr body; Array body_list; @@ -198,7 +266,7 @@ Array CacheWriteWithReLayout(Schedule sch, body = InjectPredicate(predicates, body); body = VarReplacer(vsub2newvar).Mutate(body); // Reduce nodes in ONE computeOp must be the same except value_index - // This is right only if the oringinal body ensures Reduce nodes are the same + // This is right only if the original body ensures Reduce nodes are the same if (body->is_type()) { const ir::Reduce* reduce_body = body.as(); if (first_reduce != nullptr) { @@ -234,48 +302,107 @@ Array CacheWriteWithReLayout(Schedule sch, Operation cache_op = ComputeOpNode::make( compute->name + "." 
+ scope, compute->tag, compute->attrs, new_axis, body_list); - Array cache_tensor_list; + Array cache_expr_list; for (size_t i = 0; i < tensor_size; i++) { Tensor cache_tensor = cache_op.output(i); - cache_tensor_list.push_back(cache_tensor); cache_expr_list.push_back(cache_tensor(args)); } Operation orig_new_op = ComputeOpNode::make( compute->name, compute->tag, compute->attrs, compute->axis, cache_expr_list); - // The replace of the dataflow - std::unordered_map vmap; - std::unordered_map rvmap; - vmap[orig_stage->op.output(0)] = orig_new_op.output(0); - rvmap[orig_new_op.output(0)] = orig_stage->op.output(0); - for (size_t i = 0; i < tensor_size; i++) { - vmap[orig_stage->op.output(0)] = orig_new_op.output(0); - rvmap[orig_new_op.output(0)] = orig_stage->op.output(0); + return ReplaceOriginalOp(sch, orig_stage, scope, + cache_op, orig_new_op, tensor_size); +} + + +// for tensor compute op +Array CacheWriteWithReLayoutTensor(Schedule sch, + const Array& tensor_array, + const std::string& scope) { + size_t tensor_size = tensor_array.size(); + sch->InvalidateCache(); + Tensor tensor = tensor_array[0]; + Stage orig_stage = sch[tensor->op]; + const TensorComputeOpNode* tensor_op = orig_stage->op.as(); + CHECK_EQ(tensor_op->num_outputs(), 1) + << "cache write only support single output tensor_compute_op"; + + std::unordered_set red_axis; + Array new_axis; + std::unordered_map dom_map; + + std::unordered_map vsub; + std::unordered_map vsub2newvar; + std::vector predicates; + + PrepareAxisMapping(orig_stage, tensor_op, + &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates); + + + for (int i = tensor_op->schedulable_ndim; i < static_cast(tensor_op->axis.size()); ++i) { + IterVar iv = tensor_op->axis[i]; + IterVar new_iv = IterVarNode::make( + iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type); + new_axis.push_back(new_iv); + } + Array new_regions; + for (Region old_region : tensor_op->input_regions) { + Region region; + for (Range r : old_region) { + Expr min = VarReplacer(vsub2newvar).Mutate(r->min); + Expr extent = VarReplacer(vsub2newvar).Mutate(r->extent); + region.push_back(Range::make_by_min_extent(min, extent)); + } + new_regions.push_back(region); } - ReplaceDataFlow(sch->stages, &vmap, &rvmap); - // mutate orig stage - orig_stage->op = orig_new_op; - orig_stage->all_iter_vars = orig_stage->op->root_iter_vars(); - orig_stage->leaf_iter_vars = orig_stage->all_iter_vars; - orig_stage->relations = Array(); - // create schedule for new cached stage. - ArrayNode* stages = sch->stages.CopyOnWrite(); - size_t pos = FindNodeRef(stages, orig_stage); - Stage cache_stage = Stage(cache_op); - cache_stage.set_scope(scope); - CHECK_LT(pos, stages->data.size()); - stages->data.insert(stages->data.begin() + pos, - cache_stage.node_); - sch->stage_map.Set(cache_op, cache_stage); - // Update group - cache_stage->group = orig_stage->group; - if (cache_stage->group.defined()) { - ++cache_stage->group->num_child_stages; + + Operation cache_op = TensorComputeOpNode::make( + tensor_op->name + "." 
+ scope, tensor_op->tag, new_axis, + tensor_op->reduce_axis, tensor_op->schedulable_ndim, + tensor_op->intrin, tensor_op->inputs, new_regions); + + // axis will be used in generating compute op + Array compute_axis = tensor_op->axis; + for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) { + IterVar iv = tensor_op->axis[i]; + IterVar aiv = IterVarNode::make(iv->dom, iv->var, kDataPar); + compute_axis.Set(i, aiv); } - return cache_tensor_list; + + // The reader args + Array args; + { + // cache->compute + std::unordered_map value_map; + for (IterVar iv : compute_axis) { + value_map[iv] = iv->var; + } + schedule::PassDownIndex(orig_stage, dom_map, &value_map, true); + for (IterVar iv : orig_stage->leaf_iter_vars) { + if (red_axis.count(iv)) continue; + args.push_back(value_map.at(iv)); + } + // tensorized region axis + for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) { + IterVar iv = compute_axis[i]; + args.push_back(value_map.at(iv)); + } + } + + Array cache_expr_list; + for (size_t i = 0; i < tensor_size; i++) { + Tensor cache_tensor = cache_op.output(i); + cache_expr_list.push_back(cache_tensor(args)); + } + Operation orig_new_op = ComputeOpNode::make( + tensor_op->name, tensor_op->tag, {}, + compute_axis, cache_expr_list); + return ReplaceOriginalOp(sch, orig_stage, scope, + cache_op, orig_new_op, tensor_size); } + Array Schedule::cache_write(const Array& tensor_array, const std::string& scope) { (*this)->InvalidateCache(); @@ -291,23 +418,26 @@ Array Schedule::cache_write(const Array& tensor_array, CHECK(orig_stage.same_as(tmp_stage)) << "Input tensor list must be generated by ONE computeOp"; } - return CacheWriteWithReLayout(*this, tensor_array, scope); } + Tensor Schedule::cache_write(const Tensor& tensor, const std::string& scope) { + // support original compute and tensor compute both (*this)->InvalidateCache(); - Stage orig_stage = operator[](tensor->op); - const ComputeOpNode* compute = tensor->op.as(); - CHECK(compute) - << "cache write only take ComputeOp as writers"; - CHECK_EQ(compute->num_outputs(), 1) - << "cache write only support single output ComputeOp"; - - return (CacheWriteWithReLayout(*this, {tensor}, scope))[0]; + const char* type_key = tensor->op->type_key(); + if (!strcmp(type_key, "ComputeOp")) { + return (CacheWriteWithReLayout(*this, {tensor}, scope))[0]; + } else if (!strcmp(type_key, "TensorComputeOp")) { + return (CacheWriteWithReLayoutTensor(*this, {tensor}, scope))[0]; + } else { + LOG(FATAL) << "cache write only take ComputeOp or TensorComputeOp as writers"; + return Tensor(); + } } + void RebaseNonZeroMinLoop(const Schedule& sch) { std::unordered_map rebase_map; for (Stage s : sch->stages) { diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py index 1d8603dfc98b..2f49b084b875 100644 --- a/tests/python/unittest/test_lang_tensor.py +++ b/tests/python/unittest/test_lang_tensor.py @@ -84,6 +84,78 @@ def test_tensor_reduce(): assert(isinstance(C_loaded, tvm.tensor.Tensor)) assert(str(C_loaded) == str(C)) +def test_tensor_compute1(): + m = 1024 + factor = 16 + dtype = 'float32' + + def intrin_vadd(n): + x = tvm.placeholder((n,)) + y = tvm.placeholder((n,)) + z = tvm.compute(x.shape, lambda i: x[i] + y[i]) + + def intrin_func(ins, outs): + ib = tvm.ir_builder.create() + ib.emit(tvm.call_extern(outs[0].dtype, 'vadd', ins[0].access_ptr("r"), ins[1].access_ptr('r'), outs[0].access_ptr('wr'))) + return ib.get() + + with tvm.build_config(offset_factor=n): + return 
tvm.decl_tensor_intrin(z.op, intrin_func) + + vadd = intrin_vadd(factor) + + A = tvm.placeholder((m//factor, factor), name="A", dtype=dtype) + B = tvm.placeholder((m//factor, factor), name="B", dtype=dtype) + C = tvm.compute((m//factor, factor), + lambda i: vadd(A[i, 0:factor], B[i, 0:factor])) + + s = tvm.create_schedule(C.op) + stmt = tvm.lower(s, [A, B, C], simple_mode=True) + assert isinstance(stmt.body.body, tvm.stmt.Evaluate) + +def test_tensor_compute2(): + M = 2048 + N = 1024 + L = 1024 + factor = 16 + factor1 = 32 + factor2 = 32 + dtype = 'float32' + + def intrin_gemm(m, n, l): + k = tvm.reduce_axis((0, l)) + x = tvm.placeholder((m, l)) + y = tvm.placeholder((n, l)) + # in theory, no relation + z = tvm.compute((m, n), lambda i, j: tvm.sum(x[i][k] * y[j][k], axis=k)) + + def intrin_func(ins, outs): + x_ptr = ins[0].access_ptr("r") + y_ptr = ins[1].access_ptr("r") + z_ptr = outs[0].access_ptr("w") + body = tvm.call_packed( + "gemv", x_ptr, y_ptr, z_ptr, m, n, l) + reset = tvm.call_packed( + "fill_zero", z_ptr, m, n) + update = tvm.call_packed( + "gemv_add", x_ptr, y_ptr, z_ptr, m, n, l) + return body, reset, update + + with tvm.build_config(offset_factor=n): + return tvm.decl_tensor_intrin(z.op, intrin_func) + + vgemm = intrin_gemm(factor1, factor2, factor) + + A = tvm.placeholder((M//factor1, L//factor, factor1, factor), name="A", dtype=dtype) + B = tvm.placeholder((N//factor2, L//factor, factor2, factor), name="B", dtype=dtype) + k = tvm.reduce_axis((0, L//factor), name='k') + C = tvm.compute((M//factor1, N//factor2, factor1, factor2), + lambda i, j: vgemm(A[i, k, 0:factor1, 0:factor], B[j, k, 0:factor2, 0:factor], reduce_axis=k)) + + s = tvm.create_schedule(C.op) + stmt = tvm.lower(s, [A, B, C], simple_mode=True) + assert isinstance(stmt.body.body.body.first, tvm.stmt.Evaluate) + assert isinstance(stmt.body.body.body.rest.body, tvm.stmt.Evaluate) def test_tensor_scan(): m = tvm.var("m") @@ -192,6 +264,8 @@ def test_tensor_inputs(): test_conv1d() test_tensor_slice() test_tensor() + test_tensor_compute1() + test_tensor_compute2() test_tensor_reduce() test_tensor_scan() test_scan_multi_out() diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py index 8e6f4090d403..8774514cfa17 100644 --- a/tests/python/unittest/test_schedule_schedule_ops.py +++ b/tests/python/unittest/test_schedule_schedule_ops.py @@ -276,6 +276,133 @@ def test_schedule_bound_condition(): stmt = tvm.ir_pass.Simplify(stmt) assert (isinstance(stmt.body.body.first.body.body.then_case, tvm.stmt.IfThenElse)) + +def intrin_gemv(m, n): + w = tvm.placeholder((m, n), name='w') + x = tvm.placeholder((n,), name='x') + k = tvm.reduce_axis((0, n), name='k') + z = tvm.compute((m,), lambda i: + tvm.sum(w[i, k] * x[k], axis=k), name='z') + Wb = tvm.decl_buffer(w.shape, w.dtype, + name="W", + offset_factor=16, + strides=[tvm.var('ldw'), 1]) + def intrin_func(ins, outs): + ww, xx = ins + zz = outs[0] + ww_ptr = ww.access_ptr("r") + xx_ptr = xx.access_ptr("r") + zz_ptr = zz.access_ptr("w") + body = tvm.call_packed( + "gemm", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0]) + reset = tvm.call_packed( + "fill_zero", zz_ptr, n) + update = tvm.call_packed( + "gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0]) + return body, reset, update + + with tvm.build_config(data_alignment=16, + offset_factor=16): + return tvm.decl_tensor_intrin(z.op, intrin_func, + binds={w: Wb}) + + +def test_schedule_tensor_compute1(): + # basic: split, reorder, tile + M, N, L = 2048, 1024, 512 + factor, 
rfactor = 16, 16 + A = tvm.placeholder((N//factor, L//rfactor, factor, rfactor), name='A') + B = tvm.placeholder((M, L//rfactor, rfactor), name='B') + k = tvm.reduce_axis((0, L//rfactor), name='k') + + gemv = intrin_gemv(factor, rfactor) + C = tvm.compute((N, M//factor, factor), + lambda i, j: gemv(A[i, k, 0:factor, 0:factor], B[j, k, 0:rfactor], reduce_axis=k), + name='C') + + s = tvm.create_schedule(C.op) + ai, aj, ax = s[C].op.axis + aio, aii = s[C].split(ai, 16) + s[C].reorder(aio, aj, aii) + aioo, ajo, aioi, aji = s[C].tile(aio, aj, 16, 4) + + s = s.normalize() + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + + +def intrin_vadd(n, cache_read=False, cache_write=False): + scope_ubuf = 'local' + dtype = 'float32' + x = tvm.placeholder((n,), dtype=dtype, name='vx') + y = tvm.placeholder((n,), dtype=dtype, name='vy') + z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z') + s = tvm.create_schedule(z.op) + + def create_buffer(t): + return tvm.decl_buffer(t.shape, t.dtype, + name='W'+t.name, + scope=scope_ubuf, + offset_factor=16) + + binds = {} + if cache_read: + binds[x] = create_buffer(x) + binds[y] = create_buffer(y) + if cache_write: + binds[z] = create_buffer(z) + + def intrin_func(ins, outs): + ib = tvm.ir_builder.create() + ib.emit(tvm.call_extern(outs[0].dtype, 'vadd', ins[0].access_ptr("r"), ins[1].access_ptr('r'), outs[0].access_ptr('wr'))) + return ib.get() + + with tvm.build_config(offset_factor=16): + return tvm.decl_tensor_intrin(z.op, intrin_func, binds=binds) + + +def test_schedule_tensor_compute2(): + # cache_read, cache_write + M = 1024 + factor = 16 + dtype = 'float32' + scope_ubuf = 'local' + + A = tvm.placeholder((M//factor, factor), name="A", dtype=dtype) + B = tvm.placeholder((M//factor, factor), name="B", dtype=dtype) + + vadd = intrin_vadd(factor, True, True) + C = tvm.compute((M//factor, factor), + lambda i: vadd(A[i, 0:factor], B[i, 0:factor]), name='C') + + s = tvm.create_schedule(C.op) + AL = s.cache_read(A, scope_ubuf, C) + BL = s.cache_read(B, scope_ubuf, C) + CL = s.cache_write(C, scope_ubuf) + s = s.normalize() + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + + +def test_schedule_tensor_compute3(): + # compute_at + M = 1024 + factor = 16 + dtype = 'float32' + A = tvm.placeholder((M//factor, factor), name="A", dtype=dtype) + B = tvm.placeholder((M//factor, factor), name="B", dtype=dtype) + Bi = tvm.compute((M//factor, factor), lambda i, j: B[i, j] + 5, name="Bi") + + vadd = intrin_vadd(factor) + C = tvm.compute((M//factor, factor), + lambda i: vadd(A[i, 0:factor], Bi[i, 0:factor]), name='C') + s = tvm.create_schedule(C.op) + s[Bi].compute_at(s[C], C.op.axis[0]) + s = s.normalize() + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + + if __name__ == "__main__": test_schedule_middle_cache() test_inline_multi_reduce() @@ -294,3 +421,6 @@ def test_schedule_bound_condition(): test_schedule2() test_schedule_cache() test_schedule_bound_condition() + test_schedule_tensor_compute1() + test_schedule_tensor_compute2() + test_schedule_tensor_compute3()
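
For reference, a minimal end-to-end usage sketch of the API this patch introduces, distilled from test_tensor_compute1 above. It assumes the Python API as of this change (tvm.compute accepting a TensorIntrinCall); the extern symbol 'vadd' is a placeholder, not a real kernel.

import tvm

def intrin_vadd(n):
    # Declare an elementwise-add tensor intrinsic over n-element vectors.
    x = tvm.placeholder((n,), name='vx')
    y = tvm.placeholder((n,), name='vy')
    z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')

    def intrin_func(ins, outs):
        ib = tvm.ir_builder.create()
        # 'vadd' is a placeholder extern symbol; swap in a real kernel.
        ib.emit(tvm.call_extern(outs[0].dtype, 'vadd',
                                ins[0].access_ptr('r'),
                                ins[1].access_ptr('r'),
                                outs[0].access_ptr('wr')))
        return ib.get()

    with tvm.build_config(offset_factor=n):
        return tvm.decl_tensor_intrin(z.op, intrin_func)

m, factor = 1024, 16
vadd = intrin_vadd(factor)
A = tvm.placeholder((m // factor, factor), name='A')
B = tvm.placeholder((m // factor, factor), name='B')
# Calling the intrinsic on tensor regions yields a TensorIntrinCall;
# tvm.compute turns it into a TensorComputeOp whose trailing axes come
# from the intrinsic (the lambda covers 1 dim, the shape has 2).
C = tvm.compute((m // factor, factor),
                lambda i: vadd(A[i, 0:factor], B[i, 0:factor]), name='C')
s = tvm.create_schedule(C.op)
print(tvm.lower(s, [A, B, C], simple_mode=True))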