forked from neo-ai/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TensorIR][M1b] Schedule class (apache#7847)
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn> Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com> Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin <wuwei@apache.org> Co-authored-by: Jared Roesch <roeschinc@gmail.com>
- Loading branch information
1 parent
6f4797e
commit 4be2382
Showing
11 changed files
with
1,210 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,227 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
#ifndef TVM_TIR_SCHEDULE_SCHEDULE_H_ | ||
#define TVM_TIR_SCHEDULE_SCHEDULE_H_ | ||
|
||
#include <tvm/tir/schedule/state.h> | ||
|
||
namespace tvm { | ||
namespace tir { | ||
|
||
/**************** Random variable: BlockRV ****************/ | ||
|
||
/*! \brief A random variable that evaluates to a TensorIR block */ | ||
class BlockRVNode : public runtime::Object { | ||
public: | ||
void VisitAttrs(tvm::AttrVisitor* v) {} | ||
static constexpr const char* _type_key = "tir.BlockRV"; | ||
TVM_DECLARE_FINAL_OBJECT_INFO(BlockRVNode, runtime::Object); | ||
}; | ||
|
||
/*! | ||
* \brief Managed reference to BlockRVNode | ||
* \sa BlockRVNode | ||
*/ | ||
class BlockRV : public runtime::ObjectRef { | ||
public: | ||
/*! \brief Constructor */ | ||
TVM_DLL BlockRV(); | ||
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BlockRV, runtime::ObjectRef, BlockRVNode); | ||
}; | ||
|
||
/**************** Random variable: LoopRV ****************/ | ||
|
||
/*! \brief A random variable that evaluates to a TensorIR for loop */ | ||
class LoopRVNode : public runtime::Object { | ||
public: | ||
void VisitAttrs(tvm::AttrVisitor* v) {} | ||
static constexpr const char* _type_key = "tir.LoopRV"; | ||
TVM_DECLARE_FINAL_OBJECT_INFO(LoopRVNode, runtime::Object); | ||
}; | ||
|
||
/*! | ||
* \brief Managed reference to LoopRVNode | ||
* \sa LoopRVNode | ||
*/ | ||
class LoopRV : public runtime::ObjectRef { | ||
public: | ||
/*! \brief Constructor */ | ||
TVM_DLL LoopRV(); | ||
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LoopRV, runtime::ObjectRef, LoopRVNode); | ||
}; | ||
|
||
/**************** Random variable: IntRV ****************/ | ||
|
||
/*! \brief An integer random variable */ | ||
using IntRV = PrimExpr; | ||
|
||
using IntRVNode = PrimExprNode; | ||
|
||
/**************** The Schedule class ****************/ | ||
|
||
class Schedule; | ||
|
||
/*! \brief The user-facing schedule class */ | ||
class ScheduleNode : public runtime::Object { | ||
friend class Schedule; | ||
|
||
public: | ||
virtual ~ScheduleNode() = default; | ||
|
||
static constexpr const char* _type_key = "tir.Schedule"; | ||
TVM_DECLARE_BASE_OBJECT_INFO(ScheduleNode, runtime::Object); | ||
|
||
public: | ||
/*! \brief Get the IRModule associated with this schedule. */ | ||
virtual IRModule mod() const { return state()->mod; } | ||
/*! \return The internal state of scheduling */ | ||
virtual ScheduleState state() const = 0; | ||
/*! | ||
* \brief Returns a copy of the schedule, including both its state and its symbol table, | ||
* guaranteeing that | ||
* 1) SRef tree is completely reconstructed; | ||
* 2) The IRModule being scheduled is not modified; | ||
* 3) All the random variables are valid in the copy, pointing to the correpsonding sref | ||
* reconstructed | ||
*/ | ||
virtual Schedule Copy() const = 0; | ||
/*! | ||
* \brief Seed the randomness | ||
* \param seed The new random seed, -1 if use device random, otherwise non-negative | ||
*/ | ||
virtual void Seed(int64_t seed = -1) { | ||
LOG(FATAL) << "ValueError: The schedule cannot be seeded because no randomness is allowed"; | ||
} | ||
|
||
public: | ||
/******** Lookup/Remove random variables ********/ | ||
/*! | ||
* \brief Get the block corresponding to the specific BlockRV | ||
* \param block_rv The BlockRV to be looked up | ||
* \return The corresponding block | ||
*/ | ||
virtual Block Get(const BlockRV& block_rv) const = 0; | ||
/*! | ||
* \brief Get the for loop corresponding to the specific LoopRV | ||
* \param loop_rv The LoopRV to be looked up | ||
* \return The corresponding for loop | ||
*/ | ||
virtual For Get(const LoopRV& loop_rv) const = 0; | ||
/*! | ||
* \brief Get the value corresponding to the specific random variable | ||
* \param int_rv The random variable to be looked up | ||
* \return The corresponding value | ||
*/ | ||
virtual int64_t Get(const IntRV& int_rv) const = 0; | ||
/*! | ||
* \brief Get the block sref corresponding to the specific BlockRV | ||
* \param block_rv The BlockRV to be looked up | ||
* \return The corresponding block sref | ||
*/ | ||
virtual StmtSRef GetSRef(const BlockRV& block_rv) const = 0; | ||
/*! | ||
* \brief Get the loop sref corresponding to the specific LoopRV | ||
* \param loop_rv The LoopRV to be looked up | ||
* \return The corresponding loop sref | ||
*/ | ||
virtual StmtSRef GetSRef(const LoopRV& loop_rv) const = 0; | ||
/*! | ||
* \brief Get the block/loop sref corresponding to the specific statement | ||
* \param stmt The statement to be looked up | ||
* \return The corresponding block/loop sref | ||
*/ | ||
virtual StmtSRef GetSRef(const StmtNode* stmt) const; | ||
/*! | ||
* \brief Get the block/loop sref corresponding to the specific statement | ||
* \param stmt The statement to be looked up | ||
* \return The corresponding block/loop sref | ||
*/ | ||
StmtSRef GetSRef(const Stmt& stmt) const { return this->GetSRef(stmt.get()); } | ||
/*! | ||
* \brief Remove a block random variable from the symbol table | ||
* \param block_rv The random variable to be removed | ||
*/ | ||
virtual void RemoveRV(const BlockRV& block_rv) = 0; | ||
/*! | ||
* \brief Remove a loop random variable from the symbol table | ||
* \param loop_rv The random variable to be removed | ||
*/ | ||
virtual void RemoveRV(const LoopRV& loop_rv) = 0; | ||
/*! | ||
* \brief Remove an integer random variable from the symbol table | ||
* \param int_rv The random variable to be removed | ||
*/ | ||
virtual void RemoveRV(const IntRV& int_rv) = 0; | ||
|
||
public: | ||
/******** Block/Loop relation ********/ | ||
/*! | ||
* \brief Retrieve a block in a specific function with its name | ||
* \param name The name of the block to be retrieved | ||
* \param func_name The name of the function | ||
* \return The block retrieved | ||
* \note Indexing error is raised if 0 or multiple blocks exist with the specific name | ||
*/ | ||
virtual BlockRV GetBlock(const String& name, const String& func_name = "main") = 0; | ||
/*! | ||
* \brief Get the parent loops of the block in its scope, from outer to inner | ||
* \param block_rv The query block | ||
* \return A list of loops above the given block in its scope, from outer to inner | ||
*/ | ||
virtual Array<LoopRV> GetLoops(const BlockRV& block_rv) = 0; | ||
}; | ||
|
||
/*! | ||
* \brief Managed reference to ScheduleNode | ||
* | ||
* A schedule is a set of transformations that change the order of computation but | ||
* preserve the semantics of computation. Some example of schedules: | ||
* 1) Split a loop into two; | ||
* 2) Reorder two loops; | ||
* 3) Inline the computation of a specific buffer into its consumer | ||
* | ||
* The schedule class stores auxiliary information to schedule correctly and efficiently. | ||
* | ||
* Link to tutorial: https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html | ||
* | ||
* \sa ScheduleNode | ||
*/ | ||
class Schedule : public runtime::ObjectRef { | ||
public: | ||
/*! | ||
* \brief Construct a concrete TensorIR schedule from an IRModule | ||
* \param mod The IRModule to be scheduled | ||
* \param debug_mode Do extra correctness checking after the class creation | ||
* and each time after calling the Replace method. | ||
* \return The concrete schedule created | ||
* \sa ScheduleDebugMask | ||
* \note The checks performed includes: | ||
* 1) VerifySRefTree | ||
* 2) VerifyAffineBinding | ||
* 3) VerifyRegionCover | ||
* 4) VerifyStagePipeline | ||
*/ | ||
TVM_DLL static Schedule Concrete(IRModule mod, int debug_mode); | ||
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode); | ||
}; | ||
|
||
} // namespace tir | ||
} // namespace tvm | ||
|
||
#endif // TVM_TIR_SCHEDULE_SCHEDULE_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.