Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Laser fft solver for amd #1042

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 47 additions & 3 deletions src/laser/MultiLaser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,37 @@ MultiLaser::InitData (const amrex::BoxArray& slice_ba,
CuFFTUtils::cufftErrorToString(result) << "\n";
}
#elif defined(AMREX_USE_HIP)
amrex::ignore_unused(fft_size); // TODO: fft solver on AMD
const std::size_t lengths[] = {
fft_size[0], fft_size[1]
};

rocfft_status result;
// Forward FFT plan
result = rocfft_plan_create(&(m_plan_fwd), rocfft_placement_notinplace,
rocfft_transform_type_real_forward,
#ifdef AMREX_USE_FLOAT
rocfft_precision_single,
#else
rocfft_precision_double,
#endif
2, lengths, 1, nullptr);
if ( result != rocfft_status_success ) {
amrex::Print() << " rocfft_plan_create failed! Error: " <<
result << "\n";
}
// Backward FFT plan
result = rocfft_plan_create(&(m_plan_bkw), rocfft_placement_notinplace,
rocfft_transform_type_real_forward,
#ifdef AMREX_USE_FLOAT
rocfft_precision_single,
#else
rocfft_precision_double,
#endif
2, lengths, 1, nullptr);
if ( result != rocfft_status_success ) {
amrex::Print() << " rocfft_plan_create failed! Error: " <<
result << "\n";
}
#else
// Forward FFT plan
m_plan_fwd = LaserFFT::VendorCreate(
Expand Down Expand Up @@ -818,7 +848,14 @@ MultiLaser::AdvanceSliceFFT (const Fields& fields, const amrex::Real dt, int ste
CuFFTUtils::cufftErrorToString(result) << "\n";
}
#elif defined(AMREX_USE_HIP)
amrex::Abort("Not implemented");
rocfft_execution_info execinfo = nullptr;
rocfft_status result = rocfft_execution_info_create(&execinfo);
result = rocfft_execution_info_set_stream(execinfo, amrex::Gpu::gpuStream());

void* in_buffer[2] = {(void*)m_rhs.dataPtr(), nullptr};
void* out_buffer[2] = {(void*)m_rhs_fourier.dataPtr(), nullptr};

result = rocfft_execute(m_plan_fwd, in_buffer, out_buffer, execinfo);
#else
LaserFFT::VendorExecute( m_plan_fwd );
#endif
Expand Down Expand Up @@ -856,7 +893,14 @@ MultiLaser::AdvanceSliceFFT (const Fields& fields, const amrex::Real dt, int ste
CuFFTUtils::cufftErrorToString(result) << "\n";
}
#elif defined(AMREX_USE_HIP)
amrex::Abort("Not implemented");
rocfft_execution_info execinfo = nullptr;
rocfft_status result = rocfft_execution_info_create(&execinfo);
result = rocfft_execution_info_set_stream(execinfo, amrex::Gpu::gpuStream());

void* in_buffer[2] = {(void*)m_rhs_fourier.dataPtr(), nullptr};
void* out_buffer[2] = {(void*)m_sol.dataPtr(), nullptr};

result = rocfft_execute(m_plan_bkw, in_buffer, out_buffer, execinfo);
#else
LaserFFT::VendorExecute( m_plan_bkw );
#endif
Expand Down
Loading