[LANG] Generalize compute to tensor region (apache#1476)
ZihengJiang authored and tqchen committed Oct 6, 2018
1 parent 39703e3 commit c313256
Showing 17 changed files with 1,059 additions and 131 deletions.
2 changes: 2 additions & 0 deletions include/tvm/expr.h
@@ -108,6 +108,8 @@ class Range : public HalideIR::IR::Range {
TVM_DLL static Range make_by_min_extent(Expr min, Expr extent);
};

using Region = Array<Range>;

/*!
* \brief Type of iteration variable.
* Each IterVar have a specific type.
72 changes: 70 additions & 2 deletions include/tvm/operation.h
@@ -49,7 +49,7 @@ class OperationNode : public FunctionBaseNode {
}
/*!
* \return The list of iteration variable at root
* \note root_iter_vars dedides the shape of the outputs.
* \note root_iter_vars decides the shape of the outputs.
*/
virtual Array<IterVar> root_iter_vars() const = 0;
/*!
@@ -239,6 +239,74 @@ class TVM_DLL ComputeOpNode : public OperationNode {
TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode, OperationNode);
};

/*!
* \brief A TensorCompute op that computes a tensor with a tensor intrinsic.
*/
class TensorComputeOpNode : public OperationNode {
public:
/*! \brief IterVar on each axis */
Array<IterVar> axis;
/*! \brief IterVar on each reduction axis, if the intrin will use the reduce axis */
Array<IterVar> reduce_axis;
/*! \brief number of axes that can be scheduled */
int schedulable_ndim;
/*! \brief TensorIntrin used to compute */
TensorIntrin intrin;
/*! \brief input tensors of intrin */
Array<Tensor> inputs;
/*! \brief region of input tensors */
Array<Region> input_regions;
/*! \brief constructor */
TensorComputeOpNode() {}
// override functions
int num_outputs() const final;
Array<IterVar> root_iter_vars() const final;
Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;
Array<Tensor> InputTensors() const final;
Operation ReplaceInputs(
const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(
const Operation& self,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
const Operation& self,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(
const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
Stmt BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final;

void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("tag", &tag);
v->Visit("axis", &axis);
v->Visit("reduce_axis", &reduce_axis);
v->Visit("schedulable_ndim", &schedulable_ndim);
v->Visit("intrin", &intrin);
v->Visit("inputs", &inputs);
v->Visit("input_regions", &input_regions);
}
static Operation make(std::string name,
std::string tag,
Array<IterVar> axis,
Array<IterVar> reduce_axis,
int schedulable_ndim,
TensorIntrin intrin,
Array<Tensor> tensors,
Array<Region> regions);

static constexpr const char* _type_key = "TensorComputeOp";
TVM_DECLARE_NODE_TYPE_INFO(TensorComputeOpNode, OperationNode);
};

/*!
* \brief Symbolic scan.
*/
@@ -326,7 +394,7 @@ class ExternOpNode : public OperationNode {
public:
/*! \brief The input tensors */
Array<Tensor> inputs;
/*! \brief Symbolic placeholder representationinputs */
/*! \brief Symbolic placeholder representation of inputs */
Array<Buffer> input_placeholders;
/*! \brief Symbolic placeholder representation of outputs */
Array<Buffer> output_placeholders;
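Note: TensorComputeOpNode is what tvm.compute now produces when the body returned by fcompute is a call to a tensor intrinsic over tensor regions rather than a scalar expression. The sketch below is illustrative, not taken from this commit: the "vadd" extern symbol, the sizes, and the tiling factor are made up, but the attributes inspected at the end are exactly the fields declared above (axis, schedulable_ndim, intrin, inputs, input_regions).

import tvm

factor = 16
m = 1024

def intrin_vadd(n):
    # Declare a tensor intrinsic adding two n-element vectors.
    x = tvm.placeholder((n,), name="vx")
    y = tvm.placeholder((n,), name="vy")
    z = tvm.compute(x.shape, lambda i: x[i] + y[i], name="z")

    def intrin_func(ins, outs):
        ib = tvm.ir_builder.create()
        ib.emit(tvm.call_extern("float32", "vadd",
                                ins[0].access_ptr("r"),
                                ins[1].access_ptr("r"),
                                outs[0].access_ptr("w")))
        return ib.get()

    with tvm.build_config(offset_factor=n):
        return tvm.decl_tensor_intrin(z.op, intrin_func)

vadd = intrin_vadd(factor)
A = tvm.placeholder((m // factor, factor), name="A")
B = tvm.placeholder((m // factor, factor), name="B")

# The lambda takes one argument, so only `i` is schedulable; the trailing
# factor-sized axis and the 0:factor slices are handed to the intrinsic.
C = tvm.compute((m // factor, factor),
                lambda i: vadd(A[i, 0:factor], B[i, 0:factor]))

op = C.op
print(type(op))             # tvm.tensor.TensorComputeOp
print(op.schedulable_ndim)  # 1
print(op.intrin)            # the declared vadd intrinsic
print(op.input_regions)     # per-input regions, e.g. i:i+1 and 0:factor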
53 changes: 53 additions & 0 deletions include/tvm/tensor_intrin.h
@@ -89,5 +89,58 @@ class TensorIntrinNode : public Node {
inline const TensorIntrinNode* TensorIntrin::operator->() const {
return static_cast<const TensorIntrinNode*>(node_.get());
}


// Internal node container of tensor intrinsic calling.
class TensorIntrinCallNode;

/*! \brief Tensor intrinsic calling node. */
class TensorIntrinCall : public NodeRef {
public:
TensorIntrinCall() {}
explicit TensorIntrinCall(NodePtr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const TensorIntrinCallNode* operator->() const;

/*! \brief specify container node */
using ContainerType = TensorIntrinCallNode;
};

class TensorIntrinCallNode : public Node {
public:
/*! \brief the tensor intrinsic */
TensorIntrin intrin;
/*! \brief input tensors of the intrinsic */
Array<Tensor> tensors;
/*! \brief regions of input tensors */
Array<Region> regions;
/*!
* \brief IterVar on each reduction axis, if the
* intrin will use the reduce axis
*/
Array<IterVar> reduce_axis;

void VisitAttrs(AttrVisitor* v) final {
v->Visit("intrin", &intrin);
v->Visit("tensors", &tensors);
v->Visit("regions", &regions);
v->Visit("reduce_axis", &reduce_axis);
}
static TensorIntrinCall make(TensorIntrin intrin,
Array<Tensor> tensors,
Array<Region> regions,
Array<IterVar> reduce_axis);

static constexpr const char* _type_key = "TensorIntrinCall";
TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinCallNode, Node);
};

inline const TensorIntrinCallNode* TensorIntrinCall::operator->() const {
return static_cast<const TensorIntrinCallNode*>(node_.get());
}

} // namespace tvm
#endif // TVM_TENSOR_INTRIN_H_
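Note: TensorIntrinCall is the bridge object. Calling a declared intrinsic on tensor slices (the TensorIntrin.__call__ override added later in this diff) does not compute anything; it only records the intrinsic, the input tensors, the sliced regions, and any reduction axes so that compute() can build a TensorComputeOp from them. A minimal sketch, reusing vadd, A, B, and factor from the sketch above:

call = vadd(A[0, 0:factor], B[0, 0:factor])
print(type(call))    # tvm.tensor.TensorIntrinCall
print(call.intrin)   # the vadd TensorIntrin
print(call.tensors)  # [A, B]
print(call.regions)  # one region (an Array of Range) per input tensor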
39 changes: 29 additions & 10 deletions python/tvm/api.py
@@ -243,24 +243,43 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
raise ValueError("nested tag is not allowed for now")
tag = _tag.TagScope.get_current().tag
shape = (shape,) if isinstance(shape, _expr.Expr) else shape
# for python3
shape = tuple([int(s) if isinstance(s, float) else s for s in shape])
ndim = len(shape)
code = fcompute.__code__

if fcompute.__code__.co_argcount == 0:
out_ndim = ndim
if code.co_argcount == 0:
arg_names = ["i%d" % i for i in range(ndim)]
else:
arg_names = code.co_varnames[:code.co_argcount]
out_ndim = code.co_argcount

if ndim != len(arg_names):
if out_ndim != len(arg_names):
raise ValueError("fcompute do not match dimension, ndim=%d" % ndim)

dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape)]
dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape[:out_ndim])]
body = fcompute(*[v.var for v in dim_var])
if not isinstance(body, (list, tuple)):
body = [body]
body = convert(body)
op_node = _api_internal._ComputeOp(
name, tag, attrs, dim_var, body)

if isinstance(body, _tensor.TensorIntrinCall):
for i, s in enumerate(shape[out_ndim:]):
var_name = "ax" + str(i)
dim_var.append(_IterVar((0, s), var_name, 4))
op_node = _api_internal._TensorComputeOp(name,
tag,
dim_var,
body.reduce_axis,
out_ndim,
body.intrin,
body.tensors,
body.regions)
else:
if not isinstance(body, (list, tuple)):
body = [body]
body = convert(body)
op_node = _api_internal._ComputeOp(
name, tag, attrs, dim_var, body)

num = op_node.num_outputs
outputs = tuple(op_node.output(i) for i in range(num))
return outputs[0] if num == 1 else outputs
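Note: the branch above is the user-visible change in compute(): a scalar expression body still becomes a ComputeOp, while a TensorIntrinCall body becomes the new TensorComputeOp, with the axes beyond the lambda's arity appended internally as ax0, ax1, and so on. A short comparison, again reusing vadd, A, B, m, and factor from the first sketch:

# Scalar body: unchanged path, yields a ComputeOp.
D = tvm.compute((m // factor, factor), lambda i, j: A[i, j] + B[i, j])
print(type(D.op))   # tvm.tensor.ComputeOp

# Intrinsic-call body: new path, yields a TensorComputeOp whose second
# axis is created internally and is not schedulable.
C = tvm.compute((m // factor, factor),
                lambda i: vadd(A[i, 0:factor], B[i, 0:factor]))
print([iv.var.name for iv in C.op.axis])   # ['i', 'ax0']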
@@ -529,14 +548,14 @@ def decl_buffer(shape,
dtype = float32 if dtype is None else dtype
strides = () if strides is None else strides
if offset_factor != 0 and elem_offset is None:
elem_offset = var('%s_elem_offset' % name, shape[0].dtype)
shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32"
elem_offset = var('%s_elem_offset' % name, shape_dtype)
if data is None:
data = var(name, "handle")
return _api_internal._Buffer(
data, dtype, shape, strides, elem_offset, name, scope,
data_alignment, offset_factor)


def _IterVar(dom, name, iter_type, thread_tag=''):
"""Internal function to create IterVar
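Note: the decl_buffer tweak above is a small robustness fix: when offset_factor is set, the dtype of the elem_offset variable is taken from shape[0] only if that element actually carries a dtype, falling back to "int32" for plain Python integers. A minimal sketch of the case that previously assumed an Expr-valued shape (values are illustrative):

buf = tvm.decl_buffer((16,), "float32", name="buf", offset_factor=16)
print(buf.elem_offset.dtype)   # "int32", derived from the plain-int shape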
12 changes: 12 additions & 0 deletions python/tvm/tensor.py
@@ -30,6 +30,11 @@ def dtype(self):
"""Data content of the tensor."""
return self.tensor.dtype

@register_node
class TensorIntrinCall(NodeBase):
"""Intermediate structure for calling a tensor intrinsic."""
pass


itervar_cls = None

@@ -106,6 +111,7 @@ def name(self):
return "%s.v%d" % (op.name, self.value_index)



class Operation(NodeBase):
"""Represent an operation that generate a tensor"""

@@ -155,6 +161,12 @@ def reduce_axis(self):
return self.__getattr__("reduce_axis")


@register_node
class TensorComputeOp(Operation):
"""Tensor operation."""
pass


@register_node
class ScanOp(Operation):
"""Scan operation."""
28 changes: 26 additions & 2 deletions python/tvm/tensor_intrin.py
@@ -6,9 +6,25 @@
from . import stmt as _stmt
from . import make as _make
from . import tensor as _tensor
from . import schedule as _schedule
from .build_module import current_build_config
from ._ffi.node import NodeBase, register_node


def _get_region(tslice):
region = []
for idx in tslice.indices:
if isinstance(idx, slice):
assert idx.step is None
region.append(_api.Range(idx.start, idx.stop))
else:
if isinstance(idx, _schedule.IterVar):
begin = idx.var
else:
begin = idx
region.append(_make.range_by_min_extent(begin, 1))
return region

@register_node
class TensorIntrin(NodeBase):
"""Tensor intrinsic functions for certain computation.
@@ -17,8 +33,16 @@ class TensorIntrin(NodeBase):
--------
decl_tensor_intrin: Construct a TensorIntrin
"""
pass

def __call__(self, *args, **kwargs):
tensors = [x.tensor for x in args]
regions = [_get_region(x) for x in args]
reduce_axis = []
if "reduce_axis" in kwargs:
reduce_axis = kwargs["reduce_axis"]
if not isinstance(reduce_axis, (list, tuple)):
reduce_axis = [reduce_axis]
reduce_axis = _api.convert(reduce_axis)
return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis)

def decl_tensor_intrin(op,
fcompute,
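Note: _get_region is the glue between Python slice syntax and the Region arrays stored on the nodes: an explicit start:stop slice keeps its bounds, while a bare index (a constant, Var, or IterVar) becomes a unit-extent range; __call__ then also normalizes the optional reduce_axis keyword, accepting either a single IterVar or a list. A small sketch of the slice-to-region mapping, calling the private helper directly for illustration only, with made-up shapes:

import tvm
from tvm.tensor_intrin import _get_region

A = tvm.placeholder((64, 16), name="A")
i = tvm.var("i")

# A[i, 0:16] is a TensorSlice; each index becomes one Range of the region.
region = _get_region(A[i, 0:16])
print(region[0])   # unit-extent range starting at i
print(region[1])   # the sliced dimension, min=0, extent=16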
20 changes: 20 additions & 0 deletions src/api/api_lang.cc
@@ -239,6 +239,14 @@ TVM_REGISTER_API("_TensorIntrin")
args[6]);
});

TVM_REGISTER_API("_TensorIntrinCall")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TensorIntrinCallNode::make(args[0],
args[1],
args[2],
args[3]);
});

TVM_REGISTER_API("_TensorEqual")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Tensor() == args[1].operator Tensor();
@@ -278,6 +286,18 @@ TVM_REGISTER_API("_ScanOp")
args[7]);
});

TVM_REGISTER_API("_TensorComputeOp")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TensorComputeOpNode::make(args[0],
args[1],
args[2],
args[3],
args[4],
args[5],
args[6],
args[7]);
});

TVM_REGISTER_API("_ExternOp")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ExternOpNode::make(args[0],
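Note: the registrations above are the entry points the Python side calls through _api_internal (_TensorIntrinCall from TensorIntrin.__call__, _TensorComputeOp from compute). Continuing the vadd sketch from earlier, and assuming the schedule and lowering support added in the rest of this commit's files, the new op is scheduled and lowered on its schedulable axis like any other operation:

s = tvm.create_schedule(C.op)
# Only the schedulable axis `i` is exposed to the schedule; the intrinsic
# covers the inner factor-sized axis.
print(tvm.lower(s, [A, B, C], simple_mode=True))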