Skip to content

Commit

Permalink
Redesign IRBuilder to BlockBuilder (apache#22)
Browse files Browse the repository at this point in the history
* init

* update

* update

* test case working

* update and add multi block test case

* check in

* fixes

* fix

* update

* add

* update

* add

* update

* address comments.

Co-authored-by: Altan Haan <ahaan@octoml.ai>
  • Loading branch information
YuchenJin and altanh committed Mar 2, 2022
1 parent 0e67bef commit 13ca07f
Show file tree
Hide file tree
Showing 15 changed files with 935 additions and 1,171 deletions.
204 changes: 204 additions & 0 deletions include/tvm/relax/block_builder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relax/block_builder.h
* \brief The utility for constructing Relax binding blocks.
*/
#ifndef TVM_RELAX_BLOCK_BUILDER_H_
#define TVM_RELAX_BLOCK_BUILDER_H_

#include <tvm/ir/expr.h>
#include <tvm/relax/expr.h>
#include <tvm/relay/expr.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/registry.h>
#include <tvm/support/with.h>

#include <memory>

namespace tvm {
namespace relax {

class BlockBuilder;

/*!
* \brief Utility data structure for generating unique names for IR construction.
*/
class NameTable {
public:
/*!
* \brief Generate a unique name with a specified prefix.
* \param prefix The name prefix.
* \return The generated name.
*/
inline std::string GetUniqueName(std::string prefix) {
std::replace(prefix.begin(), prefix.end(), '.', '_');
std::string unique_prefix = prefix;
auto it = alloc_map_.find(prefix);
if (it != alloc_map_.end()) {
while (alloc_map_.count(unique_prefix = prefix + std::to_string(++it->second)) > 0) {
}
}
alloc_map_[unique_prefix] = 0;
return unique_prefix;
}

private:
std::unordered_map<std::string, uint32_t> alloc_map_;
};

/*!
* \brief A builder that provides APIs to build Relax binding blocks.
*/
class BlockBuilderNode : public Object {
public:
BlockBuilderNode(std::shared_ptr<NameTable> name_table) : name_table_(name_table) {}

~BlockBuilderNode();

BlockBuilderNode() { name_table_ = std::make_shared<NameTable>(); }

/*! \brief Begin to build a DataflowBlock. */
void BeginDataflowBlock();
/*! \brief Begin to build a BindingBlock. */
void BeginBindingBlock();
/*!
* \brief End building a BindingBlock.
* \return The BindingBlock being built.
*/
BindingBlock EndBlock();
/*!
* \brief Check if the block being built is DataflowBlock or not.
* \return A boolean that indicates if the block being built is DataflowBlock or not.
*/
inline bool CurrentBlockIsDataFlow() { return CurrentFrame()->is_dataflow; }
/*!
* \brief Emits an Expr, and returns the variable it is bound to.
* \param expr The Expr to be emitted.
* \param name_hint Name hint for the bound variable.
* \return The new variable that \p expr is bound to.
*/
virtual Var Emit(const Expr& expr, std::string name_hint = "");
/*!
* \brief Emits a variable binding, and returns the bound Var.
* \param binding The variable binding.
* \return The bound variable.
*/
virtual Var Emit(const VarBinding& binding);
/*!
* \brief Emit a MatchShape.
* \param value The value of the MatchShape to be emitted.
* \param pattern The pattern of the MatchShape to be emitted.
* \param name_hint Name hint for the bound variable.
* \return The variable bound to the MatchShape.
*/
Var EmitMatchShape(const Expr& value, const Array<PrimExpr>& pattern, std::string name_hint = "");
/*!
* \brief Emit a MatchShape binding.
* \param binding The MatchShape binding to be emitted.
* \return The variable bound to the MatchShape.
*/
Var EmitMatchShape(const MatchShape& binding);
/*!
* \brief Generate an output for the current dataflow block.
* \param output The output variable of the block.
* \param name_hint Name hint for the bound variable.
* \return The variable bound to \p output.
*/
Var EmitOutput(const Expr& output, std::string name_hint = "");
/*!
* \brief Generate an output for the current dataflow block.
* \param binding The output binding to output.
* \return The variable bound to \p output.
*/
Var EmitOutput(const VarBinding& binding);
/*!
* \brief Lookup a var in the binding table \p var_map_.
* \param var The input var.
* \return The Expr bound to the input \p var.
*/
Expr LookupVar(const Var& var);
/*!
* \brief Check if two shape expressions can be proven equal at compile time.
* \param lhs The input lhs shape.
* \param rhs The input rhs shape.
* \return Whether we can prove lhs shape is the same as the rhs shape.
*/
bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs);
/*!
* \brief Normalize an Expr to complete its shape and type.
* \param expr The input expr.
* \return The expr with normalized shape and type.
*/
Expr Normalize(const Expr& expr);
/*!
* \brief Create a BlockBuilder.
* \return The created BlockBuilder.
*/
TVM_DLL static BlockBuilder Create();

void VisitAttrs(AttrVisitor* v) {}

static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "relax.BlockBuilder";
TVM_DECLARE_BASE_OBJECT_INFO(BlockBuilderNode, Object);

private:
Var Emit(const Expr& expr, bool is_dataflow, std::string name_hint);

protected:
/*!
* \brief A representation of a block frame.
*
* A block frame is a record containing the bindings needed
* to build a binding block, and a boolean to indicate if the
* block being built is a DataflowBlock or not.
*/
struct BlockFrame {
Array<Binding> bindings;
bool is_dataflow;
};
friend class BlockBuilder;
/*!
* \brief Get the current block frame.
* \return The current block frame.
*/
BlockFrame* CurrentFrame();
/*! \brief A stack to store block frames. */
std::stack<BlockFrame> block_stack_;
/*! \brief A diagnostic context for reporting errors. */
DiagnosticContext diag_ctx_ = DiagnosticContext::Default(IRModule({}, {}));
/*! \brief A binding table that maps var to value. */
// TODO(@yuchen, @altanh): make var_map_ scoped, and decide if it should be in the builder
std::unordered_map<Id, Expr, ObjectPtrHash, ObjectPtrEqual> var_map_;
/*! \brief A name table to get unique names for IR construction. */
std::shared_ptr<NameTable> name_table_;
};

class BlockBuilder : public ObjectRef {
public:
TVM_DLL explicit BlockBuilder(std::shared_ptr<NameTable> name_table);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BlockBuilder, ObjectRef, BlockBuilderNode);
};

} // namespace relax
} // namespace tvm

