From 6863ea536a6b93a03228a5904cbecbd0f3ba8002 Mon Sep 17 00:00:00 2001
From: liujuncheng
Date: Fri, 1 Apr 2022 11:37:54 +0800
Subject: [PATCH 1/2] [PersistentTable] Async write

---
 oneflow/core/embedding/persistent_table.cpp | 156 ++++++++++++--------
 1 file changed, 95 insertions(+), 61 deletions(-)

diff --git a/oneflow/core/embedding/persistent_table.cpp b/oneflow/core/embedding/persistent_table.cpp
index f50b192c55d..8929e7c19e1 100644
--- a/oneflow/core/embedding/persistent_table.cpp
+++ b/oneflow/core/embedding/persistent_table.cpp
@@ -227,6 +227,24 @@ class RingEngine final {
     }
   }
 
+  void AsyncPwrite(int fd, const void* buf, size_t count, off_t offset) {
+    if (num_readings_ == kRingQueueDepth) {
+      struct io_uring_cqe* cqe = nullptr;
+      PCHECK(io_uring_wait_cqe(&ring_, &cqe) == 0);
+      CHECK_GE(cqe->res, 0);
+      io_uring_cqe_seen(&ring_, cqe);
+    } else {
+      num_readings_ += 1;
+    }
+    io_uring_sqe* sqe = CHECK_NOTNULL(io_uring_get_sqe(&ring_));
+    io_uring_prep_write(sqe, fd, buf, count, offset);
+    pending_submit_ += 1;
+    if (pending_submit_ == kRingSubmitBatch) {
+      PCHECK(io_uring_submit(&ring_) == pending_submit_);
+      pending_submit_ = 0;
+    }
+  }
+
   void WaitUntilDone() {
     if (pending_submit_ > 0) {
       PCHECK(io_uring_submit(&ring_) == pending_submit_);
@@ -278,6 +296,20 @@ class AioEngine final {
     num_readings_ += 1;
   }
 
+  void AsyncPwrite(int fd, const void* buf, size_t count, off_t offset) {
+    if (num_readings_ == kAioQueueDepth) { WaitUntilDone(); }
+    struct iocb* cb = &cbs_.at(num_readings_);
+    cb->aio_fildes = fd;
+    cb->aio_lio_opcode = IOCB_CMD_PWRITE;
+    cb->aio_reqprio = 0;
+    cb->aio_buf = reinterpret_cast<uintptr_t>(buf);
+    cb->aio_nbytes = count;
+    cb->aio_offset = offset;
+    const long nr = 1;
+    PCHECK(syscall(__NR_io_submit, ctx_, nr, &cbs_ptr_.at(num_readings_)) >= 0);
+    num_readings_ += 1;
+  }
+
   void WaitUntilDone() {
     if (num_readings_ != 0) {
       PCHECK(syscall(__NR_io_getevents, ctx_, num_readings_, num_readings_, events_.data(), nullptr)
@@ -298,19 +330,10 @@
 constexpr size_t kCacheLineSize = 64;
 
 template<typename Engine>
-using ForRange = std::function<void(Engine* engine, size_t start, size_t end)>;
+using IoTask = std::function<void(Engine* engine)>;
 
 template<typename Engine>
-struct ParallelForTask {
-  ParallelForTask(size_t num_workers, size_t total, const ForRange<Engine>* for_range)
-      : counter(0), total(total), for_range(for_range), bc(num_workers) {}
-  union alignas(kCacheLineSize) {
-    std::atomic<size_t> counter;
-  };
-  size_t total;
-  const ForRange<Engine>* for_range;
-  BlockingCounter bc;
-};
+using ForRange = std::function<void(Engine* engine, size_t start, size_t end)>;
 
 template<typename Engine>
 class Worker final {
@@ -322,29 +345,21 @@ class Worker final {
     thread_.join();
   }
 
-  void Schedule(ParallelForTask<Engine>* task) { tasks_.Send(task); }
+  void Schedule(IoTask<Engine> task) { tasks_.Send(std::move(task)); }
 
   void Shutdown() { tasks_.Close(); }
 
  private:
  void PullTask() {
     while (true) {
-      ParallelForTask<Engine>* task = nullptr;
+      IoTask<Engine> task;
       const ChannelStatus status = tasks_.Receive(&task);
       if (status == ChannelStatus::kChannelStatusErrorClosed) { break; }
       CHECK_EQ(status, ChannelStatus::kChannelStatusSuccess);
-      while (true) {
-        const size_t start = task->counter.fetch_add(kParallelForStride, std::memory_order_relaxed);
-        if (start >= task->total) { break; }
-        const size_t next_start = start + kParallelForStride;
-        const size_t end = std::min(next_start, task->total);
-        (*task->for_range)(&engine_, start, end);
-      }
-      engine_.WaitUntilDone();
-      task->bc.Decrease();
+      task(&engine_);
     }
   }
-  Channel<ParallelForTask<Engine>*> tasks_;
+  Channel<IoTask<Engine>> tasks_;
   Engine engine_;
   std::thread thread_;
 };
@@ -538,45 +553,51 @@ void PersistentTableImpl<Key>::PutBlocks(uint32_t num_keys, const void*
   physical_table_size_ += num_padded_keys;
   CHECK_EQ(start_index % num_values_per_block_, 0);
   const uint64_t start_block_id = start_index / num_values_per_block_;
-  for (uint64_t i = 0; i < num_keys; ++i) {
-    row_id_mapping_[static_cast<const Key*>(keys)[i]] = start_index + i;
-  }
   uint64_t written_blocks = 0;
   const uint64_t block_keys_size = num_values_per_block_ * sizeof(Key);
-  while (written_blocks < num_blocks) {
-    const uint64_t batch_start_block_id = start_block_id + written_blocks;
-    const uint64_t batch_chunk_id = batch_start_block_id / num_logical_blocks_per_chunk_;
-    if (batch_chunk_id == value_files_.size()) {
-      value_files_.emplace_back(ValueFilePath(batch_chunk_id), O_CREAT | O_RDWR | O_DIRECT, 0644);
-    } else {
-      CHECK_LE(batch_chunk_id, value_files_.size());
-    }
-    if ((!writable_key_file_.IsOpen()) || writable_key_file_chunk_id_ != batch_chunk_id) {
-      writable_key_file_ = PosixFile(KeyFilePath(batch_chunk_id), O_CREAT | O_RDWR, 0644);
+  BlockingCounter bc(1);
+  workers_.at(0)->Schedule([&](Engine* engine) {
+    while (written_blocks < num_blocks) {
+      const uint64_t batch_start_block_id = start_block_id + written_blocks;
+      const uint64_t batch_chunk_id = batch_start_block_id / num_logical_blocks_per_chunk_;
+      if (batch_chunk_id == value_files_.size()) {
+        value_files_.emplace_back(ValueFilePath(batch_chunk_id), O_CREAT | O_RDWR | O_DIRECT, 0644);
+      } else {
+        CHECK_LE(batch_chunk_id, value_files_.size());
+      }
+      if ((!writable_key_file_.IsOpen()) || writable_key_file_chunk_id_ != batch_chunk_id) {
+        writable_key_file_ = PosixFile(KeyFilePath(batch_chunk_id), O_CREAT | O_RDWR, 0644);
+      }
+      PosixFile& value_file = value_files_.at(batch_chunk_id);
+      const uint64_t block_id_in_chunk =
+          batch_start_block_id - batch_chunk_id * num_logical_blocks_per_chunk_;
+      const uint64_t blocks_to_write =
+          std::min(num_blocks - written_blocks,
+                   (batch_chunk_id + 1) * num_logical_blocks_per_chunk_ - batch_start_block_id);
+      const uint64_t values_bytes = blocks_to_write * logical_block_size_;
+      const uint64_t values_offset_in_file = block_id_in_chunk * logical_block_size_;
+      CHECK_LE(value_file.Size(), values_offset_in_file);
+      value_file.Truncate(values_offset_in_file + values_bytes);
+      PCHECK(pwrite(value_file.fd(), BytesOffset(blocks, written_blocks * logical_block_size_),
+                    values_bytes, values_offset_in_file)
+             == values_bytes);
+      const uint64_t keys_offset_in_file = block_id_in_chunk * block_keys_size;
+      writable_key_file_.Truncate(keys_offset_in_file + blocks_to_write * block_keys_size);
+      const uint64_t keys_bytes = std::min(num_keys - written_blocks * num_values_per_block_,
+                                           blocks_to_write * num_values_per_block_)
+                                  * sizeof(Key);
+      engine->AsyncPwrite(writable_key_file_.fd(),
+                          BytesOffset(keys, written_blocks * block_keys_size), keys_bytes,
+                          keys_offset_in_file);
+      written_blocks += blocks_to_write;
     }
-    PosixFile& value_file = value_files_.at(batch_chunk_id);
-    const uint64_t block_id_in_chunk =
-        batch_start_block_id - batch_chunk_id * num_logical_blocks_per_chunk_;
-    const uint64_t blocks_to_write =
-        std::min(num_blocks - written_blocks,
-                 (batch_chunk_id + 1) * num_logical_blocks_per_chunk_ - batch_start_block_id);
-    const uint64_t values_bytes = blocks_to_write * logical_block_size_;
-    const uint64_t values_offset_in_file = block_id_in_chunk * logical_block_size_;
-    CHECK_LE(value_file.Size(), values_offset_in_file);
-    value_file.Truncate(values_offset_in_file + values_bytes);
-    PCHECK(pwrite(value_file.fd(), BytesOffset(blocks, written_blocks * logical_block_size_),
-                  values_bytes, values_offset_in_file)
-           == values_bytes);
-    const uint64_t keys_offset_in_file = block_id_in_chunk * block_keys_size;
-    writable_key_file_.Truncate(keys_offset_in_file + blocks_to_write * block_keys_size);
-    const uint64_t keys_bytes = std::min(num_keys - written_blocks * num_values_per_block_,
-                                         blocks_to_write * num_values_per_block_)
-                                * sizeof(Key);
-    PCHECK(pwrite(writable_key_file_.fd(), BytesOffset(keys, written_blocks * block_keys_size),
-                  keys_bytes, keys_offset_in_file)
-           == keys_bytes);
-    written_blocks += blocks_to_write;
+    engine->WaitUntilDone();
+    bc.Decrease();
+  });
+  for (uint64_t i = 0; i < num_keys; ++i) {
+    row_id_mapping_[static_cast<const Key*>(keys)[i]] = start_index + i;
   }
+  bc.WaitForeverUntilCntEqualZero();
 }
 
 template<typename Key>
@@ -747,9 +768,22 @@ void PersistentTableImpl<Key>::SaveSnapshot(const std::string& name) {
 
 template<typename Key>
 void PersistentTableImpl<Key>::ParallelFor(size_t total, const ForRange<Engine>& for_range) {
-  ParallelForTask<Engine> task(workers_.size(), total, &for_range);
-  for (size_t i = 0; i < workers_.size(); ++i) { workers_.at(i)->Schedule(&task); }
-  task.bc.WaitForeverUntilCntEqualZero();
+  BlockingCounter bc(workers_.size());
+  std::atomic<size_t> counter(0);
+  for (size_t i = 0; i < workers_.size(); ++i) {
+    workers_.at(i)->Schedule([&](Engine* engine) {
+      while (true) {
+        const size_t start = counter.fetch_add(kParallelForStride, std::memory_order_relaxed);
+        if (start >= total) { break; }
+        const size_t next_start = start + kParallelForStride;
+        const size_t end = std::min(next_start, total);
+        for_range(engine, start, end);
+      }
+      engine->WaitUntilDone();
+      bc.Decrease();
+    });
+  }
+  bc.WaitForeverUntilCntEqualZero();
 }
 
 template<typename Key>

From 30af1bcc6bc6cc764c8d236cb7d4ef49833bb970 Mon Sep 17 00:00:00 2001
From: liujuncheng
Date: Fri, 1 Apr 2022 14:56:14 +0800
Subject: [PATCH 2/2] fix

---
 oneflow/core/embedding/persistent_table.cpp | 41 ++-------------------
 1 file changed, 4 insertions(+), 37 deletions(-)

diff --git a/oneflow/core/embedding/persistent_table.cpp b/oneflow/core/embedding/persistent_table.cpp
index 8929e7c19e1..95817b2dcda 100644
--- a/oneflow/core/embedding/persistent_table.cpp
+++ b/oneflow/core/embedding/persistent_table.cpp
@@ -227,24 +227,6 @@ class RingEngine final {
     }
   }
 
-  void AsyncPwrite(int fd, const void* buf, size_t count, off_t offset) {
-    if (num_readings_ == kRingQueueDepth) {
-      struct io_uring_cqe* cqe = nullptr;
-      PCHECK(io_uring_wait_cqe(&ring_, &cqe) == 0);
-      CHECK_GE(cqe->res, 0);
-      io_uring_cqe_seen(&ring_, cqe);
-    } else {
-      num_readings_ += 1;
-    }
-    io_uring_sqe* sqe = CHECK_NOTNULL(io_uring_get_sqe(&ring_));
-    io_uring_prep_write(sqe, fd, buf, count, offset);
-    pending_submit_ += 1;
-    if (pending_submit_ == kRingSubmitBatch) {
-      PCHECK(io_uring_submit(&ring_) == pending_submit_);
-      pending_submit_ = 0;
-    }
-  }
-
   void WaitUntilDone() {
     if (pending_submit_ > 0) {
       PCHECK(io_uring_submit(&ring_) == pending_submit_);
@@ -296,20 +278,6 @@ class AioEngine final {
     num_readings_ += 1;
   }
 
-  void AsyncPwrite(int fd, const void* buf, size_t count, off_t offset) {
-    if (num_readings_ == kAioQueueDepth) { WaitUntilDone(); }
-    struct iocb* cb = &cbs_.at(num_readings_);
-    cb->aio_fildes = fd;
-    cb->aio_lio_opcode = IOCB_CMD_PWRITE;
-    cb->aio_reqprio = 0;
-    cb->aio_buf = reinterpret_cast<uintptr_t>(buf);
-    cb->aio_nbytes = count;
-    cb->aio_offset = offset;
-    const long nr = 1;
-    PCHECK(syscall(__NR_io_submit, ctx_, nr, &cbs_ptr_.at(num_readings_)) >= 0);
-    num_readings_ += 1;
-  }
-
   void WaitUntilDone() {
     if (num_readings_ != 0) {
       PCHECK(syscall(__NR_io_getevents, ctx_, num_readings_, num_readings_, events_.data(), nullptr)
@@ -556,7 +524,7 @@ void PersistentTableImpl<Key>::PutBlocks(uint32_t num_keys, const void*
   uint64_t written_blocks = 0;
   const uint64_t block_keys_size = num_values_per_block_ * sizeof(Key);
   BlockingCounter bc(1);
-  workers_.at(0)->Schedule([&](Engine* engine) {
+  workers_.at(0)->Schedule([&](Engine*) {
     while (written_blocks < num_blocks) {
       const uint64_t batch_start_block_id = start_block_id + written_blocks;
       const uint64_t batch_chunk_id = batch_start_block_id / num_logical_blocks_per_chunk_;
@@ -586,12 +554,11 @@ void PersistentTableImpl<Key>::PutBlocks(uint32_t num_keys, const void*
       const uint64_t keys_bytes = std::min(num_keys - written_blocks * num_values_per_block_,
                                            blocks_to_write * num_values_per_block_)
                                   * sizeof(Key);
-      engine->AsyncPwrite(writable_key_file_.fd(),
-                          BytesOffset(keys, written_blocks * block_keys_size), keys_bytes,
-                          keys_offset_in_file);
+      PCHECK(pwrite(writable_key_file_.fd(), BytesOffset(keys, written_blocks * block_keys_size),
+                    keys_bytes, keys_offset_in_file)
+             == keys_bytes);
       written_blocks += blocks_to_write;
     }
-    engine->WaitUntilDone();
     bc.Decrease();
   });
   for (uint64_t i = 0; i < num_keys; ++i) {