Commit 9f98052
add python test
a162837 committed Jan 1, 2025
1 parent 1556896 commit 9f98052
Showing 13 changed files with 1,127 additions and 142 deletions.
14 changes: 4 additions & 10 deletions paddle/phi/kernels/cpu/clip_tensor_grad_kernel.cc
@@ -27,18 +27,12 @@ void ClipTensorGradKernel(const Context& dev_ctx,
                           const DenseTensor& max,
                           const DenseTensor& out_grad,
                           DenseTensor* x_grad) {
-  DenseTensor ex_min;
-  MetaTensor meta_min(&ex_min);
-  CastInferMeta(min, x.dtype(), &meta_min);
-  DenseTensor ex_max;
-  MetaTensor meta_max(&ex_max);
-  CastInferMeta(max, x.dtype(), &meta_max);
-  phi::CastKernel<T, Context>(dev_ctx, min, x.dtype(), &ex_min);
-  phi::CastKernel<T, Context>(dev_ctx, max, x.dtype(), &ex_max);
+  DenseTensor tem_min = phi::Cast<T, Context>(dev_ctx, min, x.dtype());
+  DenseTensor tem_max = phi::Cast<T, Context>(dev_ctx, max, x.dtype());
 
   const T* x_data = x.data<T>();
-  const T* min_data = ex_min.data<T>();
-  const T* max_data = ex_max.data<T>();
+  const T* min_data = tem_min.data<T>();
+  const T* max_data = tem_max.data<T>();
   auto numel = x.numel();
   auto* dout = out_grad.data<T>();
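The rest of this kernel body is collapsed in the diff view. For reference, the elementwise rule a clip-by-tensor gradient kernel of this shape applies is: pass the upstream gradient through wherever min < x < max, zero it elsewhere. A minimal standalone sketch using the same raw-pointer layout as x_data / min_data / max_data / dout above (ClipGradRef and its signature are illustrative, not the committed body):

// Hedged sketch of the standard elementwise clip gradient rule:
// dx[i] = dout[i] if min[i] < x[i] < max[i], else 0.
template <typename T>
void ClipGradRef(
    const T* x, const T* lo, const T* hi, const T* dout, T* dx, int64_t n) {
  for (int64_t i = 0; i < n; ++i) {
    dx[i] = (x[i] > lo[i] && x[i] < hi[i]) ? dout[i] : static_cast<T>(0);
  }
}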
14 changes: 4 additions & 10 deletions paddle/phi/kernels/cpu/clip_tensor_kernel.cc
@@ -27,18 +27,12 @@ void ClipTensorKernel(const Context& dev_ctx,
                       const DenseTensor& min,
                       const DenseTensor& max,
                       DenseTensor* out) {
-  DenseTensor ex_min;
-  MetaTensor meta_min(&ex_min);
-  CastInferMeta(min, x.dtype(), &meta_min);
-  DenseTensor ex_max;
-  MetaTensor meta_max(&ex_max);
-  CastInferMeta(max, x.dtype(), &meta_max);
-  phi::CastKernel<T, Context>(dev_ctx, min, x.dtype(), &ex_min);
-  phi::CastKernel<T, Context>(dev_ctx, max, x.dtype(), &ex_max);
+  DenseTensor tem_min = phi::Cast<T, Context>(dev_ctx, min, x.dtype());
+  DenseTensor tem_max = phi::Cast<T, Context>(dev_ctx, max, x.dtype());
 
   const T* x_data = x.data<T>();
-  const T* min_data = ex_min.data<T>();
-  const T* max_data = ex_max.data<T>();
+  const T* min_data = tem_min.data<T>();
+  const T* max_data = tem_max.data<T>();
 
   auto x_numel = x.numel();
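Again the loop below is collapsed; the forward semantics are plain elementwise clamping of x against per-element tensor bounds. A minimal sketch under the same assumptions (ClipRef is an illustrative name, not the committed body):

// Hedged sketch: out[i] = min(max(x[i], lo[i]), hi[i]), elementwise.
template <typename T>
void ClipRef(const T* x, const T* lo, const T* hi, T* out, int64_t n) {
  for (int64_t i = 0; i < n; ++i) {
    T v = x[i] < lo[i] ? lo[i] : x[i];  // clamp from below
    out[i] = v > hi[i] ? hi[i] : v;     // then from above
  }
}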
14 changes: 4 additions & 10 deletions paddle/phi/kernels/gpu/clip_tensor_grad_kernel.cu
@@ -44,19 +44,13 @@ void ClipTensorGradKernel(const Context& dev_ctx,
                           const DenseTensor& max,
                           const DenseTensor& out_grad,
                           DenseTensor* x_grad) {
-  DenseTensor ex_min;
-  MetaTensor meta_min(&ex_min);
-  CastInferMeta(min, x.dtype(), &meta_min);
-  DenseTensor ex_max;
-  MetaTensor meta_max(&ex_max);
-  CastInferMeta(max, x.dtype(), &meta_max);
-  phi::CastKernel<T, Context>(dev_ctx, min, x.dtype(), &ex_min);
-  phi::CastKernel<T, Context>(dev_ctx, max, x.dtype(), &ex_max);
+  DenseTensor tem_min = phi::Cast<T, Context>(dev_ctx, min, x.dtype());
+  DenseTensor tem_max = phi::Cast<T, Context>(dev_ctx, max, x.dtype());
 
   const T* x_data = x.data<T>();
   auto numel = x.numel();
-  const T* min_data = ex_min.data<T>();
-  const T* max_data = ex_max.data<T>();
+  const T* min_data = tem_min.data<T>();
+  const T* max_data = tem_max.data<T>();
   const T* out_grad_data = out_grad.data<T>();
 
   T* x_grad_data = dev_ctx.template Alloc<T>(x_grad);
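The GPU version applies the same rule; below the fold it launches over the raw pointers prepared here. A hypothetical minimal grid-stride CUDA kernel of that shape (illustrative, not the committed launch code):

// Hedged sketch: grid-stride CUDA kernel applying the clip gradient rule
// to the pointers prepared above (x_data, min_data, max_data,
// out_grad_data, x_grad_data). Illustrative only.
template <typename T>
__global__ void ClipTensorGradRef(const T* x, const T* lo, const T* hi,
                                  const T* dout, T* dx, int64_t n) {
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    dx[i] = (x[i] > lo[i] && x[i] < hi[i]) ? dout[i] : static_cast<T>(0);
  }
}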
18 changes: 7 additions & 11 deletions paddle/phi/kernels/gpu/clip_tensor_kernel.cu
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/phi/kernels/clip_kernel.h"
+#include "paddle/phi/kernels/clip_tensor_kernel.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
@@ -27,7 +27,9 @@ namespace phi {
 template <typename T>
 struct ClipTensorFunctor {
   inline HOSTDEVICE T operator()(const T x, const T min_, const T max_) const {
-    return x < min_ ? min_ : x > max_ ? max_ : x;
+    T x_ = x < min_ ? min_ : x;
+    T x__ = x_ > max_ ? max_ : x_;
+    return x__;
   }
 };
 
@@ -37,16 +39,10 @@ void ClipTensorKernel(const Context& dev_ctx,
                       const DenseTensor& min,
                       const DenseTensor& max,
                       DenseTensor* out) {
-  DenseTensor ex_min;
-  MetaTensor meta_min(&ex_min);
-  CastInferMeta(min, x.dtype(), &meta_min);
-  DenseTensor ex_max;
-  MetaTensor meta_max(&ex_max);
-  CastInferMeta(max, x.dtype(), &meta_max);
-  phi::CastKernel<T, Context>(dev_ctx, min, x.dtype(), &ex_min);
-  phi::CastKernel<T, Context>(dev_ctx, max, x.dtype(), &ex_max);
+  DenseTensor tem_min = phi::Cast<T, Context>(dev_ctx, min, x.dtype());
+  DenseTensor tem_max = phi::Cast<T, Context>(dev_ctx, max, x.dtype());
 
-  std::vector<const DenseTensor*> ins = {&x, &ex_min, &ex_max};
+  std::vector<const DenseTensor*> ins = {&x, &tem_min, &tem_max};
   std::vector<DenseTensor*> outs = {out};
   dev_ctx.template Alloc<T>(out);
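The rewritten ClipTensorFunctor keeps the old nested ternary's behavior: clamp from below first, then from above, so when an element's min exceeds its max, max wins. A small host-side sanity check of that property (ClipTensorFunctorRef is a hypothetical standalone copy of the functor logic, with HOSTDEVICE elided):

#include <cassert>

// Standalone copy of the functor logic for a host-side check.
template <typename T>
struct ClipTensorFunctorRef {
  T operator()(const T x, const T min_, const T max_) const {
    T lo_clipped = x < min_ ? min_ : x;            // clamp from below
    return lo_clipped > max_ ? max_ : lo_clipped;  // then from above
  }
};

int main() {
  ClipTensorFunctorRef<float> clip;
  assert(clip(0.5f, 0.f, 1.f) == 0.5f);  // inside the range: unchanged
  assert(clip(-2.0f, 0.f, 1.f) == 0.f);  // below min: raised to min
  assert(clip(3.0f, 0.f, 1.f) == 1.f);   // above max: lowered to max
  assert(clip(0.5f, 2.f, 1.f) == 1.f);   // degenerate min > max: max wins
  return 0;
}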
202 changes: 202 additions & 0 deletions paddle/phi/kernels/onednn/clip_tensor_grad_kernel.cc
@@ -0,0 +1,202 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/clip_tensor_grad_kernel.h"

#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/elementwise_kernel.h"

namespace phi {
template <typename T, typename Context>
void ClipTensorGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          const DenseTensor& min,
                          const DenseTensor& max,
                          const DenseTensor& out_grad,
                          DenseTensor* x_grad) {
  const auto& onednn_engine = dev_ctx.GetEngine();
  auto& astream = OneDNNContext::tls().get_stream();

  // Three intermediate mask tensors with the same meta as x.
  DenseTensor t_min_mask;
  MetaTensor meta_min_mask(&t_min_mask);
  UnchangedInferMeta(x, &meta_min_mask);
  DenseTensor t_max_mask;
  MetaTensor meta_max_mask(&t_max_mask);
  UnchangedInferMeta(x, &meta_max_mask);
  DenseTensor t_zero_mask;
  MetaTensor meta_zero_mask(&t_zero_mask);
  UnchangedInferMeta(x, &meta_zero_mask);

  auto* tem_min_mask = &t_min_mask;
  auto* tem_max_mask = &t_max_mask;
  auto* tem_zero_mask = &t_zero_mask;
  auto* non_const_x = &x;
  auto* non_const_min = &min;
  auto* non_const_max = &max;
  auto* non_const_out_grad = &out_grad;

  // Stage 1: binary_lt writes tem_min_mask = (min < out_grad).
  funcs::BinaryOneDNNHandler<T> Lesshandler(dnnl::algorithm::binary_lt,
                                            -1,
                                            onednn_engine,
                                            dev_ctx.GetPlace(),
                                            non_const_min,
                                            non_const_out_grad,
                                            tem_min_mask,
                                            1.0f,
                                            1.0f,
                                            1.0f,
                                            true);

  auto src_memory_p_min1 = Lesshandler.AcquireSrcMemory(non_const_min);
  auto src_memory_p_out_grad1 =
      Lesshandler.AcquireSecondSrcMemory(non_const_out_grad);
  std::shared_ptr<dnnl::memory> dst_memory_p1 =
      Lesshandler.AcquireDstMemory(tem_min_mask);
  auto activation_p1 = Lesshandler.AcquireForwardPrimitive();

  std::unordered_map<int, dnnl::memory> args1 = {
      {DNNL_ARG_SRC_0, *src_memory_p_min1},
      {DNNL_ARG_SRC_1, *src_memory_p_out_grad1},
      {DNNL_ARG_DST, *dst_memory_p1}};

  if (Lesshandler.Has_SRC_0_Scale()) {
    args1.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0,
                  Lesshandler.Get_SRC_0_Scale_Memory()});
  }

  if (Lesshandler.Has_SRC_1_Scale()) {
    args1.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_1,
                  Lesshandler.Get_SRC_1_Scale_Memory()});
  }

  activation_p1->execute(astream, args1);

  // Stage 2: binary_gt writes tem_max_mask = (max > out_grad).
  funcs::BinaryOneDNNHandler<T> Grahandler(dnnl::algorithm::binary_gt,
                                           -1,
                                           onednn_engine,
                                           dev_ctx.GetPlace(),
                                           non_const_max,
                                           non_const_out_grad,
                                           tem_max_mask,
                                           1.0f,
                                           1.0f,
                                           1.0f,
                                           true);

  auto src_memory_p_max2 = Grahandler.AcquireSrcMemory(non_const_max);
  auto src_memory_p_out_grad2 =
      Grahandler.AcquireSecondSrcMemory(non_const_out_grad);
  std::shared_ptr<dnnl::memory> dst_memory_p2 =
      Grahandler.AcquireDstMemory(tem_max_mask);
  auto activation_p2 = Grahandler.AcquireForwardPrimitive();

  std::unordered_map<int, dnnl::memory> args2 = {
      {DNNL_ARG_SRC_0, *src_memory_p_max2},
      {DNNL_ARG_SRC_1, *src_memory_p_out_grad2},
      {DNNL_ARG_DST, *dst_memory_p2}};

  if (Grahandler.Has_SRC_0_Scale()) {
    args2.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0,
                  Grahandler.Get_SRC_0_Scale_Memory()});
  }

  if (Grahandler.Has_SRC_1_Scale()) {
    args2.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_1,
                  Grahandler.Get_SRC_1_Scale_Memory()});
  }

  activation_p2->execute(astream, args2);

  // Stage 3: binary_mul combines the two masks into tem_zero_mask.
  funcs::BinaryOneDNNHandler<T> Mulhandler1(dnnl::algorithm::binary_mul,
                                            -1,
                                            onednn_engine,
                                            dev_ctx.GetPlace(),
                                            tem_min_mask,
                                            tem_max_mask,
                                            tem_zero_mask,
                                            1.0f,
                                            1.0f,
                                            1.0f,
                                            true);

  auto src_memory_p_min3 = Mulhandler1.AcquireSrcMemory(tem_min_mask);
  auto src_memory_p_max3 = Mulhandler1.AcquireSecondSrcMemory(tem_max_mask);
  std::shared_ptr<dnnl::memory> dst_memory_p3 =
      Mulhandler1.AcquireDstMemory(tem_zero_mask);
  auto activation_p3 = Mulhandler1.AcquireForwardPrimitive();

  std::unordered_map<int, dnnl::memory> args3 = {
      {DNNL_ARG_SRC_0, *src_memory_p_min3},
      {DNNL_ARG_SRC_1, *src_memory_p_max3},
      {DNNL_ARG_DST, *dst_memory_p3}};

  if (Mulhandler1.Has_SRC_0_Scale()) {
    args3.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0,
                  Mulhandler1.Get_SRC_0_Scale_Memory()});
  }

  if (Mulhandler1.Has_SRC_1_Scale()) {
    args3.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_1,
                  Mulhandler1.Get_SRC_1_Scale_Memory()});
  }

  activation_p3->execute(astream, args3);

  // Stage 4: binary_mul applies the combined mask to x, writing x_grad.
  funcs::BinaryOneDNNHandler<T> Mulhandler2(dnnl::algorithm::binary_mul,
                                            -1,
                                            onednn_engine,
                                            dev_ctx.GetPlace(),
                                            tem_zero_mask,
                                            non_const_x,
                                            x_grad,
                                            1.0f,
                                            1.0f,
                                            1.0f,
                                            true);

  auto src_memory_p_zero4 = Mulhandler2.AcquireSrcMemory(tem_zero_mask);
  auto src_memory_p_x4 = Mulhandler2.AcquireSecondSrcMemory(non_const_x);
  std::shared_ptr<dnnl::memory> dst_memory_p4 =
      Mulhandler2.AcquireDstMemory(x_grad);
  auto activation_p4 = Mulhandler2.AcquireForwardPrimitive();

  std::unordered_map<int, dnnl::memory> args4 = {
      {DNNL_ARG_SRC_0, *src_memory_p_zero4},
      {DNNL_ARG_SRC_1, *src_memory_p_x4},
      {DNNL_ARG_DST, *dst_memory_p4}};

  if (Mulhandler2.Has_SRC_0_Scale()) {
    args4.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0,
                  Mulhandler2.Get_SRC_0_Scale_Memory()});
  }

  if (Mulhandler2.Has_SRC_1_Scale()) {
    args4.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_1,
                  Mulhandler2.Get_SRC_1_Scale_Memory()});
  }

  activation_p4->execute(astream, args4);

  astream.wait();

  x_grad->set_mem_desc(dst_memory_p4->get_desc());
}
}  // namespace phi

PD_REGISTER_KERNEL(clip_tensor_grad,
                   OneDNN,
                   ONEDNN,
                   phi::ClipTensorGradKernel,
                   float,
                   phi::dtype::bfloat16) {}
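Read end to end, the four binary primitives compose elementwise masks into the gradient. A scalar sketch of the dataflow, with the same operand choices as the handlers above (ClipTensorGradOneElem is illustrative, not part of the commit):

// Scalar view of the primitive chain, one element at a time:
//   stage 1 (binary_lt):  min_mask  = (min < out_grad)
//   stage 2 (binary_gt):  max_mask  = (max > out_grad)
//   stage 3 (binary_mul): zero_mask = min_mask * max_mask
//   stage 4 (binary_mul): x_grad    = zero_mask * x
template <typename T>
T ClipTensorGradOneElem(T x, T min_v, T max_v, T out_grad) {
  T min_mask = static_cast<T>(min_v < out_grad);
  T max_mask = static_cast<T>(max_v > out_grad);
  T zero_mask = min_mask * max_mask;
  return zero_mask * x;
}

(For comparison, the reference rule sketched after the CPU gradient kernel masks on x and scales out_grad.)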
