Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

End2End Lowering #23

Merged
merged 13 commits into from
Oct 22, 2021
6 changes: 2 additions & 4 deletions include/tvm/relax/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,7 @@ class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
* \return expr.
*/
Expr Mutate(const Expr& expr) {
if (memo_.count(expr) == 0) {
memo_[expr] = this->VisitExpr(expr);
ZihengJiang marked this conversation as resolved.
Show resolved Hide resolved
}
return Downcast<Expr>(memo_[expr]);
return this->VisitExpr(expr);
}

Expr VisitExpr(const Expr& expr) override;
Expand Down Expand Up @@ -226,6 +223,7 @@ class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
virtual void VisitBinding(const Binding& binding);
virtual Var VisitVarBinding(const VarBinding& binding);
virtual void VisitMatchShape(const MatchShape& binding);

virtual BindingBlock VisitBindingBlock(const BindingBlock& block);
virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block);

Expand Down
7 changes: 6 additions & 1 deletion python/tvm/relax/exec_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import tvm
from tvm._ffi._ctypes.packed_func import TVMRetValueHandle
from tvm.runtime import Object
from tvm.runtime.container import ShapeTuple
from tvm._ffi.base import _LIB, check_call
from . vm import Executable
from . import _ffi_api
Expand Down Expand Up @@ -89,7 +90,11 @@ def emit_call(
dst = SpecialReg.VOID_ARG
args_ = []
for arg in args:
if isinstance(arg, tvm.nd.NDArray) or isinstance(arg, tvm.DataType):
if isinstance(arg, tuple):
shape_tuple = ShapeTuple(arg)
new_arg = self.emit_constant(shape_tuple)
args_.append(new_arg)
elif isinstance(arg, (tvm.nd.NDArray, tvm.DataType, ShapeTuple)):
new_arg = self.emit_constant(arg)
args_.append(new_arg)
else:
Expand Down
38 changes: 31 additions & 7 deletions python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
# under the License.
# pylint: disable=no-else-return
# pylint: disable=unidiomatic-typecheck
from tvm import IRModule
from tvm import IRModule
from . import _ffi_api


def fma_rewrite(expr):
"""Perform fused multiply add rewriting in dataflow blocks.

Expand All @@ -29,22 +30,45 @@ def fma_rewrite(expr):
"""
return _ffi_api.fma_rewrite(expr)

def explicit_memory_rewrite(expr):
"""Perform explicit memory allocation for call_dps in dataflow blocks.
def to_non_dataflow(mod: IRModule) -> IRModule:
"""Transform all dataflow structure to non-dataflow version.

Parameters
----------
expr : tvm.relay.Expr
The input expression.
mod : tvm.IRModule
The input module.
"""
return _ffi_api.explicit_memory_rewrite(expr)
return _ffi_api.to_non_dataflow(mod)


def call_dps_rewrite(mod: IRModule) -> IRModule:
"""Perform explicit memory allocation for call_dps.

Parameters
----------
mod : tvm.IRModule
The input module.
"""
return _ffi_api.call_dps_rewrite(mod)


def memory_lower(mod: IRModule) -> IRModule:
"""Perform memory lowering. Lower the relax.builtin.alloc_tensor op to VM builtin functions.

Parameters
----------
mod : tvm.IRModule
The input module.
"""
return _ffi_api.memory_lower(mod)


def shape_lower(mod: IRModule) -> IRModule:
"""Lower the shape expression in relax to shape heap and TIR functions.

Parameters
----------
expr : tvm.IRModule
mod : tvm.IRModule
The input module.
"""
return _ffi_api.shape_lower(mod)
7 changes: 6 additions & 1 deletion python/tvm/relax/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tvm.runtime import Object, Device, Module, PackedFunc
from tvm._ffi.base import _LIB, check_call
from . import _ffi_api
from . import transform
from ..rpc.base import RPC_SESS_MASK


Expand Down Expand Up @@ -164,5 +165,9 @@ def build(mod: tvm.IRModule,
lib: tvm.runtime.Module
A runtime module that contains generated code.
"""
ex, lib = _ffi_api.VMBuild(mod, target, target_host)
new_mod = transform.to_non_dataflow(mod)
new_mod = transform.call_dps_rewrite(new_mod)
new_mod = transform.memory_lower(new_mod)
new_mod = transform.shape_lower(new_mod)
ex, lib = _ffi_api.VMBuild(new_mod, target, target_host)
return ex, lib
11 changes: 9 additions & 2 deletions src/relax/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,10 @@ Expr ExprMutator::VisitExpr_(const VarNode* op) {
}

Expr ExprMutator::VisitExpr_(const DataflowVarNode* op) {
auto it = var_remap_.find(GetRef<Var>(op));
if (it != var_remap_.end()) {
return it->second;
}
if (op->type_annotation.defined()) {
Type type = this->VisitType(op->type_annotation.value());
if (!op->type_annotation.same_as(type)) {
Expand Down Expand Up @@ -339,7 +343,7 @@ void ExprMutator::VisitBinding(const Binding& binding) {

Var ExprMutator::VisitVarBinding(const VarBinding& binding) {
Expr new_value = builder_->Normalize(this->Mutate(binding->value));
Var new_var = Downcast<Var>(this->Mutate(binding->var));

// TODO(@altanh): this probably shouldn't live here, all passes would have to make sure to do it
// in this method...
// if (new_value->shape_.defined()) {
Expand All @@ -356,6 +360,7 @@ Var ExprMutator::VisitVarBinding(const VarBinding& binding) {
// new_var->checked_type_ = new_value->checked_type_;
// }

Var new_var = binding->var;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not this->Mutate(binding->var)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because the lhs of the binding is defining a Var (not reuse it), so we do not need to mutate it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if someone wants to write a simple mutator pass that renames each var from x -> my_x?

if (!builder_->CanProveShapeEqual(new_var->shape(), new_value->shape()) ||
!StructuralEqual()(new_var->checked_type(), new_value->checked_type())) {
new_var = Var(new_var->vid, NullOpt, NullOpt, new_var->span);
Expand All @@ -380,7 +385,9 @@ Var ExprMutator::VisitVarBinding(const VarBinding& binding) {
void ExprMutator::VisitMatchShape(const MatchShape& binding) {
Expr new_value = this->Mutate(binding->value);
Expr new_pattern = this->Mutate(ShapeExpr(binding->pattern));
Var new_var = Downcast<Var>(this->Mutate(binding->var));
Var new_var = binding->var;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same question


// TODO: when value's shape/type changed, create new var
builder_->EmitMatchShape(
MatchShape(new_value, Downcast<ShapeExpr>(new_pattern)->values, new_var));
}
Expand Down
87 changes: 87 additions & 0 deletions src/relax/transform/call_dps_rewrite.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/relax/transform/call_dps_rewrite.cc
* \brief
*/
#include <tvm/relax/attrs/memory.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/type.h>
#include <tvm/tir/op.h>

#include "../../relay/transforms/pattern_utils.h"

namespace tvm {
namespace relax {

// ==================
// CallDPSMutator
// Example:
// y: Tensor[n, m] = rx.call_dps((n, m), op.identity, (x))
// -->
// lv0 = rx.call("relax.builtin.alloc_tensor", [n, m])
// rx.call_packed(op.identity, x, lv0)

class CallDPSMutator : public ExprMutator {
public:
explicit CallDPSMutator(IRModule mod) { mod_ = mod; }

IRModule Lower() {
ret_mod_ = IRModule();
for (auto& p : mod_->functions) {
Expr func = p.second;
if (p.second->IsInstance<FunctionNode>()) {
func = this->Mutate(p.second);
}
ret_mod_->Add(p.first, Downcast<BaseFunc>(func));
}
return ret_mod_;
}

Expr VisitExpr_(const CallNode* call) override {
// post-order mutation
Expr expr = ExprMutator::VisitExpr_(call);
call = expr.as<CallNode>();
// TODO(@yuchen, @altanh): using mutate cause infinite recursion
// Expr expr = ExprMutator::Mutate(GetRef<Call>(call));

static const Op& call_dps_op = Op::Get("relax.call_dps");
static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");

if (call->op == call_dps_op) {
ShapeExpr output_shape = Downcast<ShapeExpr>(call->args[0]);
Var tensor = builder_->Emit(Call(alloc_tensor_op, {call->args[0]}), "alloc");
builder_->Emit(Call(call->args[1], {call->args[2], tensor}), "_");
return tensor;
}

return GetRef<Expr>(call);
}

private:
IRModule mod_;
IRModule ret_mod_;
};

TVM_REGISTER_GLOBAL("relax.transform.call_dps_rewrite").set_body_typed([](IRModule mod) {
return CallDPSMutator(mod).Lower();
});

} // namespace relax
} // namespace tvm
82 changes: 52 additions & 30 deletions src/relax/transform/memory_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
* \file src/relax/transform/memory_rewrite.cc
* \brief
*/
#include <tvm/relax/attrs/memory.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/type.h>
#include <tvm/tir/op.h>
Expand All @@ -30,14 +31,31 @@ namespace tvm {
namespace relax {

// ==================
// ExplicitMemMutator
// MemLowerMutator
// Lower the relax.builtin.alloc_tensor op to VM builtin functions.
// Example:
// y: Tensor[n, m] = rx.call_dps((n, m), op.identity, (x))
// x = relax.builtin.alloc_tensor((m, n))
// -->
// lv0 = rx.call("relax.builtin.alloc_tensor", [n, m])
// rx.call_packed(op.identity, x, lv0)
// gv0 = relax.call_packed("vm.builtin.alloc_storage", (m * n), alignment, device_type,
// relax.attrs.AllocStorageAttrs) gv1 = relax.call_packed("vm.builtin.alloc_tensor", gv0, offset,
// (m, n), relax.attrs.AllocTensorAttrs)

class MemLowerMutator : public ExprMutator {
public:
explicit MemLowerMutator(IRModule mod) { mod_ = mod; }

IRModule Lower() {
ret_mod_ = IRModule();
for (auto& p : mod_->functions) {
Expr func = p.second;
if (p.second->IsInstance<FunctionNode>()) {
func = this->Mutate(p.second);
}
ret_mod_->Add(p.first, Downcast<BaseFunc>(func));
}
return ret_mod_;
}

class ExplicitMemMutator : public ExprMutator {
Expr ComputeStorageSize(const Expr& shape, const Type& type) const {
DynTensorType tensor_type = Downcast<DynTensorType>(type);
DataType dtype = DataType(tensor_type->dtype);
Expand All @@ -63,44 +81,48 @@ class ExplicitMemMutator : public ExprMutator {
return ret;
}

BindingBlock VisitBindingBlock(const BindingBlock& block) {
builder_->BeginBindingBlock();
for (Binding binding : block->bindings) {
this->VisitBinding(binding);
}
return builder_->EndBlock();
}

Expr VisitExpr_(const CallNode* call) override {
// post-order mutation
Expr expr = ExprMutator::VisitExpr_(call);
call = expr.as<CallNode>();
// TODO(@yuchen, @altanh): using mutate cause infinite recursion
// Expr expr = ExprMutator::Mutate(GetRef<Call>(call));

static const Op& call_dps_op = Op::Get("relax.call_dps");
static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");

if (call->op == call_dps_op) {
ShapeExpr output_shape = Downcast<ShapeExpr>(call->args[0]);
Type arg_type = Downcast<Tuple>(call->args[2])->fields[0]->checked_type();
Expr output_size = ComputeStorageSize(output_shape, arg_type);
Var tensor = builder_->Emit(Call(alloc_tensor_op, {call->args[0]}), "alloc");
builder_->Emit(Call(call->args[1], {call->args[2], tensor}), "_");
return tensor;
if (call->op == alloc_tensor_op) {
ShapeExpr tensor_shape = Downcast<ShapeExpr>(call->args[0]);
// TODO(@yuchen): Get the type of input x, options: add an attr to relax.builtin.alloc_tensor
altanh marked this conversation as resolved.
Show resolved Hide resolved
Type tensor_type = DynTensorType(2, DataType::Float(32));
altanh marked this conversation as resolved.
Show resolved Hide resolved
Expr storage_size = ComputeStorageSize(tensor_shape, tensor_type);
ShapeExpr alignment = ShapeExpr({IntImm(DataType::Int(64), 64)});
ShapeExpr device_type = ShapeExpr({IntImm(DataType::Int(64), 1)});
auto storage_attr = make_object<AllocStorageAttrs>();
storage_attr->dtype = DataType::Float(32);
storage_attr->device_type = 1;

Var storage =
builder_->Emit(Call(ExternFunc("vm.builtin.alloc_storage"),
altanh marked this conversation as resolved.
Show resolved Hide resolved
{storage_size, alignment}, Attrs(storage_attr)),
"storage");

ShapeExpr offset = ShapeExpr({IntImm(DataType::Int(64), 0)});
auto tensor_attr = make_object<AllocTensorAttrs>();
tensor_attr->dtype = DataType::Float(32);
Expr shape = call->args[0];
return builder_->Emit(
Call(ExternFunc("vm.builtin.alloc_tensor"), {storage, offset, shape}, Attrs(tensor_attr)),
"tensor");
}

return GetRef<Expr>(call);
}
};

Expr ExplicitMemRewrite(const Expr& e) {
return ExplicitMemMutator().Mutate(e);
}
private:
IRModule mod_;
IRModule ret_mod_;
YuchenJin marked this conversation as resolved.
Show resolved Hide resolved
};

TVM_REGISTER_GLOBAL("relax.transform.explicit_memory_rewrite")
.set_body_typed([](Expr expr) {
return ExplicitMemRewrite(expr);
TVM_REGISTER_GLOBAL("relax.transform.memory_lower").set_body_typed([](IRModule mod) {
return MemLowerMutator(mod).Lower();
});

} // namespace relax
Expand Down
Loading