Skip to content

Commit

Permalink
[Phi] migrate exponential kernel to phi (PaddlePaddle#44376)
Browse files Browse the repository at this point in the history
* [Phi] migrate exponential kernel to phi

* fix comment

* fix CI
  • Loading branch information
zhwesky2010 authored and Aurelius84 committed Jul 29, 2022
1 parent 57d754e commit edadc48
Show file tree
Hide file tree
Showing 12 changed files with 181 additions and 168 deletions.
86 changes: 13 additions & 73 deletions paddle/fluid/operators/exponential_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/exponential_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace operators {
Expand All @@ -21,13 +23,6 @@ class ExponentialOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ExponentialOp");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ExponentialOp");
auto dim = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", dim);
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
Expand All @@ -51,63 +46,17 @@ exponential distribution.
}
};

class ExponentialOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> &GetInputOutputWithSameType()
const override {
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
}
};

template <typename T>
class ExponentialKernel<phi::CPUContext, T> : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *out = ctx.Output<framework::Tensor>("Out");
T *out_data = out->mutable_data<T>(ctx.GetPlace());

T lambda = static_cast<T>(ctx.Attr<float>("lambda"));
int64_t size = out->numel();

auto gen = framework::DefaultCPUGenerator();
auto engine = gen->GetCPUEngine();

std::uniform_real_distribution<T> uniform(0.0, 1.0);
phi::funcs::exponential_transform<T> trans(lambda);
for (int64_t i = 0; i < size; ++i) {
out_data[i] = trans(uniform(*engine));
}
}
};

class ExponentialGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
"Input",
"Out_Grad",
"ExponentialGradOp");

auto dout_dim = ctx->GetInputDim(framework::GradVarName("Out"));
ctx->SetOutputDim(framework::GradVarName("X"), dout_dim);
}
};

template <typename T>
class ExponentialGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("exponential_grad");
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetAttrMap(this->Attrs());
retv->SetType("fill_any_like");
retv->SetInput("X", this->OutputGrad("Out"));
retv->SetAttr("value", 0.0f);
retv->SetOutput("Out", this->InputGrad("X"));
}
};

Expand All @@ -118,24 +67,15 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform;

DECLARE_INPLACE_OP_INFERER(ExponentialInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(ExponentialGradInferer,
{paddle::framework::GradVarName("Out"),
paddle::framework::GradVarName("X")});

DECLARE_INFER_SHAPE_FUNCTOR(exponential,
ExponentialInfershapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));

REGISTER_OPERATOR(exponential,
ops::ExponentialOp,
ops::ExponentialOpMaker,
ops::ExponentialOpInferVarType,
ops::ExponentialGradOpMaker<paddle::framework::OpDesc>,
ops::ExponentialGradOpMaker<paddle::imperative::OpBase>,
ExponentialInferer);
REGISTER_OPERATOR(exponential_grad,
ops::ExponentialGradOp,
ExponentialGradInferer);

REGISTER_OP_CPU_KERNEL(exponential,
ops::ExponentialKernel<phi::CPUContext, float>,
ops::ExponentialKernel<phi::CPUContext, double>);
REGISTER_OP_CPU_KERNEL(exponential_grad,
ops::ExponentialGradKernel<phi::CPUContext, float>,
ops::ExponentialGradKernel<phi::CPUContext, double>);
ExponentialInferer,
ExponentialInfershapeFunctor);
48 changes: 0 additions & 48 deletions paddle/fluid/operators/exponential_op.cu

This file was deleted.

42 changes: 0 additions & 42 deletions paddle/fluid/operators/exponential_op.h

This file was deleted.

3 changes: 2 additions & 1 deletion paddle/phi/api/yaml/generator/wrapped_infermeta_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def gene_wrapped_infermeta_and_register(api):
'const paddle::optional<Tensor>&': 'const MetaTensor&'
}

wrapped_infermeta_name = get_wrapped_infermeta_name(api.api)
wrapped_infermeta_name = get_wrapped_infermeta_name(
api.kernel['func'][0])
args = []
for input_name in api.inputs['names']:
if input_name in kernel_params:
Expand Down
11 changes: 11 additions & 0 deletions paddle/phi/api/yaml/legacy_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,17 @@
func : expm1
backward : expm1_grad

- api : exponential_
args : (Tensor x, float lambda)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : exponential
inplace : (x -> out)
backward : exponential__grad

- api : eye
args : (int64_t num_rows, int64_t num_columns, DataType dtype=DataType::FLOAT32, Place place={})
output : Tensor(out)
Expand Down
9 changes: 9 additions & 0 deletions paddle/phi/api/yaml/legacy_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,15 @@
func : expm1_grad
inplace : (out_grad -> x_grad)

- backward_api : exponential__grad
forward : exponential_ (Tensor x, float lambda) -> Tensor(out)
args : (Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
invoke : zeros_like(out_grad, DataType::UNDEFINED, {})
inplace : (out_grad -> x_grad)

- backward_api : flatten_grad
forward : flatten(Tensor x, int start_axis, int stop_axis) -> Tensor(out), Tensor(xshape)
args : (Tensor xshape, Tensor out_grad)
Expand Down
45 changes: 45 additions & 0 deletions paddle/phi/kernels/cpu/exponential_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/exponential_kernel.h"

#include <random>

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/generator.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"

namespace phi {

template <typename T, typename Context>
void ExponentialKernel(const Context& dev_ctx,
const DenseTensor& x,
float lambda,
DenseTensor* out) {
T* out_data = dev_ctx.template Alloc<T>(out);
auto engine = dev_ctx.GetGenerator()->GetCPUEngine();

std::uniform_real_distribution<T> uniform(0.0, 1.0);
phi::funcs::exponential_transform<T> trans(lambda);

for (int64_t i = 0; i < out->numel(); ++i) {
out_data[i] = trans(uniform(*engine));
}
}

} // namespace phi

PD_REGISTER_KERNEL(
exponential, CPU, ALL_LAYOUT, phi::ExponentialKernel, float, double) {}
27 changes: 27 additions & 0 deletions paddle/phi/kernels/exponential_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void ExponentialKernel(const Context &dev_ctx,
const DenseTensor &x,
float lambda,
DenseTensor *out);

} // namespace phi
36 changes: 36 additions & 0 deletions paddle/phi/kernels/gpu/exponential_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/exponential_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"

namespace phi {

template <typename T, typename Context>
void ExponentialKernel(const Context &dev_ctx,
const DenseTensor &x,
float lambda,
DenseTensor *out) {
phi::funcs::uniform_distribution<T> dist;
phi::funcs::exponential_transform<T> trans(lambda);
phi::funcs::distribution_and_transform<T>(dev_ctx, out, dist, trans);
}

} // namespace phi

PD_REGISTER_KERNEL(
exponential, GPU, ALL_LAYOUT, phi::ExponentialKernel, float, double) {}
26 changes: 26 additions & 0 deletions paddle/phi/ops/compat/exponential_sig.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/core/compat/op_utils.h"

namespace phi {

KernelSignature ExponentialOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("exponential", {"X"}, {"lambda"}, {"Out"});
}

} // namespace phi

PD_REGISTER_ARG_MAPPING_FN(exponential, phi::ExponentialOpArgumentMapping);
Loading

0 comments on commit edadc48

Please sign in to comment.