Skip to content

Commit

Permalink
[REFACTOR] top - namespace for Tensor Operation DSL (apache#4727)
Browse files Browse the repository at this point in the history
* [REFACTOR] introduce top - Tensor Operation DSL.

Historically we put Tensor, Schedule and compute under the root tvm namespace.
This is no longer a good idea as the project's scope grows larger
than the tensor operation DSL.

This PR introduces top -- a namespace for tensor operational
DSL concepts such as schedule, tensor, compute.
We moved the related files to the new top subfolder.

* Move relevant files into include/tvm/top and src/top
  • Loading branch information
tqchen authored and zhiics committed Mar 2, 2020
1 parent b4645ec commit 62a7e53
Show file tree
Hide file tree
Showing 125 changed files with 830 additions and 730 deletions.
14 changes: 9 additions & 5 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -125,20 +125,24 @@ assign_source_group("Source" ${GROUP_SOURCE})
assign_source_group("Include" ${GROUP_INCLUDE})

# Source file lists
file(GLOB COMPILER_SRCS
file(GLOB_RECURSE COMPILER_SRCS
src/node/*.cc
src/ir/*.cc
src/target/*.cc
src/api/*.cc
src/arith/*.cc
src/top/*.cc
src/api/*.cc
src/autotvm/*.cc
src/codegen/*.cc
src/lang/*.cc
src/pass/*.cc
src/op/*.cc
src/schedule/*.cc
)

file(GLOB CODEGEN_SRCS
src/codegen/*.cc
)

list(APPEND COMPILER_SRCS ${CODEGEN_SRCS})

file(GLOB_RECURSE RELAY_OP_SRCS
src/relay/op/*.cc
)
Expand Down
7 changes: 6 additions & 1 deletion include/tvm/arith/bound.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@

namespace tvm {
// forward delcare Tensor
namespace top {
class Tensor;
}
namespace arith {

/*!
Expand Down Expand Up @@ -75,7 +77,10 @@ IntSet DeduceBound(PrimExpr v, PrimExpr cond,
* \param consider_provides If provides (write) are considered.
* \return The domain that covers all the calls or provides within the given statement.
*/
Domain DomainTouched(Stmt body, const Tensor &tensor, bool consider_calls, bool consider_provides);
Domain DomainTouched(Stmt body,
const top::Tensor &tensor,
bool consider_calls,
bool consider_provides);

} // namespace arith
} // namespace tvm
Expand Down
10 changes: 6 additions & 4 deletions include/tvm/build_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,16 @@

#include <tvm/target/target.h>
#include <tvm/support/with.h>
#include <tvm/top/schedule_pass.h>

#include <string>
#include <vector>
#include <utility>
#include <unordered_map>
#include <unordered_set>

#include "runtime/packed_func.h"
#include "schedule_pass.h"

#include "lowered_func.h"

namespace tvm {
Expand Down Expand Up @@ -172,10 +174,10 @@ class BuildConfig : public ::tvm::ObjectRef {
* \param config The build configuration.
* \return The lowered function.
*/
TVM_DLL Array<LoweredFunc> lower(Schedule sch,
const Array<Tensor>& args,
TVM_DLL Array<LoweredFunc> lower(top::Schedule sch,
const Array<top::Tensor>& args,
const std::string& name,
const std::unordered_map<Tensor, Buffer>& binds,
const std::unordered_map<top::Tensor, Buffer>& binds,
const BuildConfig& config);
/*!
* \brief Split host/device function and running necessary pass before build
Expand Down
9 changes: 5 additions & 4 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,14 @@
#ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_

#include <tvm/top/schedule.h>

#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <string>
#include "expr.h"
#include "buffer.h"
#include "schedule.h"
#include "lowered_func.h"

namespace tvm {
Expand Down Expand Up @@ -203,7 +204,7 @@ Stmt Inline(Stmt stmt,
* \return Transformed stmt.
*/
Stmt StorageFlatten(Stmt stmt,
Map<Tensor, Buffer> extern_buffer,
Map<top::Tensor, Buffer> extern_buffer,
int cache_line_size,
bool create_bound_attribute = false);

Expand All @@ -217,8 +218,8 @@ Stmt StorageFlatten(Stmt stmt,
* \return Transformed stmt.
*/
Stmt RewriteForTensorCore(Stmt stmt,
Schedule schedule,
Map<Tensor, Buffer> extern_buffer);
top::Schedule schedule,
Map<top::Tensor, Buffer> extern_buffer);

/*!
* \brief Verify if there is any argument bound to compact buffer.
Expand Down
3 changes: 2 additions & 1 deletion include/tvm/lowered_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@
#ifndef TVM_LOWERED_FUNC_H_
#define TVM_LOWERED_FUNC_H_

#include <tvm/top/tensor.h>

#include <string>

#include "expr.h"
#include "tensor.h"
#include "tvm/node/container.h"

namespace tvm {
Expand Down
7 changes: 4 additions & 3 deletions include/tvm/packed_func_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@
#ifndef TVM_PACKED_FUNC_EXT_H_
#define TVM_PACKED_FUNC_EXT_H_

#include <tvm/top/tensor.h>

#include <string>
#include <memory>
#include <limits>
#include <type_traits>

#include "expr.h"
#include "tensor.h"
#include "runtime/packed_func.h"

namespace tvm {
Expand Down Expand Up @@ -116,8 +117,8 @@ inline TVMPODValue_::operator tvm::PrimExpr() const {
if (ptr->IsInstance<IterVarNode>()) {
return IterVar(ObjectPtr<Object>(ptr))->var;
}
if (ptr->IsInstance<TensorNode>()) {
return Tensor(ObjectPtr<Object>(ptr))();
if (ptr->IsInstance<top::TensorNode>()) {
return top::Tensor(ObjectPtr<Object>(ptr))();
}
CHECK(ObjectTypeChecker<PrimExpr>::Check(ptr))
<< "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
Expand Down
28 changes: 14 additions & 14 deletions include/tvm/relay/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
#ifndef TVM_RELAY_OP_ATTR_TYPES_H_
#define TVM_RELAY_OP_ATTR_TYPES_H_

#include <tvm/tensor.h>
#include <tvm/schedule.h>
#include <tvm/top/tensor.h>
#include <tvm/top/schedule.h>
#include <tvm/build_module.h>
#include <tvm/relay/type.h>
#include <tvm/relay/expr.h>
Expand Down Expand Up @@ -99,10 +99,10 @@ using TShapeDataDependant = bool;
* \return The output compute description of the operator.
*/
using FTVMCompute = runtime::TypedPackedFunc<
Array<Tensor>(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target)>;
Array<top::Tensor>(const Attrs& attrs,
const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target)>;

/*!
* \brief Build the computation schedule for
Expand All @@ -114,9 +114,9 @@ using FTVMCompute = runtime::TypedPackedFunc<
* \return schedule The computation schedule.
*/
using FTVMSchedule = runtime::TypedPackedFunc<
Schedule(const Attrs& attrs,
const Array<Tensor>& outs,
const Target& target)>;
top::Schedule(const Attrs& attrs,
const Array<top::Tensor>& outs,
const Target& target)>;

/*!
* \brief Alternate the layout of operators or replace the
Expand All @@ -131,7 +131,7 @@ using FTVMSchedule = runtime::TypedPackedFunc<
using FTVMAlterOpLayout = runtime::TypedPackedFunc<
Expr(const Attrs& attrs,
const Array<Expr>& args,
const Array<Tensor>& tinfos)>;
const Array<top::Tensor>& tinfos)>;

/*!
* \brief Convert the layout of operators or replace the
Expand All @@ -147,7 +147,7 @@ using FTVMAlterOpLayout = runtime::TypedPackedFunc<
using FTVMConvertOpLayout = runtime::TypedPackedFunc<
Expr(const Attrs& attrs,
const Array<Expr>& args,
const Array<Tensor>& tinfos,
const Array<top::Tensor>& tinfos,
const std::string& desired_layout)>;
/*!
* \brief Legalizes an expression with another expression. This function will be
Expand Down Expand Up @@ -206,9 +206,9 @@ enum AnyCodegenStrategy {
using Shape = Array<IndexExpr>;

using FShapeFunc = runtime::TypedPackedFunc<
Array<Tensor>(const Attrs& attrs,
const Array<Tensor>& inputs,
const Array<IndexExpr>& out_ndims)>;
Array<top::Tensor>(const Attrs& attrs,
const Array<top::Tensor>& inputs,
const Array<IndexExpr>& out_ndims)>;

} // namespace relay
} // namespace tvm
Expand Down
22 changes: 13 additions & 9 deletions include/tvm/operation.h → include/tvm/top/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,28 @@
*/

/*!
* \file tvm/operation.h
* \file tvm/top/operation.h
* \brief Operation node can generate one or multiple Tensors
*/
#ifndef TVM_OPERATION_H_
#define TVM_OPERATION_H_
#ifndef TVM_TOP_OPERATION_H_
#define TVM_TOP_OPERATION_H_

#include <tvm/arith/analyzer.h>
#include <tvm/top/tensor.h>
#include <tvm/top/schedule.h>

#include <tvm/expr.h>
#include <tvm/expr_operator.h>
#include <tvm/buffer.h>

#include <string>
#include <vector>
#include <unordered_map>

#include "expr.h"
#include "expr_operator.h"
#include "tensor.h"
#include "schedule.h"
#include "buffer.h"


namespace tvm {
namespace top {

using arith::IntSet;

Expand Down Expand Up @@ -655,5 +658,6 @@ inline Tensor compute(Array<PrimExpr> shape,
inline const OperationNode* Operation::operator->() const {
return static_cast<const OperationNode*>(get());
}
} // namespace top
} // namespace tvm
#endif // TVM_OPERATION_H_
#endif // TVM_TOP_OPERATION_H_
20 changes: 12 additions & 8 deletions include/tvm/schedule.h → include/tvm/top/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,24 @@
*/

/*!
* \file tvm/schedule.h
* \file tvm/top/schedule.h
* \brief Define a schedule.
*/
// Acknowledgement: Many schedule primitives originate from Halide and Loopy.
#ifndef TVM_SCHEDULE_H_
#define TVM_SCHEDULE_H_
#ifndef TVM_TOP_SCHEDULE_H_
#define TVM_TOP_SCHEDULE_H_

#include <tvm/expr.h>
#include <tvm/top/tensor.h>
#include <tvm/top/tensor_intrin.h>


#include <string>
#include <unordered_map>
#include "expr.h"
#include "tensor.h"
#include "tensor_intrin.h"

namespace tvm {

namespace tvm {
namespace top {
// Node container for Stage
class StageNode;
// Node container for Schedule
Expand Down Expand Up @@ -764,5 +767,6 @@ inline const IterVarRelationNode* IterVarRelation::operator->() const {
inline const IterVarAttrNode* IterVarAttr::operator->() const {
return static_cast<const IterVarAttrNode*>(get());
}
} // namespace top
} // namespace tvm
#endif // TVM_SCHEDULE_H_
#endif // TVM_TOP_SCHEDULE_H_
14 changes: 7 additions & 7 deletions include/tvm/schedule_pass.h → include/tvm/top/schedule_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,20 @@
*/

/*!
* \file tvm/schedule_pass.h
* \file tvm/top/schedule_pass.h
* \brief Collection of Schedule pass functions.
*
* These passes works on the schedule hyper-graph
* and infers information such as bounds, check conditions
* read/write dependencies between the IterVar
*/
#ifndef TVM_SCHEDULE_PASS_H_
#define TVM_SCHEDULE_PASS_H_
#ifndef TVM_TOP_SCHEDULE_PASS_H_
#define TVM_TOP_SCHEDULE_PASS_H_

#include "schedule.h"
#include <tvm/top/schedule.h>

namespace tvm {
namespace schedule {
namespace top {

/*!
* \brief Infer the bound of all iteration variables relates to the schedule.
Expand Down Expand Up @@ -71,6 +71,6 @@ void AutoInlineElemWise(Schedule sch);
*/
TVM_DLL void AutoInlineInjective(Schedule sch);

} // namespace schedule
} // namespace top
} // namespace tvm
#endif // TVM_SCHEDULE_PASS_H_
#endif // TVM_TOP_SCHEDULE_PASS_H_
Loading

0 comments on commit 62a7e53

Please sign in to comment.