Skip to content

Commit

Permalink
move c_identity to phi (PaddlePaddle#56215)
Browse files Browse the repository at this point in the history
  • Loading branch information
GreatV authored and BeingGod committed Sep 9, 2023
1 parent 2e9d415 commit 9c266d9
Show file tree
Hide file tree
Showing 11 changed files with 224 additions and 127 deletions.
23 changes: 10 additions & 13 deletions paddle/fluid/operators/collective/c_identity_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_identity_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"

#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace operators {
Expand Down Expand Up @@ -79,20 +82,14 @@ class CIdentityOpGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

DECLARE_INFER_SHAPE_FUNCTOR(c_identity,
CIdentityShapeFunctor,
PD_INFER_META(phi::CIdentityInferMeta));

REGISTER_OPERATOR(c_identity,
ops::CIdentityOp,
ops::CIdentityOpGradMaker<paddle::framework::OpDesc>,
ops::CIdentityOpGradMaker<paddle::imperative::OpBase>,
ops::CIdentityOpMaker);

PD_REGISTER_STRUCT_KERNEL(c_identity,
CPU,
ALL_LAYOUT,
ops::CIdentityOpCPUKernel,
float,
double,
int,
int64_t,
plat::float16) {}
ops::CIdentityOpMaker,
CIdentityShapeFunctor);
32 changes: 0 additions & 32 deletions paddle/fluid/operators/collective/c_identity_op.cu.cc

This file was deleted.

57 changes: 0 additions & 57 deletions paddle/fluid/operators/collective/c_identity_op.h

This file was deleted.

47 changes: 35 additions & 12 deletions paddle/fluid/operators/custom_device_common_op_registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ limitations under the License. */
#include "paddle/fluid/operators/custom_device_common_op_registry.h"
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/operators/collective/c_concat_op.h"
#include "paddle/fluid/operators/collective/c_identity_op.h"
#include "paddle/fluid/operators/load_combine_op.h"
#include "paddle/fluid/operators/run_program_op.h"
#include "paddle/fluid/operators/save_combine_op.h"
Expand Down Expand Up @@ -147,6 +146,25 @@ class CConcatOpCustomDeviceKernel : public framework::OpKernel<T> {
}
};

template <typename DeviceContext, typename T>
class CIdentityOpCustomDeviceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto x = ctx.Input<phi::DenseTensor>("X");
auto out = ctx.Output<phi::DenseTensor>("Out");

int rid = ctx.Attr<int>("ring_id");
PADDLE_ENFORCE_GE(
rid,
0,
platform::errors::InvalidArgument(
"The ring_id (%d) for c_identity op must be non-negative.", rid));
ctx.device_context().Alloc<T>(out);

paddle::framework::TensorCopy(*x, out->place(), out);
}
};

