Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR][REFACTOR] Cleanup unused classes #5789

Merged
merged 1 commit into from
Jun 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions include/tvm/arith/bound.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,9 @@
#include <unordered_map>

namespace tvm {
// forward delcare Tensor
namespace te {
class Tensor;
}
namespace arith {

using tir::Domain;
using tir::Region;
using tir::Stmt;
using tir::Var;
using tir::VarNode;
Expand Down Expand Up @@ -82,7 +78,7 @@ IntSet DeduceBound(PrimExpr v, PrimExpr cond,
* \param consider_stores If stores are considered.
* \return The domain that covers all the calls or provides within the given statement.
*/
Domain DomainTouched(const Stmt& body, const tir::Buffer& buffer, bool consider_loads,
Region DomainTouched(const Stmt& body, const tir::Buffer& buffer, bool consider_loads,
bool consider_stores);

} // namespace arith
Expand Down
8 changes: 5 additions & 3 deletions include/tvm/te/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,18 @@ struct TensorDom {
/*!
* \brief Base class of all operation nodes
*/
class OperationNode : public tir::FunctionBaseNode {
class OperationNode : public Object {
public:
/*! \brief optional name of the operation */
std::string name;
/*! \brief optional tag of the operation */
std::string tag;
/*! \brief additional attributes of the operation*/
Map<String, ObjectRef> attrs;
/*! \return name of the operation */
const std::string& func_name() const final { return name; }
// virtual destructor.
virtual ~OperationNode() {}
/*! \return number of outputs */
virtual int num_outputs() const = 0;
/*!
* \return The list of iteration variable at root
* \note root_iter_vars decides the shape of the outputs.
Expand Down
5 changes: 3 additions & 2 deletions include/tvm/te/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,14 @@ using namespace tvm::tir;

// internal node container for Operation
class OperationNode;
class Tensor;

/*! \brief Operation that produces tensors */
class Operation : public tir::FunctionRef {
class Operation : public ObjectRef {
public:
/*! \brief default constructor */
Operation() {}
explicit Operation(ObjectPtr<Object> n) : FunctionRef(n) {}
explicit Operation(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
Expand Down
34 changes: 0 additions & 34 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -870,40 +870,6 @@ class Let : public PrimExpr {
TVM_DEFINE_OBJECT_REF_METHODS(Let, PrimExpr, LetNode);
};

// Call node, represent a function call or a multi-dimensional array load.
//
// TODO(tvm-team):
// Refactor call with more explicit property registrations.
// rather than calling a string symbol.
// We should move most information into function itself and remove name.

/*! \brief Base node of internal functions. */
class FunctionBaseNode : public Object {
public:
/*! \brief virtual destructor */
virtual ~FunctionBaseNode() {}
/*! \return the name of the function */
virtual const std::string& func_name() const = 0;
/*! \return the number of outputs of this function */
virtual int num_outputs() const = 0;

// fall back to pointer equality now before refactor.
bool SEqualReduce(const FunctionBaseNode* other, SEqualReducer equal) const {
return this == other;
}

void SHashReduce(SHashReducer hash_reduce) const {}

static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
};

/*! \brief reference to a function */
class FunctionRef : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(FunctionRef, ObjectRef, FunctionBaseNode);
};

/*!
* \brief Call node.
*/
Expand Down
2 changes: 0 additions & 2 deletions include/tvm/tir/var.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,6 @@ enum IterVarType : int {
kTensorized = 8
};

using Domain = Array<Range>;

/*!
* \brief An iteration variable representing an iteration
* over a one dimensional interval.
Expand Down
6 changes: 3 additions & 3 deletions src/arith/domain_touched.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ class BufferTouchedDomain final : public StmtExprVisitor {
BufferTouchedDomain(const Buffer& buffer, bool consider_loads, bool consider_stores)
: buffer_(buffer), consider_loads_(consider_loads), consider_stores_(consider_stores) {}

Domain Find(const Stmt& stmt) {
Region Find(const Stmt& stmt) {
operator()(stmt);
Domain ret;
Region ret;
Range none;
for (size_t i = 0; i < bounds_.size(); ++i) {
ret.push_back(arith::Union(bounds_[i]).cover_range(none));
Expand Down Expand Up @@ -107,7 +107,7 @@ class BufferTouchedDomain final : public StmtExprVisitor {
std::unordered_map<const VarNode*, IntSet> dom_map_;
};

Domain DomainTouched(const Stmt& stmt, const Buffer& buffer, bool consider_loads,
Region DomainTouched(const Stmt& stmt, const Buffer& buffer, bool consider_loads,
bool consider_stores) {
return BufferTouchedDomain(buffer, consider_loads, consider_stores).Find(stmt);
}
Expand Down
2 changes: 1 addition & 1 deletion src/contrib/hybrid/codegen_hybrid.cc
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ std::string CodeGenHybrid::GetTensorID(const Tensor& tensor) {
if (id_map_.count(key)) {
return id_map_[key];
}
std::string name_hint = tensor->op->func_name();
std::string name_hint = tensor->op->name;
if (tensor->op->num_outputs() > 1) {
name_hint += "_v" + std::to_string(tensor->value_index);
}
Expand Down
12 changes: 6 additions & 6 deletions src/te/schedule/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ namespace tvm {
namespace te {
// key to specific tensor dimension.
struct TensorDimKey {
tir::FunctionRef f;
Operation op;
int value_index;
int dim;
TensorDimKey() {}
TensorDimKey(const Tensor& t, int dim) : f(t->op), value_index(t->value_index), dim(dim) {}
TensorDimKey(const Tensor& t, int dim) : op(t->op), value_index(t->value_index), dim(dim) {}
TensorDimKey(const Tensor& t, size_t dim)
: f(t->op), value_index(t->value_index), dim(static_cast<int>(dim)) {}
: op(t->op), value_index(t->value_index), dim(static_cast<int>(dim)) {}
inline bool operator==(const TensorDimKey& other) const {
return f == other.f && value_index == other.value_index && dim == other.dim;
return op == other.op && value_index == other.value_index && dim == other.dim;
}
inline bool operator!=(const TensorDimKey& other) const { return !operator==(other); }
};
Expand All @@ -55,7 +55,7 @@ namespace std {
template <>
struct hash<::tvm::te::TensorDimKey> {
std::size_t operator()(const ::tvm::te::TensorDimKey& k) const {
size_t lhs = ::tvm::ObjectPtrHash()(k.f);
size_t lhs = ::tvm::ObjectPtrHash()(k.op);
size_t rhs = static_cast<size_t>(k.value_index) << 16UL | static_cast<size_t>(k.dim);
lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
return lhs;
Expand Down Expand Up @@ -378,7 +378,7 @@ Map<IterVar, PrimExpr> ScanFixPointAnalysis(const Operation& scan_op) {
if (k != target && place_holder_ref.count(k)) break;
stack.pop_back();
if (!reach.count(k)) {
LOG(FATAL) << "cannot find reach of " << k.f << "-" << k.dim;
LOG(FATAL) << "cannot find reach of " << k.op << "-" << k.dim;
}

for (TensorDimKey kk : reach.at(k)) {
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/inject_prefetch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class PrefetchInjector : public StmtMutator {
if (op && op->attr_key == attr::prefetch_scope) {
Buffer buffer = Downcast<Buffer>(op->node);
CHECK_NE(loop_nest_.size(), 0U);
Domain domain = DomainTouched(op->body, buffer, true, false);
Region domain = DomainTouched(op->body, buffer, true, false);
Region region;

auto iter_var = loop_nest_.back().get();
Expand Down