
【AMP OP&Test】unit test for test_logit_op #51051

Merged: 4 commits into PaddlePaddle:develop on Mar 22, 2023

Conversation

zhangbopd (Contributor)

PR types

Others

PR changes

Others

Describe

Add fp32, fp16, and bf16 unit tests for test_logit_op.
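For context, tests like these typically compare the op's output against a NumPy baseline of logit with a tolerance loosened for fp16/bf16. A minimal sketch of such a baseline (`ref_logit` is a hypothetical helper, not code from this PR):

```python
import numpy as np

# Hypothetical NumPy reference for paddle.logit: clamp x into
# [eps, 1 - eps], then take the log-odds ln(x / (1 - x)).
def ref_logit(x, eps):
    x = np.clip(x, eps, 1.0 - eps)
    return np.log(x / (1.0 - x))
```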

paddle-bot (bot) commented Mar 1, 2023:

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@zhangbopd zhangbopd changed the title 【AMP OP&Test】test_logit_op 【AMP OP&Test】unit test for test_logit_op Mar 1, 2023
@@ -456,6 +456,29 @@ struct LogitFunctor {
}
};

#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
Contributor:

There should be no need to add this macro; just put it at the CudaFunctor below.

Contributor Author (zhangbopd):

This was originally so that CUDALogitFunctor and LogitFunctor would sit in the same place in the file; revised as suggested.


// logit(x) = ln(x/(1-x))
__device__ __forceinline__ T operator()(const T x) const {
MT y = fminf(x, (one - static_cast<MT>(eps)));
Contributor:

Just use `fmin`, `fmax`, and `log` directly for these.

Contributor Author (zhangbopd):

Revised as suggested.

@@ -55,7 +54,9 @@ def setUp(self):
self.attrs = {'eps': self.eps}

def set_attrs(self):
pass
self.dtype = np.float32
Contributor:

The default is fp64, but this changes it to fp32.

Contributor Author (zhangbopd):

Because fp32 was originally in NO_FP64_CHECK_GRAD_OP_LIST, only fp64 could be used here; revised as suggested.

class TestLogitShape(TestLogitOp):
def set_attrs(self):
self.dtype = np.float32
Contributor:

This one probably doesn't need to change.

Contributor Author (zhangbopd) commented Mar 6, 2023:

Same as above; revised as suggested.

@@ -43,9 +45,6 @@ class TestLogitOp(OpTest):
def setUp(self):
self.op_type = 'logit'
self.python_api = paddle.logit
self.dtype = np.float64
Contributor:

These are all default values, so there is no need to delete them; also keep `pass` in set_attrs below, then later cases only need to override set_attrs.

Contributor Author (zhangbopd):

This was to avoid assigning the same values twice; revised as suggested, which makes the code logic clearer.
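The pattern the reviewer describes, as a minimal standalone sketch (plain classes with hypothetical names; the real tests derive from Paddle's OpTest):

```python
import numpy as np

class TestLogitOp:
    def setUp(self):
        # Defaults live in the base setUp ...
        self.dtype = np.float64
        self.eps = 1e-8
        self.set_attrs()  # ... and subclasses override only set_attrs

    def set_attrs(self):
        pass  # the base case keeps the defaults

class TestLogitFp32(TestLogitOp):
    def set_attrs(self):
        self.dtype = np.float32  # override just the attribute under test
```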

__device__ __forceinline__ T operator()(const T x) const {
MT y = fmin(x, (one - static_cast<MT>(eps)));
y = fmax(y, static_cast<MT>(eps));
y = log(y / (one - y));
Contributor:

The API docs describe this behavior, which seems to be missing here compared with the previous Eigen implementation:
"If eps is the default None, and x < 0 or x > 1, the function returns NaN."

Contributor Author (zhangbopd):

Revised as suggested.
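A sketch of the documented eps=None behavior, extending the earlier `ref_logit` sketch (the actual fix lands in the CUDA functor):

```python
import numpy as np

def ref_logit(x, eps=None):
    if eps is None:
        # With the default eps=None, inputs outside [0, 1] must
        # produce NaN instead of being clamped into range.
        x = np.where((x < 0.0) | (x > 1.0), np.nan, x)
    else:
        x = np.clip(x, eps, 1.0 - eps)
    return np.log(x / (1.0 - x))
```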

MT x = static_cast<MT>(arg_x);
MT dx = (x < static_cast<MT>(eps) || x > one - static_cast<MT>(eps))
? zero
: (static_cast<MT>(dout) / x / (one - x));
Contributor:

Change `(static_cast<MT>(dout) / x / (one - x))` here to match the formula; the chained division may lose precision.

Contributor Author (zhangbopd):

done
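The backward formula is dx = dout / (x * (1 - x)); a NumPy sketch of the clamped gradient the diff implements (`ref_logit_grad` is a hypothetical helper):

```python
import numpy as np

def ref_logit_grad(x, dout, eps):
    # One division by the product x * (1 - x), matching
    # d/dx logit(x) = 1 / (x * (1 - x)); the chained form
    # dout / x / (1 - x) rounds twice and can lose precision.
    in_range = (x >= eps) & (x <= 1.0 - eps)
    return np.where(in_range, dout / (x * (1.0 - x)), 0.0)
```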

}

// logit(x) = ln(x/(1-x))
__device__ __forceinline__ T operator()(const T x) const {
Contributor:

x needs to be cast to MT first here, doesn't it?

Contributor Author (zhangbopd):

done
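The same idea in NumPy terms: cast the low-precision input up to the compute type MT (e.g. float32 for a float16 input) before any arithmetic, then cast back at the end (a sketch with a hypothetical helper):

```python
import numpy as np

def logit_fp16(x_fp16, eps):
    x = x_fp16.astype(np.float32)   # cast to the wider compute type first
    y = np.clip(x, eps, 1.0 - eps)
    out = np.log(y / (1.0 - y))
    return out.astype(np.float16)   # cast back to the storage type
```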

luotao1 (Contributor) left a comment:

LGTM for @unittest.skip

@ZzSean ZzSean merged commit 289677e into PaddlePaddle:develop Mar 22, 2023
@zhangbopd zhangbopd deleted the test_logit_op branch March 23, 2023 08:38