Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LANG] Generalize compute to tensor region #1476

Merged
merged 8 commits into from
Oct 6, 2018
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/tvm/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ class Range : public HalideIR::IR::Range {
TVM_DLL static Range make_by_min_extent(Expr min, Expr extent);
};

using Region = Array<Range>;

/*!
* \brief Type of iteration variable.
* Each IterVar have a specific type.
Expand Down
78 changes: 76 additions & 2 deletions include/tvm/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class OperationNode : public FunctionBaseNode {
}
/*!
* \return The list of iteration variable at root
* \note root_iter_vars dedides the shape of the outputs.
* \note root_iter_vars decides the shape of the outputs.
*/
virtual Array<IterVar> root_iter_vars() const = 0;
/*!
Expand Down Expand Up @@ -182,6 +182,80 @@ class PlaceholderOpNode : public OperationNode {
TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode, OperationNode);
};

class TensorComputeOpNode : public OperationNode {
public:
Array<IterVar> axis;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

document the fields

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

again, document each field


Array<IterVar> out_axis;

Array<IterVar> tensor_axis;

Array<IterVar> reduce_axis;

Array<Tensor> inputs;

Array<Region> input_regions;

TensorIntrin intrin;

/*! \brief constructor */
TensorComputeOpNode() {}

// override functions
int num_outputs() const final;
Array<IterVar> root_iter_vars() const final;
Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;
Array<Tensor> InputTensors() const final;
Operation ReplaceInputs(
const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(
const Operation& self,
const std::unordered_map<const Variable*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
const Operation& self,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(
const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
Stmt BuildProvide(
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final;

void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("tag", &tag);
v->Visit("axis", &axis);
v->Visit("out_axis", &out_axis);
v->Visit("tensor_axis", &tensor_axis);
v->Visit("reduce_axis", &reduce_axis);
v->Visit("inputs", &inputs);
}

static Operation make(std::string name,
std::string tag,
Array<IterVar> out_axis,
Array<IterVar> tensor_axis,
TensorIntrinCall intrin_call);

static Operation make(std::string name,
std::string tag,
Array<IterVar> out_axis,
Array<IterVar> tensor_axis,
Array<IterVar> reduce_axis,
Array<Tensor> tensors,
Array<Region> regions,
TensorIntrin intrin);

static constexpr const char* _type_key = "TensorComputeOp";
TVM_DECLARE_NODE_TYPE_INFO(TensorComputeOpNode, OperationNode);
};

/*!
* \brief A Compute op that compute a tensor on certain domain.
*/
Expand Down Expand Up @@ -326,7 +400,7 @@ class ExternOpNode : public OperationNode {
public:
/*! \brief The input tensors */
Array<Tensor> inputs;
/*! \brief Symbolic placeholder representationinputs */
/*! \brief Symbolic placeholder representation of inputs */
Array<Buffer> input_placeholders;
/*! \brief Symbolic placeholder representation of outputs */
Array<Buffer> output_placeholders;
Expand Down
47 changes: 47 additions & 0 deletions include/tvm/tensor_intrin.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,5 +89,52 @@ class TensorIntrinNode : public Node {
inline const TensorIntrinNode* TensorIntrin::operator->() const {
return static_cast<const TensorIntrinNode*>(node_.get());
}


// Internal node container of tensor intrinsics.
class TensorIntrinCallNode;

/*! \brief Tensor intrinsic node. */
class TensorIntrinCall : public NodeRef {
public:
TensorIntrinCall() {}
explicit TensorIntrinCall(NodePtr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const TensorIntrinCallNode* operator->() const;

/*! \brief specify container node */
using ContainerType = TensorIntrinCallNode;
};

class TensorIntrinCallNode : public Node {
public:
TensorIntrin intrin;
Array<Tensor> tensors;
Array<Region> regions;
Array<IterVar> reduce_axis;

void VisitAttrs(AttrVisitor* v) final {
v->Visit("intrin", &intrin);
v->Visit("tensors", &tensors);
v->Visit("regions", &regions);
v->Visit("reduce_axis", &reduce_axis);
}

static TensorIntrinCall make(TensorIntrin intrin,
Array<Tensor> tensors,
Array<Region> regions,
Array<IterVar> reduce_axis);

static constexpr const char* _type_key = "TensorIntrinCall";
TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinCallNode, Node);
};

inline const TensorIntrinCallNode* TensorIntrinCall::operator->() const {
return static_cast<const TensorIntrinCallNode*>(node_.get());
}

} // namespace tvm
#endif // TVM_TENSOR_INTRIN_H_
37 changes: 27 additions & 10 deletions python/tvm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,24 +243,41 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
raise ValueError("nested tag is not allowed for now")
tag = _tag.TagScope.current.tag
shape = (shape,) if isinstance(shape, _expr.Expr) else shape
# for python3
shape = tuple([int(s) if isinstance(s, float) else s for s in shape])
ndim = len(shape)
code = fcompute.__code__

if fcompute.__code__.co_argcount == 0:
out_ndim = ndim
if code.co_argcount == 0:
arg_names = ["i%d" % i for i in range(ndim)]
else:
arg_names = code.co_varnames[:code.co_argcount]
out_ndim = code.co_argcount

if ndim != len(arg_names):
if out_ndim != len(arg_names):
raise ValueError("fcompute do not match dimension, ndim=%d" % ndim)

dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape)]
dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape[:out_ndim])]
body = fcompute(*[v.var for v in dim_var])
if not isinstance(body, (list, tuple)):
body = [body]
body = convert(body)
op_node = _api_internal._ComputeOp(
name, tag, attrs, dim_var, body)

