Skip to content

Commit

Permalink
[Runtime] Compatibility with dmlc::Stream API changes (#16998)
Browse files Browse the repository at this point in the history
* [Runtime] Compatibility with dmlc::Stream API changes

This commit updates TVM implementations of `dmlc::Stream`.  With
dmlc/dmlc-core#686, this API now requires
the `Write` method to return the number of bytes written.  This change
allows partial writes to be correctly handled.

* Update dmlc-core version

* lint fix
  • Loading branch information
Lunderberg authored May 30, 2024
1 parent 7c2c0d9 commit 820f1b6
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 27 deletions.
3 changes: 2 additions & 1 deletion src/runtime/disco/process_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,11 @@ class DiscoPipeMessageQueue : private dmlc::Stream, private DiscoProtocol<DiscoP
return size;
}

void Write(const void* data, size_t size) final {
size_t Write(const void* data, size_t size) final {
size_t cur_size = write_buffer_.size();
write_buffer_.resize(cur_size + size);
std::memcpy(write_buffer_.data() + cur_size, data, size);
return size;
}

using dmlc::Stream::Read;
Expand Down
3 changes: 2 additions & 1 deletion src/runtime/disco/threaded_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,11 @@ class DiscoThreadedMessageQueue : private dmlc::Stream,
return size;
}

void Write(const void* data, size_t size) final {
size_t Write(const void* data, size_t size) final {
size_t cur_size = write_buffer_.size();
write_buffer_.resize(cur_size + size);
std::memcpy(write_buffer_.data() + cur_size, data, size);
return size;
}

using dmlc::Stream::Read;
Expand Down
8 changes: 6 additions & 2 deletions src/runtime/file_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,14 @@ struct SimpleBinaryFileStream : public dmlc::Stream {
CHECK(fp_ != nullptr) << "File is closed";
return std::fread(ptr, 1, size, fp_);
}
virtual void Write(const void* ptr, size_t size) {
virtual size_t Write(const void* ptr, size_t size) {
CHECK(!read_) << "File opened in read-mode, cannot write.";
CHECK(fp_ != nullptr) << "File is closed";
CHECK(std::fwrite(ptr, 1, size, fp_) == size) << "SimpleBinaryFileStream.Write incomplete";
size_t nwrite = std::fwrite(ptr, 1, size, fp_);
int err = std::ferror(fp_);

CHECK_EQ(err, 0) << "SimpleBinaryFileStream.Write incomplete: " << std::strerror(err);
return nwrite;
}
inline void Close(void) {
if (fp_ != nullptr) {
Expand Down
8 changes: 6 additions & 2 deletions src/runtime/rpc/rpc_endpoint.cc
Original file line number Diff line number Diff line change
Expand Up @@ -666,8 +666,12 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
pending_request_bytes_ -= size;
return size;
}
// wriite the data to the channel.
void Write(const void* data, size_t size) final { writer_->Write(data, size); }
// write the data to the channel.
size_t Write(const void* data, size_t size) final {
writer_->Write(data, size);
return size;
}

// Number of pending bytes requests
size_t pending_request_bytes_{0};
// The ring buffer to read data from.
Expand Down
7 changes: 2 additions & 5 deletions src/runtime/rpc/rpc_socket_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,8 @@ class SimpleSockHandler : public dmlc::Stream {
// Internal supporting.
// Override methods that inherited from dmlc::Stream.
private:
size_t Read(void* data, size_t size) final {
ICHECK_EQ(sock_.RecvAll(data, size), size);
return size;
}
void Write(const void* data, size_t size) final { ICHECK_EQ(sock_.SendAll(data, size), size); }
size_t Read(void* data, size_t size) final { return sock_.Recv(data, size); }
size_t Write(const void* data, size_t size) final { return sock_.Send(data, size); }

// Things of current class.
private:
Expand Down
5 changes: 3 additions & 2 deletions src/support/base64.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ class Base64InStream : public dmlc::Stream {
}
return size - tlen;
}
virtual void Write(const void* ptr, size_t size) {
size_t Write(const void* ptr, size_t size) final {
LOG(FATAL) << "Base64InStream do not support write";
}

Expand All @@ -229,7 +229,7 @@ class Base64OutStream : public dmlc::Stream {

using dmlc::Stream::Write;

void Write(const void* ptr, size_t size) final {
size_t Write(const void* ptr, size_t size) final {
using base64::EncodeTable;
size_t tlen = size;
const unsigned char* cptr = static_cast<const unsigned char*>(ptr);
Expand All @@ -247,6 +247,7 @@ class Base64OutStream : public dmlc::Stream {
buf__top_ = 0;
}
}
return size;
}
virtual size_t Read(void* ptr, size_t size) {
LOG(FATAL) << "Base64OutStream do not support read";
Expand Down
24 changes: 11 additions & 13 deletions src/support/pipe.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ class Pipe : public dmlc::Stream {
* \param size block size
* \return the size of data read
*/
void Write(const void* ptr, size_t size) final {
if (size == 0) return;
size_t Write(const void* ptr, size_t size) final {
if (size == 0) return 0;
#ifdef _WIN32
auto fwrite = [&]() -> ssize_t {
DWORD nwrite;
Expand All @@ -124,18 +124,16 @@ class Pipe : public dmlc::Stream {
DWORD nwrite = static_cast<DWORD>(RetryCallOnEINTR(fwrite, GetLastErrorCode));
ICHECK_EQ(static_cast<size_t>(nwrite), size) << "Write Error: " << GetLastError();
#else
while (size) {
ssize_t nwrite =
RetryCallOnEINTR([&]() { return write(handle_, ptr, size); }, GetLastErrorCode);
ICHECK_NE(nwrite, -1) << "Write Error: " << strerror(errno);

ICHECK_GT(nwrite, 0) << "Was unable to write any data to pipe";
ICHECK_LE(nwrite, size) << "Wrote " << nwrite << " bytes, "
<< "but only expected to write " << size << " bytes";
size -= nwrite;
ptr = static_cast<const char*>(ptr) + nwrite;
}
ssize_t nwrite =
RetryCallOnEINTR([&]() { return write(handle_, ptr, size); }, GetLastErrorCode);
ICHECK_NE(nwrite, -1) << "Write Error: " << strerror(errno);

ICHECK_LE(nwrite, size) << "Wrote " << nwrite << " bytes, "
<< "but only expected to write " << size << " bytes";

#endif

return nwrite;
}
/*!
* \brief Flush the pipe;
Expand Down

0 comments on commit 820f1b6

Please sign in to comment.