Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

Commit

Permalink
VM compiler.
Browse files Browse the repository at this point in the history
  • Loading branch information
YuchenJin committed Sep 29, 2021
1 parent 5e293cd commit 43628e5
Show file tree
Hide file tree
Showing 15 changed files with 531 additions and 50 deletions.
63 changes: 63 additions & 0 deletions include/tvm/relax/attrs/memory.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relax/attrs/memory.h
* \brief Attributes for memory operators.
*/
#ifndef TVM_RELAX_ATTRS_MEMORY_H_
#define TVM_RELAX_ATTRS_MEMORY_H_

#include <tvm/ir/attrs.h>

namespace tvm {
namespace relax {
/*!
* \brief Options for allocating storage.
*/
/*!
 * \brief Attributes for the relax VM alloc_storage operator.
 *
 * Carries the compile-time configuration (dtype hint and target device)
 * consumed when lowering `vm.builtin.alloc_storage` calls.
 */
struct AllocStorageAttrs : public tvm::AttrsNode<AllocStorageAttrs> {
  /*! \brief The dtype of the tensor to allocate (defaults to float32). */
  DataType dtype;
  /*! \brief The device id on which to allocate memory. */
  int device_id;
  /*! \brief The device type on which to allocate memory.
   * NOTE(review): presumably a DLDeviceType value — confirm against callers. */
  int device_type;

  TVM_DECLARE_ATTRS(AllocStorageAttrs, "relax.attrs.AllocStorageAttrs") {
    TVM_ATTR_FIELD(dtype)
        .describe("The dtype of the tensor to allocate.")
        .set_default(DataType::Float(32, 1));
    TVM_ATTR_FIELD(device_id).describe("The device id on which to allocate memory.");
    TVM_ATTR_FIELD(device_type).describe("The device type on which to allocate memory.");
  }
};

/*!
* \brief Options for allocating tensors.
*/
/*!
 * \brief Attributes for the relax VM alloc_tensor operator.
 *
 * Only the dtype is configurable; offset and shape are passed as
 * regular call arguments rather than attributes.
 */
struct AllocTensorAttrs : public tvm::AttrsNode<AllocTensorAttrs> {
  /*! \brief The dtype of the tensor to allocate (defaults to float32). */
  DataType dtype;

