
Add atan2 op and test (#33067)
* add atan2_op

* fix
ronny1996 authored Jun 17, 2021
1 parent b0984c7 commit 918aeb7
Showing 7 changed files with 528 additions and 0 deletions.
138 changes: 138 additions & 0 deletions paddle/fluid/operators/atan2_op.cc
@@ -0,0 +1,138 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/atan2_op.h"

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

namespace paddle {
namespace operators {

class Atan2Op : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X1"), "Input", "X1", "atan2");
    OP_INOUT_CHECK(ctx->HasInput("X2"), "Input", "X2", "atan2");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "atan2");

    auto in_dims = ctx->GetInputDim("X1");

    ctx->SetOutputDim("Out", in_dims);
  }
};

class Atan2OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X1", "(Tensor), The first input tensor (y-coordinate) of the atan2 op.");
    AddInput("X2", "(Tensor), The second input tensor (x-coordinate) of the atan2 op.");
    AddOutput("Out", "(Tensor), The output tensor of the atan2 op.");
    AddComment(R"DOC(
Atan2 Operator.

This operator performs elementwise atan2 on the inputs $X1$ and $X2$.

$$out = atan2(x1, x2)$$
)DOC");
  }
};

class Atan2GradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X1"), "Input", "X1", "Atan2Grad");
    OP_INOUT_CHECK(ctx->HasInput("X2"), "Input", "X2", "Atan2Grad");
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
                   "Out@Grad", "Atan2Grad");

    auto x1_grad_name = framework::GradVarName("X1");
    auto x2_grad_name = framework::GradVarName("X2");
    auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));

    if (ctx->HasOutput(x1_grad_name)) {
      ctx->SetOutputDim(x1_grad_name, dout_dims);
    }
    if (ctx->HasOutput(x2_grad_name)) {
      ctx->SetOutputDim(x2_grad_name, dout_dims);
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X1");
    return framework::OpKernelType(dtype, ctx.GetPlace());
  }
};

template <typename T>
class Atan2GradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> retv) const override {
    retv->SetType("atan2_grad");
    retv->SetInput("X1", this->Input("X1"));
    retv->SetInput("X2", this->Input("X2"));
    retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    retv->SetAttrMap(this->Attrs());
    retv->SetOutput(framework::GradVarName("X1"), this->InputGrad("X1"));
    retv->SetOutput(framework::GradVarName("X2"), this->InputGrad("X2"));
  }
};

// Integer inputs produce an FP64 output; all other dtypes pass through.
class Atan2OpVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext* ctx) const override {
    auto type = ctx->GetInputDataType("X1");
    if (ctx->GetInputDataType("X1") == framework::proto::VarType::INT32 ||
        ctx->GetInputDataType("X1") == framework::proto::VarType::INT64 ||
        ctx->GetInputDataType("X2") == framework::proto::VarType::INT32 ||
        ctx->GetInputDataType("X2") == framework::proto::VarType::INT64) {
      type = framework::proto::VarType::FP64;
    }
    ctx->SetOutputDataType("Out", type);
  }
};
} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(atan2, ops::Atan2Op, ops::Atan2OpMaker,
                  ops::Atan2GradMaker<paddle::framework::OpDesc>,
                  ops::Atan2GradMaker<paddle::imperative::OpBase>,
                  ops::Atan2OpVarTypeInference);

REGISTER_OPERATOR(atan2_grad, ops::Atan2GradOp);

REGISTER_OP_CPU_KERNEL(
    atan2, ops::Atan2Kernel<paddle::platform::CPUDeviceContext, int32_t>,
    ops::Atan2Kernel<paddle::platform::CPUDeviceContext, int64_t>,
    ops::Atan2Kernel<paddle::platform::CPUDeviceContext, float>,
    ops::Atan2Kernel<paddle::platform::CPUDeviceContext, double>,
    ops::Atan2Kernel<paddle::platform::CPUDeviceContext,
                     paddle::platform::float16>);

REGISTER_OP_CPU_KERNEL(
    atan2_grad, ops::Atan2GradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::Atan2GradKernel<paddle::platform::CPUDeviceContext, double>,
    ops::Atan2GradKernel<paddle::platform::CPUDeviceContext,
                         paddle::platform::float16>);
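A quick way to see what these registrations wire up is to call the op from Python. A minimal sketch, assuming a Paddle build that includes this commit (the paddle.atan2 API is imported from paddle/tensor/math.py later in this diff); the expected values are standard atan2 results:

import numpy as np
import paddle

x1 = paddle.to_tensor([1.0, -1.0, 1.0, -1.0])
x2 = paddle.to_tensor([1.0, 1.0, -1.0, -1.0])
out = paddle.atan2(x1, x2)  # elementwise atan2(x1, x2)
print(out.numpy())          # ~ [0.785, -0.785, 2.356, -2.356]

# Per Atan2OpVarTypeInference above, integer inputs promote to FP64.
i1 = paddle.to_tensor([1, -1], dtype='int32')
i2 = paddle.to_tensor([1, 1], dtype='int32')
print(paddle.atan2(i1, i2).dtype)  # expected: paddle.float64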
31 changes: 31 additions & 0 deletions paddle/fluid/operators/atan2_op.cu
@@ -0,0 +1,31 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/atan2_op.h"

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    atan2, ops::Atan2Kernel<paddle::platform::CUDADeviceContext, int32_t>,
    ops::Atan2Kernel<paddle::platform::CUDADeviceContext, int64_t>,
    ops::Atan2Kernel<paddle::platform::CUDADeviceContext, float>,
    ops::Atan2Kernel<paddle::platform::CUDADeviceContext, double>,
    ops::Atan2Kernel<paddle::platform::CUDADeviceContext,
                     paddle::platform::float16>);

REGISTER_OP_CUDA_KERNEL(
    atan2_grad,
    ops::Atan2GradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::Atan2GradKernel<paddle::platform::CUDADeviceContext, double>,
    ops::Atan2GradKernel<paddle::platform::CUDADeviceContext,
                         paddle::platform::float16>);
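The CUDA registrations reuse the functors defined in paddle/fluid/operators/atan2_op.h (shown below), so the precision contract is worth noting: for every element type except double, Atan2Functor casts both operands to float and calls ::atan2f, meaning float16 and integer inputs are evaluated at float32 precision. A NumPy-only sketch of that behavior for float16, purely illustrative and not part of the commit:

import numpy as np

x1 = np.float16(0.3)
x2 = np.float16(-0.7)
# Mirror Atan2Functor<float16>: cast to float32, atan2, cast back.
out = np.float16(np.arctan2(np.float32(x1), np.float32(x2)))
print(out)  # atan2 evaluated in float32, stored as float16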
168 changes: 168 additions & 0 deletions paddle/fluid/operators/atan2_op.h
@@ -0,0 +1,168 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using framework::To32BitIndex;

// Result type of atan2: integral inputs promote to double; floating-point
// types map to themselves.
template <typename T>
struct Atan2Out {
  using type = T;
};

template <>
struct Atan2Out<int32_t> {
  using type = double;
};

template <>
struct Atan2Out<int64_t> {
  using type = double;
};

template <typename T>
struct Atan2Functor {
  Atan2Functor(const T* x1, const T* x2, typename Atan2Out<T>::type* out,
               int64_t numel)
      : x1_(x1), x2_(x2), out_(out), numel_(numel) {}

  // Generic path: compute in float via ::atan2f, then cast to the result type.
  HOSTDEVICE void operator()(int64_t idx) const {
    out_[idx] = static_cast<typename Atan2Out<T>::type>(
        ::atan2f(static_cast<float>(x1_[idx]), static_cast<float>(x2_[idx])));
  }

  const T* x1_;
  const T* x2_;
  typename Atan2Out<T>::type* out_;
  int64_t numel_;
};

// Specialization: double inputs use ::atan2 at full double precision.
template <>
struct Atan2Functor<double> {
  Atan2Functor(const double* x1, const double* x2, double* out, int64_t numel)
      : x1_(x1), x2_(x2), out_(out), numel_(numel) {}

  HOSTDEVICE void operator()(int64_t idx) const {
    out_[idx] = ::atan2(x1_[idx], x2_[idx]);
  }

  const double* x1_;
  const double* x2_;
  double* out_;
  int64_t numel_;
};

// dx1 = dout * x2 / ((x1)^2 + (x2)^2)
// dx2 = - dout * x1 / ((x1)^2 + (x2)^2)
template <typename T>
struct Atan2GradFunctor {
  Atan2GradFunctor(const T* x1, const T* x2, const T* dout, T* dx1, T* dx2,
                   int64_t numel)
      : x1_(x1), x2_(x2), dout_(dout), dx1_(dx1), dx2_(dx2), numel_(numel) {}

  HOSTDEVICE void operator()(int64_t idx) const {
    float x1 = static_cast<float>(x1_[idx]);
    float x2 = static_cast<float>(x2_[idx]);
    float x = x1 * x1 + x2 * x2;
    dx1_[idx] = static_cast<T>(static_cast<float>(dout_[idx]) * x2 / x);
    dx2_[idx] = static_cast<T>(-static_cast<float>(dout_[idx]) * x1 / x);
  }

  const T* x1_;
  const T* x2_;
  const T* dout_;
  T* dx1_;
  T* dx2_;
  int64_t numel_;
};

template <>
struct Atan2GradFunctor<double> {
  Atan2GradFunctor(const double* x1, const double* x2, const double* dout,
                   double* dx1, double* dx2, int64_t numel)
      : x1_(x1), x2_(x2), dout_(dout), dx1_(dx1), dx2_(dx2), numel_(numel) {}

  HOSTDEVICE void operator()(int64_t idx) const {
    auto x = x1_[idx] * x1_[idx] + x2_[idx] * x2_[idx];
    dx1_[idx] = dout_[idx] * x2_[idx] / x;
    dx2_[idx] = -dout_[idx] * x1_[idx] / x;
  }

  const double* x1_;
  const double* x2_;
  const double* dout_;
  double* dx1_;
  double* dx2_;
  int64_t numel_;
};

template <typename DeviceContext, typename T>
class Atan2Kernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* X1 = context.Input<Tensor>("X1");
    const Tensor* X2 = context.Input<Tensor>("X2");
    Tensor* Out = context.Output<Tensor>("Out");

    auto numel = X1->numel();
    auto x1 = X1->data<T>();
    auto x2 = X2->data<T>();
    auto out = Out->mutable_data<typename Atan2Out<T>::type>(
        context.GetPlace(), size_t(numel * sizeof(typename Atan2Out<T>::type)));
    auto& dev_ctx = context.template device_context<DeviceContext>();

    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
    Atan2Functor<T> functor(x1, x2, out, numel);
    for_range(functor);
  }
};

template <typename DeviceContext, typename T>
class Atan2GradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* X1 = context.Input<Tensor>("X1");
    const Tensor* X2 = context.Input<Tensor>("X2");
    const Tensor* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
    Tensor* dX1 = context.Output<Tensor>(framework::GradVarName("X1"));
    Tensor* dX2 = context.Output<Tensor>(framework::GradVarName("X2"));

    auto numel = X1->numel();
    auto x1 = X1->data<T>();
    auto x2 = X2->data<T>();
    auto dout = dOut->data<T>();
    auto dx1 =
        dX1->mutable_data<T>(context.GetPlace(), size_t(numel * sizeof(T)));
    auto dx2 =
        dX2->mutable_data<T>(context.GetPlace(), size_t(numel * sizeof(T)));
    auto& dev_ctx = context.template device_context<DeviceContext>();

    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
    Atan2GradFunctor<T> functor(x1, x2, dout, dx1, dx2, numel);
    for_range(functor);
  }
};
} // namespace operators
} // namespace paddle
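The analytic gradients hard-coded in Atan2GradFunctor (dx1 = dout * x2 / (x1^2 + x2^2), dx2 = -dout * x1 / (x1^2 + x2^2)) can be sanity-checked against finite differences. A NumPy-only sketch, not part of the commit:

import numpy as np

x1, x2, dout, eps = 0.3, -0.7, 1.0, 1e-6
denom = x1 * x1 + x2 * x2

dx1_analytic = dout * x2 / denom
dx2_analytic = -dout * x1 / denom

# Central finite differences of atan2 with respect to each input.
dx1_numeric = dout * (np.arctan2(x1 + eps, x2) - np.arctan2(x1 - eps, x2)) / (2 * eps)
dx2_numeric = dout * (np.arctan2(x1, x2 + eps) - np.arctan2(x1, x2 - eps)) / (2 * eps)

assert np.isclose(dx1_analytic, dx1_numeric, atol=1e-5)
assert np.isclose(dx2_analytic, dx2_numeric, atol=1e-5)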
2 changes: 2 additions & 0 deletions python/paddle/__init__.py
@@ -152,6 +152,7 @@
from .tensor.math import acos # noqa: F401
from .tensor.math import asin # noqa: F401
from .tensor.math import atan # noqa: F401
from .tensor.math import atan2 # noqa: F401
from .tensor.math import ceil # noqa: F401
from .tensor.math import cos # noqa: F401
from .tensor.math import tan # noqa: F401
@@ -434,6 +435,7 @@
    'divide',
    'ceil',
    'atan',
    'atan2',
    'expand',
    'broadcast_to',
    'ones_like',
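With the export in place, forward and backward can be exercised end to end. A hedged sketch, assuming a build with this commit (Tensor.grad may be a numpy array or a Tensor depending on the Paddle version); the expected gradients follow the Atan2GradFunctor formulas:

import paddle

x1 = paddle.to_tensor([0.3], stop_gradient=False)
x2 = paddle.to_tensor([-0.7], stop_gradient=False)
out = paddle.atan2(x1, x2)
out.backward()

denom = 0.3 ** 2 + (-0.7) ** 2   # x1^2 + x2^2 = 0.58
print(x1.grad)  # expected ~ x2 / denom = -0.7 / 0.58 ≈ -1.2069
print(x2.grad)  # expected ~ -x1 / denom = -0.3 / 0.58 ≈ -0.5172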
