Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

new api trunc, test=develop #33371

Merged
merged 14 commits into from
Jun 16, 2021
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/framework/unused_var_check.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ static const std::unordered_set<std::string> &GetOpWithUnusedVarAllowSet() {
"data_norm_grad", // 0
"update_loss_scaling", // 0
"fused_embedding_eltwise_layernorm", // 0
"trunc_grad", // 1
});
return *allow_set;
}
Expand Down
89 changes: 89 additions & 0 deletions paddle/fluid/operators/trunc_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/trunc_op.h"

namespace paddle {
namespace operators {

class TruncOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "trunc");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "trunc");
auto input_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", input_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};

class TruncOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of trunc op.");
AddOutput("Out", "(Tensor), The output tensor of trunc op.");
AddComment(R"DOC(
Trunc Operator.
Returns a new tensor with the truncated integer values of input.
$$out = trunc(x)$$
)DOC");
}
};

class TruncGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "TruncGrad");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
framework::GradVarName("X"), "TruncGrad");

auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
ctx->SetOutputDim(framework::GradVarName("X"), dout_dims);
}
};

template <typename T>
class TruncGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

void Apply(GradOpPtr<T> retv) const override {
retv->SetType("trunc_grad");
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can register NoNeedBufferVars for X@GRAD to save memory, see https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/07_new_op/op_notes_cn.html#id6 for details.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thanks!

retv->SetAttrMap(this->Attrs());
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(trunc, ops::TruncOp, ops::TruncOpMaker,
ops::TruncGradOpMaker<paddle::framework::OpDesc>,
ops::TruncGradOpMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(trunc_grad, ops::TruncGradOp);

REGISTER_OP_CPU_KERNEL(trunc, ops::TruncKernel<float>, ops::TruncKernel<double>,
ops::TruncKernel<int>, ops::TruncKernel<int64_t>);

REGISTER_OP_CPU_KERNEL(trunc_grad, ops::TruncGradKernel<float>,
ops::TruncGradKernel<double>, ops::TruncGradKernel<int>,
ops::TruncGradKernel<int64_t>);
115 changes: 115 additions & 0 deletions paddle/fluid/operators/trunc_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/trunc_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_info.h"

namespace paddle {
namespace operators {

using platform::PADDLE_CUDA_NUM_THREADS;

template <typename T>
class TruncFunctor {
public:
__device__ TruncFunctor(const T x) : x_(x) {}
__device__ T operator()() { return trunc(x_); }

public:
const T x_;
};

template <>
class TruncFunctor<int> {
public:
__device__ TruncFunctor(const int x) : x_(x) {}
__device__ int operator()() { return x_; }

public:
const int x_;
};

template <>
class TruncFunctor<int64_t> {
public:
__device__ TruncFunctor(const int64_t x) : x_(x) {}
__device__ int64_t operator()() { return x_; }

public:
const int64_t x_;
};

template <typename T>
__global__ void Trunc(const T* x, T* out, int64_t N) {
CUDA_KERNEL_LOOP(index, N) {
TruncFunctor<T> functor(x[index]);
out[index] = functor();
}
}

template <typename T>
__global__ void TruncGrad(T* dx, int64_t N) {
CUDA_KERNEL_LOOP(index, N) { dx[index] = static_cast<T>(0.0); }
}

template <typename T>
class TruncCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* out = context.Output<Tensor>("Out");

const auto* x_data = x->data<T>();
auto* out_data = out->mutable_data<T>(context.GetPlace());

int64_t numel = x->numel();

int theads = PADDLE_CUDA_NUM_THREADS;
int blocks = (numel + theads - 1) / theads;

Trunc<<<blocks, theads>>>(x_data, out_data, numel);
}
};

template <typename T>
class TruncCUDAGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = context.Output<Tensor>(framework::GradVarName("X"));

const auto* dout_data = dout->data<T>();
auto* dx_data = dx->mutable_data<T>(context.GetPlace());

int64_t numel = dout->numel();

int theads = PADDLE_CUDA_NUM_THREADS;
int blocks = (numel + theads - 1) / theads;

TruncGrad<<<blocks, theads>>>(dx_data, numel);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(trunc, ops::TruncCUDAKernel<float>,
ops::TruncCUDAKernel<double>, ops::TruncCUDAKernel<int>,
ops::TruncCUDAKernel<int64_t>);

REGISTER_OP_CUDA_KERNEL(trunc_grad, ops::TruncCUDAGradKernel<float>,
ops::TruncCUDAGradKernel<double>,
ops::TruncCUDAGradKernel<int>,
ops::TruncCUDAGradKernel<int64_t>);
55 changes: 55 additions & 0 deletions paddle/fluid/operators/trunc_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <math.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T>
class TruncKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* x = context.Input<Tensor>("X");
Tensor* out = context.Output<Tensor>("Out");

size_t numel = x->numel();
const T* x_data = x->data<T>();
T* out_data = out->mutable_data<T>(context.GetPlace());

for (size_t i = 0; i < numel; i++) {
out_data[i] = trunc(x_data[i]);
}
}
};

