Skip to content

Commit

Permalink
laser fft solver for amd
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderSinn committed Nov 22, 2023
1 parent a690500 commit 656883b
Showing 1 changed file with 47 additions and 3 deletions.
50 changes: 47 additions & 3 deletions src/laser/MultiLaser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,37 @@ MultiLaser::InitData (const amrex::BoxArray& slice_ba,
CuFFTUtils::cufftErrorToString(result) << "\n";
}
#elif defined(AMREX_USE_HIP)
amrex::ignore_unused(fft_size); // TODO: fft solver on AMD
const std::size_t lengths[] = {
fft_size[0], fft_size[1]
};

rocfft_status result;
// Forward FFT plan
result = rocfft_plan_create(&(m_plan_fwd), rocfft_placement_notinplace,
rocfft_transform_type_real_forward,
#ifdef AMREX_USE_FLOAT
rocfft_precision_single,
#else
rocfft_precision_double,
#endif
2, lengths, 1, nullptr);
if ( result != rocfft_status_success ) {
amrex::Print() << " rocfft_plan_create failed! Error: " <<
result << "\n";
}
// Backward FFT plan
result = rocfft_plan_create(&(m_plan_bkw), rocfft_placement_notinplace,
rocfft_transform_type_real_forward,
#ifdef AMREX_USE_FLOAT
rocfft_precision_single,
#else
rocfft_precision_double,
#endif
2, lengths, 1, nullptr);
if ( result != rocfft_status_success ) {
amrex::Print() << " rocfft_plan_create failed! Error: " <<
result << "\n";
}
#else
// Forward FFT plan
m_plan_fwd = LaserFFT::VendorCreate(
Expand Down Expand Up @@ -818,7 +848,14 @@ MultiLaser::AdvanceSliceFFT (const Fields& fields, const amrex::Real dt, int ste
CuFFTUtils::cufftErrorToString(result) << "\n";
}
#elif defined(AMREX_USE_HIP)
amrex::Abort("Not implemented");
rocfft_execution_info execinfo = nullptr;
rocfft_status result = rocfft_execution_info_create(&execinfo);
result = rocfft_execution_info_set_stream(execinfo, amrex::Gpu::gpuStream());

void* in_buffer[2] = {(void*)m_rhs.dataPtr(), nullptr};
void* out_buffer[2] = {(void*)m_rhs_fourier.dataPtr(), nullptr};

result = rocfft_execute(m_plan_fwd, in_buffer, out_buffer, execinfo);
#else
LaserFFT::VendorExecute( m_plan_fwd );
#endif
Expand Down Expand Up @@ -856,7 +893,14 @@ MultiLaser::AdvanceSliceFFT (const Fields& fields, const amrex::Real dt, int ste
CuFFTUtils::cufftErrorToString(result) << "\n";
}
#elif defined(AMREX_USE_HIP)
amrex::Abort("Not implemented");
rocfft_execution_info execinfo = nullptr;
rocfft_status result = rocfft_execution_info_create(&execinfo);
result = rocfft_execution_info_set_stream(execinfo, amrex::Gpu::gpuStream());

void* in_buffer[2] = {(void*)m_rhs_fourier.dataPtr(), nullptr};
void* out_buffer[2] = {(void*)m_sol.dataPtr(), nullptr};

result = rocfft_execute(m_plan_bkw, in_buffer, out_buffer, execinfo);
#else
LaserFFT::VendorExecute( m_plan_bkw );
#endif
Expand Down

0 comments on commit 656883b

Please sign in to comment.