Skip to content

Commit

Permalink
improve the performance of fftc2rgrad (#63137)
Browse files Browse the repository at this point in the history
* improve fftc2rgrad

* remove comments

* ci

* update typo
  • Loading branch information
YibinLiu666 authored Jun 5, 2024
1 parent 2a0aa1a commit 8be7f50
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 21 deletions.
17 changes: 7 additions & 10 deletions paddle/phi/kernels/funcs/fft_fill_conj.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,26 +189,23 @@ template <typename T>
struct FFTFillConjGradFunctor {
T* input_;
const size_t axis_;
const int64_t* strides_;
const int64_t stride_to_last_axis;
const int64_t stride_second_to_last_axis;
const size_t double_length_;

FFTFillConjGradFunctor(T* input,
size_t axis,
const int64_t* strides,
int64_t stride_second_to_last_axis,
int64_t stride_to_last_axis,
size_t double_length)
: input_(input),
axis_(axis),
strides_(strides),
stride_to_last_axis(stride_to_last_axis),
stride_second_to_last_axis(stride_second_to_last_axis),
double_length_(double_length) {}

HOSTDEVICE void operator()(size_t index) {
size_t offtset = index; // back
size_t index_i;
for (size_t i = 0; i <= axis_; i++) {
index_i = offtset / strides_[i];
offtset %= strides_[i];
}

size_t index_i = (index % stride_second_to_last_axis) / stride_to_last_axis;
if ((0 < index_i) && (index_i < double_length_ + 1)) {
input_[index] *= static_cast<T>(2);
}
Expand Down
22 changes: 11 additions & 11 deletions paddle/phi/kernels/impl/fft_grad_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,17 +92,17 @@ void FFTC2RGradKernel(const Context& ctx,

const int64_t double_length =
out_grad.dims()[axes.back()] - x_grad->dims()[axes.back()];
const phi::DDim strides = common::stride(x_grad->dims());

#if defined(__NVCC__) || defined(__HIPCC__)
const thrust::device_vector<int64_t> strides_g(common::vectorize(strides));
const int64_t* pstrides = thrust::raw_pointer_cast(strides_g.data());
#else
const int64_t* pstrides = strides.Get();
#endif

funcs::FFTFillConjGradFunctor<C> func(
x_grad->data<C>(), axes.back(), pstrides, double_length);
int64_t stride_to_last_axis = 1;
auto ddim = x_grad->dims();
for (int i = ddim.size() - 2; i >= axes.back(); --i) {
stride_to_last_axis *= ddim[i + 1];
}
int64_t stride_second_to_last_axis = stride_to_last_axis * ddim[axes.back()];
funcs::FFTFillConjGradFunctor<C> func(x_grad->data<C>(),
axes.back(),
stride_second_to_last_axis,
stride_to_last_axis,
double_length);
size_t limit = x_grad->numel();
funcs::ForRange<Context> for_range(ctx, limit);
for_range(func);
Expand Down

0 comments on commit 8be7f50

Please sign in to comment.