Skip to content

Commit

Permalink
Einsum grad complex (#44598)
Browse files Browse the repository at this point in the history
* add complex for einsum grad kernel

* pass the ci
  • Loading branch information
2742195759 authored Jul 26, 2022
1 parent 25d3dce commit e0dd7f3
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 7 deletions.
10 changes: 8 additions & 2 deletions paddle/phi/kernels/cpu/einsum_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,11 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/einsum_grad_impl.h"

PD_REGISTER_KERNEL(
einsum_grad, CPU, ALL_LAYOUT, phi::EinsumGradKernel, float, double) {}
PD_REGISTER_KERNEL(einsum_grad,
CPU,
ALL_LAYOUT,
phi::EinsumGradKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
14 changes: 11 additions & 3 deletions paddle/phi/kernels/cpu/tile_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/tile_kernel_impl.h"

PD_REGISTER_KERNEL(
tile, CPU, ALL_LAYOUT, phi::TileKernel, bool, float, double, int, int64_t) {
}
PD_REGISTER_KERNEL(tile,
CPU,
ALL_LAYOUT,
phi::TileKernel,
bool,
float,
double,
int,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
5 changes: 5 additions & 0 deletions paddle/phi/kernels/funcs/eigen/broadcast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"

Expand Down Expand Up @@ -75,13 +76,17 @@ struct EigenBroadcastGrad<Eigen::DefaultDevice, T, Rank> {
INSTANTIATION(EigenBroadcast, bool);
INSTANTIATION(EigenBroadcast, dtype::float16);
INSTANTIATION(EigenBroadcast, dtype::bfloat16);
INSTANTIATION(EigenBroadcast, dtype::complex<float>);
INSTANTIATION(EigenBroadcast, dtype::complex<double>);
INSTANTIATION(EigenBroadcast, float);
INSTANTIATION(EigenBroadcast, double);
INSTANTIATION(EigenBroadcast, int);
INSTANTIATION(EigenBroadcast, int64_t);
INSTANTIATION(EigenBroadcastGrad, bool);
INSTANTIATION(EigenBroadcastGrad, float);
INSTANTIATION(EigenBroadcastGrad, dtype::float16);
INSTANTIATION(EigenBroadcastGrad, dtype::complex<float>);
INSTANTIATION(EigenBroadcastGrad, dtype::complex<double>);
INSTANTIATION(EigenBroadcastGrad, double);
INSTANTIATION(EigenBroadcastGrad, int);
INSTANTIATION(EigenBroadcastGrad, int64_t);
Expand Down
5 changes: 5 additions & 0 deletions paddle/phi/kernels/funcs/eigen/broadcast.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"

Expand Down Expand Up @@ -77,12 +78,16 @@ INSTANTIATION(EigenBroadcast, dtype::float16);
INSTANTIATION(EigenBroadcast, dtype::bfloat16);
INSTANTIATION(EigenBroadcast, float);
INSTANTIATION(EigenBroadcast, double);
INSTANTIATION(EigenBroadcast, dtype::complex<float>);
INSTANTIATION(EigenBroadcast, dtype::complex<double>);
INSTANTIATION(EigenBroadcast, int);
INSTANTIATION(EigenBroadcast, int64_t);
INSTANTIATION(EigenBroadcastGrad, bool);
INSTANTIATION(EigenBroadcastGrad, float);
INSTANTIATION(EigenBroadcastGrad, dtype::float16);
INSTANTIATION(EigenBroadcastGrad, double);
INSTANTIATION(EigenBroadcastGrad, dtype::complex<float>);
INSTANTIATION(EigenBroadcastGrad, dtype::complex<double>);
INSTANTIATION(EigenBroadcastGrad, int);
INSTANTIATION(EigenBroadcastGrad, int64_t);
template struct EigenBroadcastGrad<Eigen::GpuDevice, float, 0>;
Expand Down
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/einsum_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,6 @@ PD_REGISTER_KERNEL(einsum_grad,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/tile_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,6 @@ PD_REGISTER_KERNEL(tile,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

0 comments on commit e0dd7f3

Please sign in to comment.