This repository has been archived by the owner on May 22, 2023. It is now read-only.

Commit 9165dd6

Update.

YuchenJin committed Sep 30, 2021
1 parent 830eb42 commit 9165dd6

Showing 3 changed files with 77 additions and 56 deletions.
17 changes: 13 additions & 4 deletions src/relax/vm/compiler.cc
@@ -41,7 +41,12 @@ class VMFunctionCompiler : public ExprVisitor {

protected:
void VisitExpr_(const FunctionNode* func_node) {
builder_->Function("main", func_node->params.size());
if (func_node->name.defined()) {
builder_->Function(func_node->name.value()->name_hint, func_node->params.size());
} else {
builder_->Function("local_func", func_node->params.size());
}

size_t i = 0;
for (auto param : func_node->params) {
auto arg_register = NewRegister();
@@ -56,7 +61,7 @@ class VMFunctionCompiler : public ExprVisitor {
for (auto block : op->blocks) {
this->VisitBindingBlock(block);
}
- // find function return Var and emit
+ // find the function return value and emit
auto ret_reg = this->var_register_map_.find(Downcast<Var>(op->body));
ICHECK(ret_reg != this->var_register_map_.end());
builder_->EmitRet(ret_reg->second);
@@ -69,7 +74,8 @@ class VMFunctionCompiler : public ExprVisitor {
String name = extern_func->global_symbol;
if (name == "vm.builtin.alloc_storage") {
Attrs attrs = call_node->attrs;
- // Get the dtype hint from the attributes.
+
+ // Get dtype and device_type from the attributes.
auto alloc_attrs = attrs.as<AllocStorageAttrs>();
ICHECK(alloc_attrs != nullptr) << "must be the AllocStorage attrs";
DataType dtype = alloc_attrs->dtype;
@@ -127,7 +133,7 @@ class VMFunctionCompiler : public ExprVisitor {
this->var_register_map_.insert({var, this->registers_num_});
builder_->EmitCall(name, args, NewRegister());
}
- // Normal packed function
+ // Normal packed function without attributes
else {
std::vector<Instruction::Arg> args_;
for (size_t i = 0; i < call_node->args.size(); ++i) {
@@ -138,6 +144,9 @@ class VMFunctionCompiler : public ExprVisitor {
}
}
builder_->EmitCall(name, args_, Instruction::kVoidArg);
+ // this->var_register_map_.insert({var, this->registers_num_});
+ // builder_->EmitCall(name, args_, NewRegister());
+ // TODO(yuchen): what if the packed func has void return (no need to write to the dst register)?
}
}
}
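The naming change above is observable from Python: a Relax function built with an explicit name now compiles to a VM function registered under that name instead of a hard-coded "main", with "local_func" as the fallback for unnamed functions. Below is a condensed, hedged sketch of what this enables; every API call is copied from tests/python/relax/lowering_demo.py later in this commit, while the `from tvm import relax as rx` import is an assumption (the demo's import lines are not shown in the diff).

    # Condensed sketch (not part of the diff); API usage taken from the
    # demo file in this commit, import of the `rx` alias assumed.
    import tvm
    import numpy as np
    from tvm import relax as rx
    from tvm.relay import Call

    @tvm.register_func("test.op.identity")
    def identity_packed(a, b):
        b[:] = tvm.nd.array(a.asnumpy())

    x = rx.Var("x", [3, 4], rx.DynTensorType(2, "float32"))
    storage_attr = tvm.ir.attrs.make_node(
        "relax.attrs.AllocStorageAttrs", device_id=0, device_type=1)
    tensor_attr = tvm.ir.attrs.make_node("relax.attrs.AllocTensorAttrs")

    ib = rx.IRBuilder()
    with ib.function(x, "foo"):  # the FunctionNode now carries the name "foo"
        gv0 = ib.emit(Call(rx.ExternFunc("vm.builtin.alloc_storage"),
                           [rx.ShapeExpr([12]), rx.ShapeExpr([8])], storage_attr))
        gv1 = ib.emit(Call(rx.ExternFunc("vm.builtin.alloc_tensor"),
                           [gv0, rx.ShapeExpr([0]), rx.ShapeExpr([3, 4])], tensor_attr))
        ib.emit(Call(rx.ExternFunc("test.op.identity"), [x, gv1]))
        ib.emit_output(gv1)

    exec = rx.transform.compile(ib.get())
    vm = rx.VirtualMachine(exec, tvm.cpu())
    res = vm["foo"](tvm.nd.array(np.random.rand(3, 4).astype(np.float32)))
    # before this change, the compiled function was only reachable as vm["main"]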
2 changes: 1 addition & 1 deletion src/relax/vm/executable.cc
@@ -388,7 +388,7 @@ String ExecutableNode::AsPython() const {
os << " ib.emit_call(\"" << this->func_names[instr.func_idx] << "\", args=["
<< StrJoin<Instruction::Arg>(instr.args, 0, instr.num_args, ", ", InstrArgToPyStr)
<< "]";
- if (instr.dst != Instruction::kVoidArg) os << ", ret=ib.r(" << instr.dst << ")";
+ if (instr.dst != Instruction::kVoidArg) os << ", dst=ib.r(" << instr.dst << ")";
os << ")\n";
break;
}
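With this rename, the keyword that AsPython prints matches the dst register the instruction actually writes. Illustratively, a call with a destination register now renders along these lines; the function name and register numbers are made up, and rendering register arguments as ib.r(n) is an assumption — only the emit_call / args=[...] / dst=ib.r(...) shape comes from the code above.

    ib.emit_call("test.op.identity", args=[ib.r(0), ib.r(1)], dst=ib.r(2))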
114 changes: 63 additions & 51 deletions tests/python/relax/lowering_demo.py
@@ -23,92 +23,104 @@
from tvm.relay import Call
from tvm.ir import structural_equal
import numpy as np
from termcolor import colored


# def rx_func(func):
# return func.module[func.fn_name]
# Before call_dps lowering

# Before memory lowering
# @tvm.script.relax
# class Module:
# def foo(x: Tensor[(3, 4), "float32"]):
# with relax.dataflow():
# gv0 = relax.call_dps((3, 4), "test.op.identity", (x,))
# relax.output(gv0)
# return gv0

# @rx.script
# def foo(x: Tensor[(3, 4), "float32"]):
# with relax.dataflow():
# z: Tensor[(3, 4), "float32"] = relax.call_dps((3, 4), rx.extern("test.op.identity"), (x))
# relax.output(x)
# return z
# f = rx_func(foo)


# def original_program():
# shape_anno = [3, 4]
# type_anno = rx.DynTensorType(2, "float32")
# x = rx.Var("x", shape_anno, type_anno)
# ib = rx.IRBuilder()
# with ib.function(x):
# with ib.dataflow() as df:
# lv0 = rx.call_dps([3, 4], rx.extern("test.op.identity"), [x])
# gv0 = ib.emit_output(lv0)
# ib.emit_output(gv0)
# expr = ib.get()


# after rewrite
# func = rx.transform.explicit_memory_rewrite(expr)
def original_program():
shape_anno = [3, 4]
type_anno = rx.DynTensorType(2, "float32")
x = rx.Var("x", shape_anno, type_anno)
ib = rx.IRBuilder()
with ib.function(x, "foo"):
with ib.dataflow() as df:
lv0 = rx.call_dps([3, 4], rx.ExternFunc("test.op.identity"), [x])
gv0 = ib.emit_output(lv0)
ib.emit_output(gv0)
expr = ib.get()
return expr

# After memory lowering
# after call_dps lowering

# @rx.script
# def foo(x: Tensor[(3, 4), "float32"]):
#
# lv0 = relax.call(rx.extern("relax.builtin.alloc_tensor"), (3, 4))
# relax.call(rx.extern("test.op.identity"), (x, lv0))
#
# return lv0

# gv0 = relax.call_packed("relax.builtin.alloc_tensor", (3, 4))
# relax.call_packed("test.op.identity", (x, gv0))
# return gv0

def explicit_memory_rewrite():
print(colored("Original Relax program:", "green"))
func = original_program()
mod = tvm.IRModule.from_expr(func)
print(rx.parser.astext(mod))
mem_lowered_func = rx.transform.explicit_memory_rewrite(func)
new_mod = tvm.IRModule.from_expr(mem_lowered_func)
# print(new_mod.astext())
# print(rx.parser.astext(new_mod))

# After further lowering

# @rx.script
# def foo(x: Tensor[(3, 4), "float32"]):
# gv0 = relax.call(extern("vm.builtin.alloc_storage"), (12, "cpu", "float32"))
# gv1 = relax.call(extern("vm.builtin.alloc_tensor"), (gv0, 0, "float32", (3, 4)))
# gv2 = relax.call(extern("test.op.identity"), (x, gv1))
# relax.call(extern("vm.builtin.free_tensor"), (gv1))
# relax.call(extern("vm.builtin.free_storage"), (gv0))
# @tvm.script.relax
# class Module:
# def foo(x: Tensor[(3, 4), "float32"]):
# gv0 = relax.call_packed("vm.builtin.alloc_storage", (12,), (8,), device_id=0, device_type=1)
# gv1 = relax.call_packed("vm.builtin.alloc_tensor", gv0, (0,), (3, 4))
# gv2 = relax.call_packed("test.op.identity", x, gv1)
# return gv1


@tvm.register_func("test.op.identity")
def identity_packed(a, b):
b = tvm.nd.array(a.asnumpy())
b[:] = tvm.nd.array(a.asnumpy())

def Relax_to_VM():
def relax_compile_vm():
shape_anno = [3, 4]
type_anno = rx.DynTensorType(2, "float32")
x = rx.Var("x", shape_anno, type_anno)

ib = rx.IRBuilder()

storage_attr = tvm.ir.attrs.make_node(
"relax.attrs.AllocStorageAttrs", device_id=0, device_type=1
)
tensor_attr = tvm.ir.attrs.make_node("relax.attrs.AllocTensorAttrs")

with ib.function(x):
# Construct the lowest-level form of the program
with ib.function(x, "foo"):
gv0 = ib.emit(Call(rx.ExternFunc("vm.builtin.alloc_storage"),[rx.ShapeExpr([12]), rx.ShapeExpr([8])], storage_attr))
gv1 = ib.emit(Call(rx.ExternFunc("vm.builtin.alloc_tensor"),[gv0, rx.ShapeExpr([0]), rx.ShapeExpr([3, 4])], tensor_attr))
ib.emit(Call(rx.ExternFunc("test.op.identity"), [x, gv1]))
ib.emit_output(gv1)
expr = ib.get()

mod = tvm.IRModule.from_expr(expr)
print(colored("After call_dps lowering:", "green"))
print(rx.parser.astext(mod))

print(colored("Compile into a VM executable:", "green"))
exec = rx.transform.compile(expr)
print(exec.astext())
print(exec.aspython())

input = tvm.nd.array(np.random.rand(3,4))
print(colored("Run the executable on the VM:", "green"))
input = tvm.nd.array(np.random.rand(3,4).astype(np.float32))
print("input array:", input)
vm = rx.VirtualMachine(exec, tvm.cpu())
res = vm["main"](input)
print(res)
res = vm["foo"](input)
print("output array:", res)


if __name__ == "__main__":
Relax_to_VM()
explicit_memory_rewrite()
print("""
@tvm.register_func("test.op.identity")
def identity_packed(a, b):
b[:] = tvm.nd.array(a.asnumpy())
""")
relax_compile_vm()
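One detail worth noting in the identity_packed change above: the old body `b = tvm.nd.array(a.asnumpy())` only rebinds the local name b, so the output tensor the VM allocated is never written, while the new body `b[:] = ...` copies into that buffer in place. A minimal standalone sketch of the difference (assuming CPU NDArrays; tvm.nd.array, asnumpy, and full-slice assignment are the same calls the demo uses):

    import tvm
    import numpy as np

    a = tvm.nd.array(np.arange(6, dtype="float32").reshape(2, 3))
    b = tvm.nd.array(np.zeros((2, 3), dtype="float32"))

    def identity_rebind(a, b):
        b = tvm.nd.array(a.asnumpy())    # rebinds the local name; caller's b is untouched

    def identity_inplace(a, b):
        b[:] = tvm.nd.array(a.asnumpy())  # full-slice assignment copies into the caller's buffer

    identity_rebind(a, b)
    print(b.asnumpy())   # still all zeros
    identity_inplace(a, b)
    print(b.asnumpy())   # now equal to a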
