Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 4 No.20】Add i0 / i0e to paddle #52058

Merged
merged 35 commits into from
May 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
0a56f9b
added base code for i0 and i0e
PommesPeter May 10, 2023
f9458da
added grad base code for i0 and i0e
PommesPeter May 10, 2023
c97995c
added i0 and i0e python code
PommesPeter Mar 18, 2023
9a914c4
added ops and backward yaml config
PommesPeter Mar 18, 2023
07315ab
added i0 and i0e cpu kernel, but not test.
PommesPeter Mar 23, 2023
695151b
added i0 and i0e code and unittest files
PommesPeter Mar 23, 2023
bff7833
added test files
PommesPeter Mar 23, 2023
04bfe52
added i0/i0e gpu implementation code
PommesPeter Mar 25, 2023
4367c6c
updated code style
PommesPeter Mar 25, 2023
efff735
updated code style
PommesPeter Mar 25, 2023
282aa84
fixed unittest code
PommesPeter Mar 27, 2023
ccc1f61
updated i0 with eigen3
PommesPeter Apr 7, 2023
bd7aadb
fixed bug and added more test cases
PommesPeter Apr 18, 2023
46959c4
refactor: fixed static graph bug
PommesPeter Apr 18, 2023
1cef5b7
refactor: removed i0 and i0e from op_compat
PommesPeter Apr 18, 2023
d42468b
refactor: updated code style
PommesPeter Apr 18, 2023
22f1004
refactor: updated op_compat.yaml
PommesPeter Apr 18, 2023
66b5fd1
refactor: updated op_compat.yaml
PommesPeter Apr 21, 2023
924b3fd
refactor: fixed op name mapping and optimize unittest case
PommesPeter Apr 23, 2023
08ecc31
refactor: manually implement i0 / i0e
PommesPeter Apr 28, 2023
cff862c
refactor: added grad kernel for i0 / i0e, didn't finish
PommesPeter Apr 28, 2023
6c57f75
Update math.py
PommesPeter Apr 29, 2023
18757b7
refactor: added equation to doc in English and added comments for com…
PommesPeter Apr 29, 2023
db95300
refactor: removed eigen implementation
PommesPeter Apr 29, 2023
74a1404
refactor: finished i0 / i0e cpu and gpu op
PommesPeter May 4, 2023
2082bfb
refactor: updated code style
PommesPeter May 4, 2023
75a7d69
fix: find a bug but not fix
PommesPeter May 6, 2023
4aee598
fix: incorrect unittest cases
PommesPeter May 9, 2023
e665d31
update: updated code style and remove my file
PommesPeter May 9, 2023
90f0672
update: updated unittest case
PommesPeter May 9, 2023
eeca07d
fix: fixed sign error
PommesPeter May 10, 2023
d3fb6f1
fix: fixed mistakes when merging
PommesPeter May 10, 2023
0a13c0c
refactor: updated code style
PommesPeter May 10, 2023
ce88d69
refactor: remove unused code
PommesPeter May 10, 2023
30cb1b6
refactor: updated code style
PommesPeter May 10, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions paddle/phi/api/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,26 @@
kernel :
func : huber_loss_grad

