Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

Commit

Permalink
Update.
Browse files Browse the repository at this point in the history
  • Loading branch information
YuchenJin committed Oct 6, 2021
1 parent 97a5fc8 commit c7082dc
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 23 deletions.
4 changes: 3 additions & 1 deletion python/tvm/relax/vm_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import tvm
from . import vm, _ffi_api


def compile(mod: tvm.IRModule) -> vm.Executable:
"""Compile the module to VM executable. A helper function for VMCompiler.
Expand All @@ -39,6 +40,7 @@ def compile(mod: tvm.IRModule) -> vm.Executable:
compiler.compile(mod)
return compiler.get_exec()


class VMCompiler(object):
"""Compiler that compiles module to VM executable."""

Expand All @@ -47,7 +49,7 @@ def __init__(self):
self._compile = self.mod["compile"]
self._get_exec = self.mod["get_executable"]

def compile(self, mod: tvm.IRModule):
def compile(self, mod: tvm.IRModule) -> None:
"""Compile the module to VM executable.
Parameters
Expand Down
5 changes: 3 additions & 2 deletions src/relax/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ class VMFunctionCompiler : public ExprVisitor {
void EmitAllocStorage(const Call& call_node, const Var& var) {
Attrs attrs = call_node->attrs;

// TODO(@yuchen): a generic way to lower attributes for extern calls
// Get dtype and device_type from the attributes.
auto alloc_attrs = attrs.as<AllocStorageAttrs>();
ICHECK(alloc_attrs != nullptr) << "must be the AllocStorage attrs";
Expand Down Expand Up @@ -189,7 +188,7 @@ PackedFunc VMCompiler::GetFunction(const std::string& name, const ObjectPtr<Obje
this->Compile(mod);
});
} else if (name == "get_executable") {
return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { *rv = builder_->Get(); });
return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetExec(); });
} else {
LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc([name](TVMArgs args, TVMRetValue* rv) {});
Expand All @@ -213,6 +212,8 @@ void VMCompiler::Compile(IRModule mod) {
}
}

Executable VMCompiler::GetExec() { return builder_->Get(); }

runtime::Module CreateVMCompiler() {
auto compiler = make_object<VMCompiler>();
return runtime::Module(compiler);
Expand Down
5 changes: 5 additions & 0 deletions src/relax/vm/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ class VMCompiler : public runtime::ModuleNode {
* \param mod Input IRModule to be compiled.
*/
void Compile(IRModule mod);
/*!
* \brief Get the compiled executable.
* \return The compiled executable.
*/
Executable GetExec();

virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);

Expand Down
Binary file removed tests/python/relax/exec.bin
Binary file not shown.
30 changes: 10 additions & 20 deletions tests/python/relax/test_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations # must import to defer parsing of annotations
import numpy as np
import tvm
from tvm.relay import Call
Expand Down Expand Up @@ -202,27 +203,16 @@ def test_vm_storage():
assert res.shape == shape

def test_vm_compile():
shape_anno = [3, 4]
type_anno = rx.DynTensorType(2, "float32")
x = rx.Var("x", shape_anno, type_anno)
ib = rx.IRBuilder()
storage_attr = tvm.ir.attrs.make_node(
"relax.attrs.AllocStorageAttrs", device_id=0, device_type=1
)
tensor_attr = tvm.ir.attrs.make_node("relax.attrs.AllocTensorAttrs",)

# Construct the lowest-level Relax program
with ib.function(x, "foo"):
gv0 = ib.emit(Call(rx.ExternFunc("vm.builtin.alloc_storage"),[rx.ShapeExpr([12]), rx.ShapeExpr([64])], storage_attr))
gv1 = ib.emit(Call(rx.ExternFunc("vm.builtin.alloc_tensor"),[gv0, rx.ShapeExpr([0]), rx.ShapeExpr([3, 4])], tensor_attr))
ib.emit(Call(rx.ExternFunc("test.vm.identity"), [x, gv1]))
ib.emit_output(gv1)
expr = ib.get()
mod = tvm.IRModule.from_expr(expr)

# compile the module to VM executable
@rx.script
class Mod:
def foo(x: Tensor[(3, 4), "float32"]):
y = relax.call_packed("vm.builtin.alloc_storage", (12,), (64,), device_id=0, device_type=1, attrs_type_key="relax.attrs.AllocStorageAttrs")
z = relax.call_packed("vm.builtin.alloc_tensor", y, (0,), (3, 4), attrs_type_key="relax.attrs.AllocTensorAttrs")
w = relax.call_packed("test.vm.identity", x, z)
return z

mod = Mod()
exec = rx.vm_compiler.compile(mod)

input = tvm.nd.array(np.random.rand(3,4).astype(np.float32))
vm = rx.VirtualMachine(exec, tvm.cpu())
res = vm["foo"](input)
Expand Down

0 comments on commit c7082dc

Please sign in to comment.