Skip to content

Commit

Permalink
Extend TensorComputeOp to allow scalar inputs (apache#2606).
Browse files Browse the repository at this point in the history
  • Loading branch information
jdavies-huawei committed Feb 15, 2019
1 parent d05fed2 commit 781df70
Show file tree
Hide file tree
Showing 9 changed files with 99 additions and 13 deletions.
6 changes: 5 additions & 1 deletion include/tvm/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,8 @@ class TensorComputeOpNode : public OperationNode {
Array<Tensor> inputs;
/*! \brief region of input tensors */
Array<Region> input_regions;
/*! \brief scalar expression inputs */
Array<Expr> scalar_inputs;
/*! \brief constructor */
TensorComputeOpNode() {}
// override functions
Expand Down Expand Up @@ -293,6 +295,7 @@ class TensorComputeOpNode : public OperationNode {
v->Visit("intrin", &intrin);
v->Visit("inputs", &inputs);
v->Visit("input_regions", &input_regions);
v->Visit("scalar_inputs", &scalar_inputs);
}
static Operation make(std::string name,
std::string tag,
Expand All @@ -301,7 +304,8 @@ class TensorComputeOpNode : public OperationNode {
int schedulable_ndim,
TensorIntrin intrin,
Array<Tensor> tensors,
Array<Region> regions);
Array<Region> regions,
Array<Expr> scalar_inputs);

static constexpr const char* _type_key = "TensorComputeOp";
TVM_DECLARE_NODE_TYPE_INFO(TensorComputeOpNode, OperationNode);
Expand Down
16 changes: 15 additions & 1 deletion include/tvm/tensor_intrin.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ class TensorIntrinNode : public Node {
* When it is a constant, it means we can only take data in that shape.
*/
Array<Buffer> buffers;
/*! \brief List of scalar variables, used in body. These placeholders
* will be bound to expressions passed in when the TensorIntrin is called
* from a TensorComputeOp.
*/
Array<Var> scalar_params;
/*! \brief The normal statement to execute the intrinsic */
Stmt body;
/*!
Expand All @@ -69,6 +74,7 @@ class TensorIntrinNode : public Node {
v->Visit("op", &op);
v->Visit("inputs", &inputs);
v->Visit("buffers", &buffers);
v->Visit("scalar_params", &scalar_params);
v->Visit("body", &body);
v->Visit("reduce_init", &reduce_init);
v->Visit("reduce_update", &reduce_update);
Expand All @@ -78,6 +84,7 @@ class TensorIntrinNode : public Node {
Operation op,
Array<Tensor> inputs,
Array<Buffer> buffers,
Array<Var> scalar_params,
Stmt body,
Stmt reduce_init,
Stmt reduce_update);
Expand Down Expand Up @@ -116,22 +123,29 @@ class TensorIntrinCallNode : public Node {
Array<Tensor> tensors;
/*! \brief regions of input tensors */
Array<Region> regions;


/*!
* \brief IterVar on each reduction axis, if the
* intrin will use the reduce axis
*/
Array<IterVar> reduce_axis;

/*! \brief scalar expression inputs */
Array<Expr> scalar_inputs;

void VisitAttrs(AttrVisitor* v) final {
v->Visit("intrin", &intrin);
v->Visit("tensors", &tensors);
v->Visit("regions", &regions);
v->Visit("reduce_axis", &reduce_axis);
v->Visit("scalar_inputs", &scalar_inputs);
}
static TensorIntrinCall make(TensorIntrin intrin,
Array<Tensor> tensors,
Array<Region> regions,
Array<IterVar> reduce_axis);
Array<IterVar> reduce_axis,
Array<Expr> scalar_inputs);

static constexpr const char* _type_key = "TensorIntrinCall";
TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinCallNode, Node);
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,8 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
out_ndim,
body.intrin,
body.tensors,
body.regions)
body.regions,
body.scalar_inputs)
else:
if not isinstance(body, (list, tuple)):
body = [body]
Expand Down
21 changes: 17 additions & 4 deletions python/tvm/tensor_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,18 @@ def __call__(self, *args, **kwargs):
if not isinstance(reduce_axis, (list, tuple)):
reduce_axis = [reduce_axis]
reduce_axis = _api.convert(reduce_axis)
return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis)
scalar_inputs = []
if "scalar_inputs" in kwargs:
scalar_inputs = kwargs["scalar_inputs"]
if not isinstance(scalar_inputs, (list, tuple)):
scalar_inputs = [scalar_inputs]
scalar_inputs = _api.convert(scalar_inputs)
return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis, scalar_inputs)

