From 781df707d1a0e837de25854652a6cc811104d279 Mon Sep 17 00:00:00 2001 From: Jessica Davies Date: Fri, 15 Feb 2019 17:53:37 -0500 Subject: [PATCH] Extend TensorComputeOp to allow scalar inputs (#2606). --- include/tvm/operation.h | 6 ++++- include/tvm/tensor_intrin.h | 16 +++++++++++- python/tvm/api.py | 3 ++- python/tvm/tensor_intrin.py | 21 +++++++++++++--- src/api/api_lang.cc | 9 ++++--- src/lang/tensor.cc | 6 ++++- src/op/tensor_compute_op.cc | 17 ++++++++++++- src/schedule/schedule_dataflow_rewrite.cc | 7 +++++- tests/python/unittest/test_lang_schedule.py | 27 +++++++++++++++++++++ 9 files changed, 99 insertions(+), 13 deletions(-) diff --git a/include/tvm/operation.h b/include/tvm/operation.h index 3509b133cfc37..87c7ed6059863 100644 --- a/include/tvm/operation.h +++ b/include/tvm/operation.h @@ -256,6 +256,8 @@ class TensorComputeOpNode : public OperationNode { Array inputs; /*! \brief region of input tensors */ Array input_regions; + /*! \brief scalar expression inputs */ + Array scalar_inputs; /*! \brief constructor */ TensorComputeOpNode() {} // override functions @@ -293,6 +295,7 @@ class TensorComputeOpNode : public OperationNode { v->Visit("intrin", &intrin); v->Visit("inputs", &inputs); v->Visit("input_regions", &input_regions); + v->Visit("scalar_inputs", &scalar_inputs); } static Operation make(std::string name, std::string tag, @@ -301,7 +304,8 @@ class TensorComputeOpNode : public OperationNode { int schedulable_ndim, TensorIntrin intrin, Array tensors, - Array regions); + Array regions, + Array scalar_inputs); static constexpr const char* _type_key = "TensorComputeOp"; TVM_DECLARE_NODE_TYPE_INFO(TensorComputeOpNode, OperationNode); diff --git a/include/tvm/tensor_intrin.h b/include/tvm/tensor_intrin.h index 6cffc931d42a3..0d4a14ae2ba80 100644 --- a/include/tvm/tensor_intrin.h +++ b/include/tvm/tensor_intrin.h @@ -49,6 +49,11 @@ class TensorIntrinNode : public Node { * When it is a constant, it means we can only take data in that shape. 
*/ Array buffers; + /*! \brief List of scalar variables, used in body. These placeholders + * will be bound to expressions passed in when the TensorIntrin is called + * from a TensorComputeOp. + */ + Array scalar_params; /*! \brief The normal statement to execute the intrinsic */ Stmt body; /*! @@ -69,6 +74,7 @@ class TensorIntrinNode : public Node { v->Visit("op", &op); v->Visit("inputs", &inputs); v->Visit("buffers", &buffers); + v->Visit("scalar_params", &scalar_params); v->Visit("body", &body); v->Visit("reduce_init", &reduce_init); v->Visit("reduce_update", &reduce_update); @@ -78,6 +84,7 @@ class TensorIntrinNode : public Node { Operation op, Array inputs, Array buffers, + Array scalar_params, Stmt body, Stmt reduce_init, Stmt reduce_update); @@ -116,22 +123,29 @@ class TensorIntrinCallNode : public Node { Array tensors; /*! \brief regions of input tensors */ Array regions; + + /*! * \brief IterVar on each reduction axis, if the * intrin will use the reduce axis */ Array reduce_axis; + /*! 
\brief scalar expression inputs */ + Array scalar_inputs; + void VisitAttrs(AttrVisitor* v) final { v->Visit("intrin", &intrin); v->Visit("tensors", &tensors); v->Visit("regions", &regions); v->Visit("reduce_axis", &reduce_axis); + v->Visit("scalar_inputs", &scalar_inputs); } static TensorIntrinCall make(TensorIntrin intrin, Array tensors, Array regions, - Array reduce_axis); + Array reduce_axis, + Array scalar_inputs); static constexpr const char* _type_key = "TensorIntrinCall"; TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinCallNode, Node); diff --git a/python/tvm/api.py b/python/tvm/api.py index 10a97171e58fd..333bdcf339baf 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -303,7 +303,8 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None): out_ndim, body.intrin, body.tensors, - body.regions) + body.regions, + body.scalar_inputs) else: if not isinstance(body, (list, tuple)): body = [body] diff --git a/python/tvm/tensor_intrin.py b/python/tvm/tensor_intrin.py index f1f26655fe27d..4bfd65b7d2e40 100644 --- a/python/tvm/tensor_intrin.py +++ b/python/tvm/tensor_intrin.py @@ -42,12 +42,18 @@ def __call__(self, *args, **kwargs): if not isinstance(reduce_axis, (list, tuple)): reduce_axis = [reduce_axis] reduce_axis = _api.convert(reduce_axis) - return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis) + scalar_inputs = [] + if "scalar_inputs" in kwargs: + scalar_inputs = kwargs["scalar_inputs"] + if not isinstance(scalar_inputs, (list, tuple)): + scalar_inputs = [scalar_inputs] + scalar_inputs = _api.convert(scalar_inputs) + return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis, scalar_inputs) def decl_tensor_intrin(op, fcompute, name="tensor_intrin", - binds=None): + binds=None, scalar_params=None): """Declare a tensor intrinsic function. Parameters @@ -80,6 +86,9 @@ def decl_tensor_intrin(op, requirement of the function. By default, a new compact buffer is created for each tensor in the argument. 
+ scalar_params: a list of variables used by op, whose values will be passed + as scalar_inputs when the tensor intrinsic is called. + Returns ------- intrin: TensorIntrin @@ -106,11 +115,15 @@ def decl_tensor_intrin(op, offset_factor=cfg.offset_factor)) binds_list.append(buf) - body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):]) + if scalar_params: + body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):], scalar_params) + else: + body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):]) + scalar_params = [] if isinstance(body, (_expr.Expr, _stmt.Stmt)): body = [body] body = [_make.Evaluate(x) if isinstance(x, _expr.Expr) else x for x in body] if len(body) < 3: body += [None] * (3 - len(body)) return _api_internal._TensorIntrin( - name, op, inputs, binds_list, *body) + name, op, inputs, binds_list, scalar_params, *body) diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index e30111e938bd1..7cf8fe36ca6d0 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -240,7 +240,8 @@ TVM_REGISTER_API("_TensorIntrin") args[3], args[4], args[5], - args[6]); + args[6], + args[7]); }); TVM_REGISTER_API("_TensorIntrinCall") @@ -248,7 +249,8 @@ TVM_REGISTER_API("_TensorIntrinCall") *ret = TensorIntrinCallNode::make(args[0], args[1], args[2], - args[3]); + args[3], + args[4]); }); TVM_REGISTER_API("_TensorEqual") @@ -299,7 +301,8 @@ TVM_REGISTER_API("_TensorComputeOp") args[4], args[5], args[6], - args[7]); + args[7], + args[8]); }); TVM_REGISTER_API("_ExternOp") diff --git a/src/lang/tensor.cc b/src/lang/tensor.cc index 9b1a58abcee4f..f2b8274fa2b36 100644 --- a/src/lang/tensor.cc +++ b/src/lang/tensor.cc @@ -64,6 +64,7 @@ TensorIntrin TensorIntrinNode::make(std::string name, Operation op, Array inputs, Array buffers, + Array scalar_params, Stmt body, Stmt reduce_init, Stmt reduce_update) { @@ -72,6 +73,7 @@ TensorIntrin TensorIntrinNode::make(std::string name, n->op = std::move(op); n->inputs = std::move(inputs); n->buffers = 
std::move(buffers); + n->scalar_params = std::move(scalar_params); n->body = std::move(body); n->reduce_init = std::move(reduce_init); n->reduce_update = std::move(reduce_update); @@ -91,12 +93,14 @@ TVM_REGISTER_NODE_TYPE(TensorIntrinNode); TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin, Array tensors, Array regions, - Array reduce_axis) { + Array reduce_axis, + Array scalar_inputs) { auto n = make_node(); n->intrin = std::move(intrin); n->tensors = std::move(tensors); n->regions = std::move(regions); n->reduce_axis = std::move(reduce_axis); + n->scalar_inputs = std::move(scalar_inputs); return TensorIntrinCall(n); } diff --git a/src/op/tensor_compute_op.cc b/src/op/tensor_compute_op.cc index 0262db7d8fc50..1c4aa8f022592 100644 --- a/src/op/tensor_compute_op.cc +++ b/src/op/tensor_compute_op.cc @@ -56,7 +56,8 @@ Operation TensorComputeOpNode::make(std::string name, int schedulable_ndim, TensorIntrin intrin, Array tensors, - Array regions) { + Array regions, + Array scalar_inputs) { auto n = make_node(); n->name = std::move(name); n->tag = std::move(tag); @@ -66,6 +67,7 @@ Operation TensorComputeOpNode::make(std::string name, n->intrin = std::move(intrin); n->inputs = std::move(tensors); n->input_regions = std::move(regions); + n->scalar_inputs = std::move(scalar_inputs); return Operation(n); } @@ -295,6 +297,19 @@ Stmt TensorComputeOpNode::BuildProvide( std::unordered_map vmap; ir::ArgBinder binder(&vmap); + // Map the expressions passed in the call to the TensorIntrin, to the placeholder + // variables + Array user_expr = this->scalar_inputs; + Array scalar_params = this->intrin->scalar_params; + Array sp_expr; + for (auto sp : scalar_params) { + Expr esp = sp; + sp_expr.push_back(esp); + } + CHECK_EQ(sp_expr.size(), user_expr.size()); + // TODO(jdavies-huawei): what name should be used here? 
+ binder.BindArray(sp_expr, user_expr, this->name); + size_t tloc = stage->leaf_iter_vars.size(); ComputeLoopNest n = MakeLoopNest(this, stage, dom_map, debug_keep_trivial_loop); diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index 774c623a5df2a..732084ee8456d 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -391,10 +391,15 @@ Array CacheWriteWithReLayoutTensor(Schedule sch, new_regions.push_back(region); } + Array new_scalar_inputs; + for (Expr old_input : tensor_op->scalar_inputs) { + new_scalar_inputs.push_back(VarReplacer(vsub2newvar).Mutate(old_input)); + } + Operation cache_op = TensorComputeOpNode::make( tensor_op->name + "." + scope, tensor_op->tag, new_axis, tensor_op->reduce_axis, tensor_op->schedulable_ndim, - tensor_op->intrin, tensor_op->inputs, new_regions); + tensor_op->intrin, tensor_op->inputs, new_regions, new_scalar_inputs); // axis will be used in generating compute op Array compute_axis = tensor_op->axis; diff --git a/tests/python/unittest/test_lang_schedule.py b/tests/python/unittest/test_lang_schedule.py index a00785dea7afc..dcb67f4a89ff8 100644 --- a/tests/python/unittest/test_lang_schedule.py +++ b/tests/python/unittest/test_lang_schedule.py @@ -193,11 +193,38 @@ def intrin_func(ins, outs): assert(s[z].iter_var_attrs[xi].tensor_intrin == intrin) assert(s[z].iter_var_attrs[xi].iter_type == tvm.schedule.IterVar.Tensorized) +def test_tensor_intrin_scalar_params(): + n = 2 + x = tvm.placeholder((n,), name='x') + y = tvm.placeholder((n,), name='y') + v = tvm.var("v") + w = tvm.var("w") + z = tvm.compute((1, ), lambda i: v*(x[0] + y[1]) + w*(x[1] + y[0]), name='z') + def intrin_func(ins, outs, sp): + assert(isinstance(ins[0], tvm.schedule.Buffer)) + assert(ins[0].shape[0].value == n) + assert(sp[0] == v) + assert(sp[1] == w) + return tvm.call_packed("hw_func", ins[0].data, outs[0].data, sp[0], sp[1]) + intrin = tvm.decl_tensor_intrin(z.op, 
intrin_func, scalar_params=[v, w]) + assert intrin.op == z.op + assert intrin.reduce_init is None + assert tuple(intrin.inputs) == tuple(z.op.input_tensors) + assert(intrin.buffers[0].shape[0].value == n) + assert tuple(intrin.scalar_params) == tuple((v, w)) + m = 32 + x = tvm.placeholder((m,), name='x') + y = tvm.placeholder((m,), name='y') + z = tvm.compute(x.shape, lambda i: intrin(x[i:i+2], y[i:i+2], scalar_inputs=(i*i, i*25)), name='z') + s = tvm.create_schedule(z.op) + + if __name__ == "__main__": test_singleton() test_pragma() test_tensor_intrin() + test_tensor_intrin_scalar_params() test_rfactor() test_schedule_create() test_reorder()