From b4180c20cfcdc12e02189d853615e78fb1a9614b Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Thu, 17 Oct 2019 04:09:25 +0000 Subject: [PATCH] create stream --- include/tvm/runtime/vm.h | 69 +++++++++++++++++------- src/runtime/vm/executable.cc | 100 +++++++++++++++++++---------------- 2 files changed, 103 insertions(+), 66 deletions(-) diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index 011826e6f1a1..a276c658c496 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -541,34 +541,63 @@ class Executable : public ModuleNode { std::vector functions; private: - /*! \brief Save the globals. */ - void SaveGlobalSection(); - - /*! \brief Save the constant pool. */ - void SaveConstantSection(); + /*! + * \brief Save the globals. + * + * \param strm The input stream. + */ + void SaveGlobalSection(dmlc::Stream* strm); - /*! \brief Save primitive op names. */ - void SavePrimitiveOpNames(); + /*! + * \brief Save the constant pool. + * + * \param strm The input stream. + */ + void SaveConstantSection(dmlc::Stream* strm); - /*! \brief Save the vm functions. */ - void SaveCodeSection(); + /*! + * \brief Save primitive op names. + * + * \param strm The input stream. + */ + void SavePrimitiveOpNames(dmlc::Stream* strm); - /*! \brief Load the globals. */ - void LoadGlobalSection(); + /*! + * \brief Save the vm functions. + * + * \param strm The input stream. + */ + void SaveCodeSection(dmlc::Stream* strm); - /*! \brief Load the constant pool. */ - void LoadConstantSection(); + /*! + * \brief Load the globals. + * + * \param strm The input stream. + */ + void LoadGlobalSection(dmlc::Stream* strm); - /*! \brief Load primitive op names. */ - void LoadPrimitiveOpNames(); + /*! + * \brief Load the constant pool. + * + * \param strm The input stream. + */ + void LoadConstantSection(dmlc::Stream* strm); - /*! \brief Load the vm functions.*/ - void LoadCodeSection(); + /*! + * \brief Load primitive op names. + * + * \param strm The input stream. + */ + void LoadPrimitiveOpNames(dmlc::Stream* strm); - /*! \brief The stream used for serialization. */ - dmlc::Stream* strm_; + /*! + * \brief Load the vm functions. + * + * \param strm The input stream. + */ + void LoadCodeSection(dmlc::Stream* strm); - /*! \brief The serialized code. */ + /*! \brief The serialized bytecode. */ std::string code_; }; diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index 8768ddf53232..21f71af4eb8c 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -166,26 +166,32 @@ std::string Executable::Stats() const { return oss.str(); } +void SaveHeader(dmlc::Stream* strm) { + uint64_t header = kTVMVMBytecodeMagic; + strm->Write(header); + std::string version = TVM_VERSION; + strm->Write(version); +} + TVMByteArray Executable::Save() { // Initialize the stream object. - strm_ = new dmlc::MemoryStringStream(&code_); + code_.clear(); + dmlc::MemoryStringStream strm(&code_); - uint64_t header = kTVMVMBytecodeMagic; - strm_->Write(header); - std::string version = TVM_VERSION; - strm_->Write(version); + // Save header + SaveHeader(&strm); // Global section. - SaveGlobalSection(); + SaveGlobalSection(&strm); // Constant section. - SaveConstantSection(); + SaveConstantSection(&strm); // Primitive names. - SavePrimitiveOpNames(); + SavePrimitiveOpNames(&strm); // Code section. - SaveCodeSection(); + SaveCodeSection(&strm); TVMByteArray arr; arr.data = code_.c_str(); @@ -193,7 +199,7 @@ TVMByteArray Executable::Save() { return arr; } -void Executable::SaveGlobalSection() { +void Executable::SaveGlobalSection(dmlc::Stream* strm) { std::vector > globals(this->global_map.begin(), this->global_map.end()); auto comp = [](const std::pair& a, @@ -206,10 +212,10 @@ void Executable::SaveGlobalSection() { for (const auto& it : globals) { glbs.push_back(it.first); } - strm_->Write(glbs); + strm->Write(glbs); } -void Executable::SaveConstantSection() { +void Executable::SaveConstantSection(dmlc::Stream* strm) { std::vector arrays; for (const auto& obj : this->constants) { const auto* cell = obj.as(); @@ -217,13 +223,13 @@ void Executable::SaveConstantSection() { runtime::NDArray data = cell->data; arrays.push_back(const_cast(data.operator->())); } - strm_->Write(static_cast(this->constants.size())); + strm->Write(static_cast(this->constants.size())); for (const auto& it : arrays) { - runtime::SaveDLTensor(strm_, it); + runtime::SaveDLTensor(strm, it); } } -void Executable::SavePrimitiveOpNames() { +void Executable::SavePrimitiveOpNames(dmlc::Stream* strm) { std::vector primitive_names; for (const auto& it : this->primitive_map) { auto packed_index = static_cast(it.second); @@ -232,7 +238,7 @@ void Executable::SavePrimitiveOpNames() { } primitive_names[packed_index] = it.first; } - strm_->Write(primitive_names); + strm->Write(primitive_names); } // Serialize a virtual machine instruction. It creates a list that contains the @@ -384,85 +390,87 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { return VMInstructionSerializer(static_cast(instr.op), fields); } -void Executable::SaveCodeSection() { +void Executable::SaveCodeSection(dmlc::Stream* strm) { // Save the number of functions. - strm_->Write(static_cast(this->functions.size())); + strm->Write(static_cast(this->functions.size())); for (const auto& func : this->functions) { // Save the function info. VMFunctionSerializer func_format(func.name, func.register_file_size, func.instructions.size(), func.params); - func_format.Save(strm_); + func_format.Save(strm); // Serialize each instruction. for (const auto& instr : func.instructions) { const auto& serialized_instr = SerializeInstruction(instr); - serialized_instr.Save(strm_); + serialized_instr.Save(strm); } } } -runtime::Module Executable::Load(const std::string& code, const runtime::Module lib) { - std::shared_ptr exec = std::make_shared(); - exec->code_ = code; - exec->lib = lib; - // Initialize the stream object. - if (exec->strm_ == nullptr) { - exec->strm_ = new dmlc::MemoryStringStream(&exec->code_); - } - +void LoadHeader(dmlc::Stream* strm) { // Check header. uint64_t header; - STREAM_CHECK(exec->strm_->Read(&header), "header"); + STREAM_CHECK(strm->Read(&header), "header"); STREAM_CHECK(header == kTVMVMBytecodeMagic, "header"); // Check version. std::string version; - STREAM_CHECK(exec->strm_->Read(&version), "version"); + STREAM_CHECK(strm->Read(&version), "version"); STREAM_CHECK(version == TVM_VERSION, "version"); +} + +runtime::Module Executable::Load(const std::string& code, const runtime::Module lib) { + std::shared_ptr exec = std::make_shared(); + exec->lib = lib; + exec->code_ = code; + dmlc::MemoryStringStream strm(&exec->code_); + + // Load header. + LoadHeader(&strm); // Global section. - exec->LoadGlobalSection(); + exec->LoadGlobalSection(&strm); // Constant section. - exec->LoadConstantSection(); + exec->LoadConstantSection(&strm); // Primitive names that will be invoked by `InvokePacked` instructions. - exec->LoadPrimitiveOpNames(); + exec->LoadPrimitiveOpNames(&strm); // Code section. - exec->LoadCodeSection(); + exec->LoadCodeSection(&strm); return runtime::Module(exec); } -void Executable::LoadGlobalSection() { +void Executable::LoadGlobalSection(dmlc::Stream* strm) { std::vector globals; - STREAM_CHECK(strm_->Read(&globals), "global"); + STREAM_CHECK(strm->Read(&globals), "global"); for (size_t i = 0; i < globals.size(); i++) { this->global_map.insert({globals[i], i}); } } -void Executable::LoadConstantSection() { +void Executable::LoadConstantSection(dmlc::Stream* strm) { uint64_t sz; // Load the number of constants. - STREAM_CHECK(strm_->Read(&sz, sizeof(sz)), "constant"); + STREAM_CHECK(strm->Read(&sz, sizeof(sz)), "constant"); size_t size = static_cast(sz); // Load each of the constants. for (size_t i = 0; i < size; i++) { runtime::NDArray constant; - STREAM_CHECK(constant.Load(strm_), "constant"); + STREAM_CHECK(constant.Load(strm), "constant"); runtime::ObjectRef obj = runtime::vm::Tensor(constant); this->constants.push_back(obj); } } -void Executable::LoadPrimitiveOpNames() { +void Executable::LoadPrimitiveOpNames(dmlc::Stream* strm) { std::vector primitive_names; - STREAM_CHECK(strm_->Read(&primitive_names), "primitive name"); + STREAM_CHECK(strm->Read(&primitive_names), "primitive name"); for (size_t i = 0; i < primitive_names.size(); i++) { this->primitive_map.insert({primitive_names[i], i}); } @@ -630,24 +638,24 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { } } -void Executable::LoadCodeSection() { +void Executable::LoadCodeSection(dmlc::Stream* strm) { // Load the number of functions. uint64_t sz; - STREAM_CHECK(strm_->Read(&sz, sizeof(sz)), "code"); + STREAM_CHECK(strm->Read(&sz, sizeof(sz)), "code"); size_t num_funcs = static_cast(sz); this->functions.resize(num_funcs); for (size_t i = 0; i < num_funcs; i++) { // Load the function info. VMFunctionSerializer loaded_func; - STREAM_CHECK(loaded_func.Load(strm_), "code/function"); + STREAM_CHECK(loaded_func.Load(strm), "code/function"); // Load the instructions. std::vector instructions; for (size_t j = 0; j < loaded_func.num_instructions; j++) { VMInstructionSerializer instr; std::vector instr_fields; - STREAM_CHECK(instr.Load(strm_), "code/instruction"); + STREAM_CHECK(instr.Load(strm), "code/instruction"); instructions.push_back(DeserializeInstruction(instr)); }