Skip to content

Commit

Permalink
fix compile error
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderSinn committed Nov 27, 2023
1 parent 656883b commit 1c6bfdb
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 29 deletions.
2 changes: 0 additions & 2 deletions src/laser/MultiLaser.H
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,6 @@ private:
/** Geometry of the laser file, 'rt' or 'xyt' */
std::string m_file_geometry = "";

/** Nb fields in 3D array: new_real, new_imag, old_real, old_imag */
int m_nfields_3d {4};
/** Array of N slices required to compute current slice */
amrex::MultiFab m_slices;
amrex::Real m_MG_tolerance_rel = 1.e-4;
Expand Down
60 changes: 33 additions & 27 deletions src/laser/MultiLaser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#include "utils/DeprecatedInput.H"
#ifdef AMREX_USE_CUDA
# include "fields/fft_poisson_solver/fft/CuFFTUtils.H"
#elif defined(AMREX_USE_HIP)
# include "fields/fft_poisson_solver/fft/RocFFTUtils.H"
#endif
#include "particles/particles_utils/ShapeFactors.H"

Expand Down Expand Up @@ -123,36 +125,30 @@ MultiLaser::InitData (const amrex::BoxArray& slice_ba,
}
#elif defined(AMREX_USE_HIP)
const std::size_t lengths[] = {
fft_size[0], fft_size[1]
static_cast<std::size_t>(fft_size[0]), static_cast<std::size_t>(fft_size[1])
};

rocfft_status result;
// Forward FFT plan
result = rocfft_plan_create(&(m_plan_fwd), rocfft_placement_notinplace,
rocfft_transform_type_real_forward,
rocfft_transform_type_complex_forward,
#ifdef AMREX_USE_FLOAT
rocfft_precision_single,
#else
rocfft_precision_double,
#endif
2, lengths, 1, nullptr);
if ( result != rocfft_status_success ) {
amrex::Print() << " rocfft_plan_create failed! Error: " <<
result << "\n";
}
RocFFTUtils::assert_rocfft_status("rocfft_plan_create", result);
// Backward FFT plan
result = rocfft_plan_create(&(m_plan_bkw), rocfft_placement_notinplace,
rocfft_transform_type_real_forward,
rocfft_transform_type_complex_inverse,
#ifdef AMREX_USE_FLOAT
rocfft_precision_single,
#else
rocfft_precision_double,
#endif
2, lengths, 1, nullptr);
if ( result != rocfft_status_success ) {
amrex::Print() << " rocfft_plan_create failed! Error: " <<
result << "\n";
}
RocFFTUtils::assert_rocfft_status("rocfft_plan_create", result);
#else
// Forward FFT plan
m_plan_fwd = LaserFFT::VendorCreate(
Expand Down Expand Up @@ -848,14 +844,19 @@ MultiLaser::AdvanceSliceFFT (const Fields& fields, const amrex::Real dt, int ste
CuFFTUtils::cufftErrorToString(result) << "\n";
}
#elif defined(AMREX_USE_HIP)
rocfft_execution_info execinfo = nullptr;
rocfft_status result = rocfft_execution_info_create(&execinfo);
result = rocfft_execution_info_set_stream(execinfo, amrex::Gpu::gpuStream());

void* in_buffer[2] = {(void*)m_rhs.dataPtr(), nullptr};
void* out_buffer[2] = {(void*)m_rhs_fourier.dataPtr(), nullptr};

result = rocfft_execute(m_plan_fwd, in_buffer, out_buffer, execinfo);
rocfft_execution_info execinfo_fwd = nullptr;
rocfft_status result = rocfft_execution_info_create(&execinfo_fwd);
RocFFTUtils::assert_rocfft_status("rocfft_execution_info_create", result);
result = rocfft_execution_info_set_stream(execinfo_fwd, amrex::Gpu::gpuStream());
RocFFTUtils::assert_rocfft_status("rocfft_execution_info_set_stream", result);

void* in_buffer_fwd[2] = {(void*)m_rhs.dataPtr(), nullptr};
void* out_bufferfwd[2] = {(void*)m_rhs_fourier.dataPtr(), nullptr};

result = rocfft_execute(m_plan_fwd, in_buffer_fwd, out_bufferfwd, execinfo_fwd);
RocFFTUtils::assert_rocfft_status("rocfft_execute", result);
result = rocfft_execution_info_destroy(execinfo_fwd);
RocFFTUtils::assert_rocfft_status("rocfft_execution_info_destroy", result);
#else
LaserFFT::VendorExecute( m_plan_fwd );
#endif
Expand Down Expand Up @@ -893,14 +894,19 @@ MultiLaser::AdvanceSliceFFT (const Fields& fields, const amrex::Real dt, int ste
CuFFTUtils::cufftErrorToString(result) << "\n";
}
#elif defined(AMREX_USE_HIP)
rocfft_execution_info execinfo = nullptr;
rocfft_status result = rocfft_execution_info_create(&execinfo);
result = rocfft_execution_info_set_stream(execinfo, amrex::Gpu::gpuStream());

void* in_buffer[2] = {(void*)m_rhs_fourier.dataPtr(), nullptr};
void* out_buffer[2] = {(void*)m_sol.dataPtr(), nullptr};

result = rocfft_execute(m_plan_bkw, in_buffer, out_buffer, execinfo);
rocfft_execution_info execinfo_bkw = nullptr;
result = rocfft_execution_info_create(&execinfo_bkw);
RocFFTUtils::assert_rocfft_status("rocfft_execution_info_create", result);
result = rocfft_execution_info_set_stream(execinfo_bkw, amrex::Gpu::gpuStream());
RocFFTUtils::assert_rocfft_status("rocfft_execution_info_set_stream", result);

void* in_buffer_bkw[2] = {(void*)m_rhs_fourier.dataPtr(), nullptr};
void* out_buffer_bkw[2] = {(void*)m_sol.dataPtr(), nullptr};

result = rocfft_execute(m_plan_bkw, in_buffer_bkw, out_buffer_bkw, execinfo_bkw);
RocFFTUtils::assert_rocfft_status("rocfft_execute", result);
result = rocfft_execution_info_destroy(execinfo_bkw);
RocFFTUtils::assert_rocfft_status("rocfft_execution_info_destroy", result);
#else
LaserFFT::VendorExecute( m_plan_bkw );
#endif
Expand Down

0 comments on commit 1c6bfdb

Please sign in to comment.