Skip to content

Commit

Permalink
refine runtime symbol registry to hold the scalar in singleton (Paddl…
Browse files Browse the repository at this point in the history
…ePaddle#156)

to avoid out-of-date memory
  • Loading branch information
Superjomn authored Aug 6, 2020
1 parent 0ee03f7 commit 3bd9ed5
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 19 deletions.
8 changes: 4 additions & 4 deletions cinn/backends/codegen_cuda_dev_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -396,8 +396,8 @@ TEST(CodeGenCUDA, jit_host_call_cuda_kernel) {

LOG(INFO) << "fn_kernel: " << fn_kernel;

RuntimeSymbolRegistry::Global().Register("fn_kernel_ptr_", reinterpret_cast<void*>(&fn_kernel));
RuntimeSymbolRegistry::Global().Register("fn_kernel_stream_ptr_", reinterpret_cast<void*>(&stream));
RuntimeSymbolRegistry::Global().RegisterFn("fn_kernel_ptr_", reinterpret_cast<void*>(&fn_kernel));
RuntimeSymbolRegistry::Global().RegisterVar("fn_kernel_stream_ptr_", stream);

// compile host
{
Expand Down Expand Up @@ -647,8 +647,8 @@ TEST(elementwise_add, share_local_cache) {

// Register to JIT
void* stream = nullptr;
RuntimeSymbolRegistry::Global().Register("elementwise_add_kernel_ptr_", reinterpret_cast<void*>(&fn_kernel));
RuntimeSymbolRegistry::Global().Register("elementwise_add_kernel_stream_ptr_", reinterpret_cast<void*>(&stream));
RuntimeSymbolRegistry::Global().RegisterFn("elementwise_add_kernel_ptr_", reinterpret_cast<void*>(&fn_kernel));
RuntimeSymbolRegistry::Global().RegisterVar("elementwise_add_kernel_stream_ptr_", stream);

// launch the kernel

Expand Down
4 changes: 2 additions & 2 deletions cinn/backends/extern_func_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ std::ostream& operator<<(std::ostream& os, const ExternFuncID& x) {

ExternFunctionEmitterRegistry::ExternFunctionEmitterRegistry() {
// Register the runtime functions.
RuntimeSymbolRegistry::Global().Register(extern_tanh_host_repr, reinterpret_cast<void*>(__cinn_host_tanh_fp32));
RuntimeSymbolRegistry::Global().Register(extern_tanh_v_host_repr, reinterpret_cast<void*>(__cinn_host_tanh_v));
RuntimeSymbolRegistry::Global().RegisterFn(extern_tanh_host_repr, reinterpret_cast<void*>(__cinn_host_tanh_fp32));
RuntimeSymbolRegistry::Global().RegisterFn(extern_tanh_v_host_repr, reinterpret_cast<void*>(__cinn_host_tanh_v));

// tanh
Register(ExternFuncID(backend_C, extern_func__tanh), new ExternFuncEmitter_C_tanh);
Expand Down
2 changes: 1 addition & 1 deletion cinn/backends/extern_func_jit_register.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ void RegisterExternFunctionHelper(const std::string &fn_name,
ExternFunctionEmitterRegistry::Global().Register(ExternFuncID{TargetToBackendRepr(target), fn_name.c_str()},
new backends::ExternFunctionLLVMEmitter(fn_name));

RuntimeSymbolRegistry::Global().Register(fn_name, reinterpret_cast<void *>(fn_ptr));
RuntimeSymbolRegistry::Global().RegisterFn(fn_name, reinterpret_cast<void *>(fn_ptr));
}

void RegisterExternFunction::End() {
Expand Down
14 changes: 7 additions & 7 deletions cinn/backends/llvm/execution_engine_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ namespace {
bool RegisterKnownSymbols() {
decltype(auto) registry = RuntimeSymbolRegistry::Global();

registry.Register("sinf", reinterpret_cast<void *>(&sinf));
registry.Register("sin", reinterpret_cast<void *>(static_cast<double (*)(double)>(&sin)));
registry.RegisterFn("sinf", reinterpret_cast<void *>(&sinf));
registry.RegisterFn("sin", reinterpret_cast<void *>(static_cast<double (*)(double)>(&sin)));

registry.Register("cosf", reinterpret_cast<void *>(&cosf));
registry.Register("cos", reinterpret_cast<void *>(static_cast<double (*)(double)>(&cos)));
registry.RegisterFn("cosf", reinterpret_cast<void *>(&cosf));
registry.RegisterFn("cos", reinterpret_cast<void *>(static_cast<double (*)(double)>(&cos)));
return true;
}

Expand Down Expand Up @@ -232,11 +232,11 @@ TEST(ExecutionEngine, custom_runtime_symbols) {
// registry.Register("dereference_f64_ptr", (void *)+[](double *x) { return *x; });

for (size_t i = 0; i < angle.size(); i++) {
registry.Register("theta_" + std::to_string(i), reinterpret_cast<void *>(&angle[i]));
registry.RegisterVar("theta_" + std::to_string(i), angle[i]);
}

registry.Register("random_x_ptr", reinterpret_cast<void *>(&random_x));
registry.Register("random_y_ptr", reinterpret_cast<void *>(&random_y));
registry.RegisterVar("random_x_ptr", random_x);
registry.RegisterVar("random_y_ptr", random_y);
{
llvm::Type *i32_ty = builder->getInt32Ty();
llvm::FunctionType *fn_ty = llvm::FunctionType::get(i32_ty, {}, false);
Expand Down
43 changes: 42 additions & 1 deletion cinn/backends/llvm/runtime_symbol_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@

#include <glog/logging.h>

#include <any>
#include <map>
#include <mutex> // NOLINT
#include <string>
#include <string_view>
#include <variant>
#include <vector>

#include "cinn/common/macros.h"

Expand All @@ -14,19 +17,57 @@ namespace backends {

class RuntimeSymbolRegistry {
public:
using value_t = std::variant<int, int32_t, int64_t, void *>;

static RuntimeSymbolRegistry &Global();

void Register(const std::string &name, void *address);
/**
* Register function address.
* @param name Name of the symbol.
* @param address Address of the function.
*/
void RegisterFn(const std::string &name, void *address) { Register(name, address); }

/**
* Register scalar.
* @tparam T Type of the scalar.
* @param name Name of the symbol.
* @param val Scalar value.
*/
template <typename T>
void RegisterVar(const std::string &name, T val) {
auto &data = scalar_holder_[name];
data.resize(sizeof(T));
memcpy(data.data(), &val, sizeof(T));
Register(name, reinterpret_cast<void *>(data.data()));
}

/**
* Lookup a symbol from the registry.
* @param name Name of the symbol.
* @return The address if existes, or nullptr will return.
*/
void *Lookup(std::string_view name) const;

/**
* Get all the symbols.
*/
const std::map<std::string, void *> &All() const { return symbols_; }

private:
/**
* Register external symbol to the registry, the symbols in the registry will finally registered to JIT .
* @param name Name of the symbol in the JIT.
* @param address The address of the variable in external space.
*/
void Register(const std::string &name, void *address);

RuntimeSymbolRegistry() = default;
CINN_DISALLOW_COPY_AND_ASSIGN(RuntimeSymbolRegistry);

mutable std::mutex mu_;
std::map<std::string, void *> symbols_;
std::map<std::string, std::vector<int8_t>> scalar_holder_;
};

} // namespace backends
Expand Down
7 changes: 3 additions & 4 deletions cinn/common/cuda_test_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,9 @@ void CudaModuleTester::Compile(const lang::Module& m, const std::string& rewrite
CHECK(fn_kernel);
kernel_handles_.push_back(fn_kernel);

backends::RuntimeSymbolRegistry::Global().Register(kernel_fn_name + "_ptr_",
reinterpret_cast<void*>(&kernel_handles_.back()));
backends::RuntimeSymbolRegistry::Global().Register(kernel_fn_name + "_stream_ptr_",
reinterpret_cast<void*>(&stream_));
backends::RuntimeSymbolRegistry::Global().RegisterFn(kernel_fn_name + "_ptr_",
reinterpret_cast<void*>(&kernel_handles_.back()));
backends::RuntimeSymbolRegistry::Global().RegisterVar(kernel_fn_name + "_stream_ptr_", stream_);
}

jit_ = backends::SimpleJIT::Create();
Expand Down

0 comments on commit 3bd9ed5

Please sign in to comment.