ucm/store/device/ibuffered_device.h (35 additions & 15 deletions)
@@ -25,11 +25,37 @@
 #define UNIFIEDCACHE_IBUFFERED_DEVICE_H
 
 #include "idevice.h"
-#include "thread/index_pool.h"
 
 namespace UC {
 
 class IBufferedDevice : public IDevice {
+    class LinearBuffer {
+        std::shared_ptr<std::byte> addr_{nullptr};
+        size_t index_{0};
+        size_t number_{0};
+        size_t size_{0};
+
+    public:
+        void Setup(std::shared_ptr<std::byte> addr, const size_t number, const size_t size)
+        {
+            this->addr_ = addr;
+            this->number_ = number;
+            this->size_ = size;
+            this->Reset();
+        }
+        void Reset() noexcept { this->index_ = 0; }
+        bool Full() const noexcept { return this->index_ == this->number_; }
+        bool Available(const size_t size) const noexcept { return this->size_ >= size; }
+        std::shared_ptr<std::byte> Get() noexcept
+        {
+            auto addr = this->addr_.get();
+            auto buffer = addr + this->size_ * this->index_;
+            ++this->index_;
+            return std::shared_ptr<std::byte>(buffer, [](auto) {});
+        }
+    };
+    LinearBuffer buffer_;
+
 public:
     IBufferedDevice(const int32_t deviceId, const size_t bufferSize, const size_t bufferNumber)
         : IDevice{deviceId, bufferSize, bufferNumber}
@@ -39,26 +65,20 @@ class IBufferedDevice : public IDevice {
     {
         auto totalSize = this->bufferSize * this->bufferNumber;
         if (totalSize == 0) { return Status::OK(); }
-        this->_addr = this->MakeBuffer(totalSize);
-        if (!this->_addr) { return Status::OutOfMemory(); }
-        this->_indexPool.Setup(this->bufferNumber);
+        auto addr = this->MakeBuffer(totalSize);
+        if (!addr) { return Status::OutOfMemory(); }
+        this->buffer_.Setup(addr, this->bufferNumber, this->bufferSize);
         return Status::OK();
     }
     virtual std::shared_ptr<std::byte> GetBuffer(const size_t size) override
     {
-        if (!this->_addr || size > this->bufferSize) { return this->MakeBuffer(size); }
-        auto idx = this->_indexPool.Acquire();
-        if (idx != IndexPool::npos) {
-            auto ptr = this->_addr.get() + this->bufferSize * idx;
-            return std::shared_ptr<std::byte>(ptr,
-                                              [this, idx](auto) { this->_indexPool.Release(idx); });
+        if (this->buffer_.Full()) {
+            auto status = this->Synchronized();
+            if (status.Failure()) { return nullptr; }
+            this->buffer_.Reset();
         }
-        return this->MakeBuffer(size);
+        return this->buffer_.Available(size) ? this->buffer_.Get() : this->MakeBuffer(size);
     }
-
-private:
-    std::shared_ptr<std::byte> _addr{nullptr};
-    IndexPool _indexPool;
 };
 
 } // namespace UC
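For reference, a minimal self-contained sketch of the slot-recycling contract the new GetBuffer() implements: slots are handed out linearly from one preallocated slab, an exhausted slab is recycled wholesale after a synchronization point, and oversized requests fall back to a fresh allocation. LinearBuffer is copied verbatim from the diff above; the Synchronized() call is stubbed as a comment, and main(), slab, and getBuffer are illustrative names, not part of the UC API.

#include <cstddef>
#include <iostream>
#include <memory>

class LinearBuffer {
    std::shared_ptr<std::byte> addr_{nullptr};
    size_t index_{0};
    size_t number_{0};
    size_t size_{0};

public:
    void Setup(std::shared_ptr<std::byte> addr, const size_t number, const size_t size)
    {
        this->addr_ = addr;
        this->number_ = number;
        this->size_ = size;
        this->Reset();
    }
    void Reset() noexcept { this->index_ = 0; }
    bool Full() const noexcept { return this->index_ == this->number_; }
    bool Available(const size_t size) const noexcept { return this->size_ >= size; }
    std::shared_ptr<std::byte> Get() noexcept
    {
        auto addr = this->addr_.get();
        auto buffer = addr + this->size_ * this->index_;
        ++this->index_;
        // Non-owning alias: the slab stays alive in addr_, so the deleter is a no-op.
        return std::shared_ptr<std::byte>(buffer, [](auto) {});
    }
};

int main()
{
    constexpr size_t slotSize = 64, slotNumber = 4;
    std::shared_ptr<std::byte> slab(new std::byte[slotSize * slotNumber],
                                    std::default_delete<std::byte[]>());
    LinearBuffer buffer;
    buffer.Setup(slab, slotNumber, slotSize);

    auto getBuffer = [&](size_t size) {
        if (buffer.Full()) {
            // The real device calls IDevice::Synchronized() here to drain
            // in-flight transfers before recycling, returning nullptr on failure.
            buffer.Reset();
        }
        // Oversized requests bypass the slab, mirroring the MakeBuffer fallback.
        return buffer.Available(size)
                   ? buffer.Get()
                   : std::shared_ptr<std::byte>(new std::byte[size],
                                                std::default_delete<std::byte[]>());
    };

    auto first = getBuffer(slotSize);
    for (int i = 0; i < 3; ++i) { getBuffer(slotSize); } // exhaust the slab
    auto recycled = getBuffer(slotSize);                 // triggers reset, slot 0 again
    std::cout << "slot reused: " << (first.get() == recycled.get()) << "\n"; // prints 1
    return 0;
}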
ucm/store/test/e2e/nfsstore_embed.py (36 additions & 0 deletions)
@@ -80,6 +80,39 @@ def embed(store: UcmKVStoreBase, hashes: List[str], tensors: List[List[torch.Tensor]]):
     store.commit(hashes, True)
 
 
+def fetch(store: UcmKVStoreBase, hashes: List[str], tensors: List[List[torch.Tensor]]):
+    founds = store.lookup(hashes)
+    for found in founds:
+        assert found
+    block_ids = []
+    offsets = []
+    layers = []
+    for hash_id, block in zip(hashes, tensors):
+        offset = 0
+        for layer in block:
+            block_ids.append(hash_id)
+            offsets.append(offset)
+            layers.append(layer)
+            offset += layer.untyped_storage().size()
+    task = store.load(block_ids, offsets, layers)
+    assert task.task_id > 0
+    ret = store.wait(task)
+    assert ret == 0
+
+
+def cmp_and_print_diff(a, b, rtol=0.0, atol=0.0):
+    for r, (row_a, row_b) in enumerate(zip(a, b)):
+        for c, (ta, tb) in enumerate(zip(row_a, row_b)):
+            if not torch.allclose(ta, tb, rtol=rtol, atol=atol):
+                mask = ~torch.isclose(ta, tb, rtol=rtol, atol=atol)
+                diff_a = ta[mask].cpu()
+                diff_b = tb[mask].cpu()
+                print(f"DIFF at [{r}][{c}] total {mask.sum().item()} element(s)")
+                print("  a val:", diff_a.flatten())
+                print("  b val:", diff_b.flatten())
+                assert False
+
+
 def store_all_hashes(hashes):
     kvcache_block_hashes_file = "kvcache_block_hashes.txt"
     current_directory = os.path.dirname(__file__)
Expand Down Expand Up @@ -108,7 +141,10 @@ def main():
for batch in range(total_batches):
start = batch_size * batch
end = min(start + batch_size, block_number)
tensors2 = [[torch.empty_like(t) for t in row] for row in tensors]
embed(store, hashes[start:end], tensors)
fetch(store, hashes[start:end], tensors2)
cmp_and_print_diff(tensors, tensors2)
store_all_hashes(hashes)


Expand Down
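A quick standalone sketch of the flattening scheme that fetch() (and embed()) rely on: each layer of each block becomes one (block_id, byte offset, tensor) triple, with the offset advancing by the layer's untyped storage size so a block's layers pack back to back. The store calls are omitted here; flatten is a hypothetical helper name, and only torch is required.

import torch


def flatten(hashes, tensors):
    # Mirrors the triple-building loop in fetch(): one entry per (block, layer).
    block_ids, offsets, layers = [], [], []
    for hash_id, block in zip(hashes, tensors):
        offset = 0
        for layer in block:
            block_ids.append(hash_id)
            offsets.append(offset)
            layers.append(layer)
            offset += layer.untyped_storage().size()
    return block_ids, offsets, layers


if __name__ == "__main__":
    hashes = ["blk0", "blk1"]
    tensors = [[torch.ones(2, 4), torch.zeros(3)] for _ in hashes]
    block_ids, offsets, _ = flatten(hashes, tensors)
    # Two layers per block -> 4 triples; the offset restarts at 0 for each block.
    print(list(zip(block_ids, offsets)))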