
[webgpu] Enable graph capture #24900


Merged
merged 33 commits on Jul 7, 2025

Commits (33)
3458625
[webgpu] Add graph capture support
qjia7 May 15, 2025
08bfa2b
clean code
qjia7 May 29, 2025
ac47a39
Merge branch 'main' into graph_capture
qjia7 May 30, 2025
b2db0b8
Merge branch 'main' into graph_capture
qjia7 Jun 5, 2025
b6cf35d
don't need to execute commands in capture mode
qjia7 Jun 6, 2025
d2acf41
fix CI errors
qjia7 Jun 7, 2025
0756d5b
Add a Graph BufferCacheMode
qjia7 Jun 10, 2025
21d0243
add GraphSimple mode for uniform
qjia7 Jun 11, 2025
13995ad
nits
qjia7 Jun 11, 2025
c17dce6
Merge branch 'main' into graph_capture
qjia7 Jun 11, 2025
2bf3e6a
address comments
qjia7 Jun 11, 2025
004c9d0
fix unmap errors due to CreateUMA path
qjia7 Jun 13, 2025
66d558a
fix errors when no gpu
qjia7 Jun 13, 2025
9ffc44e
Merge branch 'main' into graph_capture
qjia7 Jun 13, 2025
6b5b9b2
fix CI errors
qjia7 Jun 13, 2025
6cdf804
fix CI errors
qjia7 Jun 16, 2025
4903c27
Merge branch 'main' into graph_capture
qjia7 Jun 16, 2025
5d1ff55
Merge branch 'main' into graph_capture
qjia7 Jun 17, 2025
e8d19ac
Merge branch 'main' into graph_capture
qjia7 Jun 20, 2025
6b88726
force use graph mode when enable_graph_capture is true
qjia7 Jun 20, 2025
55f595a
Merge branch 'main' into graph_capture
qjia7 Jun 26, 2025
87abbe6
decouple WebGpuContext and session_id
qjia7 Jun 27, 2025
f1efee2
Make GpuBufferAllocator and DataTransfer use the same buffer manager
qjia7 Jun 27, 2025
c9a2f0c
remove legacy session id
qjia7 Jun 28, 2025
7cf055e
remove useless comments
qjia7 Jun 28, 2025
df96164
fix CI errors
qjia7 Jun 28, 2025
0a29521
fix using incorrect buffer manager
qjia7 Jun 30, 2025
9ed4ade
address comments
qjia7 Jul 1, 2025
456c1ad
restore to use raw WebGPU handles
qjia7 Jul 2, 2025
11656cd
Merge branch 'main' into graph_capture
qjia7 Jul 2, 2025
151529f
remove external buffer manager in webgpucontext
qjia7 Jul 2, 2025
da83468
fix CI errors
qjia7 Jul 2, 2025
591a999
remove duplicated comments
qjia7 Jul 2, 2025
13 changes: 6 additions & 7 deletions onnxruntime/core/providers/webgpu/allocator.cc
@@ -3,7 +3,7 @@

#include "core/framework/session_state.h"
#include "core/providers/webgpu/allocator.h"
#include "core/providers/webgpu/webgpu_context.h"
#include "core/providers/webgpu/buffer_manager.h"

namespace onnxruntime {
namespace webgpu {
@@ -15,18 +15,17 @@ void* GpuBufferAllocator::Alloc(size_t size) {

stats_.num_allocs++;

#if !defined(__wasm__)
if (!session_initialized_ && context_.DeviceHasFeature(wgpu::FeatureName::BufferMapExtendedUsages)) {
return context_.BufferManager().CreateUMA(size);
// Check if the buffer manager supports UMA and we're not yet in an initialized session
if (!session_initialized_ && buffer_manager_.SupportsUMA()) {
return buffer_manager_.CreateUMA(size);
}
#endif // !defined(__wasm__)

return context_.BufferManager().Create(size);
return buffer_manager_.Create(size);
}

void GpuBufferAllocator::Free(void* p) {
if (p != nullptr) {
context_.BufferManager().Release(static_cast<WGPUBuffer>(p));
buffer_manager_.Release(static_cast<WGPUBuffer>(p));
stats_.num_allocs--;
}
}
12 changes: 7 additions & 5 deletions onnxruntime/core/providers/webgpu/allocator.h
@@ -9,27 +9,29 @@
namespace onnxruntime {
namespace webgpu {

class WebGpuContext;
class BufferManager;

class GpuBufferAllocator : public IAllocator {
public:
GpuBufferAllocator(const WebGpuContext& context)
GpuBufferAllocator(const BufferManager& buffer_manager)

[cpplint] Single-parameter constructors should be marked explicit. [runtime/explicit] [4] (onnxruntime/core/providers/webgpu/allocator.h:16)
: IAllocator(
OrtMemoryInfo(WEBGPU_BUFFER, OrtAllocatorType::OrtDeviceAllocator,
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0),
OrtMemTypeDefault)),
context_{context} {
buffer_manager_{buffer_manager} {
}

virtual void* Alloc(size_t size) override;
virtual void Free(void* p) override;
void GetStats(AllocatorStats* stats) override;

void OnSessionInitializationEnd();

// Return the associated BufferManager
const BufferManager& GetBufferManager() const { return buffer_manager_; }

private:
AllocatorStats stats_;
const WebGpuContext& context_;
const BufferManager& buffer_manager_;
bool session_initialized_ = false;
};

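Taken together, the allocator changes make GpuBufferAllocator depend only on a BufferManager: before session initialization ends, Alloc prefers the UMA path when SupportsUMA() reports support, afterwards it uses the regular Create path, and GetBufferManager() lets other components share the same manager. Below is a minimal standalone sketch of that flow, using simplified stand-in types rather than the real ORT classes:

#include <cstddef>
#include <cstdio>

// Simplified stand-in for the BufferManager interface used in this diff.
struct FakeBufferManager {
  bool uma_supported = false;
  bool SupportsUMA() const { return uma_supported; }
  void* CreateUMA(size_t size) const { std::printf("UMA buffer, %zu bytes\n", size); return new char[size]; }
  void* Create(size_t size) const { std::printf("default buffer, %zu bytes\n", size); return new char[size]; }
  void Release(void* p) const { delete[] static_cast<char*>(p); }
};

// Mirrors the decision made in GpuBufferAllocator::Alloc above.
class FakeGpuBufferAllocator {
 public:
  explicit FakeGpuBufferAllocator(const FakeBufferManager& buffer_manager)
      : buffer_manager_(buffer_manager) {}

  void* Alloc(size_t size) const {
    // Pre-initialization allocations (typically initializers) may take the UMA path.
    if (!session_initialized_ && buffer_manager_.SupportsUMA()) {
      return buffer_manager_.CreateUMA(size);
    }
    return buffer_manager_.Create(size);
  }
  void Free(void* p) const {
    if (p != nullptr) buffer_manager_.Release(p);
  }

  void OnSessionInitializationEnd() { session_initialized_ = true; }
  const FakeBufferManager& GetBufferManager() const { return buffer_manager_; }

 private:
  const FakeBufferManager& buffer_manager_;
  bool session_initialized_ = false;
};

int main() {
  FakeBufferManager manager{/*uma_supported=*/true};
  FakeGpuBufferAllocator allocator(manager);

  void* weights = allocator.Alloc(4096);      // before init: UMA path
  allocator.OnSessionInitializationEnd();
  void* activations = allocator.Alloc(4096);  // after init: default Create path
  allocator.Free(activations);
  allocator.Free(weights);
  return 0;
}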
206 changes: 186 additions & 20 deletions onnxruntime/core/providers/webgpu/buffer_manager.cc
@@ -37,7 +37,7 @@
wgpuBufferRelease(buffer);
}

void OnRefresh() override {
void OnRefresh(const SessionState& /*session_status*/) override {
// no-op
}
};
@@ -59,7 +59,7 @@
pending_buffers_.emplace_back(buffer);
}

void OnRefresh() override {
void OnRefresh(const SessionState& /*session_status*/) override {
Release();
pending_buffers_.clear();
}
@@ -103,7 +103,7 @@
pending_buffers_.emplace_back(buffer);
}

void OnRefresh() override {
void OnRefresh(const SessionState& /*session_status*/) override {
for (auto& buffer : pending_buffers_) {
buffers_[static_cast<size_t>(wgpuBufferGetSize(buffer))].emplace_back(buffer);
}
@@ -196,12 +196,9 @@
pending_buffers_.emplace_back(buffer);
}

void OnRefresh() override {
// TODO: consider graph capture. currently not supported

void OnRefresh(const SessionState& /*session_status*/) override {
for (auto& buffer : pending_buffers_) {
auto buffer_size = static_cast<size_t>(wgpuBufferGetSize(buffer));

auto it = buckets_.find(buffer_size);
if (it != buckets_.end() && it->second.size() < buckets_limit_[buffer_size]) {
it->second.emplace_back(buffer);
@@ -249,6 +246,155 @@
std::vector<size_t> buckets_keys_;
};

class GraphCacheManager : public IBufferCacheManager {
public:
GraphCacheManager() : buckets_limit_{BUCKET_DEFAULT_LIMIT_TABLE} {
Initialize();
}
GraphCacheManager(std::unordered_map<size_t, size_t>&& buckets_limit) : buckets_limit_{buckets_limit} {

[cpplint] Single-parameter constructors should be marked explicit. [runtime/explicit] [4] (onnxruntime/core/providers/webgpu/buffer_manager.cc:254)
Initialize();
}

size_t CalculateBufferSize(size_t request_size) override {
// binary search for the bucket size
auto it = std::lower_bound(buckets_keys_.begin(), buckets_keys_.end(), request_size);
if (it == buckets_keys_.end()) {
return NormalizeBufferSize(request_size);
} else {
return *it;
}
}

WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size) override {
auto it = buckets_.find(buffer_size);
if (it != buckets_.end() && !it->second.empty()) {
auto buffer = it->second.back();
it->second.pop_back();
return buffer;
}
return nullptr;
}

void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override {
// no-op
}

void ReleaseBuffer(WGPUBuffer buffer) override {
pending_buffers_.emplace_back(buffer);
}

void OnRefresh(const SessionState& /*session_status*/) override {
// Initialize buckets if they don't exist yet
if (buckets_.empty()) {
for (const auto& pair : buckets_limit_) {
buckets_.emplace(pair.first, std::vector<WGPUBuffer>());
}
}

for (auto& buffer : pending_buffers_) {
auto buffer_size = static_cast<size_t>(wgpuBufferGetSize(buffer));
auto it = buckets_.find(buffer_size);
if (it != buckets_.end()) {
it->second.emplace_back(buffer);
} else {
// insert a new bucket if it doesn't exist
buckets_[buffer_size] = std::vector<WGPUBuffer>{buffer};
}
}

pending_buffers_.clear();
}

~GraphCacheManager() {
for (auto& buffer : pending_buffers_) {
wgpuBufferRelease(buffer);
}
for (auto& pair : buckets_) {
for (auto& buffer : pair.second) {
wgpuBufferRelease(buffer);
}
}
}

protected:
void Initialize() {
buckets_keys_.reserve(buckets_limit_.size());
for (const auto& pair : buckets_limit_) {
buckets_keys_.push_back(pair.first);
}
std::sort(buckets_keys_.begin(), buckets_keys_.end());

[cpplint] Add #include <algorithm> for sort [build/include_what_you_use] [4] (onnxruntime/core/providers/webgpu/buffer_manager.cc:325)

#ifndef NDEBUG // if debug build
ORT_ENFORCE(std::all_of(buckets_keys_.begin(), buckets_keys_.end(), [](size_t size) { return size % 16 == 0; }),
"Bucket sizes must be multiples of 16.");

for (size_t i = 1; i < buckets_keys_.size(); ++i) {
ORT_ENFORCE(buckets_keys_[i] > buckets_keys_[i - 1], "Bucket sizes must be in increasing order.");
}
#endif
}
std::unordered_map<size_t, size_t> buckets_limit_;
std::unordered_map<size_t, std::vector<WGPUBuffer>> buckets_;

[cpplint] Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4] (onnxruntime/core/providers/webgpu/buffer_manager.cc:337)
std::vector<WGPUBuffer> pending_buffers_;
std::vector<size_t> buckets_keys_;
};

class GraphSimpleCacheManager : public IBufferCacheManager {
size_t CalculateBufferSize(size_t request_size) override {
return NormalizeBufferSize(request_size);
}

WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size) override {
auto it = buffers_.find(buffer_size);
if (it != buffers_.end() && !it->second.empty()) {
auto buffer = it->second.back();
it->second.pop_back();
return buffer;
}

return nullptr;
}

void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override {
// no-op
}

void ReleaseBuffer(WGPUBuffer buffer) override {
pending_buffers_.emplace_back(buffer);
}

void OnRefresh(const SessionState& session_status) override {
for (auto& buffer : pending_buffers_) {
if (session_status == SessionState::Default) {
buffers_[static_cast<size_t>(wgpuBufferGetSize(buffer))].emplace_back(buffer);
} else {
captured_buffers_.emplace_back(buffer);
}
}
pending_buffers_.clear();
}

public:
~GraphSimpleCacheManager() {
for (auto& buffer : pending_buffers_) {
wgpuBufferRelease(buffer);
}
for (auto& pair : buffers_) {
for (auto& buffer : pair.second) {
wgpuBufferRelease(buffer);
}
}
for (auto& buffer : captured_buffers_) {
wgpuBufferRelease(buffer);
}
}

protected:
std::map<size_t, std::vector<WGPUBuffer>> buffers_;

[cpplint] Add #include <map> for map<> [build/include_what_you_use] [4] (onnxruntime/core/providers/webgpu/buffer_manager.cc:393)
std::vector<WGPUBuffer> pending_buffers_;
std::vector<WGPUBuffer> captured_buffers_;

[cpplint] Add #include <vector> for vector<> [build/include_what_you_use] [4] (onnxruntime/core/providers/webgpu/buffer_manager.cc:395)
};

std::unique_ptr<IBufferCacheManager> CreateBufferCacheManager(BufferCacheMode cache_mode) {
switch (cache_mode) {
case BufferCacheMode::Disabled:
@@ -259,6 +405,10 @@
return std::make_unique<SimpleCacheManager>();
case BufferCacheMode::Bucket:
return std::make_unique<BucketCacheManager>();
case BufferCacheMode::Graph:
return std::make_unique<GraphCacheManager>();
case BufferCacheMode::GraphSimple:
return std::make_unique<GraphSimpleCacheManager>();
default:
ORT_NOT_IMPLEMENTED("Unsupported buffer cache mode");
}
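The GraphCacheManager introduced above rounds each request up to the nearest configured bucket with std::lower_bound over the sorted bucket keys, and falls back to plain alignment when the request exceeds the largest bucket. Here is a standalone sketch of that lookup; the bucket keys are illustrative rather than ORT's BUCKET_DEFAULT_LIMIT_TABLE, and NormalizeBufferSize is assumed here to round up to a multiple of 16:

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Stand-in for NormalizeBufferSize: round up to a 16-byte multiple, the same
// invariant the debug check in Initialize() enforces for bucket keys.
size_t NormalizeBufferSizeSketch(size_t size) {
  return (size + 15) / 16 * 16;
}

// Pick the smallest configured bucket that fits the request, as
// GraphCacheManager::CalculateBufferSize does with its sorted buckets_keys_.
size_t CalculateBufferSizeSketch(const std::vector<size_t>& sorted_bucket_keys, size_t request_size) {
  auto it = std::lower_bound(sorted_bucket_keys.begin(), sorted_bucket_keys.end(), request_size);
  if (it == sorted_bucket_keys.end()) {
    return NormalizeBufferSizeSketch(request_size);  // larger than every bucket
  }
  return *it;
}

int main() {
  const std::vector<size_t> bucket_keys = {64, 256, 1024, 4096};  // illustrative keys only

  assert(CalculateBufferSizeSketch(bucket_keys, 100) == 256);    // rounded up to the next bucket
  assert(CalculateBufferSizeSketch(bucket_keys, 4096) == 4096);  // exact bucket match
  assert(CalculateBufferSizeSketch(bucket_keys, 5000) == 5008);  // beyond all buckets: 16-byte alignment
  return 0;
}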
@@ -278,6 +428,12 @@
case BufferCacheMode::Bucket:
os << "Bucket";
break;
case BufferCacheMode::Graph:
os << "Graph";
break;
case BufferCacheMode::GraphSimple:
os << "GraphSimple";
break;
default:
os << "Unknown(" << static_cast<int>(mode) << ")";
}
@@ -292,7 +448,7 @@
default_cache_{CreateBufferCacheManager(BufferCacheMode::Disabled)} {
}

void BufferManager::Upload(void* src, WGPUBuffer dst, size_t size) {
void BufferManager::Upload(void* src, WGPUBuffer dst, size_t size) const {
// If the buffer is mapped, we can directly write to it.
void* mapped_data = wgpuBufferGetMappedRange(dst, 0, WGPU_WHOLE_MAP_SIZE); // ensure the buffer is mapped
if (mapped_data) {
@@ -317,10 +473,10 @@
auto& command_encoder = context_.GetCommandEncoder();
context_.EndComputePass();
command_encoder.CopyBufferToBuffer(staging_buffer, 0, dst, 0, buffer_size);
context_.Flush();
context_.Flush(*this);
}
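BufferManager::Upload above writes straight through the destination when it is still mapped (the CreateUMA case) and only otherwise goes through a staging buffer, CopyBufferToBuffer, and a Flush. A rough standalone sketch of that decision, with the WebGPU calls replaced by stand-ins:

#include <cstddef>
#include <cstring>
#include <vector>

// Stand-in for a WGPUBuffer that may still be mapped for CPU writes (the UMA case).
struct FakeGpuBuffer {
  std::vector<unsigned char> storage;
  bool mapped = false;

  void* GetMappedRange() { return mapped ? storage.data() : nullptr; }
  void Unmap() { mapped = false; }
};

// Mirrors the shape of BufferManager::Upload: direct write when mapped,
// otherwise a staging copy (the real code records CopyBufferToBuffer and flushes).
void UploadSketch(const void* src, FakeGpuBuffer& dst, size_t size) {
  if (void* mapped_data = dst.GetMappedRange()) {
    std::memcpy(mapped_data, src, size);  // fast path: no staging buffer needed
    dst.Unmap();
    return;
  }
  std::vector<unsigned char> staging(size);             // stand-in for the mapped staging buffer
  std::memcpy(staging.data(), src, size);
  dst.storage.assign(staging.begin(), staging.end());   // stand-in for CopyBufferToBuffer + Flush
}

int main() {
  const char payload[] = "weights";

  FakeGpuBuffer uma_buffer{std::vector<unsigned char>(sizeof(payload)), /*mapped=*/true};
  UploadSketch(payload, uma_buffer, sizeof(payload));       // direct mapped write

  FakeGpuBuffer discrete_buffer;                            // not mapped
  UploadSketch(payload, discrete_buffer, sizeof(payload));  // staging-copy fallback
  return 0;
}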

void BufferManager::MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) {
void BufferManager::MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) const {
ORT_ENFORCE(src != dst, "Source and destination buffers must be different.");
EnforceBufferUnmapped(context_, src);
EnforceBufferUnmapped(context_, dst);
@@ -337,7 +493,7 @@
command_encoder.CopyBufferToBuffer(src, 0, dst, 0, buffer_size);
}

WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) {
WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) const {
auto& cache = GetCacheManager(usage);
auto buffer_size = cache.CalculateBufferSize(size);

@@ -358,7 +514,7 @@
return buffer;
}

WGPUBuffer BufferManager::CreateUMA(size_t size, wgpu::BufferUsage usage) {
WGPUBuffer BufferManager::CreateUMA(size_t size, wgpu::BufferUsage usage) const {
ORT_ENFORCE(usage & wgpu::BufferUsage::Storage, "UMA buffer must be a storage buffer.");
auto& cache = GetCacheManager(usage);
auto buffer_size = cache.CalculateBufferSize(size);
@@ -378,12 +534,21 @@
return buffer;
}

void BufferManager::Release(WGPUBuffer buffer) {
bool BufferManager::SupportsUMA() const {
#if !defined(__wasm__)
// Check if the device supports the BufferMapExtendedUsages feature
return context_.DeviceHasFeature(wgpu::FeatureName::BufferMapExtendedUsages);
#else
return false;
#endif // !defined(__wasm__)
}

void BufferManager::Release(WGPUBuffer buffer) const {
EnforceBufferUnmapped(context_, buffer);
GetCacheManager(buffer).ReleaseBuffer(buffer);
}

void BufferManager::Download(WGPUBuffer src, void* dst, size_t size) {
void BufferManager::Download(WGPUBuffer src, void* dst, size_t size) const {
EnforceBufferUnmapped(context_, src);
auto buffer_size = NormalizeBufferSize(size);

@@ -395,7 +560,7 @@
auto& command_encoder = context_.GetCommandEncoder();
context_.EndComputePass();
command_encoder.CopyBufferToBuffer(src, 0, staging_buffer, 0, buffer_size);
context_.Flush();
context_.Flush(*this);

// TODO: revise wait in whole project

@@ -405,13 +570,14 @@

auto mapped_data = staging_buffer.GetConstMappedRange();
memcpy(dst, mapped_data, size);
staging_buffer.Unmap();
}

void BufferManager::RefreshPendingBuffers() {
storage_cache_->OnRefresh();
uniform_cache_->OnRefresh();
query_resolve_cache_->OnRefresh();
default_cache_->OnRefresh();
void BufferManager::RefreshPendingBuffers(const SessionState& session_status) const {
storage_cache_->OnRefresh(session_status);
uniform_cache_->OnRefresh(session_status);
query_resolve_cache_->OnRefresh(session_status);
default_cache_->OnRefresh(session_status);
}

IBufferCacheManager& BufferManager::GetCacheManager(wgpu::BufferUsage usage) const {
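To make the capture-aware recycling concrete: RefreshPendingBuffers now forwards the session state to every cache manager, and GraphSimpleCacheManager's OnRefresh returns released buffers to its size-keyed free list only while the session is in SessionState::Default; otherwise it parks them in captured_buffers_ so that handles referenced by the captured graph stay alive for replay. A minimal standalone sketch of that policy follows; the SessionState values other than Default and all types here are stand-ins, not ORT's definitions:

#include <cstddef>
#include <map>
#include <vector>

// Stand-ins for the real types in this diff; only Default is known from the diff,
// the other states are illustrative.
enum class SessionState { Default, Capturing, Replaying };
struct FakeBuffer { size_t size; };

class FakeGraphSimpleCache {
 public:
  void ReleaseBuffer(FakeBuffer* buffer) { pending_buffers_.push_back(buffer); }

  // Mirrors GraphSimpleCacheManager::OnRefresh: recycle only outside capture/replay.
  void OnRefresh(SessionState session_status) {
    for (FakeBuffer* buffer : pending_buffers_) {
      if (session_status == SessionState::Default) {
        buffers_[buffer->size].push_back(buffer);  // reusable by later allocations
      } else {
        captured_buffers_.push_back(buffer);       // kept alive for the captured graph
      }
    }
    pending_buffers_.clear();
  }

  FakeBuffer* TryAcquireCachedBuffer(size_t size) {
    auto it = buffers_.find(size);
    if (it == buffers_.end() || it->second.empty()) return nullptr;
    FakeBuffer* buffer = it->second.back();
    it->second.pop_back();
    return buffer;
  }

 private:
  std::map<size_t, std::vector<FakeBuffer*>> buffers_;
  std::vector<FakeBuffer*> pending_buffers_;
  std::vector<FakeBuffer*> captured_buffers_;
};

int main() {
  FakeGraphSimpleCache cache;
  FakeBuffer a{256}, b{256};

  cache.ReleaseBuffer(&a);
  cache.OnRefresh(SessionState::Capturing);  // a is parked, still owned by the captured graph
  cache.ReleaseBuffer(&b);
  cache.OnRefresh(SessionState::Default);    // b goes back to the free list
  FakeBuffer* reused = cache.TryAcquireCachedBuffer(256);  // returns b, never a
  (void)reused;
  return 0;
}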