Skip to content

Commit

Permalink
[TVMSCRIPT] Attach span information to tir nodes in tvmscript (#6910)
Browse files Browse the repository at this point in the history
  • Loading branch information
tkonolige authored Dec 5, 2020
1 parent 0d46cf7 commit fdfc7eb
Show file tree
Hide file tree
Showing 40 changed files with 1,346 additions and 640 deletions.
14 changes: 11 additions & 3 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ class IntImmNode : public PrimExprNode {
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("value", &value);
v->Visit("span", &span);
}

bool SEqualReduce(const IntImmNode* other, SEqualReducer equal) const {
Expand Down Expand Up @@ -283,6 +284,7 @@ class FloatImmNode : public PrimExprNode {
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("value", &value);
v->Visit("span", &span);
}

bool SEqualReduce(const FloatImmNode* other, SEqualReducer equal) const {
Expand Down Expand Up @@ -415,13 +417,17 @@ class RangeNode : public Object {
PrimExpr min;
/*! \brief the extend of range */
PrimExpr extent;
/*! \brief the location of this range in the source */
mutable Span span;
/*! \brief constructor */
RangeNode() {}
RangeNode(PrimExpr min, PrimExpr extent) : min(min), extent(extent) {}
RangeNode(PrimExpr min, PrimExpr extent, Span span = Span())
: min(min), extent(extent), span(span) {}

void VisitAttrs(AttrVisitor* v) {
v->Visit("min", &min);
v->Visit("extent", &extent);
v->Visit("span", &span);
}

bool SEqualReduce(const RangeNode* other, SEqualReducer equal) const {
Expand All @@ -446,8 +452,9 @@ class Range : public ObjectRef {
* \brief constructor by begin and end
* \param begin The begin of the range.
* \param end The end of the range.
* \param span The location of the Range in the source.
*/
TVM_DLL Range(PrimExpr begin, PrimExpr end);
TVM_DLL Range(PrimExpr begin, PrimExpr end, Span span = Span());
/*!
* \brief construct a new range with min and extent
* The corresponding constructor is removed,
Expand All @@ -456,8 +463,9 @@ class Range : public ObjectRef {
*
* \param min The minimum range.
* \param extent The extent of the range.
* \param span The location of the Range in the source.
*/
static Range FromMinExtent(PrimExpr min, PrimExpr extent);
static Range FromMinExtent(PrimExpr min, PrimExpr extent, Span span = Span());
// declare range.
TVM_DEFINE_OBJECT_REF_METHODS(Range, ObjectRef, RangeNode);
};
Expand Down
4 changes: 3 additions & 1 deletion include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,9 @@ class TVMArgValue : public TVMPODValue_ {
} else if (type_code_ == kTVMStr) {
return std::string(value_.v_str);
} else {
ICHECK(IsObjectRef<tvm::runtime::String>());
ICHECK(IsObjectRef<tvm::runtime::String>())
<< "Could not convert TVM object of type " << runtime::Object::TypeIndex2Key(type_code_)
<< " to a string.";
return AsObjectRef<tvm::runtime::String>().operator std::string();
}
}
Expand Down
1 change: 1 addition & 0 deletions include/tvm/tir/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class BufferNode : public Object {
v->Visit("data_alignment", &data_alignment);
v->Visit("offset_factor", &offset_factor);
v->Visit("buffer_type", &buffer_type);
v->Visit("span", &span);
}

bool SEqualReduce(const BufferNode* other, SEqualReducer equal) const {
Expand Down
23 changes: 22 additions & 1 deletion include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class StringImmNode : public PrimExprNode {
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("value", &value);
v->Visit("span", &span);
}

bool SEqualReduce(const StringImmNode* other, SEqualReducer equal) const {
Expand Down Expand Up @@ -90,6 +91,7 @@ class CastNode : public PrimExprNode {
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("value", &value);
v->Visit("span", &span);
}

bool SEqualReduce(const CastNode* other, SEqualReducer equal) const {
Expand Down Expand Up @@ -131,6 +133,7 @@ class BinaryOpNode : public PrimExprNode {
v->Visit("dtype", &(this->dtype));
v->Visit("a", &a);
v->Visit("b", &b);
v->Visit("span", &span);
}

bool SEqualReduce(const T* other, SEqualReducer equal) const {
Expand Down Expand Up @@ -312,6 +315,7 @@ class CmpOpNode : public PrimExprNode {
v->Visit("dtype", &(this->dtype));
v->Visit("a", &a);
v->Visit("b", &b);
v->Visit("span", &span);
}

bool SEqualReduce(const T* other, SEqualReducer equal) const {
Expand Down Expand Up @@ -435,6 +439,7 @@ class AndNode : public PrimExprNode {
v->Visit("dtype", &(this->dtype));
v->Visit("a", &a);
v->Visit("b", &b);
v->Visit("span", &span);
}

bool SEqualReduce(const AndNode* other, SEqualReducer equal) const {
Expand Down Expand Up @@ -473,6 +478,7 @@ class OrNode : public PrimExprNode {
v->Visit("dtype", &dtype);
v->Visit("a", &a);
v->Visit("b", &b);
v->Visit("span", &span);
}

bool SEqualReduce(const OrNode* other, SEqualReducer equal) const {
Expand Down Expand Up @@ -508,6 +514,7 @@ class NotNode : public PrimExprNode {
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("a", &a);
v->Visit("span", &span);
}

bool SEqualReduce(const NotNode* other, SEqualReducer equal) const {
Expand Down Expand Up @@ -554,6 +561,7 @@ class SelectNode : public PrimExprNode {
v->Visit("condition", &condition);
v->Visit("true_value", &true_value);
v->Visit("false_value", &false_value);
v->Visit("span", &span);
}

bool SEqualReduce(const SelectNode* other, SEqualReducer equal) const {
Expand Down Expand Up @@ -604,6 +612,7 @@ class BufferLoadNode : public PrimExprNode {
v->Visit("dtype", &(this->dtype));
v->Visit("buffer", &buffer);
v->Visit("indices", &indices);
v->Visit("span", &span);
}

bool SEqualReduce(const BufferLoadNode* other, SEqualReducer equal) const {
Expand Down Expand Up @@ -651,6 +660,7 @@ class ProducerLoadNode : public PrimExprNode {
v->Visit("dtype", &(this->dtype));
v->Visit("producer", &producer);
v->Visit("indices", &indices);
v->Visit("span", &span);
}

bool SEqualReduce(const ProducerLoadNode* other, SEqualReducer equal) const {
Expand Down Expand Up @@ -708,6 +718,7 @@ class LoadNode : public PrimExprNode {
v->Visit("buffer_var", &buffer_var);
v->Visit("index", &index);
v->Visit("predicate", &predicate);
v->Visit("span", &span);
}

bool SEqualReduce(const LoadNode* other, SEqualReducer equal) const {
Expand Down Expand Up @@ -760,6 +771,7 @@ class RampNode : public PrimExprNode {
v->Visit("base", &base);
v->Visit("stride", &stride);
v->Visit("lanes", &lanes);
v->Visit("span", &span);
}

bool SEqualReduce(const RampNode* other, SEqualReducer equal) const {
Expand Down Expand Up @@ -800,6 +812,7 @@ class BroadcastNode : public PrimExprNode {
v->Visit("dtype", &dtype);
v->Visit("value", &value);
v->Visit("lanes", &lanes);
v->Visit("span", &span);
}

bool SEqualReduce(const BroadcastNode* other, SEqualReducer equal) const {
Expand Down Expand Up @@ -843,6 +856,7 @@ class LetNode : public PrimExprNode {
v->Visit("var", &var);
v->Visit("value", &value);
v->Visit("body", &body);
v->Visit("span", &span);
}

bool SEqualReduce(const LetNode* other, SEqualReducer equal) const {
Expand Down Expand Up @@ -890,6 +904,7 @@ class CallNode : public PrimExprNode {
v->Visit("dtype", &dtype);
v->Visit("op", &op);
v->Visit("args", &args);
v->Visit("span", &span);
}

bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
Expand Down Expand Up @@ -931,6 +946,7 @@ class ShuffleNode : public PrimExprNode {
void VisitAttrs(AttrVisitor* v) {
v->Visit("vectors", &vectors);
v->Visit("indices", &indices);
v->Visit("span", &span);
}

bool SEqualReduce(const ShuffleNode* other, SEqualReducer equal) const {
Expand Down Expand Up @@ -993,6 +1009,7 @@ class CommReducerNode : public Object {
v->Visit("rhs", &rhs);
v->Visit("result", &result);
v->Visit("identity_element", &identity_element);
v->Visit("span", &span);
}

bool SEqualReduce(const CommReducerNode* other, SEqualReducer equal) const {
Expand Down Expand Up @@ -1052,6 +1069,7 @@ class ReduceNode : public PrimExprNode {
v->Visit("axis", &axis);
v->Visit("condition", &condition);
v->Visit("value_index", &value_index);
v->Visit("span", &span);
}

bool SEqualReduce(const ReduceNode* other, SEqualReducer equal) const {
Expand Down Expand Up @@ -1091,7 +1109,10 @@ class Reduce : public PrimExpr {
/*! \brief Any shape. */
class AnyNode : public PrimExprNode {
public:
void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); }
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("span", &span);
}

bool SEqualReduce(const AnyNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype);
Expand Down
Loading

0 comments on commit fdfc7eb

Please sign in to comment.