template <typename T>
class TruncGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
T* dx_data = dx->mutable_data<T>(context.GetPlace());

int numel = dx->numel();
memset(dx_data, 0.0, numel * sizeof(T));
}
};

} // namespace operators
} // namespace paddle
4 changes: 3 additions & 1 deletion python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@
from .tensor.math import prod # noqa: F401
from .tensor.math import broadcast_shape # noqa: F401
from .tensor.math import conj # noqa: F401
from .tensor.math import trunc # noqa: F401
from .tensor.math import neg # noqa: F401
from .tensor.math import lgamma # noqa: F401

Expand Down Expand Up @@ -499,5 +500,6 @@
'log2',
'log10',
'concat',
'check_shape'
'check_shape',
'trunc'
]
88 changes: 88 additions & 0 deletions python/paddle/fluid/tests/unittests/test_trunc_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard

paddle.enable_static()


class TestTruncOp(OpTest):
def setUp(self):
self.op_type = "trunc"
self.dtype = np.float64
np.random.seed(2021)
self.inputs = {'X': np.random.random((20, 20)).astype(self.dtype)}
self.outputs = {'Out': (np.trunc(self.inputs['X']))}

def init_dtype_type(self):
self.dtype = np.float64

def test_check_output(self):
self.check_output()

def test_check_grad(self):
self.check_grad(['X'], 'Out', numeric_grad_delta=1e-5)


class TestFloatTruncOp(TestTruncOp):
def init_dtype_type(self):
self.dtype = np.float32


class TestIntTruncOp(TestTruncOp):
def init_dtype_type(self):
self.dtype = np.int32


class TestTruncAPI(unittest.TestCase):
def setUp(self):
self.shape = [20, 20]
self.x = np.random.random((20, 20)).astype(np.float32)
self.place = paddle.CPUPlace()

def test_api_static(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data('X', self.shape)
out = paddle.trunc(x)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x}, fetch_list=[out])
out_ref = np.trunc(self.x)
for out in res:
self.assertEqual(np.allclose(out, out_ref, rtol=1e-08), True)

def test_api_dygraph(self):
paddle.disable_static(self.place)
x_tensor = paddle.to_tensor(self.x)
out = paddle.trunc(x_tensor)
out_ref = np.trunc(self.x)
self.assertEqual(np.allclose(out.numpy(), out_ref, rtol=1e-08), True)
paddle.enable_static()

def test_errors(self):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data('X', [20, 20], 'bool')
self.assertRaises(TypeError, paddle.trunc, x)


if __name__ == "__main__":
unittest.main()
4 changes: 3 additions & 1 deletion python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@
from .math import any # noqa: F401
from .math import broadcast_shape # noqa: F401
from .math import conj # noqa: F401
from .math import trunc # noqa: F401
from .math import neg # noqa: F401
from .math import lgamma # noqa: F401

Expand Down Expand Up @@ -350,5 +351,6 @@
'rank',
'shape',
'real',
'imag'
'imag',
'trunc'
]
Loading