Merge pull request PaddlePaddle#38 from mthreads/optimize_build_musa
Optimize build musa for phi/backends/device_code.cc
caizhi-mt authored and mt-robot committed Aug 14, 2023
2 parents b3c73b9 + cdd641f commit a455707
Showing 4 changed files with 145 additions and 16 deletions.
103 changes: 91 additions & 12 deletions paddle/phi/backends/device_code.cc
@@ -24,9 +24,7 @@ limitations under the License. */
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/flags.h"
#ifdef PADDLE_WITH_MUSA
#include <musa.h>
#endif

PHI_DECLARE_string(cuda_dir);

namespace phi {
@@ -140,7 +138,8 @@ void GPUDeviceCode::CheckAvailableStatus() {
hiprtcResult nvrtc_result =
dynload::hiprtcVersion(&nvrtc_major, &nvrtc_minor);
#elif defined(PADDLE_WITH_MUSA)
// TODO(@caizhi): enable dynload module
mtrtcResult nvrtc_result =
dynload::mtrtcVersion(&nvrtc_major, &nvrtc_minor);
#else
nvrtcResult nvrtc_result = dynload::nvrtcVersion(&nvrtc_major, &nvrtc_minor);
#endif
@@ -168,8 +167,7 @@ void GPUDeviceCode::CheckAvailableStatus() {
#ifdef PADDLE_WITH_HIP
if (nvrtc_result != HIPRTC_SUCCESS || driver_result != hipSuccess) {
#elif defined(PADDLE_WITH_MUSA)
// TODO(@caizhi): enable dynload module
if (false) {
if (nvrtc_result != MTRTC_SUCCESS || driver_result != MUSA_SUCCESS) {
#else
if (nvrtc_result != NVRTC_SUCCESS || driver_result != CUDA_SUCCESS) {
#endif
@@ -343,11 +341,85 @@ bool GPUDeviceCode::Compile(bool include_path) {
return false;
}
#elif defined(PADDLE_WITH_MUSA)
// TODO(@caizhi): enable dynload module
mtrtcProgram program;
if (!CheckNVRTCResult(dynload::mtrtcCreateProgram(&program,
kernel_.c_str(), // buffer
name_.c_str(), // name
0, // numHeaders
nullptr, // headers
nullptr), // includeNames
"mtrtcCreateProgram")) {
return false;
}

// Compile the program for specified compute_capability
auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(
DeviceContextPool::Instance().Get(place_));
is_compiled_ = false;
return false;
int compute_capability = dev_ctx->GetComputeCapability();
std::string compute_flag =
"--gpu-architecture=compute_" + std::to_string(compute_capability);
std::vector<const char*> options = {"--std=c++11", compute_flag.c_str()};
std::string include_option;
if (include_path) {
std::string cuda_include_path = FindCUDAIncludePath();
if (!cuda_include_path.empty()) {
include_option = "--include-path=" + cuda_include_path;
options.push_back(include_option.c_str());
}
}
mtrtcResult compile_result =
dynload::mtrtcCompileProgram(program, // program
options.size(), // numOptions
options.data()); // options
if (compile_result == MTRTC_ERROR_COMPILATION) {
// Obtain compilation log from the program
size_t log_size;
if (!CheckNVRTCResult(dynload::mtrtcGetProgramLogSize(program, &log_size),
"mtrtcGetProgramLogSize")) {
return false;
}
std::vector<char> log;
log.resize(log_size + 1);
if (!CheckNVRTCResult(dynload::mtrtcGetProgramLog(program, log.data()),
"nvrtcGetProgramLog")) {
return false;
}
LOG(WARNING) << "JIT compiling of MUSA code failed:"
<< "\n Kernel name: " << name_ << "\n Kernel body:\n"
<< kernel_ << "\n Compiling log: " << log.data();

return false;
}

// Obtain PTX from the program
size_t ptx_size;
if (!CheckNVRTCResult(dynload::mtrtcGetMUSASize(program, &ptx_size),
"mtrtcGetMUSASize")) {
return false;
}
ptx_.resize(ptx_size + 1);
if (!CheckNVRTCResult(dynload::mtrtcGetMUSA(program, ptx_.data()),
"mtrtcGetMUSA")) {
return false;
}

if (!CheckNVRTCResult(dynload::mtrtcDestroyProgram(&program),
"mtrtcDestroyProgram")) {
return false;
}

if (!CheckCUDADriverResult(dynload::muModuleLoadData(&module_, ptx_.data()),
"muModuleLoadData",
name_)) {
return false;
}

if (!CheckCUDADriverResult(
dynload::muModuleGetFunction(&function_, module_, name_.c_str()),
"muModuleGetFunction",
name_)) {
return false;
}
#else
nvrtcProgram program;
if (!CheckNVRTCResult(dynload::nvrtcCreateProgram(&program,
@@ -512,18 +584,25 @@ bool GPUDeviceCode::CheckNVRTCResult(hiprtcResult result,
}
return true;
}
#endif
#ifdef PADDLE_WITH_CUDA
#elif defined(PADDLE_WITH_MUSA)
bool GPUDeviceCode::CheckNVRTCResult(mtrtcResult result, std::string function) {
if (result != MTRTC_SUCCESS) {
LOG_FIRST_N(WARNING, 1)
<< "Call " << function << " for < " << name_
<< " > failed: " << dynload::mtrtcGetErrorString(result);
return false;
}
#else
bool GPUDeviceCode::CheckNVRTCResult(nvrtcResult result, std::string function) {
if (result != NVRTC_SUCCESS) {
LOG_FIRST_N(WARNING, 1)
<< "Call " << function << " for < " << name_
<< " > failed: " << dynload::nvrtcGetErrorString(result);
return false;
}
#endif
return true;
}
#endif
#endif

} // namespace phi
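For orientation, the MUSA branch added above follows the same runtime-compilation flow as the existing NVRTC path: create a program from the kernel source, compile it, dump the log on failure, and extract the compiled device binary. The sketch below condenses that flow outside PaddlePaddle's dynload wrappers; it calls the mtrtc entry points directly, so treat it as an illustrative outline rather than the code in this commit (the kernel source, name, and options are placeholders).

```cpp
// Hedged sketch: condensed version of the mtrtc compile flow used in
// GPUDeviceCode::Compile above. Assumes <mtrtc.h> from the MUSA toolkit and
// links against the musartc library directly instead of going through dynload.
#include <mtrtc.h>

#include <iostream>
#include <string>
#include <vector>

bool CompileWithMtrtc(const std::string& kernel_src,
                      const std::string& kernel_name,
                      std::vector<char>* binary) {
  mtrtcProgram program;
  if (mtrtcCreateProgram(&program, kernel_src.c_str(), kernel_name.c_str(),
                         0, nullptr, nullptr) != MTRTC_SUCCESS) {
    return false;
  }

  std::vector<const char*> options = {"--std=c++11"};
  mtrtcResult compile_result =
      mtrtcCompileProgram(program, options.size(), options.data());
  if (compile_result == MTRTC_ERROR_COMPILATION) {
    // On failure, fetch and print the compilation log before giving up.
    size_t log_size = 0;
    mtrtcGetProgramLogSize(program, &log_size);
    std::vector<char> log(log_size + 1);
    mtrtcGetProgramLog(program, log.data());
    std::cerr << "JIT compiling failed: " << log.data() << std::endl;
    mtrtcDestroyProgram(&program);
    return false;
  }

  // Extract the compiled device binary (the MUSA analogue of PTX).
  size_t binary_size = 0;
  mtrtcGetMUSASize(program, &binary_size);
  binary->resize(binary_size + 1);
  mtrtcGetMUSA(program, binary->data());

  return mtrtcDestroyProgram(&program) == MTRTC_SUCCESS;
}
```

In the commit itself the resulting binary is then handed to the driver via muModuleLoadData and muModuleGetFunction, mirroring the CUDA path.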
2 changes: 1 addition & 1 deletion paddle/phi/backends/device_code.h
@@ -73,7 +73,7 @@ class GPUDeviceCode : public DeviceCode {
#ifdef PADDLE_WITH_HIP
bool CheckNVRTCResult(hiprtcResult result, std::string function);
#elif defined(PADDLE_WITH_MUSA)

bool CheckNVRTCResult(mtrtcResult result, std::string function);
#else
bool CheckNVRTCResult(nvrtcResult result, std::string function);
#endif
12 changes: 11 additions & 1 deletion paddle/phi/backends/dynload/musartc.cc
@@ -12,12 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/backends/dynload/musartc.h"

namespace phi {
namespace dynload {

std::once_flag musartc_dso_flag;
void* musartc_dso_handle = nullptr;

#define DEFINE_WRAP(__name) DynLoad__##__name __name

MUSARTC_ROUTINE_EACH(DEFINE_WRAP);

bool HasNVRTC() {
return false;
std::call_once(musartc_dso_flag,
[]() { musartc_dso_handle = GetNVRTCDsoHandle(); });
return musartc_dso_handle != nullptr;
}

} // namespace dynload
44 changes: 42 additions & 2 deletions paddle/phi/backends/dynload/musartc.h
@@ -1,4 +1,4 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -14,11 +14,51 @@ limitations under the License. */

#pragma once

#include <mtrtc.h>

#include <mutex> // NOLINT

#include "paddle/phi/backends/dynload/dynamic_loader.h"
#include "paddle/phi/backends/dynload/port.h"

namespace phi {
namespace dynload {

extern std::once_flag musartc_dso_flag;
extern void* musartc_dso_handle;
extern bool HasNVRTC();

#define DECLARE_DYNAMIC_LOAD_NVRTC_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
using musartc_func = decltype(&::__name); \
std::call_once(musartc_dso_flag, []() { \
musartc_dso_handle = phi::dynload::GetNVRTCDsoHandle(); \
}); \
static void* p_##__name = dlsym(musartc_dso_handle, #__name); \
return reinterpret_cast<musartc_func>(p_##__name)(args...); \
} \
}; \
extern struct DynLoad__##__name __name

/**
* include all needed musartc functions
**/
#define MUSARTC_ROUTINE_EACH(__macro) \
__macro(mtrtcVersion); \
__macro(mtrtcGetErrorString); \
__macro(mtrtcCompileProgram); \
__macro(mtrtcCreateProgram); \
__macro(mtrtcDestroyProgram); \
__macro(mtrtcGetMUSA); \
__macro(mtrtcGetMUSASize); \
__macro(mtrtcGetProgramLog); \
__macro(mtrtcGetProgramLogSize)

MUSARTC_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_NVRTC_WRAP);

#undef DECLARE_DYNAMIC_LOAD_NVRTC_WRAP

} // namespace dynload
} // namespace phi
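As context for the macro above: DECLARE_DYNAMIC_LOAD_NVRTC_WRAP generates, for each mtrtc entry point, a functor that opens the musartc shared library exactly once (via std::call_once) and resolves the symbol with dlsym on first use. The snippet below is a self-contained illustration of that lazy-loading pattern, not PaddlePaddle's actual wrapper; it deliberately targets an ordinary system library (libm and cos are stand-ins) so it can be built and run on any Linux box with POSIX dlfcn.

```cpp
// Hedged illustration of the lazy dlopen/dlsym pattern the macro expands to.
// libm/"cos" are placeholders; the real wrappers resolve mtrtc* symbols from
// the MUSA runtime-compilation library. Build with: g++ example.cc -ldl
#include <dlfcn.h>

#include <iostream>
#include <mutex>

static std::once_flag dso_flag;
static void* dso_handle = nullptr;

template <typename Func>
Func LoadSymbol(const char* lib, const char* name) {
  // Open the shared library exactly once, then resolve the requested symbol.
  std::call_once(dso_flag, [lib]() { dso_handle = dlopen(lib, RTLD_LAZY); });
  return dso_handle ? reinterpret_cast<Func>(dlsym(dso_handle, name)) : nullptr;
}

int main() {
  using CosFn = double (*)(double);
  CosFn my_cos = LoadSymbol<CosFn>("libm.so.6", "cos");
  if (my_cos == nullptr) {
    std::cerr << "symbol not available" << std::endl;
    return 1;
  }
  std::cout << my_cos(0.0) << std::endl;  // prints 1
  return 0;
}
```

The same mechanism is what lets HasNVRTC() in musartc.cc report availability simply by checking whether the handle could be opened.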