def decl_tensor_intrin(op,
fcompute,
name="tensor_intrin",
binds=None):
binds=None, scalar_params=None):
"""Declare a tensor intrinsic function.
Parameters
Expand Down Expand Up @@ -80,6 +86,9 @@ def decl_tensor_intrin(op,
requirement of the function. By default, a new compact buffer is created
for each tensor in the argument.
scalar_params: a list of variables used by op, whose values will be passed
as scalar_inputs when the tensor intrinsic is called.
Returns
-------
intrin: TensorIntrin
Expand All @@ -106,11 +115,15 @@ def decl_tensor_intrin(op,
offset_factor=cfg.offset_factor))
binds_list.append(buf)

body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):])
if scalar_params:
body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):], scalar_params)
else:
body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):])
scalar_params = []
if isinstance(body, (_expr.Expr, _stmt.Stmt)):
body = [body]
body = [_make.Evaluate(x) if isinstance(x, _expr.Expr) else x for x in body]
if len(body) < 3:
body += [None] * (3 - len(body))
return _api_internal._TensorIntrin(
name, op, inputs, binds_list, *body)
name, op, inputs, binds_list, scalar_params, *body)
9 changes: 6 additions & 3 deletions src/api/api_lang.cc
Original file line number Diff line number Diff line change
Expand Up @@ -240,15 +240,17 @@ TVM_REGISTER_API("_TensorIntrin")
args[3],
args[4],
args[5],
args[6]);
args[6],
args[7]);
});

TVM_REGISTER_API("_TensorIntrinCall")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TensorIntrinCallNode::make(args[0],
args[1],
args[2],
args[3]);
args[3],
args[4]);
});

TVM_REGISTER_API("_TensorEqual")
Expand Down Expand Up @@ -299,7 +301,8 @@ TVM_REGISTER_API("_TensorComputeOp")
args[4],
args[5],
args[6],
args[7]);
args[7],
args[8]);
});

TVM_REGISTER_API("_ExternOp")
Expand Down
6 changes: 5 additions & 1 deletion src/lang/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ TensorIntrin TensorIntrinNode::make(std::string name,
Operation op,
Array<Tensor> inputs,
Array<Buffer> buffers,
Array<Var> scalar_params,
Stmt body,
Stmt reduce_init,
Stmt reduce_update) {
Expand All @@ -72,6 +73,7 @@ TensorIntrin TensorIntrinNode::make(std::string name,
n->op = std::move(op);
n->inputs = std::move(inputs);
n->buffers = std::move(buffers);
n->scalar_params = std::move(scalar_params);
n->body = std::move(body);
n->reduce_init = std::move(reduce_init);
n->reduce_update = std::move(reduce_update);
Expand All @@ -91,12 +93,14 @@ TVM_REGISTER_NODE_TYPE(TensorIntrinNode);
// Build a TensorIntrinCall node that records everything a call site passes to
// a tensor intrinsic: the intrinsic itself, the input tensors with the regions
// read from them, the reduction axes (if any), and the scalar expressions that
// will be bound to the intrinsic's scalar placeholders.
TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin,
                                            Array<Tensor> tensors,
                                            Array<Region> regions,
                                            Array<IterVar> reduce_axis,
                                            Array<Expr> scalar_inputs) {
  auto node = make_node<TensorIntrinCallNode>();
  node->intrin = std::move(intrin);
  node->tensors = std::move(tensors);
  node->regions = std::move(regions);
  node->reduce_axis = std::move(reduce_axis);
  node->scalar_inputs = std::move(scalar_inputs);
  return TensorIntrinCall(node);
}

Expand Down
17 changes: 16 additions & 1 deletion src/op/tensor_compute_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ Operation TensorComputeOpNode::make(std::string name,
int schedulable_ndim,
TensorIntrin intrin,
Array<Tensor> tensors,
Array<Region> regions) {
Array<Region> regions,
Array<Expr> scalar_inputs) {
auto n = make_node<TensorComputeOpNode>();
n->name = std::move(name);
n->tag = std::move(tag);
Expand All @@ -66,6 +67,7 @@ Operation TensorComputeOpNode::make(std::string name,
n->intrin = std::move(intrin);
n->inputs = std::move(tensors);
n->input_regions = std::move(regions);
n->scalar_inputs = std::move(scalar_inputs);
return Operation(n);
}

Expand Down Expand Up @@ -295,6 +297,19 @@ Stmt TensorComputeOpNode::BuildProvide(
std::unordered_map<const Variable*, Expr> vmap;
ir::ArgBinder binder(&vmap);

// Map the expressions passed in the call to the TensorIntrin, to the placeholder
// variables
Array<Expr> user_expr = this->scalar_inputs;
Array<Var> scalar_params = this->intrin->scalar_params;
Array<Expr> sp_expr;
for (auto sp : scalar_params) {
Expr esp = sp;
sp_expr.push_back(esp);
}
CHECK_EQ(sp_expr.size(), user_expr.size());
// TODO(jdavies-huawei): what name should be used here?
binder.BindArray(sp_expr, user_expr, this->name);

size_t tloc = stage->leaf_iter_vars.size();
ComputeLoopNest n = MakeLoopNest(this, stage, dom_map, debug_keep_trivial_loop);

Expand Down
7 changes: 6 additions & 1 deletion src/schedule/schedule_dataflow_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -391,10 +391,15 @@ Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch,
new_regions.push_back(region);
}

Array<Expr> new_scalar_inputs;
for (Expr old_input : tensor_op->scalar_inputs) {
new_scalar_inputs.push_back(VarReplacer(vsub2newvar).Mutate(old_input));
}

Operation cache_op = TensorComputeOpNode::make(
tensor_op->name + "." + scope, tensor_op->tag, new_axis,
tensor_op->reduce_axis, tensor_op->schedulable_ndim,
tensor_op->intrin, tensor_op->inputs, new_regions);
tensor_op->intrin, tensor_op->inputs, new_regions, new_scalar_inputs);

// axis will be used in generating compute op
Array<IterVar> compute_axis = tensor_op->axis;
Expand Down
27 changes: 27 additions & 0 deletions tests/python/unittest/test_lang_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,38 @@ def intrin_func(ins, outs):
assert(s[z].iter_var_attrs[xi].tensor_intrin == intrin)
assert(s[z].iter_var_attrs[xi].iter_type == tvm.schedule.IterVar.Tensorized)

def test_tensor_intrin_scalar_params():
    """Declare a tensor intrinsic with scalar_params and invoke it with scalar_inputs."""
    elems = 2
    x = tvm.placeholder((elems,), name='x')
    y = tvm.placeholder((elems,), name='y')
    v = tvm.var("v")
    w = tvm.var("w")
    z = tvm.compute((1,),
                    lambda i: v*(x[0] + y[1]) + w*(x[1] + y[0]),
                    name='z')

    def intrin_func(ins, outs, sp):
        # The intrinsic body receives the scalar placeholders through ``sp``.
        assert isinstance(ins[0], tvm.schedule.Buffer)
        assert ins[0].shape[0].value == elems
        assert sp[0] == v
        assert sp[1] == w
        return tvm.call_packed("hw_func", ins[0].data, outs[0].data, sp[0], sp[1])

    intrin = tvm.decl_tensor_intrin(z.op, intrin_func, scalar_params=[v, w])
    # The declared intrinsic must record the op, buffers and scalar placeholders.
    assert intrin.op == z.op
    assert intrin.reduce_init is None
    assert tuple(intrin.inputs) == tuple(z.op.input_tensors)
    assert intrin.buffers[0].shape[0].value == elems
    assert tuple(intrin.scalar_params) == (v, w)

    # Call the intrinsic from a compute op, binding concrete expressions
    # (i*i, i*25) to the scalar placeholders via scalar_inputs.
    extent = 32
    x = tvm.placeholder((extent,), name='x')
    y = tvm.placeholder((extent,), name='y')
    z = tvm.compute(x.shape,
                    lambda i: intrin(x[i:i+2], y[i:i+2], scalar_inputs=(i*i, i*25)),
                    name='z')
    s = tvm.create_schedule(z.op)



if __name__ == "__main__":
test_singleton()
test_pragma()
test_tensor_intrin()
test_tensor_intrin_scalar_params()
test_rfactor()
test_schedule_create()
test_reorder()
Expand Down

0 comments on commit 781df70

Please sign in to comment.