# Gradient of i0: consumes the forward input x plus the upstream gradient;
# x_grad shares x's meta (UnchangedInferMeta over [x]).
- backward_op : i0_grad
forward : i0 (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : i0_grad

# Gradient of i0e: additionally consumes the saved forward output `out`
# alongside x and out_grad; x_grad again shares x's meta.
- backward_op : i0e_grad
forward : i0e (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : i0e_grad

- backward_op : imag_grad
forward : imag (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
Expand Down
18 changes: 18 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -920,6 +920,24 @@
intermediate : residual
backward : huber_loss_grad

# i0: elementwise unary op; output meta mirrors the input (UnchangedInferMeta).
- op : i0
args : (Tensor x)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : i0
backward : i0_grad

# i0e: elementwise unary op, same meta contract as i0; its backward
# additionally saves the forward output (see i0e_grad in backward.yaml).
- op : i0e
args : (Tensor x)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : i0e
backward : i0e_grad

- op : imag
args : (Tensor x)
output : Tensor (out)
Expand Down
42 changes: 42 additions & 0 deletions paddle/phi/kernels/cpu/i0_grad_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/i0_grad_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/impl/bessel_grad_kernel_impl.h"

namespace phi {

// CPU backward kernel for the i0 op: walks all elements with ForRange and
// delegates the per-element gradient math to I0GradFunctor, writing into
// x_grad (allocated here).
template <typename T, typename Context>
void I0GradKernel(const Context& ctx,
                  const DenseTensor& x,
                  const DenseTensor& out_grad,
                  DenseTensor* x_grad) {
  const int64_t numel = x.numel();
  const T* in_ptr = x.data<T>();
  const T* dout_ptr = out_grad.data<T>();
  T* dx_ptr = ctx.template Alloc<T>(x_grad);  // materialize output storage

  phi::funcs::ForRange<Context> range_runner(ctx, numel);
  range_runner(I0GradFunctor<T>(in_ptr, dout_ptr, dx_ptr, numel));
}

} // namespace phi

PD_REGISTER_KERNEL(i0_grad, CPU, ALL_LAYOUT, phi::I0GradKernel, float, double) {
}
37 changes: 37 additions & 0 deletions paddle/phi/kernels/cpu/i0_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/i0_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/impl/bessel_kernel_impl.h"

namespace phi {

// CPU forward kernel for i0: applies I0Functor elementwise over `x` via
// ForRange, writing results into `out` (allocated here).
template <typename T, typename Context>
void I0Kernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
  const int64_t numel = x.numel();
  const T* in_ptr = x.data<T>();
  T* out_ptr = ctx.template Alloc<T>(out);  // materialize output storage

  phi::funcs::ForRange<Context> range_runner(ctx, numel);
  range_runner(I0Functor<T>(in_ptr, out_ptr, numel));
}

} // namespace phi

PD_REGISTER_KERNEL(i0, CPU, ALL_LAYOUT, phi::I0Kernel, float, double) {}
44 changes: 44 additions & 0 deletions paddle/phi/kernels/cpu/i0e_grad_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/i0e_grad_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/impl/bessel_grad_kernel_impl.h"

namespace phi {

// CPU backward kernel for the i0e op.
//
// Inputs: the forward input `x`, the saved forward output `out`, and the
// upstream gradient `out_grad`; fills `x_grad` (allocated here) by applying
// I0eGradFunctor elementwise via ForRange.
template <typename T, typename Context>
void I0eGradKernel(const Context& ctx,
                   const DenseTensor& x,
                   const DenseTensor& out,
                   const DenseTensor& out_grad,
                   DenseTensor* x_grad) {
  const int64_t size = x.numel();
  const T* x_data = x.data<T>();
  const T* out_data = out.data<T>();
  const T* out_grad_data = out_grad.data<T>();
  // Fixed local name: was `x_gard_data` (typo, inconsistent with the
  // sibling i0 kernels).
  T* x_grad_data = ctx.template Alloc<T>(x_grad);

  phi::funcs::ForRange<Context> for_range(ctx, size);
  I0eGradFunctor<T> functor(x_data, out_data, out_grad_data, x_grad_data, size);
  for_range(functor);
}

} // namespace phi

PD_REGISTER_KERNEL(
i0e_grad, CPU, ALL_LAYOUT, phi::I0eGradKernel, float, double) {}
37 changes: 37 additions & 0 deletions paddle/phi/kernels/cpu/i0e_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/i0e_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/impl/bessel_kernel_impl.h"

namespace phi {

// CPU forward kernel for i0e: applies I0eFunctor elementwise over `x` via
// ForRange, writing results into `out` (allocated here).
template <typename T, typename Context>
void I0eKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
  const int64_t numel = x.numel();
  const T* in_ptr = x.data<T>();
  T* out_ptr = ctx.template Alloc<T>(out);  // materialize output storage

  phi::funcs::ForRange<Context> range_runner(ctx, numel);
  range_runner(I0eFunctor<T>(in_ptr, out_ptr, numel));
}

} // namespace phi

PD_REGISTER_KERNEL(i0e, CPU, ALL_LAYOUT, phi::I0eKernel, float, double) {}
39 changes: 39 additions & 0 deletions paddle/phi/kernels/gpu/i0_grad_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/i0_grad_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/impl/bessel_grad_kernel_cuda_impl.h"

namespace phi {

// GPU backward kernel for the i0 op: launches an elementwise kernel over
// (x, out_grad) with CudaI0GradFunctor, writing into x_grad.
template <typename T, typename Context>
void I0GradKernel(const Context& ctx,
                  const DenseTensor& x,
                  const DenseTensor& out_grad,
                  DenseTensor* x_grad) {
  ctx.template Alloc<T>(x_grad);  // allocate output before launch
  std::vector<const DenseTensor*> inputs{&x, &out_grad};
  std::vector<DenseTensor*> outputs{x_grad};
  phi::funcs::ElementwiseKernel<T>(
      ctx, inputs, &outputs, CudaI0GradFunctor<T>());
}

} // namespace phi

PD_REGISTER_KERNEL(i0_grad, GPU, ALL_LAYOUT, phi::I0GradKernel, float, double) {
}
35 changes: 35 additions & 0 deletions paddle/phi/kernels/gpu/i0_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/i0_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/impl/bessel_kernel_cuda_impl.h"

namespace phi {

// GPU forward kernel for i0: launches an elementwise kernel over `x` with
// CudaI0Functor, writing into `out`.
template <typename T, typename Context>
void I0Kernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
  ctx.template Alloc<T>(out);  // allocate output before launch
  std::vector<const DenseTensor*> inputs{&x};
  std::vector<DenseTensor*> outputs{out};
  phi::funcs::ElementwiseKernel<T>(ctx, inputs, &outputs, CudaI0Functor<T>());
}

} // namespace phi

