Skip to content

Commit

Permalink
[ROCM]:Delete the special target and fix compiler options
Browse files Browse the repository at this point in the history
The runtime compiler API builds only for a specific target if one is bound, so the hard-coded target option is removed.

'--include-path' is not supported by hipcc; "-I/include/folder"
is the better choice.

Fix unit tests:
        * device_code_test
        * test_code_generator
        * test_fusion_group_pass
        * test_fusion_group_op

Signed-off-by: jiajuku <jiajuku12@163.com>
  • Loading branch information
onepick committed Jul 18, 2023
1 parent 006bd95 commit 1a25fd8
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
10 changes: 8 additions & 2 deletions paddle/fluid/framework/ir/fusion_group/cuda_resources.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,13 @@ __device__ inline double Log(double x) { return log(x); }
__device__ inline double Sqrt(double x) { return sqrt(x); }
)";

#ifdef PADDLE_WITH_HIP
// HIP variant: device source text declaring __half overloads of Exp, Log and
// Sqrt that forward to hexp, hlog and hsqrt respectively. The definitions are
// kept as a raw string, not compiled as part of this header.
// NOTE(review): presumably this text is prepended to generated fusion-group
// kernel source before runtime compilation — confirm against the users of
// this constant.
static constexpr char predefined_cuda_functions_fp16[] = R"(
__device__ inline __half Exp(const __half x) { return hexp(x); }
__device__ inline __half Log(const __half x) { return hlog(x); }
__device__ inline __half Sqrt(const __half x) { return hsqrt(x); }
)";
#else
// List some built-in functions of __half implemented in cuda_fp16.hpp
static constexpr char predefined_cuda_functions_fp16[] = R"(
#define __HALF_TO_US(var) *(reinterpret_cast<unsigned short *>(&(var)))
Expand Down Expand Up @@ -306,7 +312,7 @@ __device__ inline __half Sqrt(const __half x) { return hsqrt(x); }
typedef __half float16;
)";

#endif
static constexpr char cuda_kernel_template_1d[] = R"(
extern "C" __global__ void $func_name($parameters) {
for(int idx = blockIdx.x * blockDim.x + threadIdx.x;
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/backends/device_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -257,12 +257,12 @@ bool GPUDeviceCode::Compile(bool include_path) {
auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(
DeviceContextPool::Instance().Get(place_));
int compute_capability = dev_ctx->GetComputeCapability();
std::vector<const char*> options = {"-std=c++11", "--amdgpu-target=gfx906"};
std::vector<const char*> options = {"-std=c++11"};
std::string include_option;
if (include_path) {
std::string cuda_include_path = FindCUDAIncludePath();
if (!cuda_include_path.empty()) {
include_option = "--include-path=" + cuda_include_path;
include_option = "-I" + cuda_include_path;
options.push_back(include_option.c_str());
}
}
Expand Down

0 comments on commit 1a25fd8

Please sign in to comment.