Skip to content

Commit

Permalink
create stream
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Oct 17, 2019
1 parent 764b34c commit b4180c2
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 66 deletions.
69 changes: 49 additions & 20 deletions include/tvm/runtime/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -541,34 +541,63 @@ class Executable : public ModuleNode {
std::vector<VMFunction> functions;

private:
/*! \brief Save the globals. */
void SaveGlobalSection();

/*! \brief Save the constant pool. */
void SaveConstantSection();
/*!
* \brief Save the globals.
*
* \param strm The input stream.
*/
void SaveGlobalSection(dmlc::Stream* strm);

/*! \brief Save primitive op names. */
void SavePrimitiveOpNames();
/*!
* \brief Save the constant pool.
*
* \param strm The input stream.
*/
void SaveConstantSection(dmlc::Stream* strm);

/*! \brief Save the vm functions. */
void SaveCodeSection();
/*!
* \brief Save primitive op names.
*
* \param strm The input stream.
*/
void SavePrimitiveOpNames(dmlc::Stream* strm);

/*! \brief Load the globals. */
void LoadGlobalSection();
/*!
* \brief Save the vm functions.
*
* \param strm The input stream.
*/
void SaveCodeSection(dmlc::Stream* strm);

/*! \brief Load the constant pool. */
void LoadConstantSection();
/*!
* \brief Load the globals.
*
* \param strm The input stream.
*/
void LoadGlobalSection(dmlc::Stream* strm);

/*! \brief Load primitive op names. */
void LoadPrimitiveOpNames();
/*!
* \brief Load the constant pool.
*
* \param strm The input stream.
*/
void LoadConstantSection(dmlc::Stream* strm);

/*! \brief Load the vm functions.*/
void LoadCodeSection();
/*!
* \brief Load primitive op names.
*
* \param strm The input stream.
*/
void LoadPrimitiveOpNames(dmlc::Stream* strm);

/*! \brief The stream used for serialization. */
dmlc::Stream* strm_;
/*!
* \brief Load the vm functions.
*
* \param strm The input stream.
*/
void LoadCodeSection(dmlc::Stream* strm);

/*! \brief The serialized code. */
/*! \brief The serialized bytecode. */
std::string code_;
};

Expand Down
100 changes: 54 additions & 46 deletions src/runtime/vm/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,34 +166,40 @@ std::string Executable::Stats() const {
return oss.str();
}

void SaveHeader(dmlc::Stream* strm) {
uint64_t header = kTVMVMBytecodeMagic;
strm->Write(header);
std::string version = TVM_VERSION;
strm->Write(version);
}

TVMByteArray Executable::Save() {
// Initialize the stream object.
strm_ = new dmlc::MemoryStringStream(&code_);
code_.clear();
dmlc::MemoryStringStream strm(&code_);

uint64_t header = kTVMVMBytecodeMagic;
strm_->Write(header);
std::string version = TVM_VERSION;
strm_->Write(version);
// Save header
SaveHeader(&strm);

// Global section.
SaveGlobalSection();
SaveGlobalSection(&strm);

// Constant section.
SaveConstantSection();
SaveConstantSection(&strm);

// Primitive names.
SavePrimitiveOpNames();
SavePrimitiveOpNames(&strm);

// Code section.
SaveCodeSection();
SaveCodeSection(&strm);

TVMByteArray arr;
arr.data = code_.c_str();
arr.size = code_.length();
return arr;
}

