Commit

increase code readability

hypercubestart committed Mar 24, 2020
1 parent 3e18cde commit 0d78212
Showing 4 changed files with 163 additions and 155 deletions.
2 changes: 1 addition & 1 deletion include/tvm/relay/transform.h
@@ -89,7 +89,7 @@ TVM_DLL Pass DeadCodeElimination(bool inline_once = false);
*
* \return the pass
*/
TVM_DLL Pass GradientCell();
TVM_DLL Pass LazyGradientInit();

/*!
* \brief Fold constant expressions.
9 changes: 5 additions & 4 deletions python/tvm/relay/transform/transform.py
@@ -219,18 +219,19 @@ def DeadCodeElimination(inline_once=False):
"""
return _ffi_api.DeadCodeElimination(inline_once)

def GradientCell():
"""Reduces memory usage of tensors with all 0s or 1s
def LazyGradientInit():
"""Reduces memory usage of gradient tensors
Parameters
----------
Returns
-------
ret: tvm.relay.Pass
The registered pass that delays or reduces memory allocation
A pass which delays and/or reduces memory allocation
by lazily allocating zero- or one-filled tensors.
"""
return _ffi_api.GradientCell()
return _ffi_api.LazyGradientInit()
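
For orientation, a minimal usage sketch of the renamed pass (not part of this commit; it assumes a TVM build that includes this change, and the toy function is hypothetical):

    import tvm
    from tvm import relay

    # Toy function whose body allocates a zero-filled tensor.
    x = relay.var("x", shape=(2, 2), dtype="float32")
    f = relay.Function([x], relay.add(x, relay.zeros((2, 2), "float32")))
    mod = tvm.IRModule.from_expr(f)

    # The pass reads checked_type(), so run type inference first; it then
    # rewrites zeros/ones into lazy GradCell constructors.
    mod = relay.transform.InferType()(mod)
    mod = relay.transform.LazyGradientInit()(mod)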

def FoldConstant():
"""Fold the constant expressions in a Relay program.
@@ -19,12 +19,14 @@

/*!
*
* \file gradient_cell.cc
* \file lazy_gradient_init.cc
*
* \brief Convert all tensors to a Gradient Cell
* \brief Lazily instantiate 0-filled or 1-filled tensors.
* This pass should be used after reverse-mode AD so that gradient tensors
* are not instantiated until after the forward pass.
*
* This pass delays or removes memory allocation by converting tensors into
* GradCell, an algebraic data type defined in gradient.rly
* GradCell, an algebraic data type defined in gradient.rly.
*
* This will delay or decrease memory usage. All calls to
* ones, ones_like, zeros, zeros_like will call the One or Zero constructor
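
The GradCell type referenced here is defined in gradient.rly. A rough Python analogue of its three constructors and of FromGradCell (illustrative only; the real definition is a Relay algebraic data type, and these class names merely mirror it):

    from dataclasses import dataclass
    from typing import Callable
    import numpy as np

    @dataclass
    class Raw:    # tensor already materialized
        value: np.ndarray

    @dataclass
    class One:    # 1-filled tensor, allocated only on demand
        thunk: Callable[[], np.ndarray]

    @dataclass
    class Zero:   # 0-filled tensor, allocated only on demand
        thunk: Callable[[], np.ndarray]

    def from_grad_cell(cell):
        # Mirrors FromGradCell: force the thunk only when a tensor is needed.
        return cell.value if isinstance(cell, Raw) else cell.thunk()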
@@ -67,13 +69,28 @@ namespace tvm {
namespace relay {

/*!
* \brief Visitor to wrap inputs
* \brief Visitor that wraps input tensors in the GradCell Raw constructor
*
* Recursively looks at the type of the expression (only TensorType and TupleType are supported for now)
* and either calls the GradCell constructor if it is a TensorType
* or unfolds and recursively visits if it is a TupleType
*/
class InputVisitor: public ExprFunctor<Expr(const Expr&, const Type&)> {
public:
explicit InputVisitor(IRModule module): module_(module) {}

Expr wrapExpr(const Expr expr, const Type& type) {
Expr VisitExpr_(const VarNode* op, const Type& t) final {
std::cout << op->type_annotation << std::endl;
return WrapExpr(GetRef<Var>(op), op->type_annotation);
}

Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final {
return WrapExpr(GetRef<TupleGetItem>(op), t);
}
private:
IRModule module_;

Expr WrapExpr(const Expr expr, const Type& type) {
if (type.as<TensorTypeNode>()) {
return Call(module_->GetConstructor("GradCell", "Raw"),
{expr}, Attrs(), {type});
@@ -89,27 +106,30 @@ class InputVisitor: public ExprFunctor<Expr(const Expr&, const Type&)> {

return expr;
}

Expr VisitExpr_(const VarNode* op, const Type& t) final {
std::cout << op->type_annotation << std::endl;
return wrapExpr(GetRef<Var>(op), op->type_annotation);
}

Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final {
return wrapExpr(GetRef<TupleGetItem>(op), t);
}
private:
IRModule module_;
};

/*!
* \brief Visitor to unwrap output
* \brief Visitor that unwraps expressions of GradCell type into tensors
*
* Recursively looks at the type of the expression
* and either applies the FromGradCell function if it is a TypeCall to GradCell
* or unfolds and recursively visits if it is a TupleType
*/
class OutputVisitor: public ExprFunctor<Expr(const Expr&, const Type&)> {
public:
explicit OutputVisitor(IRModule module): module_(module) {}

Expr unwrapExpr(const Expr expr, const Type& type) {
Expr VisitExpr_(const CallNode* op, const Type& t) final {
return UnwrapExpr(GetRef<Call>(op), t);
}

Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final {
return UnwrapExpr(GetRef<TupleGetItem>(op), t);
}
private:
IRModule module_;

Expr UnwrapExpr(const Expr expr, const Type& type) {
if (auto* type_call = type.as<TypeCallNode>()) {
if (type_call->func.same_as(module_->GetGlobalTypeVar("GradCell"))) {
return Call(module_->GetGlobalVar("FromGradCell"), {expr});
@@ -127,32 +147,22 @@ class OutputVisitor: public ExprFunctor<Expr(const Expr&, const Type&)> {

return expr;
}

Expr VisitExpr_(const CallNode* op, const Type& t) final {
return unwrapExpr(GetRef<Call>(op), t);
}

Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final {
return unwrapExpr(GetRef<TupleGetItem>(op), t);
}
private:
IRModule module_;
};
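
Taken together, the two visitors perform a type-directed wrap/unwrap recursion. A toy model reusing Raw and from_grad_cell from the sketch above (types are modeled as the strings "tensor"/"gradcell" or a tuple of types; this is not the TVM API):

    def wrap(value, ty):
        # InputVisitor analogue: tensors become Raw cells, tuples recurse.
        if ty == "tensor":
            return Raw(value)
        if isinstance(ty, tuple):  # TupleType analogue
            return tuple(wrap(v, t) for v, t in zip(value, ty))
        return value

    def unwrap(cell, ty):
        # OutputVisitor analogue: GradCell results are forced back to tensors.
        if ty == "gradcell":
            return from_grad_cell(cell)
        if isinstance(ty, tuple):
            return tuple(unwrap(c, t) for c, t in zip(cell, ty))
        return cell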

class GradientCellTransform: public ExprMutator, public TypeMutator {
class LazyGradientInitializer: public ExprMutator, public TypeMutator {
public:
explicit GradientCellTransform(IRModule module):
explicit LazyGradientInitializer(IRModule module):
module_(module) {
module_->ImportFromStd("gradient.rly");
}

/*!
* \brief apply GradientCell transformation and wrap function
* \brief apply LazyGradientInit transformation and wrap function
* so that the function type stays the same
*
* Input/output types should only be a combination of TupleTypes and TensorTypes
*/
Expr transform(const Expr& e) {
Expr Transform(const Expr& e) {
auto* f = (e).as<FunctionNode>();
auto* transformed = this->Mutate(e).as<FunctionNode>();

@@ -179,90 +189,46 @@ class GradientCellTransform: public ExprMutator, public TypeMutator {
}
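
In terms of the toy helpers above, Transform's contract looks like this (a loose analogue: the real pass mutates the Relay AST and preserves the function's checked type, rather than wrapping a Python callable):

    def transform(mutated_fn, param_tys, ret_ty):
        # Wrap inputs (InputVisitor), run the rewritten body on GradCells,
        # then unwrap the result (OutputVisitor) so the outer type is unchanged.
        def wrapped(*args):
            cells = [wrap(a, t) for a, t in zip(args, param_tys)]
            return unwrap(mutated_fn(*cells), ret_ty)
        return wrapped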

Expr VisitExpr_(const CallNode* call_node) final {
// optimize operators
if (auto* op = (call_node->op).as<OpNode>()) {
Expr op_expr = GetRef<Op>(op);
if (op_expr == Op::Get("add") && call_node->args.size() == 2 &&
AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) {
// case: "add" between two tensors of the same size
const auto addFunc = module_->GetGlobalVar("AddGradCell");
tvm::Array<Expr> args;
// create add function
Type paramType = call_node->args[0]->checked_type();
tvm::Array<Var> params = {Var("lhs", paramType),
Var("rhs", paramType)};
Expr callAdd = Call(Op::Get("add"), {params[0], params[1]});
Expr addTensorsFunc = Function(params, callAdd, paramType,
Array<TypeVar>());

// pass add function and tensors into arguments
args.push_back(addTensorsFunc);
for (Expr expr : call_node->args) {
args.push_back(VisitExpr(expr));
}
return Call(addFunc, args, Attrs(), {paramType});
} else if (op_expr == Op::Get("multiply") && call_node->args.size() == 2 &&
AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) {
// case: "multiply" between two tensors of the same size
const auto multFunc = module_->GetGlobalVar("MultiplyGradCell");
// create multiply function
tvm::Array<Expr> args;
Type paramType = call_node->args[0]->checked_type();
tvm::Array<Var> params = {Var("lhs", paramType),
Var("rhs", paramType)};
Expr callMultiply = Call(Op::Get("multiply"),
{params[0], params[1]});
Expr multTensorsFunc = Function(params, callMultiply, paramType,
Array<TypeVar>());

// pass multiply function and tensors into arguments
args.push_back(multTensorsFunc);
for (Expr expr : call_node->args) {
args.push_back(VisitExpr(expr));
}
return Call(multFunc, args, Attrs(), {paramType});
} else if (op_expr == Op::Get("ones")) {
// ones operator, use One constructor of GradCell
Expr func = Function({}, {ExprMutator::VisitExpr_(call_node)},
{call_node->checked_type()}, {});
return Call(module_->GetConstructor("GradCell", "One"),
{func}, Attrs(), {call_node->checked_type()});
} else if (op_expr == Op::Get("zeros")) {
// zeros operator, use Zero constructor of GradCell
Expr func = Function({}, {ExprMutator::VisitExpr_(call_node)},
{call_node->checked_type()}, {});
return Call(module_->GetConstructor("GradCell", "Zero"),
{func}, Attrs(), {call_node->checked_type()});

if (op_expr == Op::Get("add")) {
return CallGradCellFunction(call_node, module_->GetGlobalVar("AddGradCell"));
}

// handle other ops + zeros_like + ones_like
// we put zeros_like and ones_like here to make use of
// code converting the arguments of CallNode into Tensor
const auto fromFunc = module_->GetGlobalVar("FromGradCell");
tvm::Array<Expr> args;
// use FromGradCell to convert args to Tensor
for (Expr expr : call_node->args) {
args.push_back(Call(fromFunc,
{VisitExpr(expr)}, Attrs(), {expr->checked_type()}));
if (op_expr == Op::Get("multiply")) {
return CallGradCellFunction(call_node, module_->GetGlobalVar("MultiplyGradCell"));
}

const Expr tensorRes = Call(call_node->op, args);
if (op_expr == Op::Get("ones") || op_expr == Op::Get("zeros")) {
// fn() -> T, function returns result of the operation
Expr func = Function({}, {ExprMutator::VisitExpr_(call_node)},
{call_node->checked_type()}, {});
// call appropriate GradCell constructor
std::string constructor_name = op_expr == Op::Get("ones") ? "One" : "Zero";
return Call(module_->GetConstructor("GradCell", constructor_name),
{func}, Attrs(), {call_node->checked_type()});
}

if (op_expr == Op::Get("ones_like")) {
Expr onesFunction = Function({}, tensorRes,
if (op_expr == Op::Get("ones_like") || op_expr == Op::Get("zeros_like")) {
// ones_like and zeros_like need TensorType input
Expr result = CallPrimitiveOp(call_node);
// fn() -> T, function returns result of operation
Expr func = Function({}, result,
{call_node->checked_type()}, Array<TypeVar>());
// call appropriate GradCell constructor
std::string constructor_name = op_expr == Op::Get("ones_like") ? "One" : "Zero";
return Call(module_->GetConstructor("GradCell", "One"),
{onesFunction}, Attrs(), {call_node->checked_type()});
} else if (op_expr == Op::Get("zeros_like")) {
Expr zerosFunction = Function({}, tensorRes,
{call_node->checked_type()}, Array<TypeVar>());
return Call(module_->GetConstructor("GradCell", "Zero"),
{zerosFunction}, Attrs(), {call_node->checked_type()});
{func}, Attrs(), {call_node->checked_type()});
}
return Call(module_->GetConstructor("GradCell", "Raw"), {tensorRes},

// handle all other ops
Expr result = CallPrimitiveOp(call_node);
// wrap result with Raw constructor
return Call(module_->GetConstructor("GradCell", "Raw"), {result},
Attrs(), {call_node->checked_type()});
}
// call-> op is not a relay op
// not an op
return ExprMutator::VisitExpr_(call_node);
}
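
The payoff of the One/Zero branches above is that a filled tensor is never allocated unless something actually forces the cell. Continuing the toy classes from earlier (sizes are arbitrary):

    n = 4096
    ones_cell = One(lambda: np.ones((n, n), dtype="float32"))  # no allocation yet
    dense = from_grad_cell(ones_cell)  # the 64 MB buffer is materialized only here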

@@ -280,23 +246,70 @@ class GradientCellTransform: public ExprMutator, public TypeMutator {
private:
// Module
IRModule module_;

/*!
* \brief Convert a call to the add/multiply op into a call to the overloaded function for the GradCell type
*/
Expr CallGradCellFunction(const CallNode* call_node, GlobalVar overloaded_op) {
// can only use overloaded functions if 2 arguments of same type
if (call_node->args.size() != 2 ||
!AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) {
Expr result = CallPrimitiveOp(call_node);
return Call(module_->GetConstructor("GradCell", "Raw"), {result},
Attrs(), {call_node->checked_type()});
}

tvm::Array<Expr> args;
// create "fallback" function for overloaded function
Type paramType = call_node->args[0]->checked_type();
tvm::Array<Var> params = {Var("lhs", paramType),
Var("rhs", paramType)};
// use primitive op in this case
Expr callOp = Call(call_node->op, {params[0], params[1]});
Expr func = Function(params, callOp, paramType,
Array<TypeVar>());

// pass "fallback" function and tensors as arguments
args.push_back(func);
for (Expr expr : call_node->args) {
args.push_back(VisitExpr(expr));
}
// return new call to overloaded function
return Call(overloaded_op, args, Attrs(), {paramType});
}
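
AddGradCell and MultiplyGradCell are the gradient.rly overloads that receive this fallback function. Their expected behavior, in the same toy vocabulary (a sketch of the algebraic identities involved, e.g. Zero is the identity for add, not the literal Relay code):

    def add_grad_cell(fallback, lhs, rhs):
        # Zero + x == x: skip both the allocation and the add.
        if isinstance(lhs, Zero):
            return rhs
        if isinstance(rhs, Zero):
            return lhs
        # Otherwise force both cells and apply the fallback primitive op.
        return Raw(fallback(from_grad_cell(lhs), from_grad_cell(rhs)))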

/*!
* \brief Convert calls to other ops by converting args into TensorType
* \return call expr returning result of op
*/
Expr CallPrimitiveOp(const CallNode* call_node) {
const auto fromFunc = module_->GetGlobalVar("FromGradCell");
tvm::Array<Expr> args;
// use FromGradCell to convert args to Tensor
for (Expr expr : call_node->args) {
args.push_back(Call(fromFunc,
{VisitExpr(expr)}, Attrs(), {expr->checked_type()}));
}
// result of operation
return Call(call_node->op, args);
}
};

Expr GradientCell(const Expr& e, IRModule mod) {
return GradientCellTransform(mod).transform(e);
Expr LazyGradientInit(const Expr& e, IRModule mod) {
return LazyGradientInitializer(mod).Transform(e);
}

namespace transform {
Pass GradientCell() {
Pass LazyGradientInit() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(GradientCell(f, m));
return Downcast<Function>(LazyGradientInit(f, m));
};
return CreateFunctionPass(pass_func, 2, "GradientCell", {});
return CreateFunctionPass(pass_func, 2, "LazyGradientInit", {});
}

TVM_REGISTER_GLOBAL("relay._transform.GradientCell")
.set_body_typed(GradientCell);
TVM_REGISTER_GLOBAL("relay._transform.LazyGradientInit")
.set_body_typed(LazyGradientInit);

} // namespace transform
