Skip to content

Commit

Permalink
[VM][DMLC] Lower memory usage when loading and dumping weights (#13877)
Browse files Browse the repository at this point in the history
* initial commit

* update additional use cases

* typo

* asf header, summary

* clean up

* lint

* move code to src/runtime/file_utils.h

* file utils is cool
  • Loading branch information
AndrewZhaoLuo authored Feb 2, 2023
1 parent f0ea9e4 commit 9008ec2
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 16 deletions.
37 changes: 37 additions & 0 deletions src/runtime/file_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,43 @@ std::string SaveParams(const Map<String, NDArray>& params);
* \param params Parameters to save.
*/
void SaveParams(dmlc::Stream* strm, const Map<String, NDArray>& params);

/*!
* \brief A dmlc stream which wraps standard file operations.
*/
struct SimpleBinaryFileStream : public dmlc::Stream {
public:
SimpleBinaryFileStream(const std::string& path, std::string mode) {
const char* fname = path.c_str();

CHECK(mode == "wb" || mode == "rb") << "Only allowed modes are 'wb' and 'rb'";
read_ = mode == "rb";
fp_ = std::fopen(fname, mode.c_str());
CHECK(fp_ != nullptr) << "Unable to open file " << path;
}
virtual ~SimpleBinaryFileStream(void) { this->Close(); }
virtual size_t Read(void* ptr, size_t size) {
CHECK(read_) << "File opened in write-mode, cannot read.";
CHECK(fp_ != nullptr) << "File is closed";
return std::fread(ptr, 1, size, fp_);
}
virtual void Write(const void* ptr, size_t size) {
CHECK(!read_) << "File opened in read-mode, cannot write.";
CHECK(fp_ != nullptr) << "File is closed";
CHECK(std::fwrite(ptr, 1, size, fp_) == size) << "SimpleBinaryFileStream.Write incomplete";
}
inline void Close(void) {
if (fp_ != nullptr) {
std::fclose(fp_);
fp_ = nullptr;
}
}

private:
std::FILE* fp_ = nullptr;
bool read_;
}; // class SimpleBinaryFileStream

} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_FILE_UTILS_H_
22 changes: 6 additions & 16 deletions src/runtime/vm/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -344,10 +344,8 @@ void Executable::MoveLateBoundConstantsToStream(dmlc::Stream* stream, size_t byt
}

void Executable::MoveLateBoundConstantsToFile(const std::string& path, size_t byte_limit) {
std::string bytes;
dmlc::MemoryStringStream stream(&bytes);
tvm::runtime::SimpleBinaryFileStream stream(path, "wb");
MoveLateBoundConstantsToStream(&stream, byte_limit);
SaveBinaryToFile(path, bytes);
}

void Executable::LoadLateBoundConstantsFromStream(dmlc::Stream* stream) {
Expand Down Expand Up @@ -381,9 +379,7 @@ void Executable::LoadLateBoundConstantsFromMap(Map<String, NDArray> map) {
}

void Executable::LoadLateBoundConstantsFromFile(const std::string& path) {
std::string bytes;
LoadBinaryFromFile(path, &bytes);
dmlc::MemoryStringStream stream(&bytes);
tvm::runtime::SimpleBinaryFileStream stream(path, "rb");
LoadLateBoundConstantsFromStream(&stream);
}

Expand Down Expand Up @@ -1063,22 +1059,16 @@ Module ExecutableLoadBinary(void* strm) {
}

void Executable::SaveToFile(const std::string& path, const std::string& format) {
std::string data;
dmlc::MemoryStringStream writer(&data);
dmlc::SeekStream* strm = &writer;
SaveToBinary(strm);
SaveBinaryToFile(path, data);
tvm::runtime::SimpleBinaryFileStream stream(path, "wb");
SaveToBinary(&stream);
}

TVM_REGISTER_GLOBAL("runtime.module.loadbinary_VMExecutable").set_body_typed(ExecutableLoadBinary);

// Load module from module.
Module ExecutableLoadFile(const std::string& file_name, const std::string& format) {
std::string data;
LoadBinaryFromFile(file_name, &data);
dmlc::MemoryStringStream reader(&data);
dmlc::Stream* strm = &reader;
auto exec = ExecutableLoadBinary(reinterpret_cast<void*>(strm));
tvm::runtime::SimpleBinaryFileStream stream(file_name, "rb");
auto exec = ExecutableLoadBinary(reinterpret_cast<void*>(&stream));
return exec;
}

Expand Down

0 comments on commit 9008ec2

Please sign in to comment.