Skip to content

Commit

Permalink
[PHI decoupling] replace dependency of inclusive_scan.h from phi (#48980
Browse files Browse the repository at this point in the history
)

* replace dependency of inclusive_scan.h from phi

* format code
  • Loading branch information
Patrick-Star125 authored Dec 12, 2022
1 parent 00f2031 commit c9f4cfa
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 33 deletions.
44 changes: 21 additions & 23 deletions paddle/phi/kernels/gpu/cumprod_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@

#include <thrust/transform.h>

#include "paddle/fluid/operators/math/inclusive_scan.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/cumprod.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/inclusive_scan.h"
// NOTE(@xiongkun): use of IsComplex<>
#include "paddle/phi/core/utils/data_type.h"

Expand Down Expand Up @@ -194,16 +194,15 @@ void CumprodGradKernel(const Context &dev_ctx,
auto zero_mask = const_cast<Allocator &>(dev_ctx.GetAllocator())
.Allocate(numel * sizeof(uint8_t));
auto *zero_mask_data = reinterpret_cast<uint8_t *>(zero_mask->ptr());
paddle::operators::math::InclusiveScan<uint8_t, cub::Max>(
zero_mask_without_cummax_data,
zero_mask_data,
outer_dim,
mid_dim,
inner_dim,
static_cast<uint8_t>(0),
cub::Max(),
/*reverse=*/false,
dev_ctx);
phi::funcs::InclusiveScan<uint8_t, cub::Max>(zero_mask_without_cummax_data,
zero_mask_data,
outer_dim,
mid_dim,
inner_dim,
static_cast<uint8_t>(0),
cub::Max(),
/*reverse=*/false,
dev_ctx);
zero_mask_without_cummax = nullptr;

// Step 2: calculate reversed cumsum(dy * y)
Expand All @@ -222,16 +221,15 @@ void CumprodGradKernel(const Context &dev_ctx,
.Allocate(numel * sizeof(T));
auto *dy_mul_y_reversed_cumsum_data =
reinterpret_cast<T *>(dy_mul_y_reversed_cumsum->ptr());
paddle::operators::math::InclusiveScan<T, cub::Sum>(
dy_mul_y_data,
dy_mul_y_reversed_cumsum_data,
outer_dim,
mid_dim,
inner_dim,
static_cast<T>(0),
cub::Sum(),
/*reverse=*/true,
dev_ctx);
phi::funcs::InclusiveScan<T, cub::Sum>(dy_mul_y_data,
dy_mul_y_reversed_cumsum_data,
outer_dim,
mid_dim,
inner_dim,
static_cast<T>(0),
cub::Sum(),
/*reverse=*/true,
dev_ctx);

// Step 3: calculate the gradient value except the first zero position.
// The gradient value of the first zero position is filled with out[idx-1],
Expand Down Expand Up @@ -262,7 +260,7 @@ void CumprodGradKernel(const Context &dev_ctx,
// Step 4: calculate cumprod of x_filled_one
auto *x_filled_one_cumprod_data =
dy_mul_y_reversed_cumsum_data; // reuse former allocated memory
paddle::operators::math::InclusiveScan<T, funcs::MultiplyFunctor<T>>(
phi::funcs::InclusiveScan<T, funcs::MultiplyFunctor<T>>(
x_filled_one_data,
x_filled_one_cumprod_data,
outer_dim,
Expand All @@ -284,7 +282,7 @@ void CumprodGradKernel(const Context &dev_ctx,
funcs::MultiplyFunctor<T>());
auto *dy_mul_x_filled_one_cumprod_reversed_cumsum =
dy_mul_y_reversed_cumsum_data; // reuse former allocated memory
paddle::operators::math::InclusiveScan<T, cub::Sum>(
phi::funcs::InclusiveScan<T, cub::Sum>(
dy_mul_x_filled_one_cumprod,
dy_mul_x_filled_one_cumprod_reversed_cumsum,
outer_dim,
Expand Down
20 changes: 10 additions & 10 deletions paddle/phi/kernels/gpu/cumprod_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

#include "paddle/phi/kernels/cumprod_kernel.h"

#include "paddle/fluid/operators/math/inclusive_scan.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/cumprod.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/inclusive_scan.h"

namespace phi {

Expand All @@ -35,15 +35,15 @@ void CumprodKernel(const Context &dev_ctx,

const auto *x_data = x->data<T>();
auto *y_data = dev_ctx.template Alloc<T>(y);
paddle::operators::math::InclusiveScan(x_data,
y_data,
outer_dim,
mid_dim,
inner_dim,
static_cast<T>(1),
funcs::MultiplyFunctor<T>(),
/*reverse=*/false,
dev_ctx);
phi::funcs::InclusiveScan(x_data,
y_data,
outer_dim,
mid_dim,
inner_dim,
static_cast<T>(1),
funcs::MultiplyFunctor<T>(),
/*reverse=*/false,
dev_ctx);
}

} // namespace phi
Expand Down

0 comments on commit c9f4cfa

Please sign in to comment.