
new api trunc, test=develop #33371

Merged Jun 16, 2021 · 14 commits
Conversation

@zhangbo9674 zhangbo9674 (Contributor) commented Jun 7, 2021

PR types

New features

PR changes

APIs

Describe

Create a new Paddle API: paddle.trunc(x, name=None).

1. API description:
Given an input tensor, return a new tensor containing the truncated integer values of the input. Supported data types: int32, int64, float, double.

2. Summary of additions:
(1) Two new Ops: the forward Op trunc and its corresponding backward Op trunc_grad;
(2) A new Python API: paddle.trunc();
(3) A new unit test file: test_trunc_op.py.

3. API usage example:

import paddle

input = paddle.rand([2,2],'float32')
print(input)
# Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
#         [[0.02331470, 0.42374918],
#         [0.79647720, 0.74970269]])

output = paddle.trunc(input)
print(output)
# Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
#         [[0., 0.],
# [0., 0.]])

4. Documentation preview screenshot:
[screenshot]

paddle-bot-old bot commented Jun 7, 2021

Thanks for your contribution!
Please wait for the CI result first. See the Paddle CI Manual for details.

Comment on lines 23 to 25
__global__ void TruncGrad(const T* dout, T* dx, int N) {
  CUDA_KERNEL_LOOP(index, N) { dx[index] = 0.0; }
}
Contributor:

The input argument dout is not needed.

Contributor Author:

Done, thanks!


void Apply(GradOpPtr<T> retv) const override {
  retv->SetType("trunc_grad");
  retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
Contributor:

You can register NoNeedBufferVars for X@GRAD to save memory, see https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/07_new_op/op_notes_cn.html#id6 for details.

Contributor Author:

Done, thanks!
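
For context, a minimal sketch of what such a registration might look like, assuming the class names used elsewhere in this PR (the inferer name itself is illustrative). Since trunc_grad only needs the shape of Out@GRAD to fill X@GRAD with zeros, the data buffer of Out@GRAD can be declared unneeded:

// Illustrative sketch: trunc_grad never reads the contents of Out@GRAD
// (it only needs its dims to size X@GRAD), so its buffer can be freed early.
DECLARE_NO_NEED_BUFFER_VARS_INFERER(TruncGradNoNeedBufferVarsInferer,
                                    framework::GradVarName("Out"));

// The inferer is then passed along when registering the backward op:
REGISTER_OPERATOR(trunc_grad, ops::TruncGradOp,
                  ops::TruncGradNoNeedBufferVarsInferer);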

const T* x_data = x->data<T>();
T* out_data = out->mutable_data<T>(context.GetPlace());

int numel = x->numel();
Contributor:

int -> int64_t, since there are tensors of huge size.

Contributor Author:

Done, thanks!
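
The fix is mechanical; a sketch of the adjusted snippet:

const T* x_data = x->data<T>();
T* out_data = out->mutable_data<T>(context.GetPlace());

// int64_t rather than int: numel() can exceed INT_MAX for large tensors.
int64_t numel = x->numel();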

# [0., 0.]])
'''
if in_dygraph_mode():
    out = _varbase_creator(dtype=x.dtype)
Contributor:

Unused out.

Contributor Author:

Done, thanks!
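
Presumably the dygraph branch then calls the generated op binding directly, which allocates its own output; a sketch (the binding name is assumed from Paddle's op registration convention):

if in_dygraph_mode():
    # No output VarBase needs to be pre-created; the generated binding
    # returns the result tensor itself.
    return core.ops.trunc(x)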

attrs = {}

helper = LayerHelper("trunc", **locals())
check_variable_and_dtype(x, 'X', ['float16', 'float32', 'float64'], 'trunc')
Contributor:

It seems float16 is not supported in the C++ code.

Contributor Author:

Done, thanks!
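
After the fix, the dtype whitelist presumably matches the types the C++ kernels actually register, i.e. those listed in the PR description (a sketch):

check_variable_and_dtype(x, 'X', ['int32', 'int64', 'float32', 'float64'], 'trunc')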

Examples:
    .. code-block:: python

        import paddle
        import numpy as np
Contributor:

You can use paddle.rand() directly instead of numpy.random.random.

Contributor Author:

Done, thanks!

        self.outputs = {'Out': np.trunc(self.inputs['X'])}

    def init_dtype_type(self):
        pass
Contributor:

Why not put self.dtype = np.float64 here?

Contributor Author:

Done, thanks!
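
For context, a minimal sketch of the test class with the dtype hook filled in, assuming Paddle's OpTest harness and the structure visible in the snippets above:

import numpy as np
from op_test import OpTest  # Paddle's operator unit-test base class

class TestTruncOp(OpTest):
    def setUp(self):
        self.op_type = "trunc"
        self.init_dtype_type()
        self.inputs = {'X': np.random.random((20, 20)).astype(self.dtype)}
        self.outputs = {'Out': np.trunc(self.inputs['X'])}

    def init_dtype_type(self):
        # The default dtype lives here so subclasses can override it
        # to cover other supported types.
        self.dtype = np.float64

    def test_check_output(self):
        self.check_output()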

Comment on lines 1 to 10
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Contributor:

There are blank lines in the copyright header.

Contributor Author:

Done, thanks!

Contributor:

Is it changed?

Contributor Author:

Now it is changed, done, thanks!

@zhangbo9674 zhangbo9674 closed this Jun 7, 2021
@zhangbo9674 zhangbo9674 reopened this Jun 7, 2021
namespace operators {

template <typename T>
class truncFunctor {
Contributor:

Suggested change:
- class truncFunctor {
+ class TruncFunctor {

Contributor Author:

Done, thanks!

 public:
  __device__ truncFunctor(const T x) : _x(x) {}
  __device__ T operator()() { return trunc(_x); }
  const T _x;
Contributor:

_x -> x_, please refer to the Google C++ style guide: https://google.github.io/styleguide/cppguide.html.

Contributor Author:

Done, thanks
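
Applying both review points (the class name and the trailing-underscore member name), the functor presumably ends up along these lines (a sketch):

template <typename T>
class TruncFunctor {
 public:
  __device__ TruncFunctor(const T x) : x_(x) {}
  __device__ T operator()() { return trunc(x_); }

 private:
  const T x_;  // trailing underscore per Google C++ style for data members
};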


template <typename T>
__global__ void TruncGrad(T* dx, int64_t N) {
  CUDA_KERNEL_LOOP(index, N) { dx[index] = 0.0; }
Contributor:

Suggested change:
- CUDA_KERNEL_LOOP(index, N) { dx[index] = 0.0; }
+ CUDA_KERNEL_LOOP(index, N) { dx[index] = static_cast<T>(0.0); }

Contributor Author:

Done, thanks

Comment on lines 66 to 67
dim3 blockSize(256);
dim3 gridSize((numel + blockSize.x - 1) / blockSize.x);
Contributor:

threads = platform::PADDLE_CUDA_NUM_THREADS;
blocks = (numel + threads - 1) / threads;

Contributor Author:

Done, thanks!
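
Put together, the launch presumably looks something like this sketch (variable names such as dx and dx_data are assumed from the surrounding kernel code):

// PADDLE_CUDA_NUM_THREADS is Paddle's standard block size; the block
// count is rounded up so that all numel elements are covered.
T* dx_data = dx->mutable_data<T>(context.GetPlace());
int64_t numel = dx->numel();

int threads = platform::PADDLE_CUDA_NUM_THREADS;
int blocks = (numel + threads - 1) / threads;

auto stream = context.template device_context<platform::CUDADeviceContext>().stream();
TruncGrad<T><<<blocks, threads, 0, stream>>>(dx_data, numel);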

Comment on lines 1 to 10
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Contributor:

Is it changed?

@@ -855,6 +855,44 @@ def add_n(inputs, name=None):
return out


def trunc(input, name=None):
    '''
Contributor:

A new API requires a documentation preview screenshot, and a Chinese documentation PR must be submitted at the same time.

Contributor Author (@zhangbo9674, Jun 10, 2021):

(1) Documentation preview screenshots for the new API:
[screenshot]
[screenshot]
(2) Chinese documentation PR:
PaddlePaddle/docs#3585


paddle.set_device('cpu')
input = paddle.rand([2,2],'float32')
output = paddle.trunc(input)
Contributor:

The content of the original tensor is not shown, which may hinder users' understanding of the behavior.

Contributor Author:

The original tensor content has been added, thanks!


import paddle

paddle.set_device('cpu')
Contributor:

Why set the device to CPU specifically?

Contributor Author:

There is no need to set the device to CPU specifically; it has been removed, thanks!

paddle.set_device('cpu')
input = paddle.rand([2,2],'float32')
output = paddle.trunc(input)
print(output)
Contributor:

The output shown in the comment does not match the code; either change it to print(output.numpy()), or change the comment to the tensor's actual output.

Contributor Author:

The comment has been changed to the tensor output, thanks!

zhiqiu previously approved these changes Jun 10, 2021
@zhiqiu zhiqiu (Contributor) left a comment:

LGTM

lanxianghit previously approved these changes Jun 10, 2021
@lanxianghit lanxianghit (Contributor) left a comment:

LGTM

This API is used to return a new tensor with the truncated integer values of the input.

Args:
input (Tensor): The input tensor, its data type should be int, float, double.
Contributor:

The data types should be more specific, e.g. int8, int16, int32, float32, etc.

Contributor Author:

Done, thanks!
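
The sharpened Args entry presumably ends up along these lines (a sketch consistent with the types listed in the PR description; the wording of the name entry is assumed boilerplate):

Args:
    input (Tensor): The input tensor. Supported data types: int32, int64, float32, float64.
    name (str, optional): Name for the operation. Default: None.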

@zhangbo9674 zhangbo9674 dismissed stale reviews from lanxianghit and zhiqiu via 2783560 June 11, 2021 02:05
TCChenlong previously approved these changes Jun 11, 2021
@TCChenlong TCChenlong (Contributor) left a comment:

LGTM

zhiqiu previously approved these changes Jun 11, 2021
lanxianghit previously approved these changes Jun 11, 2021
@zhangbo9674 zhangbo9674 dismissed stale reviews from zhiqiu and TCChenlong via 982b57b June 11, 2021 06:45
zhiqiu previously approved these changes Jun 15, 2021
@zhiqiu zhiqiu (Contributor) left a comment:

LGTM
