diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index b762e7808ab0..c1567ee337e5 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + # GH actions. # We use it to cover windows and mac builds # Jenkins is still the primary CI @@ -31,7 +48,13 @@ jobs: run: >- conda build --output-folder=conda/pkg conda/recipe && conda install tvm -c ./conda/pkg +<<<<<<< HEAD - name: Build iOS RPC +======= +<<<<<<< HEAD + - name: Build iOS RPC@MacOS + if: startsWith(matrix.os, 'macOS') +>>>>>>> Relax Virtual Machine run: | IOS_VERSION="14.0" CMAKE_FLAGS="-DCMAKE_BUILD_TYPE=Release \ @@ -47,11 +70,29 @@ jobs: cd build-ios-simulator cmake .. ${CMAKE_FLAGS} cmake --build . --target ios_rpc +<<<<<<< HEAD - name: Test shell: bash -l {0} run: >- python -m pytest -v tests/python/all-platform-minimal-test - name: Test iOS RPC +======= +======= +>>>>>>> Relax Virtual Machine + - name: Test@Win + if: startsWith(matrix.os, 'windows') + shell: cmd /C call {0} + run: >- + python -m pytest -v tests/python/all-platform-minimal-test + - name: Test@MacOS + if: startsWith(matrix.os, 'macOS') + shell: bash -l {0} + run: >- + python -m pytest -v tests/python/all-platform-minimal-test +<<<<<<< HEAD + - name: Test iOS RPC@MacOS + if: startsWith(matrix.os, 'macOS') +>>>>>>> Relax Virtual Machine shell: bash -l {0} run: >- python -m pip install tornado psutil cloudpickle && @@ -59,6 +100,7 @@ jobs: export BUNDLE_ID=org.apache.tvmrpc && export BUNDLE_PATH=build-ios-simulator/apps/ios_rpc/ios_rpc/src/ios_rpc-build/Release-iphonesimulator/tvmrpc.app && python -m pytest -v tests/python/contrib/test_rpc_server_device.py +<<<<<<< HEAD Windows: runs-on: windows-2016 @@ -78,3 +120,7 @@ jobs: run: >- python -m pytest -v tests/python/all-platform-minimal-test +======= +======= +>>>>>>> Relax Virtual Machine +>>>>>>> Relax Virtual Machine diff --git a/CMakeLists.txt b/CMakeLists.txt index f87a3a9f617f..426ff5a4f0a6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -262,6 +262,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS src/parser/*.cc src/printer/*.cc src/support/*.cc + src/relax/*.cc ) tvm_file_glob(GLOB CODEGEN_SRCS diff --git a/include/tvm/relax/builder.h b/include/tvm/relax/builder.h new file mode 100644 index 000000000000..40e124041e20 --- /dev/null +++ b/include/tvm/relax/builder.h @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/builder.h + * \brief + */ +#ifndef TVM_RELAX_BUILDER_H_ +#define TVM_RELAX_BUILDER_H_ + +#include +#include +#include +#include +#include + +#include "./vm/bytecode.h" +#include "./vm/executable.h" + +namespace tvm { +namespace relax { + +namespace vm = tvm::runtime::relax_vm; + +class ExecBuilder; + +/*! + * \brief A builder provides api to build VM executable with instructions. + */ +class ExecBuilderNode : public Object { + public: + /*! \brief The mutable internal executable node. */ + ObjectPtr exec; // mutable + /*! + * \brief To annotate the start of a vm function. + * \param func The function name. + * \param num_inputs The number of inputs. + */ + void Function(std::string func, int64_t num_inputs); + /*! + * \brief Emit a call instruction for a packed function. + * \param func The packed function name. + * \param args The arguments of the function. + * \param ret The return register. + */ + void EmitCall(std::string func, std::vector args, vm::RegName ret); + /*! + * \brief Emit a ret instruction. + * \param result The return result. + */ + void EmitRet(vm::RegName result); + /*! + * \brief Emit a constant value to the constant pool. + * \return The index that represents the constant. + */ + vm::Index EmitConstant(ObjectRef obj); + /*! + * \brief Get the built executable. + * \return The built executable. + */ + vm::Executable Get(); + /*! + * \brief Create a ExecBuilder. + * \return The ExecBuilder. + */ + TVM_DLL static ExecBuilder Create(); + + void VisitAttrs(AttrVisitor* v) {} + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.ExecBuilder"; + TVM_DECLARE_FINAL_OBJECT_INFO(ExecBuilderNode, Object); + + private: + /*! + * \brief Formalize the executable. + */ + void Formalize(); +}; + +class ExecBuilder : public ObjectRef { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ExecBuilder, ObjectRef, ExecBuilderNode); +}; + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_BUILDER_H_ diff --git a/include/tvm/relax/vm/bytecode.h b/include/tvm/relax/vm/bytecode.h new file mode 100644 index 000000000000..a980705194b4 --- /dev/null +++ b/include/tvm/relax/vm/bytecode.h @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/vm/bytecode.h + * \brief The bytecode for the virtual machine. + */ +#ifndef TVM_RELAX_VM_BYTECODE_H_ +#define TVM_RELAX_VM_BYTECODE_H_ + +#include +#include + +#include +#include + +namespace tvm { +namespace runtime { +namespace relax_vm { + + +/*! + * \brief The storage type for the bytecode in the VM. + */ +using ExecWord = int64_t; + +/*! \brief A register name. */ +using RegName = ExecWord; + +/*! + * \brief An alias for the integer type used ubiquitously in the VM. + */ +using Index = ExecWord; + +/*! + * \brief An enumeration of Relax's opcodes. + * + * The opcode is used to implement instruction + * as a tagged union. + */ +enum class Opcode { + Call = 1U, + Ret = 2U, +}; + + +/*! \brief A single virtual machine instruction. + * + * The representation of the instruction is as + * a tagged union. + * + * The first field represents which instruction, + * and by extension which field of the union + * is active. + */ +struct Instruction { + /*! \brief Random magic number that represents void argument. */ + static constexpr RegName kVoidArg = 0x00EC66FE0321975A; + /*! \brief Random magic number that represents the VM state. */ + static constexpr RegName kVMStateRegister = 0x008D14FA4379015C; + /*! + * \brief The kind of instruction's argument. + */ + enum ArgKind { + kRegister = 0, + kImmediate = 1, + kConstIdx = 2, + }; + /*! + * \brief The auxiliary data structure for instruction argument. + */ + struct Arg { + /*! \brief The number of bit for storing value. */ + static constexpr ExecWord kValueBit = sizeof(ExecWord) * 8 - 8; + /*! \brief The bit mask of the value part. */ + static constexpr ExecWord kValueMask = (static_cast(1) << kValueBit) - 1; + /*! \brief Construct a void argument. */ + explicit Arg() : data(Instruction::kVoidArg) {} + /*! \brief Construct from the data. */ + explicit Arg(ExecWord data) : data(data) {} + /*! \brief Construct from the kind and value. */ + Arg(ArgKind kind, Index value) { + // TODO(ziheng): check value? + this->data = (static_cast(kind) << kValueBit) | + (value & kValueMask); + } + /*! + * \brief Get the kind of argument.. + * \return The kind of argument. + */ + ArgKind kind() const { + uint8_t kind = (data >> kValueBit) & 0xFF; + return Instruction::ArgKind(kind); + } + /*! + * \brief Get the value of argument.. + * \return The value of argument. + */ + ExecWord value() const { + return data & ((static_cast(1) << kValueBit) - 1); + } + /*! \brief The underlying stored data. */ + ExecWord data; + }; + /*! \brief The instruction opcode. */ + Opcode op; + /*! \brief The destination register. */ + RegName dst; + union { + struct /* Call */ { + /*! \brief The index into the packed function table. */ + Index func_idx; + /*! \brief The number of arguments to the packed function. */ + Index num_args; + /*! \brief The arguments of the packed function. */ + Arg* args; + }; + struct /* Ret */ { + /*! \brief The return result. */ + RegName result; + }; + }; + /*! + * \brief Construct a Call instruction. + * \param func_idx The index of the function to call. + * \param num_args The number of arguments. + * \param args The input arguments. + * \param dst The destination register. + * \return The call instruction. + */ + static Instruction Call(Index func_idx, Index num_args, + Arg* args, + RegName dst); + /*! + * \brief Construct a return instruction. + * \param result The register containing the return value. + * \return The return instruction. + */ + static Instruction Ret(RegName result); +}; + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm + +#endif // TVM_RELAX_VM_BYTECODE_H_ diff --git a/include/tvm/relax/vm/executable.h b/include/tvm/relax/vm/executable.h new file mode 100644 index 000000000000..c9009c1c59d0 --- /dev/null +++ b/include/tvm/relax/vm/executable.h @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/vm/executable.h + * \brief + */ +#ifndef TVM_RELAX_VM_EXECUTABLE_H_ +#define TVM_RELAX_VM_EXECUTABLE_H_ + +#include +#include +#include +#include "./bytecode.h" + +namespace tvm { +namespace runtime { +namespace relax_vm { + +class Executable; + +/*! + * \brief A representation of a Relax function in the VM. + * + * Contains metadata about the compiled function, as + * well as the compiled VM instructions. + */ +struct VMFunction { + /*! \brief The function's name. */ + std::string name; + /*! \brief The start instruction index of the function. */ + Index start_instr; + /*! \brief The number of arguments of the function. */ + Index num_args; + /*! \brief The register file size of the function. */ + Index register_file_size; +}; + +/*! + * \brief The executable emitted by the VM compiler. + * + * The executable contains information (e.g. data in different memory regions) + * to run in a virtual machine. + */ +class ExecutableNode : public Object { + public: + /*! + * \brief Print the detailed statistics of the given code, i.e. number of + * globls and constants, etc. + */ + std::string Stats() const; + /*! + * \brief Get the i-th instruction from the executable. + * \return The instruction. + */ + Instruction GetInstruction(Index i) const; + /*! + * \brief Print the instructions as text format. + */ + String AsText() const; + /*! + * \brief Print the instructions as python program. + */ + String AsPython() const; + /*! + * \brief Write the Executable to the binary stream in serialized form. + * \param stream The binary stream to save the executable to. + */ + void SaveToBinary(dmlc::Stream* stream); + /*! + * \brief Load Executable from the binary stream in serialized form. + * \param stream The binary stream that load the executable from. + */ + static Executable LoadFromBinary(void* stream); + /*! + * \brief Write the Executable to the provided path as a file contianing its serialized content. + * \param path The path to write the serialized data to. + */ + void SaveToFile(const std::string& path); + /*! + * \brief Load Executable from the file. + * \param file_name The file that load the executable from. + */ + static Executable LoadFromFile(const std::string& file_name); + /*! \brief The virtual machine's function table. */ + std::vector global_funcs; + /*! \brief A map from globals (as strings) to their index in the function map. */ + std::unordered_map global_map; + /*! \brief The global constant pool. */ + std::vector constants; + /*! \brief The name of packed functions. */ + std::vector func_names; + /*! \brief A mapping from the packed function (as string) to the index that + * corresponds to the position of the `packed_funcs` list in a `VirtualMachine` object. + */ + std::unordered_map func2idx; + /*! \brief The offset of instruction. */ + std::vector instr_offset; + /*! \brief The byte data of instruction. */ + std::vector instr_data; + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.Executable"; + TVM_DECLARE_FINAL_OBJECT_INFO(ExecutableNode, Object); + + private: + /*! + * \brief Save the globals. + * \param strm The input stream. + */ + void SaveGlobalSection(dmlc::Stream* strm); + /*! + * \brief Save the constant pool. + * \param strm The input stream. + */ + void SaveConstantSection(dmlc::Stream* strm); + /*! + * \brief Save the instructions. + * \param strm The input stream. + */ + void SaveCodeSection(dmlc::Stream* strm); + /*! + * \brief Save the packed functions. + * \param strm The input stream. + */ + void SavePackedFuncNames(dmlc::Stream* strm); + /*! + * \brief Load the globals. + * \param strm The input stream. + */ + void LoadGlobalSection(dmlc::Stream* strm); + /*! + * \brief Load the constant pool. + * \param strm The input stream. + */ + void LoadConstantSection(dmlc::Stream* strm); + /*! + * \brief Load the instructions. + * \param strm The input stream. + */ + void LoadCodeSection(dmlc::Stream* strm); + /*! + * \brief Save the packed functions. + * \param strm The input stream. + */ + void LoadPackedFuncNames(dmlc::Stream* strm); +}; + +/*! \brief Reference to Executable. */ +class Executable : public ObjectRef { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Executable, ObjectRef, ExecutableNode); +}; + + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm + +#endif // TVM_RELAX_VM_EXECUTABLE_H_ diff --git a/include/tvm/relax/vm/memory_manager.h b/include/tvm/relax/vm/memory_manager.h new file mode 100644 index 000000000000..8ae5bd185957 --- /dev/null +++ b/include/tvm/relax/vm/memory_manager.h @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/vm/memory_manager.h + * \brief Abstract device memory management API + */ +#ifndef TVM_RELAX_VM_MEMORY_MANAGER_H_ +#define TVM_RELAX_VM_MEMORY_MANAGER_H_ + +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { +namespace runtime { +namespace relax_vm { + +struct Buffer { + /*! \brief The pointer to the allocated block of memory. */ + void* data{nullptr}; + /*! \brief The size of the block. */ + size_t size{0}; + /*! \brief The device of the allocated buffers. */ + Device device; +}; + +enum AllocatorType { + kNaive = 1, + kPooled, +}; + +class Allocator { + public: + explicit Allocator(AllocatorType type) : type_(type) {} + virtual ~Allocator() = default; + /*! \brief Allocate an empty NDArray using from the allocator. + * \param shape The shape of the NDArray. + * \param dtype The datatype of the NDArray. + * \param dev The device where the array is allocated. + * \return The empty NDArray. + */ + runtime::NDArray Empty(std::vector shape, DLDataType dtype, Device dev); + /*! \brief Return the allocator type. */ + inline AllocatorType type() const { return type_; } + /*! \brief Allocate a buffer given a size, alignment and type. + * \param nbytes The size of the buffer. + * \param alignment The alignment of the buffer. + * \param type_hint A type hint to the allocator. + * \return A sized allocation in the form of a buffer. + */ + virtual Buffer Alloc(size_t nbytes, size_t alignment, DLDataType type_hint) = 0; + /*! \brief Free a buffer allocated by the allocator. + * \param buffer The buffer to free. + */ + virtual void Free(const Buffer& buffer) = 0; + + private: + AllocatorType type_; +}; + +class MemoryManager { + public: + static MemoryManager* Global(); + /*! + * \brief Get or create an allocator given the device and allocator type. + * \param dev The TVM device + * \param type The allocator type + * \return The memory allocator. + */ + static Allocator* GetOrCreateAllocator(Device dev, AllocatorType type); + /*! + * \brief Get an allocator given the device. + * \param dev The TVM device + * \return The memory allocator. + */ + static Allocator* GetAllocator(Device dev); + + private: + MemoryManager() {} + + private: + std::mutex mutex_; + std::unordered_map> allocators_; +}; + +/*! \brief An object representing a storage allocation. */ +class StorageObj : public Object { + public: + /*! \brief The index into the VM function table. */ + Buffer buffer; + + /*! \brief Allocate an NDArray from a given piece of storage. */ + runtime::NDArray AllocNDArray(size_t offset, ShapeTuple shape, DLDataType dtype); + + /*! \brief The deleter for an NDArray when allocated from underlying storage. */ + static void Deleter(Object* ptr); + + ~StorageObj() { + auto alloc = MemoryManager::Global()->GetAllocator(buffer.device); + alloc->Free(buffer); + } + + static constexpr const uint32_t _type_index = runtime::TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.Storage"; + TVM_DECLARE_FINAL_OBJECT_INFO(StorageObj, Object); +}; + +/*! \brief reference to storage. */ +class Storage : public ObjectRef { + public: + explicit Storage(Buffer buffer); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Storage, ObjectRef, StorageObj); +}; + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm + +#endif // TVM_RELAX_VM_MEMORY_MANAGER_H_ diff --git a/include/tvm/relax/vm/vm.h b/include/tvm/relax/vm/vm.h new file mode 100644 index 000000000000..f579573df72c --- /dev/null +++ b/include/tvm/relax/vm/vm.h @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/vm/vm.h + * \brief + */ +#ifndef TVM_RELAX_VM_VM_H_ +#define TVM_RELAX_VM_VM_H_ + +#include "./bytecode.h" +#include "./executable.h" +#include "./memory_manager.h" + +namespace tvm { +namespace runtime { +namespace relax_vm { + +/*! + * \brief The register type. + */ +using RegType = TVMRetValue; + +/*! + * \brief A representation of a stack frame. + * + * A stack frame is a record containing the information needed + * to restore the caller's virtual machine state after returning + * from a function call. + */ +struct VMFrame { + /*! \brief The return program counter. */ + Index return_pc; + /*! \brief Statically allocated space for objects */ + std::vector register_file; + /*! \brief Register in caller's frame to put return value */ + RegName caller_return_register; + + VMFrame(Index pc, Index register_file_size) + : return_pc(pc), register_file(register_file_size), caller_return_register(0) {} +}; + +/*! + * \brief The state of virtual machine, which can be referred in + * instruction. + */ +struct VMState { + /*! \brief The memory allocators. */ + std::vector allocators; +}; + +/*! + * \brief The virtual machine. + * + * The virtual machine contains all the current execution state, + * as well as the executable. + * + * The goal is to have a single self-contained object, + * enabling one to easily pass around VMs, execute them on + * multiple threads, or serialize them to disk or over the + * wire. + */ +class VirtualMachine : public runtime::ModuleNode { + public: + /*! + * \brief Initialize the virtual machine for a set of devices. + * \param devices The set of TVM devices. + * \param alloc_types The allocator types for each device. + */ + void Init(const std::vector& devices, const std::vector& alloc_types); + /*! + * \brief load the executable and module for the virtual machine. + * \param exec The executable. + * \param mod The library module. + */ + void Load(Executable exec, runtime::Module mod); + /*! + * \brief Get a PackedFunc from module. + * + * The PackedFunc may not be fully initialized, + * there might still be first time running overhead when + * executing the function on certain devices. + * For benchmarking, use prepare to eliminate + * + * \param name the name of the function. + * \param sptr_to_self The shared_ptr that points to this module node. + * + * \return PackedFunc(nullptr) when it is not available. + * + * \note The function will always remain valid. + * If the function needs resource from the module(e.g. late linking), + * it should capture sptr_to_self. + */ + virtual PackedFunc GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) final; + + virtual ~VirtualMachine() final {} + + const char* type_key() const final { return "relax.VirtualMachine"; } + /*! \brief The state of the virtual machine, which can be referred by + * instructions. + */ + VMState state; + + protected: + /*! \brief Push a call frame on to the call stack. */ + void PushFrame(Index ret_pc, const VMFunction& vm_func); + /*! + * \brief Pop a frame off the call stack. + * \return The number of frames left. + */ + void PopFrame(); + /*! + * \brief Write to a VM register. + * \param reg The register to write to. + * \param obj The object to write to. + */ + inline void WriteRegister(RegName reg, const RegType& obj); + /*! + * \brief Read a VM register. + * \param reg The register to read from. + * \return The read object. + */ + inline RegType ReadRegister(RegName reg) const; + /*! + * \brief Invoke a VM function. + * \param fidx The function index. + * \param args The arguments to the function. + * \return The object representing the result. + */ + RegType Invoke(Index fidx, const std::vector& args); + /*! \brief Run VM dispatch loop. */ + void RunLoop(); + + private: + /*! \brief The loaded executable. */ + Executable exec_; + /*! \brief The loaded module. */ + runtime::Module mod_; + /*! \brief The current stack of call frames. */ + std::vector frames_; + /*! \brief The virtual machine PC. */ + Index pc_{0}; + /*! \brief The special return register. */ + RegType return_value_; + /*! \brief The devices. */ + std::vector devices_; +}; + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm + +#endif // TVM_RELAX_VM_VM_H_ diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py new file mode 100644 index 000000000000..497937327991 --- /dev/null +++ b/python/tvm/relax/__init__.py @@ -0,0 +1,2 @@ +from .vm import VirtualMachine, load_exec_from_file +from .builder import ExecBuilder diff --git a/python/tvm/relax/_ffi_api.py b/python/tvm/relax/_ffi_api.py new file mode 100644 index 000000000000..62b8f8a2e5a8 --- /dev/null +++ b/python/tvm/relax/_ffi_api.py @@ -0,0 +1,3 @@ +import tvm._ffi + +tvm._ffi._init_api("relax", __name__) diff --git a/python/tvm/relax/builder.py b/python/tvm/relax/builder.py new file mode 100644 index 000000000000..0ffcb82f501b --- /dev/null +++ b/python/tvm/relax/builder.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from enum import IntEnum +import tvm +from tvm.runtime import Object +from tvm._ffi.base import _LIB, check_call +from . import _ffi_api + +class SpecialReg(IntEnum): + """Magic numbers that represent special registers in vm.""" + VOID_ARG = 0x00EC66FE0321975A + VM_STATE = 0x008D14FA4379015C + +class VMFuncScope(object): + """An object corresponds to each VM function, working as a context manager.""" + stack = [] + + def __enter__(self): + VMFuncScope.stack.append(self) + return self + + def __exit__(self, ptype, value, trace): + VMFuncScope.stack.pop() + +@tvm._ffi.register_object("relax.ExecBuilder") +class ExecBuilder(Object): + """A builder to emit instructions and build executable for the virtual machine.""" + def __init__(self): + self.__init_handle_by_constructor__(_ffi_api.ExecBuilderCreate) + + def r(self, idx): + """set instruction's argument as a register.""" + return _ffi_api.ExecBuilderR(self, idx) + + def imm(self, value): + """set instruction's argument as an immediate.""" + return _ffi_api.ExecBuilderImm(self, value) + + def c(self, idx): + """set instruction's argument as a constant.""" + return _ffi_api.ExecBuilderC(self, idx) + + def void_arg(self): + return self.r(SpecialReg.VOID_ARG) + + def vm_state(self): + return self.r(SpecialReg.VM_STATE) + + def function(self, func_name, num_inputs=0): + """annotate a VM function.""" + _ffi_api.ExecBuilderFunction(self, func_name, num_inputs) + return VMFuncScope() + + def _check_scope(self): + if len(VMFuncScope.stack) == 0: + raise ValueError("emit should happen in a function scope") + + def emit_constant(self, const): + return _ffi_api.ExecBuilderEmitConstant(self, const) + + def emit_call(self, name, args=[], dst=None): + """emit a call instruction which calls a packed function.""" + self._check_scope() + if dst is None: + dst = SpecialReg.VOID_ARG + args_ = [] + for arg in args: + if isinstance(arg, tvm.nd.NDArray): + new_arg = self.emit_constant(arg) + args_.append(new_arg) + else: + args_.append(arg) + _ffi_api.ExecBuilderEmitCall(self, name, args_, dst) + + def emit_ret(self, result): + """emit a return instruction""" + self._check_scope() + _ffi_api.ExecBuilderEmitRet(self, result) + + def get(self): + """return the executable""" + return _ffi_api.ExecBuilderGet(self) + diff --git a/python/tvm/relax/vm.py b/python/tvm/relax/vm.py new file mode 100644 index 000000000000..ed8a1884230d --- /dev/null +++ b/python/tvm/relax/vm.py @@ -0,0 +1,121 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm.runtime import Object +from tvm._ffi.base import _LIB, check_call +from . import _ffi_api +from ..rpc.base import RPC_SESS_MASK + + +@tvm._ffi.register_object("relax.Executable") +class Executable(Object): + """The executable object emitted by the VM compiler or the ExecBuilder.""" + def __init__(self): + self.__init_handle_by_constructor__(_ffi_api.Executable) + + def stats(self): + """print the detailed statistics of the executable.""" + return _ffi_api.ExecutableStats(self) + + def save_to_file(self, file_name): + """serialize and write the executable to a file.""" + return _ffi_api.ExecutableSaveToFile(self, file_name) + + def astext(self): + """print the instructions as text format.""" + return _ffi_api.ExecutableAsText(self) + + def aspython(self): + """print the instructions as python program.""" + return _ffi_api.ExecutableAsPython(self) + +def load_exec_from_file(file_name): + return _ffi_api.ExecutableLoadFromFile(file_name) + +class VirtualMachine(object): + """Relax VM runtime.""" + + NAIVE_ALLOCATOR = 1 + POOLED_ALLOCATOR = 2 + + def __init__(self, exec, device, memory_cfg=None, mod=None): + """ + Construct a VirtualMachine wrapper object. + + Parameters + ---------- + exec: Executable + The VM executable. + + device : tvm.runtime.Device or List[tvm.runtime.Device] + The device to deploy the module. + + memory_cfg : str or Dict[tvm.runtime.Device, str], optional + Config the type of memory allocator. The allocator type can be ["naive", + "pooled"]. If memory_cfg is None, all devices will use pooled allocator + by default. If memory_cfg is string, all devices will use the specified + allocator type. If memory_cfg is a dict, each device uses the allocator + type specified in the dict, or pooled allocator if not specified in the + dict. + + Returns + ------- + vm: VirtualMachine + A VM wrapper object. + """ + self.module = _ffi_api.VirtualMachine(exec, mod) + self._setup_device(device, memory_cfg) + + def _setup_device(self, dev, memory_cfg): + """init devices and allocators.""" + devs = dev + if not isinstance(dev, (list, tuple)): + if not isinstance(dev, tvm.runtime.Device): + raise TypeError( + "dev is expected to be Device or \ + List[Device]" + ) + devs = [dev] + + # CPU is required for executing shape functions + if not any(c.device_type % RPC_SESS_MASK == tvm.cpu().device_type for c in devs): + devs.append(tvm.cpu()) + + default_alloc_type = VirtualMachine.POOLED_ALLOCATOR + if memory_cfg is None: + memory_cfg = {} + elif isinstance(memory_cfg, str): + assert memory_cfg in ["naive", "pooled"] + if memory_cfg == "naive": + default_alloc_type = VirtualMachine.NAIVE_ALLOCATOR + memory_cfg = {} + elif not isinstance(memory_cfg, dict): + raise TypeError( + "memory_cfg is expected be string or dictionary, " + + "but received {}".format(type(memory_cfg)) + ) + init_args = [] + for device in devs: + init_args.append(device.device_type % RPC_SESS_MASK) + init_args.append(device.device_id) + alloc_type = memory_cfg[device] if device in memory_cfg else default_alloc_type + init_args.append(alloc_type) + _ffi_api.VirtualMachineInit(self.module, *init_args) + + def __getitem__(self, key): + return self.module[key] diff --git a/src/relax/builder.cc b/src/relax/builder.cc new file mode 100644 index 000000000000..63661eefe63a --- /dev/null +++ b/src/relax/builder.cc @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/builder.cc + */ + +#include + +#include + +namespace tvm { +namespace relax { + +using namespace vm; + +TVM_REGISTER_NODE_TYPE(ExecBuilderNode); + +ExecBuilder ExecBuilderNode::Create() { + ExecBuilder ret(make_object()); + ret->exec = make_object(); + return ret; +} + +vm::Index ExecBuilderNode::EmitConstant(ObjectRef obj) { + vm::Index idx = exec->constants.size(); + exec->constants.push_back(obj); + return vm::Instruction::Arg(vm::Instruction::kConstIdx, idx).data; +} + +void ExecBuilderNode::Function(std::string func_name, int64_t num_inputs) { + const auto& m = exec->global_map; + ICHECK(m.find(func_name) == m.end()); + VMFunction vmfunc; + vmfunc.name = func_name; + vmfunc.start_instr = exec->instr_offset.size(); + vmfunc.num_args = num_inputs; + exec->global_map[func_name] = exec->global_funcs.size(); + exec->global_funcs.push_back(vmfunc); +} + +void ExecBuilderNode::EmitCall(std::string func, std::vector args, RegName dst) { + // store function + if (exec->func2idx.find(func) == exec->func2idx.end()) { + exec->func2idx[func] = exec->func_names.size(); + exec->func_names.push_back(func); + } + Index func_idx = exec->func2idx[func]; + // store instruction + exec->instr_offset.push_back(exec->instr_data.size()); + exec->instr_data.push_back(static_cast(Opcode::Call)); + exec->instr_data.push_back(dst); + exec->instr_data.push_back(func_idx); + exec->instr_data.push_back(args.size()); + // store arguments + std::transform(args.cbegin(), args.cend(), std::back_inserter(exec->instr_data), + [](Instruction::Arg arg) { return arg.data; }); +} + +void ExecBuilderNode::EmitRet(RegName result) { + exec->instr_offset.push_back(exec->instr_data.size()); + exec->instr_data.push_back(static_cast(Opcode::Ret)); + exec->instr_data.push_back(result); +} + +// helper function to check if an executable is legal by checking if registers are used properly +bool CheckExecutable(Executable exec) { + for (auto it = exec->global_funcs.cbegin(); it != exec->global_funcs.cend(); ++it) { + Index num_inputs = it->num_args; + std::unordered_set dst_registers; + std::unordered_set arg_registers; + size_t start_instr = it->start_instr; + size_t end_instr = exec->instr_offset.size(); + for (size_t idx = start_instr; idx < end_instr; ++idx) { + Instruction instr = exec->GetInstruction(idx); + switch (instr.op) { + case Opcode::Call: { + for (int i = 0; i < instr.num_args; ++i) { + if (instr.args[i].kind() == Instruction::kRegister && + instr.args[i].value() == Instruction::kVMStateRegister) { + continue; + } + if (instr.args[i].kind() == Instruction::kRegister && + instr.args[i].value() >= num_inputs && + dst_registers.find(instr.args[i].value()) == dst_registers.end()) { + LOG(ERROR) << "register r(" << instr.args[i].value() << ") in VM function \"" + << it->name << "\" is used as input while the number of inputs is only " + << num_inputs << ".\n"; + return false; + } + arg_registers.emplace(instr.args[i].value()); + } + if (instr.dst != Instruction::kVoidArg) { + dst_registers.emplace(instr.dst); + } + break; + } + case Opcode::Ret: { + arg_registers.emplace(instr.result); + for (int i = 0; i < num_inputs; i++) { + if (arg_registers.find(i) == arg_registers.end()) { + LOG(WARNING) << "register r(" << i << ") in VM function \"" << it->name + << "\" is unused as input.\n"; + } + } + break; + } + default: + LOG(FATAL) << "should never hit this case: " << static_cast(instr.op); + break; + } + } + } + return true; +} + +Executable ExecBuilderNode::Get() { + CheckExecutable(Executable(this->exec)); + this->Formalize(); + return Executable(this->exec); +} + +void ExecBuilderNode::Formalize() { + // a pass to formalize user-specified register indexes in the order of use + // and decide the number of registers to allocate for each VMFunction in the Executable + for (auto it = this->exec->global_funcs.begin(); it != this->exec->global_funcs.end(); ++it) { + Index num_inputs = it->num_args; + RegName register_idx = num_inputs; + std::unordered_map register_map; + size_t start_instr = it->start_instr; + size_t end_instr = this->exec->instr_offset.size(); + for (size_t idx = start_instr; idx < end_instr; ++idx) { + Instruction instr = this->exec->GetInstruction(idx); + switch (instr.op) { + case Opcode::Call: { + for (int i = 0; i < instr.num_args; ++i) { + if (instr.args[i].kind() == Instruction::kRegister && + register_map.find(instr.args[i].value()) != register_map.end()) { + this->exec->instr_data[this->exec->instr_offset[idx] + 4 + i] = + register_map[instr.args[i].value()]; + } + } + if (instr.dst != Instruction::kVoidArg && instr.dst >= num_inputs && + register_map.find(instr.dst) == register_map.end()) { + this->exec->instr_data[this->exec->instr_offset[idx] + 1] = register_idx; + register_map[instr.dst] = register_idx++; + } + break; + } + case Opcode::Ret: { + if (register_map.find(instr.result) != register_map.end()) { + this->exec->instr_data[this->exec->instr_offset[idx] + 1] = register_map[instr.result]; + } + break; + } + default: + LOG(FATAL) << "should never hit this case: " << static_cast(instr.op); + break; + } + } + it->register_file_size = register_idx; + } +} + +TVM_REGISTER_GLOBAL("relax.ExecBuilderCreate").set_body_typed(ExecBuilderNode::Create); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitConstant") + .set_body_typed([](ExecBuilder builder, ObjectRef obj) { return builder->EmitConstant(obj); }); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderFunction") + .set_body_typed([](ExecBuilder builder, String name, int64_t num_inputs) { + return builder->Function(name, num_inputs); + }); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitCall") + .set_body_typed([](ExecBuilder builder, String name, Array args, int64_t dst) { + std::vector args_; + for (size_t i = 0; i < args.size(); ++i) { + args_.push_back(static_cast(args[i]->value)); + } + Instruction::Arg dst_(dst); + CHECK_EQ(dst_.kind(), Instruction::ArgKind::kRegister); + builder->EmitCall(name, args_, dst_.value()); + }); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitRet") + .set_body_typed([](ExecBuilder builder, int64_t result) { builder->EmitRet(result); }); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderR").set_body_typed([](ExecBuilder builder, int64_t value) { + return Instruction::Arg(Instruction::kRegister, value).data; +}); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderImm").set_body_typed([](ExecBuilder builder, int64_t value) { + return Instruction::Arg(Instruction::kImmediate, value).data; +}); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderC").set_body_typed([](ExecBuilder builder, int64_t value) { + return Instruction::Arg(Instruction::kConstIdx, value).data; +}); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderGet").set_body_typed([](ExecBuilder builder) { + return builder->Get(); +}); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/vm/builtin.cc b/src/relax/vm/builtin.cc new file mode 100644 index 000000000000..b7655ba32b9e --- /dev/null +++ b/src/relax/vm/builtin.cc @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/vm/builtin.cc + * \brief + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace runtime { +namespace relax_vm { + +using tvm::runtime::NDArray; + +TVM_REGISTER_GLOBAL("vm.builtin.shape_of").set_body_typed([](NDArray arr) { return arr.Shape(); }); + +TVM_REGISTER_GLOBAL("vm.builtin.alloc_heap").set_body_typed([](int64_t size) { + return NDArray::Empty(ShapeTuple({size}), DLDataType{kDLInt, 64, 1}, DLDevice{kDLCPU, 0}); +}); + +TVM_REGISTER_GLOBAL("vm.builtin.match_shape") +.set_body([](runtime::TVMArgs args, runtime::TVMRetValue* rv) { + ShapeTuple shape = args[0]; + NDArray heap = args[1]; + int64_t* heap_data = reinterpret_cast(heap.ToDLPack()->dl_tensor.data); + for (int i = 2; i < args.size(); ++i) { + int64_t heap_idx = args[i]; + ICHECK(heap_idx >= 0 && heap_idx < heap.Shape()[0]); + heap_data[heap_idx] = shape[i - 2]; + } +}); + +TVM_REGISTER_GLOBAL("vm.builtin.make_shape") +.set_body([](runtime::TVMArgs args, runtime::TVMRetValue* rv) { + NDArray heap = args[0]; + int64_t* heap_data = reinterpret_cast(heap.ToDLPack()->dl_tensor.data); + std::vector shape; + for (int i = 1; i < args.size(); ++i) { + int64_t heap_idx = args[i]; + ICHECK(heap_idx >= 0 && heap_idx < heap.Shape()[0]); + shape.push_back(heap_data[heap_idx]); + } + *rv = ShapeTuple(shape); +}); + +TVM_REGISTER_GLOBAL("vm.builtin.alloc_storage") +.set_body_typed([](void* vm_state_ptr, Index size, Index alignment, Index device_type, + DLDataType dtype_hint) { + VMState* vm_state = static_cast(vm_state_ptr); + DLOG(INFO) << "AllocStorage: allocation_size=" << size << ", alignment=" << alignment + << ", dtype_hint=" << runtime::DLDataType2String(dtype_hint) + << ", device_type=" << device_type; + + auto storage_obj = runtime::SimpleObjAllocator().make_object(); + ICHECK_LT(static_cast(device_type), vm_state->allocators.size()) + << "Memory allocator for device " << device_type << " has not been initialized"; + auto* alloc = vm_state->allocators[device_type]; + ICHECK(alloc) << "Did you forget to init the VirtualMachine with devices?"; + storage_obj->buffer = alloc->Alloc(size, alignment, dtype_hint); + Storage storage(storage_obj); + return storage; +}); + +TVM_REGISTER_GLOBAL("vm.builtin.alloc_tensor") +.set_body_typed([](Storage storage, Index offset, DLDataType dtype, ShapeTuple shape) { + auto tensor = storage->AllocNDArray(offset, shape, dtype); + return tensor; +}); + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm diff --git a/src/relax/vm/bytecode.cc b/src/relax/vm/bytecode.cc new file mode 100644 index 000000000000..6da75f3893dd --- /dev/null +++ b/src/relax/vm/bytecode.cc @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/vm/bytecode.cc + * \brief The bytecode for Relax virtual machine. + */ + +#include +#include +#include + +#include + +namespace tvm { +namespace runtime { +namespace relax_vm { + +Instruction Instruction::Call(Index func_idx, Index num_args, + Instruction::Arg* args, RegName dst) { + Instruction instr; + instr.op = Opcode::Call; + instr.dst = dst; + instr.func_idx = func_idx; + instr.num_args = num_args; + instr.args = args; + return instr; +} + +Instruction Instruction::Ret(RegName result) { + Instruction instr; + instr.op = Opcode::Ret; + instr.result = result; + return instr; +} + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm diff --git a/src/relax/vm/executable.cc b/src/relax/vm/executable.cc new file mode 100644 index 000000000000..58563df3d788 --- /dev/null +++ b/src/relax/vm/executable.cc @@ -0,0 +1,433 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/vm/executable.cc + * \brief + */ + +#include +#include +#include + +#include +#include + +#include "../../runtime/file_utils.h" + +namespace tvm { +namespace runtime { +namespace relax_vm { + +/*! \brief The magic number for the serialized VM bytecode file */ +constexpr uint64_t kTVMVMBytecodeMagic = 0xD225DE2F4214151D; + +#define STREAM_CHECK(val, section) \ + ICHECK(val) << "Invalid VM file format in the " << section << " section." \ + << "\n"; + +TVM_REGISTER_OBJECT_TYPE(ExecutableNode); + +std::string ExecutableNode::Stats() const { + std::ostringstream oss; + oss << "Relax VM executable statistics:" << std::endl; + + // Get the number of constants and the shape of each of them. + oss << " Constant shapes (# " << constants.size() << "): ["; + for (const auto& it : constants) { + const auto constant = Downcast(it); + const auto& shape = constant.Shape(); + // Scalar + if (shape.empty()) { + oss << "scalar, "; + continue; + } + oss << "["; + for (auto s : shape) { + oss << s << ", "; + } + oss.seekp(-2, oss.cur); + oss << "], " << std::endl; + } + if (!constants.empty()) oss.seekp(-2, oss.cur); + oss << "]" << std::endl; + + // Get the number of globals and the name of each of them. + oss << " Globals (#" << global_funcs.size() << "): ["; + for (const auto& it : global_funcs) { + oss << it.name << ", "; + } + if (!global_map.empty()) oss.seekp(-2, oss.cur); + oss << "]" << std::endl; + + // Get the number of packed funcs and the name of each of them. + oss << " Packed functions (#" << func_names.size() << "): ["; + for (const auto& it : func_names) { + oss << it << ", "; + } + if (!func_names.empty()) { + oss.seekp(-2, oss.cur); + } + oss << "]" << std::endl; + + return oss.str(); +} + +Instruction ExecutableNode::GetInstruction(Index i) const { + size_t offset = instr_offset[i]; + Opcode op = static_cast(instr_data[offset]); + switch (op) { + case Opcode::Call: { + RegName dst = instr_data[offset + 1]; + Index func_idx = instr_data[offset + 2]; + Index num_args = instr_data[offset + 3]; + ExecWord* args = const_cast(&instr_data[offset + 4]); + return Instruction::Call(func_idx, num_args, reinterpret_cast(args), dst); + } + case Opcode::Ret: { + RegName result = instr_data[offset + 1]; + return Instruction::Ret(result); + } + default: + LOG(FATAL) << "should never hit this case: " << static_cast(op); + break; + } + return Instruction(); +} + +void SaveHeader(dmlc::Stream* strm) { + uint64_t header = kTVMVMBytecodeMagic; + strm->Write(header); + std::string version = TVM_VERSION; + strm->Write(version); +} + +void LoadHeader(dmlc::Stream* strm) { + // Check header. + uint64_t header; + STREAM_CHECK(strm->Read(&header), "header"); + STREAM_CHECK(header == kTVMVMBytecodeMagic, "header"); + + // Check version. + std::string version; + STREAM_CHECK(strm->Read(&version), "version"); + STREAM_CHECK(version == TVM_VERSION, "version"); +} + +void ExecutableNode::SaveToBinary(dmlc::Stream* stream) { + std::string code; + // Initialize the stream object. + dmlc::MemoryStringStream strm(&code); + + // Save header + SaveHeader(&strm); + + // Global section. + SaveGlobalSection(&strm); + + // Constant section. + SaveConstantSection(&strm); + + // Packedfunc names section. + SavePackedFuncNames(&strm); + + // Code section. + SaveCodeSection(&strm); + + stream->Write(code); +} + +void ExecutableNode::SaveToFile(const std::string& path) { + std::string data; + dmlc::MemoryStringStream writer(&data); + dmlc::SeekStream* strm = &writer; + ExecutableNode::SaveToBinary(strm); + runtime::SaveBinaryToFile(path, data); +} + +Executable ExecutableNode::LoadFromBinary(void* stream) { + std::string code; + static_cast(stream)->Read(&code); + dmlc::MemoryStringStream strm(&code); + + auto exec = make_object(); + + // Load header. + LoadHeader(&strm); + + // Global section. + exec->LoadGlobalSection(&strm); + + // Constant section. + exec->LoadConstantSection(&strm); + + // Packedfunc names section. + exec->LoadPackedFuncNames(&strm); + + // Code section. + exec->LoadCodeSection(&strm); + + return Executable(exec); +} + +Executable ExecutableNode::LoadFromFile(const std::string& file_name) { + std::string data; + runtime::LoadBinaryFromFile(file_name, &data); + dmlc::MemoryStringStream reader(&data); + dmlc::Stream* strm = &reader; + auto exec = ExecutableNode::LoadFromBinary(reinterpret_cast(strm)); + return exec; +} + +void SerializeVMFunc(const VMFunction& func, dmlc::Stream* strm) { + strm->Write(func.name); + strm->Write(func.start_instr); + strm->Write(func.num_args); + strm->Write(func.register_file_size); +} + +VMFunction DeserializeVMFunc(dmlc::Stream* strm) { + VMFunction func; + STREAM_CHECK(strm->Read(&func.name), "vmfunc name"); + STREAM_CHECK(strm->Read(&func.start_instr), "vmfunc start_instr"); + STREAM_CHECK(strm->Read(&func.num_args), "vmfunc num_args"); + STREAM_CHECK(strm->Read(&func.register_file_size), "vmfunc register_file_size"); + return func; +} + +void ExecutableNode::SaveGlobalSection(dmlc::Stream* strm) { + strm->Write(static_cast(this->global_funcs.size())); + for (const auto& func : this->global_funcs) { + SerializeVMFunc(func, strm); + } +} + +void ExecutableNode::SaveConstantSection(dmlc::Stream* strm) { + std::vector arrays; + for (const auto& obj : this->constants) { + const auto cell = Downcast(obj); + arrays.push_back(const_cast(cell.operator->())); + } + strm->Write(static_cast(this->constants.size())); + for (const auto& it : arrays) { + runtime::SaveDLTensor(strm, it); + } +} + +void ExecutableNode::SavePackedFuncNames(dmlc::Stream* strm) { strm->Write(func_names); } + +void ExecutableNode::SaveCodeSection(dmlc::Stream* strm) { + strm->Write(instr_offset); + strm->Write(instr_data); +} + +void ExecutableNode::LoadGlobalSection(dmlc::Stream* strm) { + uint64_t sz; + STREAM_CHECK(strm->Read(&sz, sizeof(sz)), "constant"); + size_t size = static_cast(sz); + for (size_t i = 0; i < size; i++) { + VMFunction func = DeserializeVMFunc(strm); + this->global_funcs.push_back(func); + } + for (size_t i = 0; i < global_funcs.size(); ++i) { + this->global_map[global_funcs[i].name] = i; + } +} + +void ExecutableNode::LoadConstantSection(dmlc::Stream* strm) { + uint64_t sz; + // Load the number of constants. + STREAM_CHECK(strm->Read(&sz, sizeof(sz)), "constant"); + + size_t size = static_cast(sz); + // Load each of the constants. + for (size_t i = 0; i < size; i++) { + runtime::NDArray constant; + STREAM_CHECK(constant.Load(strm), "constant"); + this->constants.push_back(constant); + } +} + +void ExecutableNode::LoadPackedFuncNames(dmlc::Stream* strm) { + STREAM_CHECK(strm->Read(&(this->func_names)), "packed func names"); + for (size_t i = 0; i < func_names.size(); ++i) { + this->func2idx[func_names[i]] = i; + } +} + +void ExecutableNode::LoadCodeSection(dmlc::Stream* strm) { + STREAM_CHECK(strm->Read(&(this->instr_offset)), "instr offset"); + STREAM_CHECK(strm->Read(&(this->instr_data)), "instr data"); +} + +template +std::string StrJoin(T* items, int offset, int cnt, std::string delim = ", ", + std::function repr = std::to_string) { + if (cnt == 0) { + return ""; + } + std::ostringstream oss; + oss << repr(items[offset]); + for (int i = 1; i < cnt; ++i) { + oss << delim << repr(items[offset + i]); + } + return oss.str(); +} + +std::string RegNameToStr(RegName reg) { + if (reg == Instruction::kVoidArg) { + return "void"; + } + if (reg == Instruction::kVMStateRegister) { + return "%state"; + } + return "%" + std::to_string(reg); +} + +std::string InstrArgToStr(Instruction::Arg arg) { + // only for argument + switch (arg.kind()) { + case Instruction::kRegister: + return RegNameToStr(arg.value()); + case Instruction::kImmediate: + return "i" + std::to_string(arg.value()); + case Instruction::kConstIdx: + return "c[" + std::to_string(arg.value()) + "]"; + default: + LOG(FATAL) << "Wrong instruction kind: " << arg.kind(); + return ""; + } +} + +std::string InstrArgToPyStr(Instruction::Arg arg) { + switch (arg.kind()) { + case Instruction::kRegister: + if (arg.value() == Instruction::kVMStateRegister) { + return "ib.r(state)"; + } + return "ib.r(" + std::to_string(arg.value()) + ")"; + case Instruction::kImmediate: + return "ib.imm(" + std::to_string(arg.value()) + ")"; + case Instruction::kConstIdx: + return "ib.c(" + std::to_string(arg.value()) + ")"; + default: + LOG(FATAL) << "Wrong instruction kind: " << arg.kind(); + return ""; + } +} + +String ExecutableNode::AsText() const { + // print the text format + std::ostringstream os; + for (size_t fidx = 0; fidx < this->global_funcs.size(); ++fidx) { + const VMFunction& gfunc = this->global_funcs[fidx]; + os << "@" << gfunc.name << ":\n"; + size_t start_instr = gfunc.start_instr; + size_t end_instr = this->instr_offset.size(); + if ((fidx + 1) < global_funcs.size()) { + end_instr = global_funcs[fidx + 1].start_instr; + } + for (size_t idx = start_instr; idx < end_instr; ++idx) { + os << " "; + Instruction instr = this->GetInstruction(idx); + switch (instr.op) { + case Opcode::Call: { + os << std::setw(6) << std::left << "call" << std::setw(16) << std::left + << this->func_names[instr.func_idx] << " in: " << std::setw(12) << std::left + << StrJoin(instr.args, 0, instr.num_args, ", ", InstrArgToStr) + << " dst: " << RegNameToStr(instr.dst) << "\n"; + break; + } + case Opcode::Ret: { + os << std::setw(6) << std::left << "ret" + << "ret " << RegNameToStr(instr.result) << "\n"; + break; + } + default: + LOG(FATAL) << "should never hit this case: " << static_cast(instr.op); + break; + } + } + os << "\n"; + } + return String(os.str()); +} + +String ExecutableNode::AsPython() const { + // print the python format + std::ostringstream os; + os << "ib = rx.Builder()\n"; + for (size_t fidx = 0; fidx < this->global_funcs.size(); ++fidx) { + const VMFunction& gfunc = this->global_funcs[fidx]; + os << "with ib.function(\"" << gfunc.name << "\", num_inputs=" << gfunc.num_args << "):\n"; + size_t start_instr = gfunc.start_instr; + size_t end_instr = this->instr_offset.size(); + if ((fidx + 1) < global_funcs.size()) { + end_instr = global_funcs[fidx + 1].start_instr; + } + for (size_t idx = start_instr; idx < end_instr; ++idx) { + Instruction instr = this->GetInstruction(idx); + switch (instr.op) { + case Opcode::Call: { + os << " ib.emit_call(\"" << this->func_names[instr.func_idx] << "\", args=[" + << StrJoin(instr.args, 0, instr.num_args, ", ", InstrArgToPyStr) + << "]"; + if (instr.dst != Instruction::kVoidArg) os << ", ret=ib.r(" << instr.dst << ")"; + os << ")\n"; + break; + } + case Opcode::Ret: { + os << " ib.emit_ret(ib.r(" << instr.result << "))\n"; + break; + } + default: + LOG(FATAL) << "should never hit this case: " << static_cast(instr.op); + break; + } + } + } + return String(os.str()); +} + +TVM_REGISTER_GLOBAL("relax.Executable").set_body_typed([]() { return Executable(); }); + +TVM_REGISTER_GLOBAL("relax.ExecutableStats").set_body_typed([](Executable exec) { + return exec->Stats(); +}); + +TVM_REGISTER_GLOBAL("relax.ExecutableAsText").set_body_typed([](Executable exec) { + return exec->AsText(); +}); + +TVM_REGISTER_GLOBAL("relax.ExecutableAsPython").set_body_typed([](Executable exec) { + return exec->AsPython(); +}); + +TVM_REGISTER_GLOBAL("relax.ExecutableSaveToFile") + .set_body_typed([](Executable exec, std::string file_name) { + return exec->SaveToFile(file_name); + }); + +TVM_REGISTER_GLOBAL("relax.ExecutableLoadFromFile").set_body_typed([](std::string file_name) { + return ExecutableNode::LoadFromFile(file_name); +}); + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm diff --git a/src/relax/vm/memory_manager.cc b/src/relax/vm/memory_manager.cc new file mode 100644 index 000000000000..7c4a8e1e19fc --- /dev/null +++ b/src/relax/vm/memory_manager.cc @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/vm/memory_manager.cc + * \brief Allocate and manage memory for the Relay VM. + */ +#include + +#include +#include + +#include "naive_allocator.h" +#include "pooled_allocator.h" + +namespace tvm { +namespace runtime { +namespace relax_vm { + +static void BufferDeleter(Object* obj) { + auto* ptr = static_cast(obj); + ICHECK(ptr->manager_ctx != nullptr); + Buffer* buffer = reinterpret_cast(ptr->manager_ctx); + MemoryManager::GetAllocator(buffer->device)->Free(*(buffer)); + delete buffer; + delete ptr; +} + +void StorageObj::Deleter(Object* obj) { + auto* ptr = static_cast(obj); + // When invoking AllocNDArray we don't own the underlying allocation + // and should not delete the buffer, but instead let it be reclaimed + // by the storage object's destructor. + // + // We did bump the reference count by 1 to keep alive the StorageObj + // allocation in case this NDArray is the sole owner. + // + // We decrement the object allowing for the buffer to release our + // reference count from allocation. + StorageObj* storage = reinterpret_cast(ptr->manager_ctx); + storage->DecRef(); + delete ptr; +} + +inline void VerifyDataType(DLDataType dtype) { + ICHECK_GE(dtype.lanes, 1); + if (dtype.code == kDLFloat) { + ICHECK_EQ(dtype.bits % 8, 0); + } else { + // allow uint1 as a special flag for bool. + if (dtype.bits == 1 && dtype.code == kDLUInt) return; + ICHECK_EQ(dtype.bits % 8, 0); + } + ICHECK_EQ(dtype.bits & (dtype.bits - 1), 0); +} + +inline size_t GetDataAlignment(const DLTensor& arr) { + size_t align = (arr.dtype.bits / 8) * arr.dtype.lanes; + if (align < runtime::kAllocAlignment) return runtime::kAllocAlignment; + return align; +} + +runtime::NDArray StorageObj::AllocNDArray(size_t offset, ShapeTuple shape, DLDataType dtype) { + VerifyDataType(dtype); + + // crtical zone: allocate header, cannot throw + runtime::NDArray::Container* container = + new runtime::NDArray::Container(nullptr, shape, dtype, this->buffer.device); + + container->SetDeleter(StorageObj::Deleter); + size_t needed_size = runtime::GetDataSize(container->dl_tensor); + this->IncRef(); + // The manager context pointer must continue to point to the storage object + // which owns the backing memory, and keeps track of the reference count. + // + // When we free a container we extract the storage object, decrement its + // reference count, then destroy the container, but leave the underlying + // buffer intact. + container->manager_ctx = reinterpret_cast(this); + + // is this UB? + // The only change we make w.r.t offset is modifying the data pointer + // of the backing tensor to point into the buffer instead of its start. + auto offset_ptr = reinterpret_cast(this->buffer.data) + offset; + container->dl_tensor.data = reinterpret_cast(offset_ptr); + + runtime::NDArray ret(runtime::GetObjectPtr(container)); + // RAII in effect, now run the check. + + ICHECK(offset + needed_size <= this->buffer.size) + << "storage allocation failure, attempted to allocate " << needed_size << " at offset " + << offset << " in region that is " << this->buffer.size << "bytes"; + + return ret; +} + +MemoryManager* MemoryManager::Global() { + // NOTE: explicitly use new to avoid exit-time destruction of global state + // Global state will be recycled by OS as the process exits. + static auto* inst = new MemoryManager(); + return inst; +} + +Allocator* MemoryManager::GetOrCreateAllocator(Device dev, AllocatorType type) { + MemoryManager* m = MemoryManager::Global(); + std::lock_guard lock(m->mutex_); + if (m->allocators_.find(dev) == m->allocators_.end()) { + std::unique_ptr alloc; + switch (type) { + case kNaive: { + DLOG(INFO) << "New naive allocator for " << runtime::DeviceName(dev.device_type) << "(" + << dev.device_id << ")"; + alloc.reset(new NaiveAllocator(dev)); + break; + } + case kPooled: { + DLOG(INFO) << "New pooled allocator for " << runtime::DeviceName(dev.device_type) << "(" + << dev.device_id << ")"; + alloc.reset(new PooledAllocator(dev)); + break; + } + default: + LOG(FATAL) << "Unknown allocator type: " << type; + } + auto ret = alloc.get(); + m->allocators_.emplace(dev, std::move(alloc)); + return ret; + } + auto alloc = m->allocators_.at(dev).get(); + if (alloc->type() != type) { + LOG(WARNING) << "The type of existing allocator for " << runtime::DeviceName(dev.device_type) + << "(" << dev.device_id << ") is different from the request type (" + << alloc->type() << " vs " << type << ")"; + } + return alloc; +} + +Allocator* MemoryManager::GetAllocator(Device dev) { + MemoryManager* m = MemoryManager::Global(); + std::lock_guard lock(m->mutex_); + auto it = m->allocators_.find(dev); + if (it == m->allocators_.end()) { + LOG(FATAL) << "Allocator for " << runtime::DeviceName(dev.device_type) << "(" << dev.device_id + << ") has not been created yet."; + } + return it->second.get(); +} + +runtime::NDArray Allocator::Empty(std::vector shape, DLDataType dtype, DLDevice dev) { + VerifyDataType(dtype); + runtime::NDArray::Container* container = + new runtime::NDArray::Container(nullptr, shape, dtype, dev); + container->SetDeleter(BufferDeleter); + size_t size = runtime::GetDataSize(container->dl_tensor); + size_t alignment = GetDataAlignment(container->dl_tensor); + Buffer* buffer = new Buffer; + *buffer = this->Alloc(size, alignment, dtype); + container->manager_ctx = reinterpret_cast(buffer); + container->dl_tensor.data = buffer->data; + return runtime::NDArray(runtime::GetObjectPtr(container)); +} + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm diff --git a/src/relax/vm/naive_allocator.h b/src/relax/vm/naive_allocator.h new file mode 100644 index 000000000000..08a4159ec1c3 --- /dev/null +++ b/src/relax/vm/naive_allocator.h @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/vm/naive_allocator.h + */ +#ifndef TVM_RELAX_VM_NAIVE_ALLOCATOR_H_ +#define TVM_RELAX_VM_NAIVE_ALLOCATOR_H_ + +#include +#include + +#include + +namespace tvm { +namespace runtime { +namespace relax_vm { + +class NaiveAllocator final : public Allocator { + public: + explicit NaiveAllocator(Device dev) : Allocator(kNaive), used_memory_(0), device_(dev) {} + + Buffer Alloc(size_t nbytes, size_t alignment, DLDataType type_hint) override { + Buffer buf; + buf.device = device_; + buf.size = nbytes; + buf.data = + runtime::DeviceAPI::Get(device_)->AllocDataSpace(device_, nbytes, alignment, type_hint); + used_memory_.fetch_add(nbytes, std::memory_order_relaxed); + DLOG(INFO) << "allocate " << nbytes << " B, used memory " << used_memory_ << " B"; + return buf; + } + + void Free(const Buffer& buffer) override { + runtime::DeviceAPI::Get(device_)->FreeDataSpace(buffer.device, buffer.data); + used_memory_.fetch_sub(buffer.size, std::memory_order_relaxed); + DLOG(INFO) << "free " << buffer.size << " B, used memory " << used_memory_ << " B"; + } + + private: + std::atomic used_memory_; + Device device_; +}; + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm + +#endif // TVM_RELAX_VM_NAIVE_ALLOCATOR_H_ diff --git a/src/relax/vm/pooled_allocator.h b/src/relax/vm/pooled_allocator.h new file mode 100644 index 000000000000..919b84667124 --- /dev/null +++ b/src/relax/vm/pooled_allocator.h @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/vm/pooled_allocator.h + */ +#ifndef TVM_RELAX_VM_POOLED_ALLOCATOR_H_ +#define TVM_RELAX_VM_POOLED_ALLOCATOR_H_ + +#include +#include + +#include +#include +#include +#include + +namespace tvm { +namespace runtime { +namespace relax_vm { + +class PooledAllocator final : public Allocator { + public: + static constexpr size_t kDefaultPageSize = 4096; + + explicit PooledAllocator(Device dev, size_t page_size = kDefaultPageSize) + : Allocator(kPooled), page_size_(page_size), used_memory_(0), device_(dev) {} + + ~PooledAllocator() { ReleaseAll(); } + + Buffer Alloc(size_t nbytes, size_t alignment, DLDataType type_hint) override { + std::lock_guard lock(mu_); + size_t size = ((nbytes + page_size_ - 1) / page_size_) * page_size_; + auto&& it = memory_pool_.find(size); + if (it != memory_pool_.end() && !it->second.empty()) { + auto&& pool = it->second; + auto ret = pool.back(); + pool.pop_back(); + return ret; + } + Buffer buf; + buf.device = device_; + buf.size = size; + try { + buf.data = + runtime::DeviceAPI::Get(device_)->AllocDataSpace(device_, size, alignment, type_hint); + } catch (InternalError& err) { + LOG(WARNING) << "PooledAllocator got InternalError during allocation: " << err.message(); + LOG(WARNING) << "Trying to release all unused memory and reallocate..."; + ReleaseAll(); + buf.data = + runtime::DeviceAPI::Get(device_)->AllocDataSpace(device_, size, alignment, type_hint); + } + + used_memory_.fetch_add(size, std::memory_order_relaxed); + DLOG(INFO) << "allocate " << size << " B, used memory " << used_memory_ << " B"; + return buf; + } + + void Free(const Buffer& buffer) override { + std::lock_guard lock(mu_); + if (memory_pool_.find(buffer.size) == memory_pool_.end()) { + memory_pool_.emplace(buffer.size, std::vector{}); + } + memory_pool_.at(buffer.size).push_back(buffer); + DLOG(INFO) << "reclaim buffer " << buffer.size; + } + + private: + void ReleaseAll() { + std::lock_guard lock(mu_); + for (auto const& it : memory_pool_) { + auto const& pool = it.second; + for (auto const& buf : pool) { + runtime::DeviceAPI::Get(buf.device)->FreeDataSpace(buf.device, buf.data); + } + } + memory_pool_.clear(); + used_memory_ = 0; + DLOG(INFO) << "release all buffers"; + } + + private: + size_t page_size_; + std::atomic used_memory_; + std::unordered_map > memory_pool_; + std::recursive_mutex mu_; + Device device_; +}; + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm + +#endif // TVM_RELAX_VM_POOLED_ALLOCATOR_H_ diff --git a/src/relax/vm/vm.cc b/src/relax/vm/vm.cc new file mode 100644 index 000000000000..bbd7c13d8d71 --- /dev/null +++ b/src/relax/vm/vm.cc @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/vm/vm.cc + * \brief + */ + +#include + +namespace tvm { +namespace runtime { +namespace relax_vm { + +class DummyModule : public runtime::ModuleNode { + public: + virtual PackedFunc GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) final { + return nullptr; + } + + const char* type_key() const final { return "relax.DummyModule"; } +}; + +PackedFunc VirtualMachine::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { + const auto& m = exec_->global_map; + if (m.find(name) != m.end()) { + Index gf_idx = m.at(name); + return PackedFunc([sptr_to_self, this, gf_idx](TVMArgs args, TVMRetValue* rv) { + std::vector inputs(args.size()); + for (int i = 0; i < args.size(); ++i) { + inputs[i] = args[i]; + } + *rv = this->Invoke(gf_idx, inputs); + }); + } else { + LOG(FATAL) << "Unknown function: " << name; + return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); + } +} + +void VirtualMachine::Load(Executable exec, runtime::Module mod) { + this->exec_ = exec; + this->mod_ = mod; +} + +RegType VirtualMachine::Invoke(Index gf_idx, const std::vector& args) { + const VMFunction& gfunc = exec_->global_funcs[gf_idx]; + PushFrame(this->pc_ + 1, gfunc); + // load arguments to the register file + ICHECK(static_cast(gfunc.num_args) == args.size()); + for (size_t i = 0; i < args.size(); ++i) { + WriteRegister(i, args[i]); + } + // set program counter + pc_ = gfunc.start_instr; + RunLoop(); + return return_value_; +} + +void VirtualMachine::Init(const std::vector& devices, + const std::vector& alloc_types) { + ICHECK_EQ(devices.size(), alloc_types.size()); + for (size_t i = 0; i < devices.size(); i++) { + auto dev_type = static_cast(devices[i].device_type); + auto alloc = MemoryManager::GetOrCreateAllocator(devices[i], alloc_types[i]); + if (devices_.size() <= dev_type) { + devices_.resize(dev_type + 1); + state.allocators.resize(dev_type + 1); + } + devices_[dev_type] = devices[i]; + state.allocators[dev_type] = alloc; + } +} + +void VirtualMachine::RunLoop() { + size_t start_frame = frames_.size(); + while (true) { + if (static_cast(pc_) >= exec_->instr_offset.size()) { + LOG(FATAL) << "run into invalide section"; + } + Instruction instr = exec_->GetInstruction(pc_); + switch (instr.op) { + case Opcode::Call: { + std::string func_name = exec_->func_names[instr.func_idx]; + DLOG(INFO) << "\n pc = " << pc_ << ", execute: " << func_name; + PackedFunc func = mod_->GetFunction(func_name, true); + if (func == nullptr) { + func = *(mod_->GetFuncFromEnv(func_name)); + } + + std::vector values(instr.num_args); + std::vector tcodes(instr.num_args); + runtime::TVMArgsSetter setter(values.data(), tcodes.data()); + for (Index i = 0; i < instr.num_args; ++i) { + Instruction::Arg arg = instr.args[i]; + switch (arg.kind()) { + case Instruction::kRegister: { + if (arg.value() == Instruction::kVMStateRegister) { + setter(i, &(this->state)); + } else { + setter(i, ReadRegister(arg.value())); + } + break; + } + case Instruction::kImmediate: { + setter(i, arg.value()); + break; + } + case Instruction::kConstIdx: { + setter(i, this->exec_->constants[arg.value()]); + break; + } + default: { + LOG(FATAL) << ""; + } + } + } + TVMArgs args(values.data(), tcodes.data(), values.size()); + TVMRetValue ret; + func.CallPacked(args, &ret); + if (instr.dst != Instruction::kVoidArg) { + WriteRegister(instr.dst, ret); + } + pc_++; + break; + } + case Opcode::Ret: { + // If we have hit the point from which we started + // running, we should return to the caller breaking + // the dispatch loop. + return_value_ = ReadRegister(instr.result); + auto caller_return_register = frames_.back().caller_return_register; + PopFrame(); + if (frames_.size() < start_frame) { + ICHECK(frames_.size() == start_frame - 1); + return; + } + // Otherwise we are just returning from a local call. + WriteRegister(caller_return_register, return_value_); + break; + } + } + } +} + +void VirtualMachine::PushFrame(Index ret_pc, const VMFunction& vm_func) { + auto frame = VMFrame(ret_pc, vm_func.register_file_size); + frames_.push_back(frame); +} + +void VirtualMachine::PopFrame() { + ICHECK_GT(frames_.size(), 0); + const VMFrame& fr = frames_.back(); + pc_ = fr.return_pc; + frames_.pop_back(); +} + +inline void VirtualMachine::WriteRegister(Index r, const RegType& val) { + frames_.back().register_file[r] = val; +} + +inline RegType VirtualMachine::ReadRegister(Index r) const { + return frames_.back().register_file[r]; +} + +runtime::Module CreateVirtualMachine(Executable exec, Optional mod) { + runtime::Module mod_; + if (!mod) { + mod_ = runtime::Module(make_object()); + } else { + mod_ = mod.value(); + } + auto vm = make_object(); + vm->Load(exec, mod_); + return runtime::Module(vm); +} + +TVM_REGISTER_GLOBAL("relax.VirtualMachine") + .set_body_typed([](Executable exec, Optional mod) { + return CreateVirtualMachine(exec, mod); + }); + +// initilize the VirtualMachine, takes variable-length arguments +// first argument is a runtime::Module, followed by one or more device_type, device_id, +// and the AllocatorType associated with the device. +TVM_REGISTER_GLOBAL("relax.VirtualMachineInit").set_body([](TVMArgs args, TVMRetValue* rv) { + ICHECK_EQ(args.size() % 3, 1); + runtime::Module mod = args[0]; + auto vm = static_cast(mod.operator->()); + std::vector devices; + std::vector alloc_types; + for (int i = 0; i < args.size() / 3; ++i) { + Device dev; + int device_type = args[i * 3 + 1]; + dev.device_type = DLDeviceType(device_type); + dev.device_id = args[i * 3 + 2]; + int type = args[i * 3 + 3]; + devices.push_back(dev); + alloc_types.push_back(AllocatorType(type)); + } + vm->Init(devices, alloc_types); +}); + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm diff --git a/tests/python/relax/test_relax_vm.py b/tests/python/relax/test_relax_vm.py new file mode 100644 index 000000000000..cf22c03741b3 --- /dev/null +++ b/tests/python/relax/test_relax_vm.py @@ -0,0 +1,189 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import tvm +from tvm import relax as rx +from tvm.runtime import container + + +@tvm.register_func("test.vm.move") +def move(src): + return src + +@tvm.register_func("test.vm.add") +def add(a, b): + ret = a.asnumpy() + b.asnumpy() + return tvm.nd.array(ret) + +@tvm.register_func("test.vm.mul") +def mul(a, b): + ret = a.asnumpy() * b.asnumpy() + return tvm.nd.array(ret) + +def test_vm_execute(): + ib = rx.ExecBuilder() + with ib.function("func0", num_inputs=2): + ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(1)], dst=ib.r(2)) + ib.emit_ret(ib.r(2)) + ex = ib.get() + vm = rx.VirtualMachine(ex, tvm.cpu()) + a = tvm.nd.array(np.random.rand(4,)) + b = tvm.nd.array(np.random.rand(4,)) + add_res = vm["func0"](a, b) + np.testing.assert_allclose(add_res.asnumpy(), a.asnumpy() + b.asnumpy()) + +def test_vm_multiple_func(): + ib = rx.ExecBuilder() + with ib.function("func0", num_inputs=2): + ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(1)], dst=ib.r(2)) + ib.emit_ret(ib.r(2)) + with ib.function("func1", num_inputs=2): + ib.emit_call("test.vm.mul", args=[ib.r(0), ib.r(1)], dst=ib.r(2)) + ib.emit_ret(ib.r(2)) + ex = ib.get() + vm = rx.VirtualMachine(ex, tvm.cpu()) + a = tvm.nd.array(np.random.rand(4,)) + b = tvm.nd.array(np.random.rand(4,)) + mul_res = vm["func1"](a, b) + add_res = vm["func0"](a, b) + np.testing.assert_allclose(add_res.asnumpy(), a.asnumpy() + b.asnumpy()) + np.testing.assert_allclose(mul_res.asnumpy(), a.asnumpy() * b.asnumpy()) + +def test_vm_serialize(): + ib = rx.ExecBuilder() + with ib.function("func0", num_inputs=2): + ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(1)], dst=ib.r(2)) + ib.emit_ret(ib.r(2)) + arr = tvm.nd.array(np.random.rand(4,)) + with ib.function("func1", num_inputs=2): + ib.emit_call("test.vm.mul", args=[ib.r(0), arr], dst=ib.r(1)) + ib.emit_ret(ib.r(1)) + exec0 = ib.get() + exec0.save_to_file("exec.bin") + exec1 = rx.load_exec_from_file("exec.bin") + assert exec0.astext() == exec1.astext() + +def test_vm_checker(): + ib = rx.ExecBuilder() + try: + with ib.function("func0", num_inputs=2): + ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(2)], dst=ib.r(2)) + ib.emit_ret(ib.r(2)) + ib.get() + except ValueError as ex: + assert True + +def test_vm_formalize(): + ib0 = rx.ExecBuilder() + ib1 = rx.ExecBuilder() + with ib0.function("func0", num_inputs=2): + ib0.emit_call("test.vm.add", args=[ib0.r(0), ib0.r(1)], dst=ib0.r(100)) + ib0.emit_call("test.vm.mul", args=[ib0.r(1), ib0.r(100)], dst=ib0.r(50)) + ib0.emit_ret(ib0.r(50)) + with ib1.function("func0", num_inputs=2): + ib1.emit_call("test.vm.add", args=[ib1.r(0), ib1.r(1)], dst=ib1.r(2)) + ib1.emit_call("test.vm.mul", args=[ib1.r(1), ib1.r(2)], dst=ib1.r(3)) + ib1.emit_ret(ib1.r(3)) + exec0 = ib0.get() + exec1 = ib1.get() + assert exec0.astext() == exec1.astext() + +@tvm.register_func("test.vm.add_scalar") +def add_scalar(a, b): + return a + b + +@tvm.register_func("test.vm.get_device_id") +def get_device_id(device): + return device.device_id + +def test_vm_operand(): + ib0 = rx.ExecBuilder() + with ib0.function("func0", num_inputs=2): + ib0.emit_call("test.vm.add_scalar", args=[ib0.r(0), ib0.r(1)], dst=ib0.r(2)) + ib0.emit_ret(ib0.r(2)) + exec0 = ib0.get() + vm = rx.VirtualMachine(exec0, tvm.cpu()) + res = vm["func0"](2, 3) + assert res == 5 + + ib1 = rx.ExecBuilder() + with ib1.function("func1", num_inputs=1): + ib1.emit_call("test.vm.get_device_id", args=[ib1.r(0)], dst=ib1.r(1)) + ib1.emit_ret(ib1.r(1)) + exec1 = ib1.get() + vm = rx.VirtualMachine(exec1, tvm.cpu()) + res = vm["func1"](tvm.cpu(3)) + assert res == 3 + +def test_vm_shapeof(): + ib = rx.ExecBuilder() + shape = (32, 16) + arr = tvm.nd.array(np.random.rand(*shape)) + with ib.function("main", num_inputs=0): + ib.emit_call("vm.builtin.shape_of", args=[arr], dst=ib.r(0)) + ib.emit_ret(ib.r(0)) + ex = ib.get() + vm = rx.VirtualMachine(ex, tvm.cpu()) + res = vm["main"]() + for i, s in enumerate(res): + assert s == shape[i] + +def test_vm_heap(): + ib = rx.ExecBuilder() + shape = (32, 16) + arr = tvm.nd.array(np.random.rand(*shape)) + with ib.function("main", num_inputs=0): + ib.emit_call("vm.builtin.alloc_heap", args=[ib.imm(2)], dst=ib.r(0)) + ib.emit_call("vm.builtin.shape_of", args=[arr], dst=ib.r(1)) + ib.emit_call("vm.builtin.match_shape", args=[ib.r(1), ib.r(0), ib.imm(0), ib.imm(1)]) + ib.emit_call("vm.builtin.make_shape", args=[ib.r(0), ib.imm(0), ib.imm(1)], dst=ib.r(2)) + ib.emit_ret(ib.r(2)) + ex = ib.get() + vm = rx.VirtualMachine(ex, tvm.cpu()) + res = vm["main"]() + for i, s in enumerate(res): + assert s == shape[i] + +def test_vm_storage(): + ib = rx.ExecBuilder() + with ib.function("main", num_inputs=7): + ib.emit_call("vm.builtin.alloc_storage", args=[ib.vm_state(), ib.r(0), ib.r(1), ib.r(2), ib.r(3)], dst=ib.r(7)) + ib.emit_call("vm.builtin.alloc_tensor", args=[ib.r(7), ib.r(4), ib.r(5), ib.r(6)], dst=ib.r(8)) + ib.emit_ret(ib.r(8)) + ex = ib.get() + vm = rx.VirtualMachine(ex, tvm.cpu()) + dtype = tvm.DataType('float32') + cpu_dev = tvm.cpu().device_type + buffer_size = 24 + alignment = 8 + offset = 0 + shape = (32, 16) + shape_tuple = container.ShapeTuple(shape) + res = vm["main"](buffer_size, alignment, cpu_dev, dtype, offset, dtype, shape_tuple) + assert res.device == tvm.cpu() + assert res.shape == shape + +if __name__ == "__main__": + test_vm_execute() + test_vm_multiple_func() + test_vm_checker() + test_vm_formalize() + test_vm_operand() + test_vm_serialize() + test_vm_shapeof() + test_vm_heap() + test_vm_storage()