[REFACTOR] top - namespace for Tensor Operation DSL #4727

Merged 2 commits on Jan 16, 2020
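At a glance, this PR moves the Tensor Operation DSL (Tensor, Operation, Schedule and the schedule passes) out of the root `tvm::` namespace into `tvm::top`, relocating the corresponding headers from `include/tvm/` to `include/tvm/top/`. A minimal sketch of what the move means for downstream C++ code follows; the function and variable names are illustrative, and only the headers and the `tvm::top` qualification come from this diff.

```cpp
// Sketch only: how user code refers to the DSL types after this PR.
// Previously: #include <tvm/tensor.h> / <tvm/schedule.h> and tvm::Tensor / tvm::Schedule.
#include <tvm/top/tensor.h>
#include <tvm/top/schedule.h>

// Code that used to take tvm::Tensor / tvm::Schedule now qualifies them with top::.
// inspect() is a hypothetical helper; the types are the only part taken from the PR.
void inspect(const tvm::top::Tensor& t, const tvm::top::Schedule& s) {
  (void)t;
  (void)s;
}
```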
14 changes: 9 additions & 5 deletions CMakeLists.txt
@@ -124,20 +124,24 @@ assign_source_group("Source" ${GROUP_SOURCE})
assign_source_group("Include" ${GROUP_INCLUDE})

# Source file lists
file(GLOB COMPILER_SRCS
file(GLOB_RECURSE COMPILER_SRCS
src/node/*.cc
src/ir/*.cc
src/target/*.cc
src/api/*.cc
src/arith/*.cc
src/top/*.cc
src/api/*.cc
src/autotvm/*.cc
src/codegen/*.cc
src/lang/*.cc
src/pass/*.cc
src/op/*.cc
src/schedule/*.cc
)

file(GLOB CODEGEN_SRCS
src/codegen/*.cc
)

list(APPEND COMPILER_SRCS ${CODEGEN_SRCS})

file(GLOB_RECURSE RELAY_OP_SRCS
src/relay/op/*.cc
)
7 changes: 6 additions & 1 deletion include/tvm/arith/bound.h
@@ -32,7 +32,9 @@

namespace tvm {
// forward declare Tensor
namespace top {
class Tensor;
}
namespace arith {

/*!
@@ -75,7 +77,10 @@ IntSet DeduceBound(PrimExpr v, PrimExpr cond,
* \param consider_provides If provides (write) are considered.
* \return The domain that covers all the calls or provides within the given statement.
*/
Domain DomainTouched(Stmt body, const Tensor &tensor, bool consider_calls, bool consider_provides);
Domain DomainTouched(Stmt body,
const top::Tensor &tensor,
bool consider_calls,
bool consider_provides);

} // namespace arith
} // namespace tvm
10 changes: 6 additions & 4 deletions include/tvm/build_module.h
@@ -26,14 +26,16 @@

#include <tvm/target/target.h>
#include <tvm/support/with.h>
#include <tvm/top/schedule_pass.h>

#include <string>
#include <vector>
#include <utility>
#include <unordered_map>
#include <unordered_set>

#include "runtime/packed_func.h"
#include "schedule_pass.h"

#include "lowered_func.h"

namespace tvm {
@@ -172,10 +174,10 @@ class BuildConfig : public ::tvm::ObjectRef {
* \param config The build configuration.
* \return The lowered function.
*/
TVM_DLL Array<LoweredFunc> lower(Schedule sch,
const Array<Tensor>& args,
TVM_DLL Array<LoweredFunc> lower(top::Schedule sch,
const Array<top::Tensor>& args,
const std::string& name,
const std::unordered_map<Tensor, Buffer>& binds,
const std::unordered_map<top::Tensor, Buffer>& binds,
const BuildConfig& config);
/*!
* \brief Split host/device function and running necessary pass before build
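Since `lower()` now takes `top::Schedule` and `top::Tensor`, call sites only change in how the argument types are spelled. A hedged sketch of such a call site is below; the schedule `sch`, the tensors `a` and `b`, and the wrapper function are placeholders, not taken from the PR.

```cpp
// Sketch: invoking the updated lower() signature from build_module.h.
#include <tvm/build_module.h>
#include <unordered_map>

tvm::Array<tvm::LoweredFunc> lower_example(tvm::top::Schedule sch,
                                           tvm::top::Tensor a,
                                           tvm::top::Tensor b) {
  // binds now maps top::Tensor (rather than tvm::Tensor) to pre-allocated Buffers.
  std::unordered_map<tvm::top::Tensor, tvm::Buffer> binds;
  return tvm::lower(sch, {a, b}, "main", binds, tvm::BuildConfig::Create());
}
```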
9 changes: 5 additions & 4 deletions include/tvm/ir_pass.h
@@ -27,13 +27,14 @@
#ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_

#include <tvm/top/schedule.h>

#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <string>
#include "expr.h"
#include "buffer.h"
#include "schedule.h"
#include "lowered_func.h"

namespace tvm {
@@ -203,7 +204,7 @@ Stmt Inline(Stmt stmt,
* \return Transformed stmt.
*/
Stmt StorageFlatten(Stmt stmt,
Map<Tensor, Buffer> extern_buffer,
Map<top::Tensor, Buffer> extern_buffer,
int cache_line_size,
bool create_bound_attribute = false);

Expand All @@ -217,8 +218,8 @@ Stmt StorageFlatten(Stmt stmt,
* \return Transformed stmt.
*/
Stmt RewriteForTensorCore(Stmt stmt,
Schedule schedule,
Map<Tensor, Buffer> extern_buffer);
top::Schedule schedule,
Map<top::Tensor, Buffer> extern_buffer);

/*!
* \brief Verify if there is any argument bound to compact buffer.
3 changes: 2 additions & 1 deletion include/tvm/lowered_func.h
@@ -25,10 +25,11 @@
#ifndef TVM_LOWERED_FUNC_H_
#define TVM_LOWERED_FUNC_H_

#include <tvm/top/tensor.h>

#include <string>

#include "expr.h"
#include "tensor.h"
#include "tvm/node/container.h"

namespace tvm {
7 changes: 4 additions & 3 deletions include/tvm/packed_func_ext.h
@@ -25,13 +25,14 @@
#ifndef TVM_PACKED_FUNC_EXT_H_
#define TVM_PACKED_FUNC_EXT_H_

#include <tvm/top/tensor.h>

#include <string>
#include <memory>
#include <limits>
#include <type_traits>

#include "expr.h"
#include "tensor.h"
#include "runtime/packed_func.h"

namespace tvm {
@@ -116,8 +117,8 @@ inline TVMPODValue_::operator tvm::PrimExpr() const {
if (ptr->IsInstance<IterVarNode>()) {
return IterVar(ObjectPtr<Object>(ptr))->var;
}
if (ptr->IsInstance<TensorNode>()) {
return Tensor(ObjectPtr<Object>(ptr))();
if (ptr->IsInstance<top::TensorNode>()) {
return top::Tensor(ObjectPtr<Object>(ptr))();
}
CHECK(ObjectTypeChecker<PrimExpr>::Check(ptr))
<< "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
28 changes: 14 additions & 14 deletions include/tvm/relay/op_attr_types.h
@@ -24,8 +24,8 @@
#ifndef TVM_RELAY_OP_ATTR_TYPES_H_
#define TVM_RELAY_OP_ATTR_TYPES_H_

#include <tvm/tensor.h>
#include <tvm/schedule.h>
#include <tvm/top/tensor.h>
#include <tvm/top/schedule.h>
#include <tvm/build_module.h>
#include <tvm/relay/type.h>
#include <tvm/relay/expr.h>
@@ -99,10 +99,10 @@ using TShapeDataDependant = bool;
* \return The output compute description of the operator.
*/
using FTVMCompute = runtime::TypedPackedFunc<
Array<Tensor>(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target)>;
Array<top::Tensor>(const Attrs& attrs,
const Array<top::Tensor>& inputs,
const Type& out_type,
const Target& target)>;

/*!
* \brief Build the computation schedule for
Expand All @@ -114,9 +114,9 @@ using FTVMCompute = runtime::TypedPackedFunc<
* \return schedule The computation schedule.
*/
using FTVMSchedule = runtime::TypedPackedFunc<
Schedule(const Attrs& attrs,
const Array<Tensor>& outs,
const Target& target)>;
top::Schedule(const Attrs& attrs,
const Array<top::Tensor>& outs,
const Target& target)>;

/*!
* \brief Alternate the layout of operators or replace the
Expand All @@ -131,7 +131,7 @@ using FTVMSchedule = runtime::TypedPackedFunc<
using FTVMAlterOpLayout = runtime::TypedPackedFunc<
Expr(const Attrs& attrs,
const Array<Expr>& args,
const Array<Tensor>& tinfos)>;
const Array<top::Tensor>& tinfos)>;

/*!
* \brief Convert the layout of operators or replace the
Expand All @@ -147,7 +147,7 @@ using FTVMAlterOpLayout = runtime::TypedPackedFunc<
using FTVMConvertOpLayout = runtime::TypedPackedFunc<
Expr(const Attrs& attrs,
const Array<Expr>& args,
const Array<Tensor>& tinfos,
const Array<top::Tensor>& tinfos,
const std::string& desired_layout)>;
/*!
* \brief Legalizes an expression with another expression. This function will be
@@ -206,9 +206,9 @@ enum AnyCodegenStrategy {
using Shape = Array<IndexExpr>;

using FShapeFunc = runtime::TypedPackedFunc<
Array<Tensor>(const Attrs& attrs,
const Array<Tensor>& inputs,
const Array<IndexExpr>& out_ndims)>;
Array<top::Tensor>(const Attrs& attrs,
const Array<top::Tensor>& inputs,
const Array<IndexExpr>& out_ndims)>;

} // namespace relay
} // namespace tvm
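The practical effect of this hunk is that Relay operator compute and schedule functions now spell their tensor arguments as `top::Tensor`. A minimal, illustrative FTVMCompute-compatible function under that assumption; the operator itself is made up, and only the signature follows the typedef above.

```cpp
// Sketch: a compute rule matching the renamed FTVMCompute signature.
#include <tvm/relay/op_attr_types.h>

tvm::Array<tvm::top::Tensor> IdentityCompute(
    const tvm::Attrs& attrs,
    const tvm::Array<tvm::top::Tensor>& inputs,
    const tvm::relay::Type& out_type,
    const tvm::Target& target) {
  // A trivial rule: forward the inputs unchanged. Real operators would build
  // their output tensors here (e.g. via the compute() DSL under tvm::top).
  return inputs;
}
```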
22 changes: 13 additions & 9 deletions include/tvm/operation.h → include/tvm/top/operation.h
@@ -18,25 +18,28 @@
*/

/*!
* \file tvm/operation.h
* \file tvm/top/operation.h
* \brief Operation node can generate one or multiple Tensors
*/
#ifndef TVM_OPERATION_H_
#define TVM_OPERATION_H_
#ifndef TVM_TOP_OPERATION_H_
#define TVM_TOP_OPERATION_H_

#include <tvm/arith/analyzer.h>
#include <tvm/top/tensor.h>
#include <tvm/top/schedule.h>

#include <tvm/expr.h>
#include <tvm/expr_operator.h>
#include <tvm/buffer.h>

#include <string>
#include <vector>
#include <unordered_map>

#include "expr.h"
#include "expr_operator.h"
#include "tensor.h"
#include "schedule.h"
#include "buffer.h"


namespace tvm {
namespace top {

using arith::IntSet;

@@ -655,5 +658,6 @@ inline Tensor compute(Array<PrimExpr> shape,
inline const OperationNode* Operation::operator->() const {
return static_cast<const OperationNode*>(get());
}
} // namespace top
} // namespace tvm
#endif // TVM_OPERATION_H_
#endif // TVM_TOP_OPERATION_H_
20 changes: 12 additions & 8 deletions include/tvm/schedule.h → include/tvm/top/schedule.h
@@ -18,21 +18,24 @@
*/

/*!
* \file tvm/schedule.h
* \file tvm/top/schedule.h
* \brief Define a schedule.
*/
// Acknowledgement: Many schedule primitives originate from Halide and Loopy.
#ifndef TVM_SCHEDULE_H_
#define TVM_SCHEDULE_H_
#ifndef TVM_TOP_SCHEDULE_H_
#define TVM_TOP_SCHEDULE_H_

#include <tvm/expr.h>
#include <tvm/top/tensor.h>
#include <tvm/top/tensor_intrin.h>


#include <string>
#include <unordered_map>
#include "expr.h"
#include "tensor.h"
#include "tensor_intrin.h"

namespace tvm {

namespace tvm {
namespace top {
// Node container for Stage
class StageNode;
// Node container for Schedule
@@ -764,5 +767,6 @@ inline const IterVarRelationNode* IterVarRelation::operator->() const {
inline const IterVarAttrNode* IterVarAttr::operator->() const {
return static_cast<const IterVarAttrNode*>(get());
}
} // namespace top
} // namespace tvm
#endif // TVM_SCHEDULE_H_
#endif // TVM_TOP_SCHEDULE_H_
14 changes: 7 additions & 7 deletions include/tvm/schedule_pass.h → include/tvm/top/schedule_pass.h
@@ -18,20 +18,20 @@
*/

/*!
* \file tvm/schedule_pass.h
* \file tvm/top/schedule_pass.h
* \brief Collection of Schedule pass functions.
*
* These passes work on the schedule hyper-graph
* and infer information such as bounds, check conditions,
* and read/write dependencies between the IterVars.
*/
#ifndef TVM_SCHEDULE_PASS_H_
#define TVM_SCHEDULE_PASS_H_
#ifndef TVM_TOP_SCHEDULE_PASS_H_
#define TVM_TOP_SCHEDULE_PASS_H_

#include "schedule.h"
#include <tvm/top/schedule.h>

namespace tvm {
namespace schedule {
namespace top {

/*!
* \brief Infer the bound of all iteration variables related to the schedule.
@@ -71,6 +71,6 @@ void AutoInlineElemWise(Schedule sch);
*/
TVM_DLL void AutoInlineInjective(Schedule sch);

} // namespace schedule
} // namespace top
} // namespace tvm
#endif // TVM_SCHEDULE_PASS_H_
#endif // TVM_TOP_SCHEDULE_PASS_H_
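Because the passes move from `tvm::schedule` to `tvm::top`, their call sites only need a namespace update. A sketch of that migration, assuming a schedule `sch` built elsewhere; only `InferBound` and `AutoInlineInjective` are taken from the header above, the wrapper function is illustrative.

```cpp
// Sketch: calling the schedule passes after the namespace move.
#include <tvm/top/schedule_pass.h>

void run_schedule_passes(tvm::top::Schedule sch) {
  // Was: tvm::schedule::AutoInlineInjective(sch);
  tvm::top::AutoInlineInjective(sch);
  // Was: tvm::schedule::InferBound(sch);
  auto bounds = tvm::top::InferBound(sch);
  (void)bounds;  // bound map consumed by downstream lowering.
}
```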