  TVM_DECLARE_ATTRS(AllocTensorAttrs, "relax.attrs.AllocTensorAttrs") {
    TVM_ATTR_FIELD(dtype)
        .describe("The dtype of the tensor to allocate.")
        .set_default(DataType::Float(32, 1));
  }
};

} // namespace relax
} // namespace tvm
#endif // TVM_RELAX_ATTRS_MEMORY_H_
2 changes: 1 addition & 1 deletion include/tvm/relax/vm/exec_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class ExecBuilderNode : public Object {
* \brief Emit a constant value to the constant pool.
* \return The index that represents the constant.
*/
vm::Index EmitConstant(ObjectRef obj);
vm::Index EmitConstant(TVMRetValue obj);
/*!
* \brief Get the built executable.
* \return The built executable.
Expand Down
6 changes: 3 additions & 3 deletions include/tvm/relax/vm/executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class ExecutableNode : public Object {
/*! \brief A map from globals (as strings) to their index in the function map. */
std::unordered_map<std::string, Index> global_map;
/*! \brief The global constant pool. */
std::vector<ObjectRef> constants;
std::vector<TVMRetValue> constants;
/*! \brief The name of packed functions. */
std::vector<std::string> func_names;
/*! \brief A mapping from the packed function (as string) to the index that
Expand All @@ -129,7 +129,7 @@ class ExecutableNode : public Object {
* \brief Save the constant pool.
* \param strm The input stream.
*/
void SaveConstantSection(dmlc::Stream* strm);
// void SaveConstantSection(dmlc::Stream* strm);
/*!
* \brief Save the instructions.
* \param strm The input stream.
Expand All @@ -149,7 +149,7 @@ class ExecutableNode : public Object {
* \brief Load the constant pool.
* \param strm The input stream.
*/
void LoadConstantSection(dmlc::Stream* strm);
// void LoadConstantSection(dmlc::Stream* strm);
/*!
* \brief Load the instructions.
* \param strm The input stream.
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@

# Operator
from .op.base import call_dps
from .op.op_attrs import AllocStorageAttrs, AllocTensorAttrs

# IRBuilder
IRBuilder = ir_builder.IRBuilder
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@
# Operators
from .base import *
from .tensor import *
from .op_attrs import *
28 changes: 28 additions & 0 deletions python/tvm/relax/op/op_attrs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""The attributes node used for Relax operators"""
from tvm.ir import Attrs
import tvm._ffi

@tvm._ffi.register_object("relax.attrs.AllocStorageAttrs")
class AllocStorageAttrs(Attrs):
    """Attributes used in alloc_storage operators.

    Python proxy for the node registered on the C++ side as
    ``relax.attrs.AllocStorageAttrs``; fields are exposed through the
    :class:`tvm.ir.Attrs` base class.
    """


@tvm._ffi.register_object("relax.attrs.AllocTensorAttrs")
class AllocTensorAttrs(Attrs):
    """Attributes used in alloc_tensor operators.

    Python proxy for the node registered on the C++ side as
    ``relax.attrs.AllocTensorAttrs``; fields are exposed through the
    :class:`tvm.ir.Attrs` base class.
    """
1 change: 1 addition & 0 deletions python/tvm/relax/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@
"""Relax IR analysis. """

from .transform import *
from .compile import *
23 changes: 23 additions & 0 deletions python/tvm/relax/transform/compile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return
# pylint: disable=unidiomatic-typecheck
from . import _ffi_api


def compile(expr):
    """Compile a relax expression to a VM executable.

    Thin wrapper over the C++ global registered as
    ``relax.transform.compile``.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The expression to compile (presumably a relax Function — confirm
        against the C++ compiler's expectations).

    Returns
    -------
    The compiled executable produced by the FFI call.

    Note
    ----
    This shadows the Python built-in ``compile``; prefer importing the
    module and calling it qualified.
    """
    return _ffi_api.compile(expr)
4 changes: 4 additions & 0 deletions src/relax/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,16 @@
*/
#include <tvm/relax/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relax/attrs/memory.h>

#include "op_common.h"

namespace tvm {
namespace relax {

// Register the attrs nodes so they can be created and reflected through
// the FFI (matches the Python classes registered under the same keys).
TVM_REGISTER_NODE_TYPE(AllocStorageAttrs);
TVM_REGISTER_NODE_TYPE(AllocTensorAttrs);

bool EqualConstInt(const PrimExpr& lhs, int64_t value) {
if (const int64_t* pvalue = tir::as_const_int(lhs)) {
return pvalue[0] == value;
Expand Down
167 changes: 167 additions & 0 deletions src/relax/vm/compiler.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/relax/vm/compiler.cc
* \brief A compiler from relay::Module to the VM byte code.
*/

#include "compiler.h"

#include <tvm/relax/attrs/memory.h>
#include <tvm/relax/expr_functor.h>

namespace tvm {
namespace runtime {
namespace relax_vm {

using namespace relax;

/*!
 * \brief Compiles a single relax function into VM bytecode.
 *
 * Walks the expression tree, assigning each bound Var a fresh virtual
 * register and emitting instructions through an ExecBuilder. The two VM
 * memory builtins (alloc_storage / alloc_tensor) get special argument
 * lowering; every other ExternFunc call is emitted as a generic packed call.
 */
class VMFunctionCompiler : public ExprVisitor {
 public:
  VMFunctionCompiler() { builder_ = ExecBuilderNode::Create(); }

  /*! \brief Return the executable built so far. */
  Executable Get() { return builder_->Get(); }

 protected:
  // Open a VM function named "main" and bind each parameter to a fresh
  // register; parameter i is expected to land in register i.
  void VisitExpr_(const FunctionNode* func_node) {
    builder_->Function("main", func_node->params.size());
    size_t i = 0;
    for (auto param : func_node->params) {
      auto arg_register = NewRegister();
      ICHECK_EQ(i, arg_register);
      var_register_map_.insert({param, arg_register});
      ++i;
    }
    ExprVisitor::VisitExpr_(func_node);
  }

  // Compile each binding block, then emit a return of the register holding
  // the body. NOTE(review): assumes the SeqExpr body is a Var that has
  // already been bound to a register — the Downcast/lookup fail otherwise.
  void VisitExpr_(const SeqExprNode* op) {
    for (auto block : op->blocks) {
      this->VisitBindingBlock(block);
    }
    // find function return Var and emit
    auto ret_reg = this->var_register_map_.find(Downcast<Var>(op->body));
    ICHECK(ret_reg != this->var_register_map_.end());
    builder_->EmitRet(ret_reg->second);
  }

  // Lower a single `var = call(...)` binding. NOTE(review): assumes every
  // binding value is a Call to an ExternFunc — other bindings are ignored.
  void VisitVarBinding(const VarBinding& binding) {
    Var var = binding->var;
    Call call_node = Downcast<Call>(binding->value);
    if (auto* extern_func = call_node->op.as<relax::ExternFuncNode>()) {
      String name = extern_func->global_symbol;
      if (name == "vm.builtin.alloc_storage") {
        Attrs attrs = call_node->attrs;
        // Get the dtype hint from the attributes.
        auto alloc_attrs = attrs.as<AllocStorageAttrs>();
        ICHECK(alloc_attrs != nullptr) << "must be the AllocStorage attrs";
        DataType dtype = alloc_attrs->dtype;
        int device_type = alloc_attrs->device_type;
        // args[0] = size, args[1] = alignment; both must be constant
        // single-element ShapeExprs (the IntImm Downcasts fail otherwise).
        PrimExpr size = Downcast<ShapeExpr>(call_node->args[0])->values[0];
        PrimExpr alignment = Downcast<ShapeExpr>(call_node->args[1])->values[0];

        std::vector<Instruction::Arg> args;
        // The builtin also receives the VM state as its implicit first arg.
        args.push_back(Instruction::Arg(Instruction::kVMStateRegister));
        args.push_back(Instruction::Arg(Instruction::kImmediate, Downcast<IntImm>(size)->value));
        args.push_back(
            Instruction::Arg(Instruction::kImmediate, Downcast<IntImm>(alignment)->value));
        args.push_back(Instruction::Arg(Instruction::kImmediate, device_type));

        // store dtype in constant pool
        TVMRetValue data_type;
        data_type = dtype;
        Index index = builder_->EmitConstant(data_type);
        args.push_back(Instruction::Arg(Instruction::kConstIdx, index));

        // Destination register doubles as the var's home.
        this->var_register_map_.insert({var, this->registers_num_});
        builder_->EmitCall(name, args, NewRegister());
      } else if (name == "vm.builtin.alloc_tensor") {
        Attrs attrs = call_node->attrs;
        auto alloc_attrs = attrs.as<AllocTensorAttrs>();
        ICHECK(alloc_attrs != nullptr) << "must be the AllocTensor attrs";
        DataType dtype = alloc_attrs->dtype;

        std::vector<Instruction::Arg> args;
        // args[0] = the storage Var returned by a prior alloc_storage.
        auto storage_reg = this->var_register_map_.find(Downcast<Var>(call_node->args[0]));
        ICHECK(storage_reg != this->var_register_map_.end());
        args.push_back(Instruction::Arg(Instruction::kRegister, storage_reg->second));

        // args[1] = constant byte offset into the storage.
        PrimExpr offset = Downcast<ShapeExpr>(call_node->args[1])->values[0];
        args.push_back(Instruction::Arg(Instruction::kImmediate, Downcast<IntImm>(offset)->value));

        // store dtype in constant pool
        TVMRetValue data_type;
        data_type = dtype;
        Index index = builder_->EmitConstant(data_type);
        args.push_back(Instruction::Arg(Instruction::kConstIdx, index));

        // store shape in constant pool
        // args[2] = tensor shape; every dim must be a constant IntImm.
        std::vector<int64_t> shape;
        auto shape_expr = Downcast<ShapeExpr>(call_node->args[2])->values;
        for (PrimExpr i : shape_expr) {
          shape.push_back(Downcast<IntImm>(i)->value);
        }
        auto shape_tuple = ShapeTuple(shape);
        TVMRetValue shape_tuple_value;
        shape_tuple_value = shape_tuple;
        index = builder_->EmitConstant(shape_tuple_value);
        args.push_back(Instruction::Arg(Instruction::kConstIdx, index));

        this->var_register_map_.insert({var, this->registers_num_});
        builder_->EmitCall(name, args, NewRegister());
      }
      // Normal packed function
      else {
        std::vector<Instruction::Arg> args_;
        for (size_t i = 0; i < call_node->args.size(); ++i) {
          // NOTE(review): non-Var arguments are silently skipped here —
          // confirm constants/shapes are never passed to plain packed funcs.
          if (call_node->args[i].as<VarNode>()) {
            auto reg = this->var_register_map_.find(Downcast<Var>(call_node->args[i]));
            ICHECK(reg != this->var_register_map_.end());
            args_.push_back(Instruction::Arg(Instruction::kRegister, reg->second));
          }
        }
        // NOTE(review): result is discarded (kVoidArg) and `var` is never
        // mapped to a register, so a later use of `var` would fail the
        // lookups above — TODO confirm this is intended.
        builder_->EmitCall(name, args_, Instruction::kVoidArg);
      }
    }
  }

  /*! \brief Allocate the next virtual register number. */
  size_t NewRegister() { return registers_num_++; }

  /*! \brief Internal ExecBuilder. */
  relax::ExecBuilder builder_;
  /*! \brief Total number of virtual registers allocated. */
  size_t registers_num_ = 0;
  /*! \brief Map from var to register number. */
  std::unordered_map<Var, RegName, ObjectPtrHash, ObjectPtrEqual> var_register_map_;
};

/*!
 * \brief Compile a relay expression into a VM executable.
 * \param e The expression to compile.
 * \return The executable produced by the function compiler.
 */
Executable Compile(const relay::Expr& e) {
  VMFunctionCompiler compiler;
  compiler.VisitExpr(e);
  return compiler.Get();
}

// Expose the VM compiler to the frontend under "relax.transform.compile"
// (consumed by python/tvm/relax/transform/compile.py).
TVM_REGISTER_GLOBAL("relax.transform.compile").set_body_typed([](relay::Expr e) {
  return Compile(e);
});

} // namespace relax_vm
} // namespace runtime
} // namespace tvm
Loading

0 comments on commit 43628e5

Please sign in to comment.