if isinstance(body, _tensor.TensorIntrinCall):
tensor_var = []
for i, s in enumerate(shape[out_ndim:]):
var_name = "ax" + str(i)
tensor_var.append(_IterVar((0, s), var_name, 4))
op_node = _api_internal._TensorComputeOp(name,
tag,
dim_var,
tensor_var,
body)
else:
if not isinstance(body, (list, tuple)):
body = [body]
body = convert(body)
op_node = _api_internal._ComputeOp(
name, tag, attrs, dim_var, body)

num = op_node.num_outputs
outputs = tuple(op_node.output(i) for i in range(num))
return outputs[0] if num == 1 else outputs
Expand Down Expand Up @@ -529,14 +546,14 @@ def decl_buffer(shape,
dtype = float32 if dtype is None else dtype
strides = () if strides is None else strides
if offset_factor != 0 and elem_offset is None:
elem_offset = var('%s_elem_offset' % name, shape[0].dtype)
shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32"
elem_offset = var('%s_elem_offset' % name, shape_dtype)
if data is None:
data = var(name, "handle")
return _api_internal._Buffer(
data, dtype, shape, strides, elem_offset, name, scope,
data_alignment, offset_factor)


def _IterVar(dom, name, iter_type, thread_tag=''):
"""Internal function to create IterVar

Expand Down
12 changes: 12 additions & 0 deletions python/tvm/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ def dtype(self):
"""Data content of the tensor."""
return self.tensor.dtype

@register_node
class TensorIntrinCall(NodeBase):
"""Intermediate structure for calling a tensor intrinsic."""
pass


itervar_cls = None

Expand Down Expand Up @@ -106,6 +111,7 @@ def name(self):
return "%s.v%d" % (op.name, self.value_index)



class Operation(NodeBase):
"""Represent an operation that generate a tensor"""

Expand Down Expand Up @@ -155,6 +161,12 @@ def reduce_axis(self):
return self.__getattr__("reduce_axis")


@register_node
class TensorComputeOp(Operation):
"""Tensor operation."""
pass


@register_node
class ScanOp(Operation):
"""Scan operation."""
Expand Down
28 changes: 26 additions & 2 deletions python/tvm/tensor_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,25 @@
from . import stmt as _stmt
from . import make as _make
from . import tensor as _tensor
from . import schedule as _schedule
from .build_module import current_build_config
from ._ffi.node import NodeBase, register_node


def _get_region(tslice):
region = []
for idx in tslice.indices:
if isinstance(idx, slice):
assert idx.step is None
region.append(_api.Range(idx.start, idx.stop))
else:
if isinstance(idx, _schedule.IterVar):
begin = idx.var
else:
begin = idx
region.append(_make.range_by_min_extent(begin, 1))
return region

@register_node
class TensorIntrin(NodeBase):
"""Tensor intrinsic functions for certain computation.
Expand All @@ -17,8 +33,16 @@ class TensorIntrin(NodeBase):
--------
decl_tensor_intrin: Construct a TensorIntrin
"""
pass

def __call__(self, *args, **kwargs):
tensors = [x.tensor for x in args]
regions = [_get_region(x) for x in args]
reduce_axis = []
if "reduce_axis" in kwargs:
reduce_axis = kwargs["reduce_axis"]
if not isinstance(reduce_axis, (list, tuple)):
reduce_axis = [reduce_axis]
reduce_axis = _api.convert(reduce_axis)
return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis)

def decl_tensor_intrin(op,
fcompute,
Expand Down
17 changes: 17 additions & 0 deletions src/api/api_lang.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,14 @@ TVM_REGISTER_API("_TensorIntrin")
args[6]);
});

TVM_REGISTER_API("_TensorIntrinCall")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TensorIntrinCallNode::make(args[0],
args[1],
args[2],
args[3]);
});

TVM_REGISTER_API("_TensorEqual")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Tensor() == args[1].operator Tensor();
Expand Down Expand Up @@ -278,6 +286,15 @@ TVM_REGISTER_API("_ScanOp")
args[7]);
});

TVM_REGISTER_API("_TensorComputeOp")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TensorComputeOpNode::make(args[0],
args[1],
args[2],
args[3],
args[4]);
});

TVM_REGISTER_API("_ExternOp")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ExternOpNode::make(args[0],
Expand Down
Loading