Commit
[Phi] migrate exponential kernel to phi
zhwesky2010 committed Jul 19, 2022
1 parent 4c1e77d commit 34bc3a9
Showing 13 changed files with 183 additions and 167 deletions.
86 changes: 13 additions & 73 deletions paddle/fluid/operators/exponential_op.cc
@@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/exponential_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace operators {
@@ -21,13 +23,6 @@ class ExponentialOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ExponentialOp");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ExponentialOp");
auto dim = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", dim);
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
@@ -51,63 +46,17 @@ exponential distribution.
}
};

class ExponentialOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> &GetInputOutputWithSameType()
const override {
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
}
};

template <typename T>
class ExponentialKernel<phi::CPUContext, T> : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *out = ctx.Output<framework::Tensor>("Out");
T *out_data = out->mutable_data<T>(ctx.GetPlace());

T lambda = static_cast<T>(ctx.Attr<float>("lambda"));
int64_t size = out->numel();

auto gen = framework::DefaultCPUGenerator();
auto engine = gen->GetCPUEngine();

std::uniform_real_distribution<T> uniform(0.0, 1.0);
phi::funcs::exponential_transform<T> trans(lambda);
for (int64_t i = 0; i < size; ++i) {
out_data[i] = trans(uniform(*engine));
}
}
};

class ExponentialGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
"Input",
"Out_Grad",
"ExponentialGradOp");

auto dout_dim = ctx->GetInputDim(framework::GradVarName("Out"));
ctx->SetOutputDim(framework::GradVarName("X"), dout_dim);
}
};

template <typename T>
class ExponentialGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("exponential_grad");
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetAttrMap(this->Attrs());
retv->SetType("fill_any_like");
retv->SetInput("X", this->OutputGrad("Out"));
retv->SetAttr("value", 0.0f);
retv->SetOutput("Out", this->InputGrad("X"));
}
};

@@ -118,24 +67,15 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform;

DECLARE_INPLACE_OP_INFERER(ExponentialInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(ExponentialGradInferer,
{paddle::framework::GradVarName("Out"),
paddle::framework::GradVarName("X")});

DECLARE_INFER_SHAPE_FUNCTOR(exponential,
ExponentialInfershapeFunctor,
PD_INFER_META(phi::ExponentialInferMeta));

REGISTER_OPERATOR(exponential,
ops::ExponentialOp,
ops::ExponentialOpMaker,
ops::ExponentialOpInferVarType,
ops::ExponentialGradOpMaker<paddle::framework::OpDesc>,
ops::ExponentialGradOpMaker<paddle::imperative::OpBase>,
ExponentialInferer);
REGISTER_OPERATOR(exponential_grad,
ops::ExponentialGradOp,
ExponentialGradInferer);

REGISTER_OP_CPU_KERNEL(exponential,
ops::ExponentialKernel<phi::CPUContext, float>,
ops::ExponentialKernel<phi::CPUContext, double>);
REGISTER_OP_CPU_KERNEL(exponential_grad,
ops::ExponentialGradKernel<phi::CPUContext, float>,
ops::ExponentialGradKernel<phi::CPUContext, double>);
ExponentialInferer,
ExponentialInfershapeFunctor);
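
Worth calling out in this file: the grad maker no longer emits a dedicated exponential_grad op; it now generates fill_any_like with value = 0. The reason is that exponential_ samples in place, overwriting X with draws that do not depend on X's values, so the gradient of Out with respect to X is identically zero. A minimal standalone sketch of that argument (the sample_exponential helper is hypothetical, not Paddle code):

```cpp
// Minimal sketch, not Paddle code: the forward of an in-place sampling op
// never reads its input values, so d out / d x == 0 for every element.
#include <cmath>
#include <cstdio>
#include <random>

// Hypothetical stand-in for one element of exponential_: the x argument is
// deliberately unused, mirroring how the kernel overwrites the tensor.
double sample_exponential(double /*x_ignored*/, double lambda, unsigned seed) {
  std::mt19937_64 engine(seed);
  std::uniform_real_distribution<double> uniform(0.0, 1.0);
  return -std::log(1.0 - uniform(engine)) / lambda;
}

int main() {
  const double lambda = 1.5, eps = 1e-6;
  // Same seed means the same draw; perturbing x changes nothing.
  double y0 = sample_exponential(1.0, lambda, 42);
  double y1 = sample_exponential(1.0 + eps, lambda, 42);
  std::printf("finite difference = %g\n", (y1 - y0) / eps);  // prints 0
}
```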
48 changes: 0 additions & 48 deletions paddle/fluid/operators/exponential_op.cu

This file was deleted.

42 changes: 0 additions & 42 deletions paddle/fluid/operators/exponential_op.h

This file was deleted.

10 changes: 10 additions & 0 deletions paddle/phi/api/yaml/legacy_api.yaml
@@ -672,6 +672,16 @@
func : expm1
backward : expm1_grad

- api : exponential_
args : (Tensor x, float lambda)
output : Tensor(out)
infer_meta :
func : ExponentialInferMeta
kernel :
func : exponential
inplace : (x -> out)
backward : exponential__grad

- api : eye
args : (int64_t num_rows, int64_t num_columns, DataType dtype=DataType::FLOAT32, Place place={})
output : Tensor(out)
9 changes: 9 additions & 0 deletions paddle/phi/api/yaml/legacy_backward.yaml
@@ -683,6 +683,15 @@
func : expm1_grad
inplace : (out_grad -> x_grad)

- backward_api : exponential__grad
forward : exponential_ (Tensor x, float lambda) -> Tensor(out)
args : (Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
invoke : zeros_like(out_grad, DataType::UNDEFINED, {})
inplace : (out_grad -> x_grad)

- backward_api : flatten_grad
forward : flatten(Tensor x, int start_axis, int stop_axis) -> Tensor(out), Tensor(xshape)
args : (Tensor xshape, Tensor out_grad)
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/unary.cc
@@ -3378,6 +3378,11 @@ void IdentityLossInferMeta(const MetaTensor& x,
}
}

void ExponentialInferMeta(const MetaTensor& x, float lambda, MetaTensor* out) {
out->set_dims(x.dims());
out->set_dtype(x.dtype());
}

} // namespace phi

PD_REGISTER_INFER_META_FN(flatten, phi::FlattenInferMeta);
1 change: 1 addition & 0 deletions paddle/phi/infermeta/unary.h
@@ -486,4 +486,5 @@ void ChannelShuffleInferMeta(const MetaTensor& x,

void IdentityLossInferMeta(const MetaTensor& x, int reduction, MetaTensor* out);

void ExponentialInferMeta(const MetaTensor& x, float lambda, MetaTensor* out);
} // namespace phi
45 changes: 45 additions & 0 deletions paddle/phi/kernels/cpu/exponential_kernel.cc
@@ -0,0 +1,45 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/exponential_kernel.h"

#include <random>

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/generator.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"

namespace phi {

template <typename T, typename Context>
void ExponentialKernel(const Context& dev_ctx,
const DenseTensor& x,
float lambda,
DenseTensor* out) {
T* out_data = dev_ctx.template Alloc<T>(out);
auto engine = dev_ctx.GetGenerator()->GetCPUEngine();

std::uniform_real_distribution<T> uniform(0.0, 1.0);
phi::funcs::exponential_transform<T> trans(lambda);

for (int64_t i = 0; i < out->numel(); ++i) {
out_data[i] = trans(uniform(*engine));
}
}

} // namespace phi

PD_REGISTER_KERNEL(
exponential, CPU, ALL_LAYOUT, phi::ExponentialKernel, float, double) {}
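
The CPU kernel above draws U ~ Uniform(0, 1) and feeds it through phi::funcs::exponential_transform. Assuming that transform is the usual inverse CDF, X = -log(1 - U) / lambda, a quick standalone sanity check (a sketch under that assumption, not Paddle code) is that the sample mean approaches 1 / lambda:

```cpp
// Sketch under the assumption that exponential_transform implements the
// inverse CDF X = -log(1 - U) / lambda; the empirical mean should be ~1/lambda.
#include <cmath>
#include <cstdio>
#include <random>

int main() {
  const double lambda = 2.0;
  const int n = 1000000;
  std::mt19937_64 engine(7);
  std::uniform_real_distribution<double> uniform(0.0, 1.0);

  double sum = 0.0;
  for (int i = 0; i < n; ++i) {
    sum += -std::log(1.0 - uniform(engine)) / lambda;
  }
  std::printf("sample mean = %f, expected = %f\n", sum / n, 1.0 / lambda);
}
```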
27 changes: 27 additions & 0 deletions paddle/phi/kernels/exponential_kernel.h
@@ -0,0 +1,27 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void ExponentialKernel(const Context &dev_ctx,
const DenseTensor &x,
float lambda,
DenseTensor *out);

} // namespace phi
36 changes: 36 additions & 0 deletions paddle/phi/kernels/gpu/exponential_kernel.cu
@@ -0,0 +1,36 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/exponential_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"

namespace phi {

template <typename T, typename Context>
void ExponentialKernel(const Context &dev_ctx,
const DenseTensor &x,
float lambda,
DenseTensor *out) {
phi::funcs::uniform_distribution<T> dist;
phi::funcs::exponential_transform<T> trans(lambda);
phi::funcs::distribution_and_transform<T>(dev_ctx, out, dist, trans);
}

} // namespace phi

PD_REGISTER_KERNEL(
exponential, GPU, ALL_LAYOUT, phi::ExponentialKernel, float, double) {}
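
Unlike the CPU path, the GPU kernel has no explicit loop: phi::funcs::distribution_and_transform fuses uniform generation and the exponential transform on device. As a rough illustration only (not Paddle's implementation, which manages generator state through the device context), a raw CUDA kernel computing the same per-element transform could look like:

```cuda
// Illustrative CUDA sketch, not Paddle's implementation: each thread draws a
// uniform via cuRAND's Philox generator and applies the exponential transform.
#include <cstdio>
#include <cuda_runtime.h>
#include <curand_kernel.h>

__global__ void exponential_sample(float* out, int n, float lambda,
                                   unsigned long long seed) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  curandStatePhilox4_32_10_t state;
  curand_init(seed, /*subsequence=*/i, /*offset=*/0, &state);
  float u = curand_uniform(&state);  // u in (0, 1], so logf(u) stays finite
  out[i] = -logf(u) / lambda;        // Exponential(lambda) sample
}

int main() {
  const int n = 8;
  float *d_out, h_out[n];
  cudaMalloc(&d_out, n * sizeof(float));
  exponential_sample<<<1, n>>>(d_out, n, /*lambda=*/2.0f, /*seed=*/42ULL);
  cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) std::printf("%f\n", h_out[i]);
  cudaFree(d_out);
  return 0;
}
```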
