Skip to content

Commit 8e5af29

Browse files
authored
[Sampler] Add missing sync in gpu verifier (#2262)
1 parent cfd3b2c commit 8e5af29

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

cpp/serve/function_table.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -135,7 +135,7 @@ void FunctionTable::Init(String reload_lib_path, Device device, picojson::object
135 135   static_cast<int>(tvm::runtime::memory::AllocatorType::kPooled), static_cast<int>(kDLCPU), 0,
136 136   static_cast<int>(tvm::runtime::memory::AllocatorType::kPooled));
137 137   this->mod_get_func = [this](const std::string& name) -> PackedFunc {
138     - return this->local_vm->GetFunction(name, false);
    138 + return this->local_vm->GetFunction(name, true);
139 139   };
140 140   this->get_global_func = [](const std::string& name) -> PackedFunc {
141 141   const auto* f = tvm::runtime::Registry::Get(name);

cpp/serve/sampler/gpu_sampler.cc

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -244,6 +244,7 @@ class GPUSampler : public SamplerObj {
244 244   token_tree_first_child_device, token_tree_next_sibling_device,
245 245   uniform_samples_device, token_tree_parent_ptr_device);
246 246
    247 + DeviceAPI::Get(device_)->SyncStreamFromTo(device_, compute_stream_, copy_stream_);
247 248   CopyArray(token_tree_parent_ptr_device, token_tree_parent_ptr_host, copy_stream_);
248 249
249 250   std::vector<SampleResult> additional_sample_result;

0 commit comments

Comments
 (0)