PD_REGISTER_KERNEL(i0, GPU, ALL_LAYOUT, phi::I0Kernel, float, double) {}
40 changes: 40 additions & 0 deletions paddle/phi/kernels/gpu/i0e_grad_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/i0e_grad_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/impl/bessel_grad_kernel_cuda_impl.h"

namespace phi {

// GPU backward kernel for the i0e op: launches an elementwise kernel over
// (x, out, out_grad) with CudaI0eGradFunctor, writing into x_grad. Note the
// saved forward output `out` is an input here, unlike i0_grad.
template <typename T, typename Context>
void I0eGradKernel(const Context& ctx,
                   const DenseTensor& x,
                   const DenseTensor& out,
                   const DenseTensor& out_grad,
                   DenseTensor* x_grad) {
  ctx.template Alloc<T>(x_grad);  // allocate output before launch
  std::vector<const DenseTensor*> inputs{&x, &out, &out_grad};
  std::vector<DenseTensor*> outputs{x_grad};
  phi::funcs::ElementwiseKernel<T>(
      ctx, inputs, &outputs, CudaI0eGradFunctor<T>());
}

} // namespace phi

PD_REGISTER_KERNEL(
i0e_grad, GPU, ALL_LAYOUT, phi::I0eGradKernel, float, double) {}
35 changes: 35 additions & 0 deletions paddle/phi/kernels/gpu/i0e_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/i0e_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/impl/bessel_kernel_cuda_impl.h"

namespace phi {

// GPU forward kernel for i0e: launches an elementwise kernel over `x` with
// CudaI0eFunctor, writing into `out`.
template <typename T, typename Context>
void I0eKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
  ctx.template Alloc<T>(out);  // allocate output before launch
  std::vector<const DenseTensor*> inputs{&x};
  std::vector<DenseTensor*> outputs{out};
  phi::funcs::ElementwiseKernel<T>(ctx, inputs, &outputs, CudaI0eFunctor<T>());
}

} // namespace phi

PD_REGISTER_KERNEL(i0e, GPU, ALL_LAYOUT, phi::I0eKernel, float, double) {}
28 changes: 28 additions & 0 deletions paddle/phi/kernels/i0_grad_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"

namespace phi {

// Backward kernel for the i0 op: computes the gradient w.r.t. `x` from the
// forward input `x` and the upstream gradient `out_grad`, writing the result
// into `x_grad`. Implemented per backend (see cpu/i0_grad_kernel.cc and
// gpu/i0_grad_kernel.cu).
template <typename T, typename Context>
void I0GradKernel(const Context& ctx,
                  const DenseTensor& x,
                  const DenseTensor& out_grad,
                  DenseTensor* x_grad);

} // namespace phi
Loading