Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[relay][pass manager] Open transform namespace #3226

Merged
merged 6 commits into from
May 22, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 117 additions & 16 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <string>
#include <vector>

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -182,6 +183,122 @@ class Pass : public NodeRef {
using ContainerType = PassNode;
};

class Sequential;

/*!
* \brief The SequentialNode contains a set of passes that transform Relay
* programs from one AST to another semantically equivalent one.
*
* One example of this level of pass is that the pass manager needs to correctly
* perform a host of optimizations with a given optimization level and disabled
* passes.
*/
class SequentialNode : public PassNode {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

declare node, but hide the implementation inside. (User don't need to see SequentialNode, they only need to see Sequantial)

public:
/* \brief The pass meta data.*/
PassInfo pass_info;

/*! \brief A list of passes that used to compose a sequential pass. */
tvm::Array<Pass> passes;
/*!
* \brief A list of disabled passes that should be excluded when executing the
* sequential pass.
*/
tvm::Array<tvm::Expr> disabled;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("pass_info", &pass_info);
v->Visit("passes", &passes);
v->Visit("disabled", &disabled);
}

/*!
* \brief Get the pass information/meta data.
*/
PassInfo Info() const { return pass_info; }

/*!
* \brief Add a pass to the pass list.
*
* \param pass The candidate pass to be added.
*/
void AddPass(const Pass& pass) {
passes.push_back(pass);
}

TVM_DLL static Sequential make(tvm::Array<Pass> passes,
PassInfo pass_info,
tvm::Array<tvm::Expr> disabled);

/*!
* \brief Resolve the pass dependency. It globs all required passes by
* a given pass and executes them.
*
* \param mod The module that an optimization pass runs on.
*
* \return The updated module after resolving pass dependencies.
*
* TODO(zhiics) Build a dependency graph among the passes using provided
* metadata, i.e. required_passes. Likely, we can have a data structure, i.e.
* PassInfo, to store the relevant information including the parent passes.
*/
void ResolveDependency(const Module& mod);

TVM_DLL std::vector<std::string> DisabledPasses() const;

/*!
* \brief Perform optimizations on a series of passes. The aforementioned
* typical pass manager jobs could be done by it. This function could
* be overloaded to focus on different metrics, i.e. performance,
* memory footprint, etc.
*
* \param mod The module that an optimization pass runs on.
*
* \return Return the updated module.
*/
Module operator()(const Module& mod) const final;

/*!
* \brief Set the context information for a sequential pass.
*
* \param pass_ctx The context information for a sequential pass.
*/
void SetContext(const PassContext& pass_ctx) final;

static constexpr const char* _type_key = "relay.Sequential";
TVM_DECLARE_NODE_TYPE_INFO(SequentialNode, PassNode);

private:
/*!
* \brief The context information that is used to help perform a module pass.
*/
PassContext pass_ctx_;
};

class Sequential : public Pass {
public:
/*!
* \brief The constructor of `Sequential`.
* \param passes The passes to apply.
* \param pass_info The pass metadata.
* \param disabled The passes that will not be applied.
*/
TVM_DLL Sequential(tvm::Array<Pass> passes,
PassInfo pass_info,
tvm::Array<tvm::Expr> disabled);
Sequential() = default;

explicit Sequential(tvm::NodePtr<::tvm::Node> n) : Pass(n) {}

const SequentialNode* operator->() const {
return static_cast<const SequentialNode*>(this->node_.get());
}

using ContainerType = Sequential;
};

// RELAY_DEFINE_NODE_REF(Sequential, SequentialNode, Pass);

