From 9c1db82e66b7b2afed0c33abbba8adaf5d6d9a3f Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Wed, 1 May 2024 15:28:47 -0700
Subject: [PATCH] [Sampler] Add missing sync in gpu verifier

---
 cpp/serve/function_table.cc      | 2 +-
 cpp/serve/sampler/gpu_sampler.cc | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc
index 16db4a8a03..bdf28dfdb5 100644
--- a/cpp/serve/function_table.cc
+++ b/cpp/serve/function_table.cc
@@ -135,7 +135,7 @@ void FunctionTable::Init(String reload_lib_path, Device device, picojson::object
         static_cast(tvm::runtime::memory::AllocatorType::kPooled), static_cast(kDLCPU), 0,
         static_cast(tvm::runtime::memory::AllocatorType::kPooled));
     this->mod_get_func = [this](const std::string& name) -> PackedFunc {
-      return this->local_vm->GetFunction(name, false);
+      return this->local_vm->GetFunction(name, true);
     };
     this->get_global_func = [](const std::string& name) -> PackedFunc {
       const auto* f = tvm::runtime::Registry::Get(name);
diff --git a/cpp/serve/sampler/gpu_sampler.cc b/cpp/serve/sampler/gpu_sampler.cc
index a1c7a308bc..36cb6e5c0a 100644
--- a/cpp/serve/sampler/gpu_sampler.cc
+++ b/cpp/serve/sampler/gpu_sampler.cc
@@ -244,6 +244,7 @@ class GPUSampler : public SamplerObj {
         token_tree_first_child_device, token_tree_next_sibling_device, uniform_samples_device,
         token_tree_parent_ptr_device);
+    DeviceAPI::Get(device_)->SyncStreamFromTo(device_, compute_stream_, copy_stream_);
     CopyArray(token_tree_parent_ptr_device, token_tree_parent_ptr_host, copy_stream_);
     std::vector additional_sample_result;