#endif // TVM_RELAX_BLOCK_BUILDER_H_
44 changes: 31 additions & 13 deletions include/tvm/relax/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@

#include <tvm/ir/error.h>
#include <tvm/node/functor.h>
#include <tvm/relax/block_builder.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/ir_builder.h>
#include <tvm/relay/adt.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/function.h>
Expand Down Expand Up @@ -167,6 +167,9 @@ class ExprVisitor : public ExprFunctor<void(const Expr& n)> {
virtual void VisitMatchShape(const MatchShape& binding);
virtual void VisitBindingBlock(const BindingBlock& block);
virtual void VisitDataflowBlock(const DataflowBlock& block);

protected:
std::unordered_map<const Object*, size_t> visit_counter_;
};

void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);
Expand All @@ -180,11 +183,22 @@ void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);
*/
class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
public:
ExprMutator() {
name_table_ = std::make_shared<NameTable>();
builder_ = BlockBuilder(name_table_);
}

/*!
* \brief Mutate is alias for VisitExpr
* \return expr.
*/
Expr Mutate(const Expr& expr) { return this->VisitExpr(expr); }
Expr Mutate(const Expr& expr) {
if (memo_.count(expr) == 0) {
memo_[expr] = this->VisitExpr(expr);
}
return Downcast<Expr>(memo_[expr]);
}

Expr VisitExpr(const Expr& expr) override;
Expr VisitExpr_(const ConstantNode* op) override;
Expr VisitExpr_(const TupleNode* op) override;
Expand All @@ -208,28 +222,32 @@ class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
* visitor for types which transform them appropriately.
*/
virtual Type VisitType(const Type& t);
virtual void VisitBinding(const Binding& binding, IRBuilder& builder);
virtual Var VisitVarBinding(const VarBinding& binding, IRBuilder& builder);
virtual void VisitMatchShape(const MatchShape& binding, IRBuilder& builder);

virtual void VisitBinding(const Binding& binding);
virtual Var VisitVarBinding(const VarBinding& binding);
virtual void VisitMatchShape(const MatchShape& binding);
virtual BindingBlock VisitBindingBlock(const BindingBlock& block);
virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block);

protected:
IRBuilder builder_;
Expr MutateWithPrologue(const Expr& expr, bool is_dataflow);
/*! \brief Look up the value binded to a var. */
Expr LookupVar(Var var);
// A remapping table: pre var -> post var
std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_;
std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> memo_;
std::shared_ptr<NameTable> name_table_;
BlockBuilder builder_;
};

// TODO(@yuchen, @altan): Refactor to enforce dataflow mutator only rewrite stuff in dataflow blocks
/*! \brief Dataflow Graph Rewriting for Custom Rewriting Passes
*/
class DataflowMutator : public ExprMutator {
public:
virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block);
virtual Var VisitVarBinding(const VarBinding& binding, IRBuilder& builder);
void VisitBinding(const Binding& binding) final;

protected:
/*! \brief Look up the value binded to a var. */
Expr LookupVar(Var var);
// A remapping table: pre var -> post var
std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> pre_post_var_map_;
virtual Var VisitDataflowVarBinding(const VarBinding& binding);
};

} // namespace relax
Expand Down
Loading

0 comments on commit 13ca07f

Please sign in to comment.