template <typename DeviceContext, typename T>
class CSplitOpCustomDeviceKernel : public framework::OpKernel<T> {
public:
Expand Down Expand Up @@ -1363,17 +1381,22 @@ void RegisterCustomDeviceCommonKernel(const std::string& dev_type) {
REGISTER_OP_CUSTOM_DEVICE_KERNEL(
c_identity,
device_type,
paddle::operators::
CIdentityOpKernel<float, paddle::platform::CustomDeviceContext>,
paddle::operators::
CIdentityOpKernel<double, paddle::platform::CustomDeviceContext>,
paddle::operators::
CIdentityOpKernel<int, paddle::platform::CustomDeviceContext>,
paddle::operators::
CIdentityOpKernel<int64_t, paddle::platform::CustomDeviceContext>,
paddle::operators::CIdentityOpKernel<
paddle::platform::float16,
paddle::platform::CustomDeviceContext>) {}
paddle::operators::CIdentityOpCustomDeviceKernel<
paddle::platform::CustomDeviceContext,
float>,
paddle::operators::CIdentityOpCustomDeviceKernel<
paddle::platform::CustomDeviceContext,
double>,
paddle::operators::CIdentityOpCustomDeviceKernel<
paddle::platform::CustomDeviceContext,
int>,
paddle::operators::CIdentityOpCustomDeviceKernel<
paddle::platform::CustomDeviceContext,
int64_t>,
paddle::operators::CIdentityOpCustomDeviceKernel<
paddle::platform::CustomDeviceContext,
paddle::platform::float16>) {}

REGISTER_OP_CUSTOM_DEVICE_KERNEL(
c_sync_calc_stream,
device_type,
Expand Down
14 changes: 14 additions & 0 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,20 @@ void ClipByNormInferMeta(const MetaTensor& x, float max_norm, MetaTensor* out) {
out->share_lod(x);
}

void CIdentityInferMeta(const MetaTensor& x,
int ring_id,
bool use_calc_stream,
bool use_model_parallel,
MetaTensor* out) {
PADDLE_ENFORCE_GE(
ring_id,
0,
errors::InvalidArgument(
"The ring_id (%d) for c_identity must be non-negative.", ring_id));
out->set_dims(x.dims());
out->set_dtype(x.dtype());
}

void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out) {
out->set_dims(x.dims());
out->set_dtype(dtype == DataType::UNDEFINED ? x.dtype() : dtype);
Expand Down
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,12 @@ void ClassCenterSampleInferMeta(const MetaTensor& label,

void ClipByNormInferMeta(const MetaTensor& x, float max_norm, MetaTensor* out);

void CIdentityInferMeta(const MetaTensor& x,
int ring_id,
bool use_calc_stream,
bool use_model_parallel,
MetaTensor* out);

void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out);

void CropInferMeta(const MetaTensor& x,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,25 +1,28 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_identity_op.h"
#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;
namespace phi {

PD_REGISTER_STRUCT_KERNEL(c_identity,
XPU,
ALL_LAYOUT,
ops::CIdentityOpKernel,
float,
double,
int,
int64_t,
plat::float16) {}
template <typename T, typename Context>
void CIdentityKernel(const Context& dev_ctx,
const DenseTensor& x,
int ring_id,
bool use_calc_stream,
bool use_model_parallel,
DenseTensor* out);
} // namespace phi
43 changes: 43 additions & 0 deletions paddle/phi/kernels/cpu/c_identity_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/c_identity_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void CIdentityKernel(const Context& dev_ctx,
const DenseTensor& x,
int ring_id,
bool use_calc_stream,
bool use_model_parallel,
DenseTensor* out) {
PADDLE_THROW(
errors::Unavailable("Do not support c_identity for cpu kernel now."));
}

} // namespace phi

PD_REGISTER_KERNEL(c_identity,
CPU,
ALL_LAYOUT,
phi::CIdentityKernel,
float,
double,
int,
int64_t,
phi::dtype::float16) {}
30 changes: 30 additions & 0 deletions paddle/phi/kernels/gpu/c_identity_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/c_identity_kernel.h"
#include "paddle/phi/kernels/impl/c_identity_kernel_impl.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

PD_REGISTER_KERNEL(c_identity,
GPU,
ALL_LAYOUT,
phi::CIdentityKernel,
float,
double,
int,
int64_t,
phi::dtype::bfloat16,
phi::dtype::float16) {}
41 changes: 41 additions & 0 deletions paddle/phi/kernels/impl/c_identity_kernel_impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/kernels/c_identity_kernel.h"

#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void CIdentityKernel(const Context& dev_ctx,
const DenseTensor& x,
int ring_id,
bool use_calc_stream,
bool use_model_parallel,
DenseTensor* out) {
PADDLE_ENFORCE_GE(
ring_id,
0,
errors::InvalidArgument(
"The ring_id (%d) for c_identity op must be non-negative.", ring_id));

dev_ctx.template Alloc<T>(out);

phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
}

} // namespace phi
Loading

0 comments on commit 9c266d9

Please sign in to comment.