void Executable::SaveGlobalSection() {
void Executable::SaveGlobalSection(dmlc::Stream* strm) {
std::vector<std::pair<std::string, Index> > globals(this->global_map.begin(),
this->global_map.end());
auto comp = [](const std::pair<std::string, Index>& a,
Expand All @@ -206,24 +212,24 @@ void Executable::SaveGlobalSection() {
for (const auto& it : globals) {
glbs.push_back(it.first);
}
strm_->Write(glbs);
strm->Write(glbs);
}

void Executable::SaveConstantSection() {
void Executable::SaveConstantSection(dmlc::Stream* strm) {
std::vector<DLTensor*> arrays;
for (const auto& obj : this->constants) {
const auto* cell = obj.as<runtime::vm::TensorObj>();
CHECK(cell != nullptr);
runtime::NDArray data = cell->data;
arrays.push_back(const_cast<DLTensor*>(data.operator->()));
}
strm_->Write(static_cast<uint64_t>(this->constants.size()));
strm->Write(static_cast<uint64_t>(this->constants.size()));
for (const auto& it : arrays) {
runtime::SaveDLTensor(strm_, it);
runtime::SaveDLTensor(strm, it);
}
}

void Executable::SavePrimitiveOpNames() {
void Executable::SavePrimitiveOpNames(dmlc::Stream* strm) {
std::vector<std::string> primitive_names;
for (const auto& it : this->primitive_map) {
auto packed_index = static_cast<size_t>(it.second);
Expand All @@ -232,7 +238,7 @@ void Executable::SavePrimitiveOpNames() {
}
primitive_names[packed_index] = it.first;
}
strm_->Write(primitive_names);
strm->Write(primitive_names);
}

// Serialize a virtual machine instruction. It creates a list that contains the
Expand Down Expand Up @@ -384,85 +390,87 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) {
return VMInstructionSerializer(static_cast<Index>(instr.op), fields);
}

void Executable::SaveCodeSection() {
void Executable::SaveCodeSection(dmlc::Stream* strm) {
// Save the number of functions.
strm_->Write(static_cast<uint64_t>(this->functions.size()));
strm->Write(static_cast<uint64_t>(this->functions.size()));
for (const auto& func : this->functions) {
// Save the function info.
VMFunctionSerializer func_format(func.name,
func.register_file_size,
func.instructions.size(),
func.params);
func_format.Save(strm_);
func_format.Save(strm);

// Serialize each instruction.
for (const auto& instr : func.instructions) {
const auto& serialized_instr = SerializeInstruction(instr);
serialized_instr.Save(strm_);
serialized_instr.Save(strm);
}
}
}

runtime::Module Executable::Load(const std::string& code, const runtime::Module lib) {
std::shared_ptr<Executable> exec = std::make_shared<Executable>();
exec->code_ = code;
exec->lib = lib;
// Initialize the stream object.
if (exec->strm_ == nullptr) {
exec->strm_ = new dmlc::MemoryStringStream(&exec->code_);
}

void LoadHeader(dmlc::Stream* strm) {
// Check header.
uint64_t header;
STREAM_CHECK(exec->strm_->Read(&header), "header");
STREAM_CHECK(strm->Read(&header), "header");
STREAM_CHECK(header == kTVMVMBytecodeMagic, "header");

// Check version.
std::string version;
STREAM_CHECK(exec->strm_->Read(&version), "version");
STREAM_CHECK(strm->Read(&version), "version");
STREAM_CHECK(version == TVM_VERSION, "version");
}

runtime::Module Executable::Load(const std::string& code, const runtime::Module lib) {
std::shared_ptr<Executable> exec = std::make_shared<Executable>();
exec->lib = lib;
exec->code_ = code;
dmlc::MemoryStringStream strm(&exec->code_);

// Load header.
LoadHeader(&strm);

// Global section.
exec->LoadGlobalSection();
exec->LoadGlobalSection(&strm);

// Constant section.
exec->LoadConstantSection();
exec->LoadConstantSection(&strm);

// Primitive names that will be invoked by `InvokePacked` instructions.
exec->LoadPrimitiveOpNames();
exec->LoadPrimitiveOpNames(&strm);

// Code section.
exec->LoadCodeSection();
exec->LoadCodeSection(&strm);

return runtime::Module(exec);
}

void Executable::LoadGlobalSection() {
void Executable::LoadGlobalSection(dmlc::Stream* strm) {
std::vector<std::string> globals;
STREAM_CHECK(strm_->Read(&globals), "global");
STREAM_CHECK(strm->Read(&globals), "global");
for (size_t i = 0; i < globals.size(); i++) {
this->global_map.insert({globals[i], i});
}
}

void Executable::LoadConstantSection() {
void Executable::LoadConstantSection(dmlc::Stream* strm) {
uint64_t sz;
// Load the number of constants.
STREAM_CHECK(strm_->Read(&sz, sizeof(sz)), "constant");
STREAM_CHECK(strm->Read(&sz, sizeof(sz)), "constant");

size_t size = static_cast<size_t>(sz);
// Load each of the constants.
for (size_t i = 0; i < size; i++) {
runtime::NDArray constant;
STREAM_CHECK(constant.Load(strm_), "constant");
STREAM_CHECK(constant.Load(strm), "constant");
runtime::ObjectRef obj = runtime::vm::Tensor(constant);
this->constants.push_back(obj);
}
}

void Executable::LoadPrimitiveOpNames() {
void Executable::LoadPrimitiveOpNames(dmlc::Stream* strm) {
std::vector<std::string> primitive_names;
STREAM_CHECK(strm_->Read(&primitive_names), "primitive name");
STREAM_CHECK(strm->Read(&primitive_names), "primitive name");
for (size_t i = 0; i < primitive_names.size(); i++) {
this->primitive_map.insert({primitive_names[i], i});
}
Expand Down Expand Up @@ -630,24 +638,24 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) {
}
}

void Executable::LoadCodeSection() {
void Executable::LoadCodeSection(dmlc::Stream* strm) {
// Load the number of functions.
uint64_t sz;
STREAM_CHECK(strm_->Read(&sz, sizeof(sz)), "code");
STREAM_CHECK(strm->Read(&sz, sizeof(sz)), "code");

size_t num_funcs = static_cast<size_t>(sz);
this->functions.resize(num_funcs);
for (size_t i = 0; i < num_funcs; i++) {
// Load the function info.
VMFunctionSerializer loaded_func;
STREAM_CHECK(loaded_func.Load(strm_), "code/function");
STREAM_CHECK(loaded_func.Load(strm), "code/function");

// Load the instructions.
std::vector<Instruction> instructions;
for (size_t j = 0; j < loaded_func.num_instructions; j++) {
VMInstructionSerializer instr;
std::vector<Index> instr_fields;
STREAM_CHECK(instr.Load(strm_), "code/instruction");
STREAM_CHECK(instr.Load(strm), "code/instruction");
instructions.push_back(DeserializeInstruction(instr));
}

Expand Down

0 comments on commit b4180c2

Please sign in to comment.