diff --git a/include/tvm/relax/attrs/memory.h b/include/tvm/relax/attrs/memory.h
new file mode 100644
index 0000000000..91988906a2
--- /dev/null
+++ b/include/tvm/relax/attrs/memory.h
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/attrs/memory.h
+ * \brief Attributes for memory operators.
+ */
+#ifndef TVM_RELAX_ATTRS_MEMORY_H_
+#define TVM_RELAX_ATTRS_MEMORY_H_
+
+#include <tvm/ir/attrs.h>
+
+namespace tvm {
+namespace relax {
+/*!
+ * \brief Options for allocating storage.
+ */
+struct AllocStorageAttrs : public tvm::AttrsNode<AllocStorageAttrs> {
+  DataType dtype;
+  int device_id;
+  int device_type;
+
+  TVM_DECLARE_ATTRS(AllocStorageAttrs, "relax.attrs.AllocStorageAttrs") {
+    TVM_ATTR_FIELD(dtype)
+        .describe("The dtype of the tensor to allocate.")
+        .set_default(DataType::Float(32, 1));
+    TVM_ATTR_FIELD(device_id).describe("The device id on which to allocate memory.");
+    TVM_ATTR_FIELD(device_type).describe("The device type on which to allocate memory.");
+  }
+};
+
+/*!
+ * \brief Options for allocating tensors.
+ */
+struct AllocTensorAttrs : public tvm::AttrsNode<AllocTensorAttrs> {
+  DataType dtype;
+
+  TVM_DECLARE_ATTRS(AllocTensorAttrs, "relax.attrs.AllocTensorAttrs") {
+    TVM_ATTR_FIELD(dtype)
+        .describe("The dtype of the tensor to allocate.")
+        .set_default(DataType::Float(32, 1));
+  }
+};
+
+}  // namespace relax
+}  // namespace tvm
+#endif  // TVM_RELAX_ATTRS_MEMORY_H_
diff --git a/include/tvm/relax/op_attr_types.h b/include/tvm/relax/op_attr_types.h
index d38b272971..749220a5b3 100644
--- a/include/tvm/relax/op_attr_types.h
+++ b/include/tvm/relax/op_attr_types.h
@@ -43,7 +43,8 @@ using relay::Call;
  * \param diag_ctx The diagnostic context for reporting errors.
  * \return The inferred output shape expression.
  */
-using FInferShape = runtime::TypedPackedFunc<Optional<Expr>(const Call& call, DiagnosticContext diag_ctx)>;
+using FInferShape =
+    runtime::TypedPackedFunc<Optional<Expr>(const Call& call, DiagnosticContext diag_ctx)>;
 
 /*!
  * \brief Infer the output type for operators. This function will
diff --git a/include/tvm/relax/vm/exec_builder.h b/include/tvm/relax/vm/exec_builder.h
index 415c9533fe..c59cba99d8 100644
--- a/include/tvm/relax/vm/exec_builder.h
+++ b/include/tvm/relax/vm/exec_builder.h
@@ -21,8 +21,8 @@
  * \file tvm/relax/vm/exec_builder.h
  * \brief Helper to build VM executables.
  */
-#ifndef TVM_RELAX_EXEC_BUILDER_H_
-#define TVM_RELAX_EXEC_BUILDER_H_
+#ifndef TVM_RELAX_VM_EXEC_BUILDER_H_
+#define TVM_RELAX_VM_EXEC_BUILDER_H_
 
 #include
 #include
@@ -52,7 +52,7 @@ class ExecBuilderNode : public Object {
    * \param func The function name.
    * \param num_inputs The number of inputs.
    */
-  void Function(std::string func, int64_t num_inputs);
+  void EmitFunction(std::string func, int64_t num_inputs);
   /*!
    * \brief Emit a call instruction for a packed function.
    * \param func The packed function name.
@@ -69,7 +69,7 @@ class ExecBuilderNode : public Object {
    * \brief Emit a constant value to the constant pool.
    * \return The index that represents the constant.
    */
-  vm::Index EmitConstant(ObjectRef obj);
+  vm::Index EmitConstant(TVMRetValue obj);
   /*!
    * \brief Get the built executable.
    * \return The built executable.
@@ -102,4 +102,4 @@ class ExecBuilder : public ObjectRef {
 }  // namespace relax
 }  // namespace tvm
 
-#endif  // TVM_RELAX_EXEC_BUILDER_H_
+#endif  // TVM_RELAX_VM_EXEC_BUILDER_H_
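For orientation, this is roughly how the renamed builder entry points surface on the Python side. A minimal sketch based on the existing ExecBuilder bindings and the tests in this PR; the "test.vm.add" packed function name is an assumption borrowed from the test suite, not part of this change:

import numpy as np
import tvm
from tvm import relax as rx

ib = rx.ExecBuilder()
const = tvm.nd.array(np.ones((2, 2), dtype="float32"))
with ib.function("func0", num_inputs=1):  # backed by ExecBuilderNode::EmitFunction
    # NDArray (and, after this PR, DataType) arguments are routed through
    # emit_constant, i.e. ExecBuilderNode::EmitConstant(TVMRetValue).
    ib.emit_call("test.vm.add", args=[ib.r(0), const], dst=ib.r(1))
    ib.emit_ret(ib.r(1))
exec0 = ib.get()  # ExecBuilderNode::Get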
diff --git a/include/tvm/relax/vm/executable.h b/include/tvm/relax/vm/executable.h
index c9009c1c59..1d98432a2c 100644
--- a/include/tvm/relax/vm/executable.h
+++ b/include/tvm/relax/vm/executable.h
@@ -103,7 +103,7 @@ class ExecutableNode : public Object {
   /*! \brief A map from globals (as strings) to their index in the function map. */
   std::unordered_map<std::string, Index> global_map;
   /*! \brief The global constant pool. */
-  std::vector<ObjectRef> constants;
+  std::vector<TVMRetValue> constants;
   /*! \brief The name of packed functions. */
   std::vector<std::string> func_names;
   /*! \brief A mapping from the packed function (as string) to the index that
diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py
index 19f1ecf6d7..27e32fe716 100644
--- a/python/tvm/relax/__init__.py
+++ b/python/tvm/relax/__init__.py
@@ -24,6 +24,7 @@
 from . import parser
 from . import analysis
 from . import transform
+from . import vm_compiler
 
 # Expr
@@ -61,9 +62,11 @@
 ExecBuilder = exec_builder.ExecBuilder
 VirtualMachine = vm.VirtualMachine
 load_exec_from_file = vm.load_exec_from_file
+compile = vm_compiler.compile
 
 # Operator
 from .op.base import call_dps
+from .op.op_attrs import AllocStorageAttrs, AllocTensorAttrs
 
 # IRBuilder
 IRBuilder = ir_builder.IRBuilder
diff --git a/python/tvm/relax/base.py b/python/tvm/relax/base.py
deleted file mode 100644
index b85ac77c6e..0000000000
--- a/python/tvm/relax/base.py
+++ /dev/null
@@ -1,4 +0,0 @@
-# Skeleton AST so we can get prototype working before this PR is merged
-class rxNode:
-    def __init__(self, span):
-        self.span = span
diff --git a/python/tvm/relax/exec_builder.py b/python/tvm/relax/exec_builder.py
index 0ffcb82f50..a8a834a160 100644
--- a/python/tvm/relax/exec_builder.py
+++ b/python/tvm/relax/exec_builder.py
@@ -80,7 +80,7 @@ def emit_call(self, name, args=[], dst=None):
             dst = SpecialReg.VOID_ARG
         args_ = []
         for arg in args:
-            if isinstance(arg, tvm.nd.NDArray):
+            if isinstance(arg, (tvm.nd.NDArray, tvm.DataType)):
                 new_arg = self.emit_constant(arg)
                 args_.append(new_arg)
             else:
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 03f8b8858a..ca70193968 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -20,3 +20,4 @@
 # Operators
 from .base import *
 from .tensor import *
+from .op_attrs import *
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
new file mode 100644
index 0000000000..db5d1a6f92
--- /dev/null
+++ b/python/tvm/relax/op/op_attrs.py
@@ -0,0 +1,28 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""The attribute nodes used for Relax operators"""
+from tvm.ir import Attrs
+import tvm._ffi
+
+
+@tvm._ffi.register_object("relax.attrs.AllocStorageAttrs")
+class AllocStorageAttrs(Attrs):
+    """Attributes used in alloc_storage operators"""
+
+
+@tvm._ffi.register_object("relax.attrs.AllocTensorAttrs")
+class AllocTensorAttrs(Attrs):
+    """Attributes used in alloc_tensor operators"""
\ No newline at end of file
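These Python mirrors matter mainly for the parser path: a relax.call_packed(...) carrying an attrs_type_key produces a Call whose attrs node the VM compiler dispatches on. A minimal sketch, mirroring test_vm_compile at the bottom of this diff:

from __future__ import annotations
from tvm import relax as rx

@rx.script
class Mod:
    def foo(x: Tensor[(3, 4), "float32"]):
        # attrs_type_key selects the registered relax.attrs.AllocStorageAttrs node
        y = relax.call_packed("vm.builtin.alloc_storage", (12,), (64,), device_id=0, device_type=1, attrs_type_key="relax.attrs.AllocStorageAttrs")
        z = relax.call_packed("vm.builtin.alloc_tensor", y, (0,), (3, 4), attrs_type_key="relax.attrs.AllocTensorAttrs")
        return z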
diff --git a/python/tvm/relax/parser.py b/python/tvm/relax/parser.py
index a757971098..cb3eb79ed8 100644
--- a/python/tvm/relax/parser.py
+++ b/python/tvm/relax/parser.py
@@ -1,3 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
 from __future__ import annotations
 
 import inspect
diff --git a/python/tvm/relax/vm_compiler.py b/python/tvm/relax/vm_compiler.py
new file mode 100644
index 0000000000..99afa96610
--- /dev/null
+++ b/python/tvm/relax/vm_compiler.py
@@ -0,0 +1,70 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name, redefined-builtin
+"""
+The Relax Virtual Machine compiler.
+"""
+from typing import List, Optional, Union, Dict
+import tvm
+from . import vm, _ffi_api
+
+
+def compile(mod: tvm.IRModule) -> vm.Executable:
+    """Compile the module to VM executable. A helper function for VMCompiler.
+
+    Parameters
+    ----------
+    mod : tvm.IRModule
+        The Relax module to build.
+
+    Returns
+    -------
+    exec : tvm.relax.Executable
+        The VM executable that contains the bytecode.
+    """
+    compiler = VMCompiler()
+    compiler.compile(mod)
+    return compiler.get_exec()
+
+
+class VMCompiler(object):
+    """Compiler that compiles a module to VM executable."""
+
+    def __init__(self):
+        self.mod = _ffi_api.VMCompiler()
+        self._compile = self.mod["compile"]
+        self._get_exec = self.mod["get_executable"]
+
+    def compile(self, mod: tvm.IRModule) -> None:
+        """Compile the module to VM executable.
+
+        Parameters
+        ----------
+        mod : tvm.IRModule
+            The IRModule to build.
+        """
+        self._compile(mod)
+
+    def get_exec(self) -> vm.Executable:
+        """Get the VM executable.
+
+        Returns
+        -------
+        exec : tvm.relax.Executable
+            The VM executable that contains bytecode.
+        """
+        return self._get_exec()
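End to end, the new entry point is meant to be used like this. A sketch mirroring test_vm_compile; mod is an IRModule such as the Mod shown earlier:

import numpy as np
import tvm
from tvm import relax as rx

ex = rx.vm_compiler.compile(mod)  # also re-exported as rx.compile
vm = rx.VirtualMachine(ex, tvm.cpu())
inp = tvm.nd.array(np.random.rand(3, 4).astype(np.float32))
res = vm["foo"](inp)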
+ """ + compiler = VMCompiler() + compiler.compile(mod) + return compiler.get_exec() + + +class VMCompiler(object): + """Compiler that compiles module to VM executable.""" + + def __init__(self): + self.mod = _ffi_api.VMCompiler() + self._compile = self.mod["compile"] + self._get_exec = self.mod["get_executable"] + + def compile(self, mod: tvm.IRModule) -> None: + """Compile the module to VM executable. + + Parameters + ---------- + mod : tvm.IRModule + The IRModule to build. + """ + self._compile(mod) + + def get_exec(self) -> vm.Executable: + """Get the VM executable. + + Returns + ------- + exec : tvm.relax.Executable + The VM executable that contains bytecode. + """ + return self._get_exec() diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index b40d41b4c7..c7e1e58419 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include @@ -24,6 +25,9 @@ namespace tvm { namespace relax { +TVM_REGISTER_NODE_TYPE(AllocStorageAttrs); +TVM_REGISTER_NODE_TYPE(AllocTensorAttrs); + bool EqualConstInt(const PrimExpr& lhs, int64_t value) { if (const int64_t* pvalue = tir::as_const_int(lhs)) { return pvalue[0] == value; diff --git a/src/relax/vm/compiler.cc b/src/relax/vm/compiler.cc new file mode 100644 index 0000000000..2ebe6a3fd9 --- /dev/null +++ b/src/relax/vm/compiler.cc @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/vm/compiler.cc + * \brief A compiler from relay::Module to the VM byte code. + */ + +#include "compiler.h" + +#include +#include + +#include +#include +#include + +namespace tvm { +namespace runtime { +namespace relax_vm { + +using namespace relax; + +class VMFunctionCompiler : public ExprVisitor { + public: + explicit VMFunctionCompiler(ExecBuilderNode* builder) { builder_ = GetRef(builder); } + + protected: + /*! \brief A counter for naming local functions. */ + int local_func_counter_ = 0; + + // TODO(@yuchen): support visiting other IR nodes + void VisitExpr_(const FunctionNode* func_node) { + if (func_node->name.defined()) { + builder_->EmitFunction(func_node->name.value()->name_hint, func_node->params.size()); + } else { + // TODO(@yuchen): handle local functions that capture local vars outside the func + // TODO(@yuchen): a renaming pass to resolve name conflicts, e.g. 
+  void VisitVarBinding(const VarBinding& binding) {
+    Var var = binding->var;
+    // TODO(@yuchen): support other nodes than Call
+    Call call_node = Downcast<Call>(binding->value);
+    if (auto* extern_func = call_node->op.as<ExternFuncNode>()) {
+      String name = extern_func->global_symbol;
+      if (name == "vm.builtin.alloc_storage") {
+        EmitAllocStorage(call_node, var);
+      } else if (name == "vm.builtin.alloc_tensor") {
+        EmitAllocTensor(call_node, var);
+      } else {
+        // normal packed function without attributes
+        std::vector<Instruction::Arg> args;
+        for (size_t i = 0; i < call_node->args.size(); ++i) {
+          if (call_node->args[i].as<VarNode>()) {
+            auto reg = this->var_register_map_.find(Downcast<Var>(call_node->args[i]));
+            ICHECK(reg != this->var_register_map_.end());
+            args.push_back(Instruction::Arg(Instruction::kRegister, reg->second));
+          }
+        }
+        // TODO(@yuchen): what if the packed func has void return (no need to write to the dst
+        // register)?
+        builder_->EmitCall(name, args, NewRegister(var));
+      }
+    } else {
+      LOG(FATAL) << "TODO: support compiling everything other than extern functions.";
+    }
+  }
+
+  void EmitAllocStorage(const Call& call_node, const Var& var) {
+    Attrs attrs = call_node->attrs;
+
+    // Get dtype and device_type from the attributes.
+    auto alloc_attrs = attrs.as<AllocStorageAttrs>();
+    ICHECK(alloc_attrs != nullptr) << "must be the AllocStorage attrs";
+    DataType dtype = alloc_attrs->dtype;
+    int device_type = alloc_attrs->device_type;
+    PrimExpr size = Downcast<ShapeExpr>(call_node->args[0])->values[0];
+    PrimExpr alignment = Downcast<ShapeExpr>(call_node->args[1])->values[0];
+
+    std::vector<Instruction::Arg> args;
+    args.push_back(Instruction::Arg(Instruction::kVMStateRegister));
+    args.push_back(Instruction::Arg(Instruction::kImmediate, Downcast<IntImm>(size)->value));
+    args.push_back(Instruction::Arg(Instruction::kImmediate, Downcast<IntImm>(alignment)->value));
+    args.push_back(Instruction::Arg(Instruction::kImmediate, device_type));
+
+    // store dtype in constant pool
+    TVMRetValue data_type;
+    data_type = dtype;
+    Index index = this->builder_->EmitConstant(data_type);
+    args.push_back(Instruction::Arg(Instruction::kConstIdx, index));
+
+    builder_->EmitCall("vm.builtin.alloc_storage", args, NewRegister(var));
+  }
+
+  void EmitAllocTensor(const Call& call_node, const Var& var) {
+    Attrs attrs = call_node->attrs;
+
+    // Get dtype from the attributes.
+    auto alloc_attrs = attrs.as<AllocTensorAttrs>();
+    ICHECK(alloc_attrs != nullptr) << "must be the AllocTensor attrs";
+    DataType dtype = alloc_attrs->dtype;
+
+    std::vector<Instruction::Arg> args;
+    auto storage_reg = this->var_register_map_.find(Downcast<Var>(call_node->args[0]));
+    ICHECK(storage_reg != this->var_register_map_.end());
+    args.push_back(Instruction::Arg(Instruction::kRegister, storage_reg->second));
+
+    PrimExpr offset = Downcast<ShapeExpr>(call_node->args[1])->values[0];
+    args.push_back(Instruction::Arg(Instruction::kImmediate, Downcast<IntImm>(offset)->value));
+
+    // store dtype in constant pool
+    TVMRetValue data_type;
+    data_type = dtype;
+    Index index = builder_->EmitConstant(data_type);
+    args.push_back(Instruction::Arg(Instruction::kConstIdx, index));
+
+    // TODO(@yuchen, @ziheng): support symbolic shape when connecting with shape lowering
+    // store shape in constant pool
+    // note: ShapeTuple constants are not yet covered by the executable's save/load path,
+    // which only handles NDArray and DLDataType entries
+    std::vector<int64_t> shape;
+    auto shape_expr = Downcast<ShapeExpr>(call_node->args[2])->values;
+    for (PrimExpr i : shape_expr) {
+      shape.push_back(Downcast<IntImm>(i)->value);
+    }
+    auto shape_tuple = ShapeTuple(shape);
+    TVMRetValue shape_tuple_value;
+    shape_tuple_value = shape_tuple;
+    index = builder_->EmitConstant(shape_tuple_value);
+    args.push_back(Instruction::Arg(Instruction::kConstIdx, index));
+
+    builder_->EmitCall("vm.builtin.alloc_tensor", args, NewRegister(var));
+  }
+
+  size_t NewRegister(Var var) {
+    size_t reg = this->registers_num_++;
+    this->var_register_map_.insert({var, reg});
+    return reg;
+  }
+
+  /*! \brief Internal ExecBuilder. */
+  relax::ExecBuilder builder_;
+  /*! \brief Total number of virtual registers allocated. */
+  size_t registers_num_ = 0;
+  /*! \brief Map from var to register number. */
+  std::unordered_map<Var, size_t, ObjectPtrHash, ObjectPtrEqual> var_register_map_;
+};
+
+PackedFunc VMCompiler::GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) {
+  if (name == "compile") {
+    return PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
+      ICHECK_EQ(args.num_args, 1);
+      IRModule mod = args[0];
+      this->Compile(mod);
+    });
+  } else if (name == "get_executable") {
+    return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetExec(); });
+  } else {
+    LOG(FATAL) << "Unknown packed function: " << name;
+    return PackedFunc([name](TVMArgs args, TVMRetValue* rv) {});
+  }
+}
+
+void VMCompiler::Compile(IRModule mod) {
+  // TODO(@yuchen, @ziheng): support lowering PrimFuncs
+  for (auto& func : mod->functions) {
+    if (auto* n = func.second.as<FunctionNode>()) {
+      VMFunctionCompiler func_compiler(builder_.operator->());
+      func_compiler.VisitExpr(GetRef<Function>(n));
+    }
+  }
+}
+
+Executable VMCompiler::GetExec() { return builder_->Get(); }
+
+runtime::Module CreateVMCompiler() {
+  auto compiler = make_object<VMCompiler>();
+  return runtime::Module(compiler);
+}
+
+TVM_REGISTER_GLOBAL("relax.VMCompiler").set_body_typed([]() { return CreateVMCompiler(); });
+
+}  // namespace relax_vm
+}  // namespace runtime
+}  // namespace tvm
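To make the lowering concrete: for an alloc_storage/alloc_tensor call, VMFunctionCompiler emits the same instruction sequence one would write by hand with the Python ExecBuilder, with sizes and device type as immediates and the dtype (plus the static shape) placed in the constant pool. A sketch of the equivalent hand-built program, mirroring test_vm_constant_serialize below:

import tvm
from tvm import relax as rx

ib = rx.ExecBuilder()
dtype = tvm.DataType("float32")
with ib.function("main", num_inputs=1):
    # EmitAllocStorage: vm_state, size, alignment, device_type, dtype (const pool)
    ib.emit_call("vm.builtin.alloc_storage", args=[ib.vm_state(), ib.imm(24), ib.imm(64), ib.imm(1), dtype], dst=ib.r(1))
    # EmitAllocTensor: storage register, offset, dtype (const pool), shape
    ib.emit_call("vm.builtin.alloc_tensor", args=[ib.r(1), ib.imm(0), dtype, ib.r(0)], dst=ib.r(2))
    ib.emit_ret(ib.r(2))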
diff --git a/src/relax/vm/compiler.h b/src/relax/vm/compiler.h
new file mode 100644
index 0000000000..c4e10493f7
--- /dev/null
+++ b/src/relax/vm/compiler.h
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/vm/compiler.h
+ * \brief A compiler to compile a Relax IRModule to the VM executable.
+ */
+
+#ifndef TVM_RELAX_VM_COMPILER_H_
+#define TVM_RELAX_VM_COMPILER_H_
+
+#include <tvm/ir/module.h>
+#include <tvm/relax/vm/exec_builder.h>
+#include <tvm/relax/vm/executable.h>
+
+#include <string>
+
+namespace tvm {
+namespace runtime {
+namespace relax_vm {
+
+class VMCompiler : public runtime::ModuleNode {
+ public:
+  /*!
+   * \brief Compile the functions in a Module.
+   * \param mod Input IRModule to be compiled.
+   */
+  void Compile(IRModule mod);
+  /*!
+   * \brief Get the compiled executable.
+   * \return The compiled executable.
+   */
+  Executable GetExec();
+
+  virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);
+
+  const char* type_key() const { return "relax.VMCompiler"; }
+
+ protected:
+  /*! \brief Internal executable builder. */
+  relax::ExecBuilder builder_ = relax::ExecBuilderNode::Create();
+};
+
+}  // namespace relax_vm
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // TVM_RELAX_VM_COMPILER_H_
diff --git a/src/relax/vm/exec_builder.cc b/src/relax/vm/exec_builder.cc
index 4f681571fd..2b2545f966 100644
--- a/src/relax/vm/exec_builder.cc
+++ b/src/relax/vm/exec_builder.cc
@@ -37,13 +37,13 @@ ExecBuilder ExecBuilderNode::Create() {
   return ret;
 }
 
-vm::Index ExecBuilderNode::EmitConstant(ObjectRef obj) {
+vm::Index ExecBuilderNode::EmitConstant(TVMRetValue obj) {
   vm::Index idx = exec->constants.size();
   exec->constants.push_back(obj);
   return vm::Instruction::Arg(vm::Instruction::kConstIdx, idx).data;
 }
 
-void ExecBuilderNode::Function(std::string func_name, int64_t num_inputs) {
+void ExecBuilderNode::EmitFunction(std::string func_name, int64_t num_inputs) {
   const auto& m = exec->global_map;
   ICHECK(m.find(func_name) == m.end());
   VMFunction vmfunc;
@@ -181,13 +181,16 @@ TVM_REGISTER_GLOBAL("relax.ExecBuilderCreate")
 .set_body_typed(ExecBuilderNode::Create);
 
 TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitConstant")
-.set_body_typed([](ExecBuilder builder, ObjectRef obj) {
-  return builder->EmitConstant(obj);
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+  ExecBuilder builder = args[0];
+  TVMRetValue rt;
+  rt = args[1];
+  *ret = builder->EmitConstant(rt);
 });
 
 TVM_REGISTER_GLOBAL("relax.ExecBuilderFunction")
 .set_body_typed([](ExecBuilder builder, String name, int64_t num_inputs) {
-  return builder->Function(name, num_inputs);
+  return builder->EmitFunction(name, num_inputs);
 });
 
 TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitCall")
diff --git a/src/relax/vm/executable.cc b/src/relax/vm/executable.cc
index 58563df3d7..8620055987 100644
--- a/src/relax/vm/executable.cc
+++ b/src/relax/vm/executable.cc
@@ -38,6 +38,12 @@ namespace relax_vm {
 /*! \brief The magic number for the serialized VM bytecode file */
 constexpr uint64_t kTVMVMBytecodeMagic = 0xD225DE2F4214151D;
 
+/*! \brief Possible types in the constant pool */
+enum ConstantType : int {
+  kNDArray = 0,
+  kDLDataType = 1,
+};
+
 #define STREAM_CHECK(val, section)                                          \
   ICHECK(val) << "Invalid VM file format in the " << section << " section." \
               << "\n";
@@ -48,22 +54,35 @@ std::string ExecutableNode::Stats() const {
   std::ostringstream oss;
   oss << "Relax VM executable statistics:" << std::endl;
 
-  // Get the number of constants and the shape of each of them.
-  oss << "  Constant shapes (# " << constants.size() << "): [";
+  // Get the number of constants.
+  // If the constant is an NDArray, get the shape of each of them.
+  // If the constant is a DLDataType, get the data type of each of them.
+  oss << "  Constant pool (# " << constants.size() << "): [";
   for (const auto& it : constants) {
-    const auto constant = Downcast<runtime::NDArray>(it);
-    const auto& shape = constant.Shape();
-    // Scalar
-    if (shape.empty()) {
-      oss << "scalar, ";
-      continue;
-    }
-    oss << "[";
-    for (auto s : shape) {
-      oss << s << ", ";
+    if (it.IsObjectRef<runtime::NDArray>()) {
+      const auto ndarray = it.operator tvm::runtime::NDArray();
+      const auto& shape = ndarray.Shape();
+      // Scalar
+      if (shape.empty()) {
+        oss << "scalar, ";
+        continue;
+      }
+      oss << "[";
+      for (auto s : shape) {
+        oss << s << ", ";
+      }
+      oss.seekp(-2, oss.cur);
+      oss << "], ";
+    } else {
+      try {
+        DLDataType dtype = it.operator DLDataType();
+        oss << dtype << ", ";
+      } catch (std::exception& exc) {
+        LOG(FATAL) << "Constant pool can only contain NDArray and DLDataType, but got "
+                   << ArgTypeCode2Str(it.type_code());
+      }
     }
-    oss.seekp(-2, oss.cur);
-    oss << "], " << std::endl;
   }
   if (!constants.empty()) oss.seekp(-2, oss.cur);
   oss << "]" << std::endl;
@@ -219,14 +238,21 @@ void ExecutableNode::SaveGlobalSection(dmlc::Stream* strm) {
 }
 
 void ExecutableNode::SaveConstantSection(dmlc::Stream* strm) {
-  std::vector<DLTensor*> arrays;
-  for (const auto& obj : this->constants) {
-    const auto cell = Downcast<runtime::NDArray>(obj);
-    arrays.push_back(const_cast<DLTensor*>(cell.operator->()));
-  }
   strm->Write(static_cast<uint64_t>(this->constants.size()));
-  for (const auto& it : arrays) {
-    runtime::SaveDLTensor(strm, it);
+  for (const auto& it : this->constants) {
+    if (it.IsObjectRef<runtime::NDArray>()) {
+      strm->Write(ConstantType::kNDArray);
+      runtime::SaveDLTensor(strm, it.operator DLTensor*());
+    } else {
+      try {
+        strm->Write(ConstantType::kDLDataType);
+        strm->Write(it.operator DLDataType());
+      } catch (std::exception& exc) {
+        LOG(FATAL) << "Constant pool can only contain NDArray and DLDataType, but got "
+                   << ArgTypeCode2Str(it.type_code());
+      }
+    }
   }
 }
 
@@ -256,11 +282,26 @@ void ExecutableNode::LoadConstantSection(dmlc::Stream* strm) {
   STREAM_CHECK(strm->Read(&sz, sizeof(sz)), "constant");
   size_t size = static_cast<size_t>(sz);
 
+  runtime::NDArray ndarray;
+  DLDataType dtype;
   // Load each of the constants.
   for (size_t i = 0; i < size; i++) {
-    runtime::NDArray constant;
-    STREAM_CHECK(constant.Load(strm), "constant");
-    this->constants.push_back(constant);
+    int constant_type;
+    STREAM_CHECK(strm->Read(&constant_type, sizeof(constant_type)), "constant");
+    if (constant_type == ConstantType::kNDArray) {
+      ndarray.Load(strm);
+      TVMRetValue cell;
+      cell = ndarray;
+      this->constants.push_back(cell);
+    } else if (constant_type == ConstantType::kDLDataType) {
+      strm->Read(&dtype);
+      TVMRetValue cell;
+      cell = dtype;
+      this->constants.push_back(cell);
+    } else {
+      LOG(FATAL) << "Constant pool can only contain NDArray and DLDataType, but got "
+                 << ArgTypeCode2Str(constant_type) << " when loading the VM constant pool.";
+    }
   }
 }
 
@@ -388,7 +429,7 @@ String ExecutableNode::AsPython() const {
       os << "    ib.emit_call(\"" << this->func_names[instr.func_idx] << "\", args=["
          << StrJoin(instr.args, 0, instr.num_args, ", ", InstrArgToPyStr) << "]";
-      if (instr.dst != Instruction::kVoidArg) os << ", ret=ib.r(" << instr.dst << ")";
+      if (instr.dst != Instruction::kVoidArg) os << ", dst=ib.r(" << instr.dst << ")";
       os << ")\n";
       break;
     }
diff --git a/tests/python/relax/test_vm.py b/tests/python/relax/test_vm.py
index cf22c03741..68d317e2ba 100644
--- a/tests/python/relax/test_vm.py
+++ b/tests/python/relax/test_vm.py
@@ -14,8 +14,10 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from __future__ import annotations  # must import to defer parsing of annotations
 import numpy as np
 import tvm
+from tvm.relay import Call
 from tvm import relax as rx
 from tvm.runtime import container
 
@@ -34,6 +36,10 @@ def mul(a, b):
     ret = a.asnumpy() * b.asnumpy()
     return tvm.nd.array(ret)
 
+@tvm.register_func("test.vm.identity")
+def identity_packed(a, b):
+    b[:] = tvm.nd.array(a.asnumpy())
+
 def test_vm_execute():
     ib = rx.ExecBuilder()
     with ib.function("func0", num_inputs=2):
@@ -77,6 +83,25 @@ def test_vm_serialize():
     exec1 = rx.load_exec_from_file("exec.bin")
     assert exec0.astext() == exec1.astext()
 
+def test_vm_constant_serialize():
+    dtype = tvm.DataType('float32')
+    shape = (3, 4)
+    shape_tuple = container.ShapeTuple(shape)
+    input = tvm.nd.array(np.random.rand(3,4).astype(np.float32))
+    ib = rx.ExecBuilder()
+    with ib.function("main", num_inputs=1):
+        ib.emit_call("vm.builtin.alloc_storage", args=[ib.vm_state(), ib.imm(24), ib.imm(64), ib.imm(1), dtype], dst=ib.r(1))
+        ib.emit_call("vm.builtin.alloc_tensor", args=[ib.r(1), ib.imm(0), dtype, ib.r(0)], dst=ib.r(2))
+        ib.emit_call("test.vm.identity", args=[input, ib.r(2)])
+        ib.emit_ret(ib.r(2))
+    exec0 = ib.get()
+    exec0.save_to_file("exec.bin")
+    exec1 = rx.load_exec_from_file("exec.bin")
+    assert exec0.astext() == exec1.astext()
+    vm = rx.VirtualMachine(exec1, tvm.cpu())
+    res = vm["main"](shape_tuple)
+    np.testing.assert_allclose(input.asnumpy(), res.asnumpy())
+
 def test_vm_checker():
     ib = rx.ExecBuilder()
     try:
@@ -177,6 +202,22 @@ def test_vm_storage():
     assert res.device == tvm.cpu()
     assert res.shape == shape
 
+def test_vm_compile():
+    @rx.script
+    class Mod:
+        def foo(x: Tensor[(3, 4), "float32"]):
+            y = relax.call_packed("vm.builtin.alloc_storage", (12,), (64,), device_id=0, device_type=1, attrs_type_key="relax.attrs.AllocStorageAttrs")
+            z = relax.call_packed("vm.builtin.alloc_tensor", y, (0,), (3, 4), attrs_type_key="relax.attrs.AllocTensorAttrs")
+            w = relax.call_packed("test.vm.identity", x, z)
+            return z
+
+    mod = Mod()
+    exec = rx.vm_compiler.compile(mod)
+    input = tvm.nd.array(np.random.rand(3,4).astype(np.float32))
+    vm = rx.VirtualMachine(exec, tvm.cpu())
+    res = vm["foo"](input)
+    np.testing.assert_allclose(input.asnumpy(), res.asnumpy())
+
 if __name__ == "__main__":
     test_vm_execute()
     test_vm_multiple_func()
@@ -184,6 +225,8 @@ def test_vm_storage():
     test_vm_formalize()
     test_vm_operand()
     test_vm_serialize()
+    test_vm_constant_serialize()
     test_vm_shapeof()
     test_vm_heap()
     test_vm_storage()
+    test_vm_compile()
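For completeness, the PackedFunc plumbing that the Python VMCompiler wrapper relies on can also be driven directly: module-style lookup on the returned runtime.Module dispatches through VMCompiler::GetFunction. A sketch using the private _ffi_api binding, exactly as vm_compiler.py does internally:

from tvm.relax import _ffi_api

compiler_mod = _ffi_api.VMCompiler()  # runtime.Module wrapping relax_vm::VMCompiler
compiler_mod["compile"](mod)          # dispatched via VMCompiler::GetFunction("compile")
ex = compiler_mod["get_executable"]()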