/*
* \brief Create a module pass.
*
Expand Down Expand Up @@ -213,22 +330,6 @@ Pass CreateFunctionPass(
int opt_level,
const std::string& name,
const tvm::Array<tvm::Expr>& required);
/*
* \brief Create a sequential pass.
*
* \param passes The optimization passes will be performed.
* \param opt_level The optimization level of the sequential pass.
* \param name The name of the sequential pass.
* \param required The list of the passes that the sequential pass is dependent on.
* \param disabled The disabled passes.
*
* \return The created sequential pass.
*/
Pass CreateSequential(const tvm::Array<Pass>& passes,
int opt_level,
const std::string& name,
const tvm::Array<tvm::Expr>& required,
const tvm::Array<tvm::Expr>& disabled);

} // namespace transform
} // namespace relay
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def __init__(self,
if not isinstance(required, (list, tuple)):
raise TypeError("Required is expected to be the type of list/tuple.")

self.__init_handle_by_constructor__(_transform.CreateSequential,
self.__init_handle_by_constructor__(_transform.Sequential,
passes, opt_level, name, required,
disabled)

Expand Down
119 changes: 13 additions & 106 deletions src/relay/pass/pass_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,100 +169,6 @@ class FunctionPassNode : public PassNode {

RELAY_DEFINE_NODE_REF(FunctionPass, FunctionPassNode, Pass);

class Sequential;

/*!
* \brief The SequentialNode contains a set of passes that transform Relay
* programs from one AST to another semantically equivalent one.
*
* One example of this level of pass is that the pass manager needs to correctly
* perform a host of optimizations with a given optimization level and disabled
* passes.
*/
class SequentialNode : public PassNode {
public:
/* \brief The pass meta data.*/
PassInfo pass_info;

/*! \brief A list of passes that used to compose a sequential pass. */
tvm::Array<Pass> passes;
/*!
* \brief A list of disabled passes that should be excluded when executing the
* sequential pass.
*/
tvm::Array<tvm::Expr> disabled;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("pass_info", &pass_info);
v->Visit("passes", &passes);
v->Visit("disabled", &disabled);
}

/*!
* \brief Get the pass information/meta data.
*/
PassInfo Info() const { return pass_info; }

/*!
* \brief Add a pass to the pass list.
*
* \param pass The candidate pass to be added.
*/
void AddPass(const Pass& pass) {
passes.push_back(pass);
}

TVM_DLL static Sequential make(tvm::Array<Pass> passes,
PassInfo pass_info,
tvm::Array<tvm::Expr> disabled);

/*!
* \brief Resolve the pass dependency. It globs all required passes by
* a given pass and executes them.
*
* \param mod The module that an optimization pass runs on.
*
* \return The updated module after resolving pass dependencies.
*
* TODO(zhiics) Build a dependency graph among the passes using provided
* metadata, i.e. required_passes. Likely, we can have a data structure, i.e.
* PassInfo, to store the relevant information including the parent passes.
*/
void ResolveDependency(const Module& mod);

TVM_DLL std::vector<std::string> DisabledPasses() const;

/*!
* \brief Perform optimizations on a series of passes. The aforementioned
* typical pass manager jobs could be done by it. This function could
* be overloaded to focus on different metrics, i.e. performance,
* memory footprint, etc.
*
* \param mod The module that an optimization pass runs on.
*
* \return Return the updated module.
*/
Module operator()(const Module& mod) const final;

/*!
* \brief Set the context information for a sequential pass.
*
* \param pass_ctx The context information for a sequential pass.
*/
void SetContext(const PassContext& pass_ctx) final;

static constexpr const char* _type_key = "relay.Sequential";
TVM_DECLARE_NODE_TYPE_INFO(SequentialNode, PassNode);

private:
/*!
* \brief The context information that is used to help perform a module pass.
*/
PassContext pass_ctx_;
};

RELAY_DEFINE_NODE_REF(Sequential, SequentialNode, Pass);

PassInfo PassInfoNode::make(int opt_level, std::string name,
tvm::Array<tvm::Expr> required) {
auto pass_info = make_node<PassInfoNode>();
Expand Down Expand Up @@ -350,6 +256,16 @@ bool FunctionPassNode::SkipFunction(const Function& func) const {
return pval && pval->value != 0;
}

Sequential::Sequential(tvm::Array<Pass> passes,
PassInfo pass_info,
tvm::Array<tvm::Expr> disabled) {
auto n = make_node<SequentialNode>();
n->passes = std::move(passes);
n->pass_info = std::move(pass_info);
n->disabled = std::move(disabled);
node_ = std::move(n);
}

Sequential SequentialNode::make(tvm::Array<Pass> passes,
PassInfo pass_info,
tvm::Array<tvm::Expr> disabled) {
Expand Down Expand Up @@ -414,15 +330,6 @@ Pass CreateFunctionPass(
return FunctionPassNode::make(pass_func, pass_info);
}

Pass CreateSequential(const tvm::Array<Pass>& passes,
int opt_level,
const std::string& name,
const tvm::Array<tvm::Expr>& required,
const tvm::Array<tvm::Expr>& disabled) {
PassInfo pass_info = PassInfoNode::make(opt_level, name, required);
return SequentialNode::make(passes, pass_info, disabled);
}

TVM_REGISTER_NODE_TYPE(PassInfoNode);

TVM_REGISTER_API("relay._transform.PassInfo")
Expand Down Expand Up @@ -488,20 +395,20 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)

TVM_REGISTER_NODE_TYPE(SequentialNode);

TVM_REGISTER_API("relay._transform.CreateSequential")
TVM_REGISTER_API("relay._transform.Sequential")
.set_body([](TVMArgs args, TVMRetValue* ret) {
tvm::Array<Pass> passes = args[0];
int opt_level = args[1];
std::string name = args[2];
tvm::Array<tvm::Expr> required = args[3];
tvm::Array<tvm::Expr> disabled = args[4];
PassInfo pass_info = PassInfoNode::make(opt_level, name, required);
*ret = SequentialNode::make(passes, pass_info, disabled);
*ret = Sequential(passes, pass_info, disabled);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<SequentialNode>([](const SequentialNode* node,
tvm::IRPrinter* p) {
tvm::IRPrinter* p) {
const PassInfoNode* seq_pn = node->Info().operator->();
p->stream << "Run Sequential pass: " << seq_pn->name
<< " at the optimization level. " << seq_pn->opt_level;
Expand Down