Skip to content

Commit

Permalink
[TIR][REFACTOR] Cleanup unused classes
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Jun 12, 2020
1 parent 04496d3 commit 12f9bbe
Show file tree
Hide file tree
Showing 9 changed files with 21 additions and 58 deletions.
8 changes: 2 additions & 6 deletions include/tvm/arith/bound.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,9 @@
#include <unordered_map>

namespace tvm {
// forward delcare Tensor
namespace te {
class Tensor;
}
namespace arith {

using tir::Domain;
using tir::Region;
using tir::Stmt;
using tir::Var;
using tir::VarNode;
Expand Down Expand Up @@ -82,7 +78,7 @@ IntSet DeduceBound(PrimExpr v, PrimExpr cond,
* \param consider_stores If stores are considered.
* \return The domain that covers all the calls or provides within the given statement.
*/
Domain DomainTouched(const Stmt& body, const tir::Buffer& buffer, bool consider_loads,
Region DomainTouched(const Stmt& body, const tir::Buffer& buffer, bool consider_loads,
bool consider_stores);

} // namespace arith
Expand Down
8 changes: 5 additions & 3 deletions include/tvm/te/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,18 @@ struct TensorDom {
/*!
* \brief Base class of all operation nodes
*/
class OperationNode : public tir::FunctionBaseNode {
class OperationNode : public Object {
public:
/*! \brief optional name of the operation */
std::string name;
/*! \brief optional tag of the operation */
std::string tag;
/*! \brief additional attributes of the operation*/
Map<String, ObjectRef> attrs;
/*! \return name of the operation */
const std::string& func_name() const final { return name; }
// virtual destructor.
virtual ~OperationNode() {}
/*! \return number of outputs */
virtual int num_outputs() const = 0;
/*!
* \return The list of iteration variable at root
* \note root_iter_vars decides the shape of the outputs.
Expand Down
5 changes: 3 additions & 2 deletions include/tvm/te/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,14 @@ using namespace tvm::tir;

// internal node container for Operation
class OperationNode;
class Tensor;

/*! \brief Operation that produces tensors */
class Operation : public tir::FunctionRef {
class Operation : public ObjectRef {
public:
/*! \brief default constructor */
Operation() {}
explicit Operation(ObjectPtr<Object> n) : FunctionRef(n) {}
explicit Operation(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
Expand Down
34 changes: 0 additions & 34 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -870,40 +870,6 @@ class Let : public PrimExpr {
TVM_DEFINE_OBJECT_REF_METHODS(Let, PrimExpr, LetNode);
};

// Call node, represent a function call or a multi-dimensional array load.
//
// TODO(tvm-team):
// Refactor call with more explicit property registrations.
// rather than calling a string symbol.
// We should move most information into function itself and remove name.

/*! \brief Base node of internal functions. */
class FunctionBaseNode : public Object {
public:
/*! \brief virtual destructor */
virtual ~FunctionBaseNode() {}
/*! \return the name of the function */
virtual const std::string& func_name() const = 0;
/*! \return the number of outputs of this function */
virtual int num_outputs() const = 0;

// fall back to pointer equality now before refactor.
bool SEqualReduce(const FunctionBaseNode* other, SEqualReducer equal) const {
return this == other;
}

void SHashReduce(SHashReducer hash_reduce) const {}

static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
};

/*! \brief reference to a function */
class FunctionRef : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(FunctionRef, ObjectRef, FunctionBaseNode);
};

/*!
* \brief Call node.
*/
Expand Down
2 changes: 0 additions & 2 deletions include/tvm/tir/var.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,6 @@ enum IterVarType : int {
kTensorized = 8
};

using Domain = Array<Range>;

/*!
* \brief An iteration variable representing an iteration
* over a one dimensional interval.
Expand Down
6 changes: 3 additions & 3 deletions src/arith/domain_touched.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ class BufferTouchedDomain final : public StmtExprVisitor {
BufferTouchedDomain(const Buffer& buffer, bool consider_loads, bool consider_stores)
: buffer_(buffer), consider_loads_(consider_loads), consider_stores_(consider_stores) {}

Domain Find(const Stmt& stmt) {
Region Find(const Stmt& stmt) {
operator()(stmt);
Domain ret;
Region ret;
Range none;
for (size_t i = 0; i < bounds_.size(); ++i) {
ret.push_back(arith::Union(bounds_[i]).cover_range(none));
Expand Down Expand Up @@ -107,7 +107,7 @@ class BufferTouchedDomain final : public StmtExprVisitor {
std::unordered_map<const VarNode*, IntSet> dom_map_;
};

Domain DomainTouched(const Stmt& stmt, const Buffer& buffer, bool consider_loads,
Region DomainTouched(const Stmt& stmt, const Buffer& buffer, bool consider_loads,
bool consider_stores) {
return BufferTouchedDomain(buffer, consider_loads, consider_stores).Find(stmt);
}
Expand Down
2 changes: 1 addition & 1 deletion src/contrib/hybrid/codegen_hybrid.cc
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ std::string CodeGenHybrid::GetTensorID(const Tensor& tensor) {
if (id_map_.count(key)) {
return id_map_[key];
}
std::string name_hint = tensor->op->func_name();
std::string name_hint = tensor->op->name;
if (tensor->op->num_outputs() > 1) {
name_hint += "_v" + std::to_string(tensor->value_index);
}
Expand Down
12 changes: 6 additions & 6 deletions src/te/schedule/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ namespace tvm {
namespace te {
// key to specific tensor dimension.
struct TensorDimKey {
tir::FunctionRef f;
Operation op;
int value_index;
int dim;
TensorDimKey() {}
TensorDimKey(const Tensor& t, int dim) : f(t->op), value_index(t->value_index), dim(dim) {}
TensorDimKey(const Tensor& t, int dim) : op(t->op), value_index(t->value_index), dim(dim) {}
TensorDimKey(const Tensor& t, size_t dim)
: f(t->op), value_index(t->value_index), dim(static_cast<int>(dim)) {}
: op(t->op), value_index(t->value_index), dim(static_cast<int>(dim)) {}
inline bool operator==(const TensorDimKey& other) const {
return f == other.f && value_index == other.value_index && dim == other.dim;
return op == other.op && value_index == other.value_index && dim == other.dim;
}
inline bool operator!=(const TensorDimKey& other) const { return !operator==(other); }
};
Expand All @@ -55,7 +55,7 @@ namespace std {
template <>
struct hash<::tvm::te::TensorDimKey> {
std::size_t operator()(const ::tvm::te::TensorDimKey& k) const {
size_t lhs = ::tvm::ObjectPtrHash()(k.f);
size_t lhs = ::tvm::ObjectPtrHash()(k.op);
size_t rhs = static_cast<size_t>(k.value_index) << 16UL | static_cast<size_t>(k.dim);
lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
return lhs;
Expand Down Expand Up @@ -378,7 +378,7 @@ Map<IterVar, PrimExpr> ScanFixPointAnalysis(const Operation& scan_op) {
if (k != target && place_holder_ref.count(k)) break;
stack.pop_back();
if (!reach.count(k)) {
LOG(FATAL) << "cannot find reach of " << k.f << "-" << k.dim;
LOG(FATAL) << "cannot find reach of " << k.op << "-" << k.dim;
}

for (TensorDimKey kk : reach.at(k)) {
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/inject_prefetch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class PrefetchInjector : public StmtMutator {
if (op && op->attr_key == attr::prefetch_scope) {
Buffer buffer = Downcast<Buffer>(op->node);
CHECK_NE(loop_nest_.size(), 0U);
Domain domain = DomainTouched(op->body, buffer, true, false);
Region domain = DomainTouched(op->body, buffer, true, false);
Region region;

auto iter_var = loop_nest_.back().get();
Expand Down

0 comments on commit 12f9bbe

Please sign in to comment.