【AMP OP&Test】unit test for test_logit_op #51051
Conversation
Your PR was submitted successfully. Thank you for your contribution to this open-source project!
@@ -456,6 +456,29 @@ struct LogitFunctor {
  }
};

#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
This macro guard shouldn't be needed; it can go directly at the CudaFunctor below.
It was originally added so that CUDALogitFunctor and LogitFunctor could sit in the same place in the file; updated per the suggestion.
// logit(x) = ln(x/(1-x))
__device__ __forceinline__ T operator()(const T x) const {
  MT y = fminf(x, (one - static_cast<MT>(eps)));
Just use `fmin`, `fmax`, and `log` directly here.
Updated per the suggestion.
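The clamp-then-log logic discussed above can be sketched in plain Python as a reference (an illustration of the functor's math, not Paddle's actual CUDA kernel; the `eps=0.0` default is an assumption for the sketch):

```python
import math

def logit(x, eps=0.0):
    """Reference logit: clamp x into [eps, 1 - eps], then compute ln(x / (1 - x))."""
    y = min(x, 1.0 - eps)   # upper clamp, like fmin(x, one - eps)
    y = max(y, eps)         # lower clamp, like fmax(y, eps)
    return math.log(y / (1.0 - y))
```

For example, `logit(0.5)` is exactly 0, and with `eps=0.01` an input of 0.999 is first clamped to 0.99 before the log is taken.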
@@ -55,7 +54,9 @@ def setUp(self):
        self.attrs = {'eps': self.eps}

    def set_attrs(self):
        pass
        self.dtype = np.float32
The default is fp64, but here it has become fp32.
Because fp32 was originally covered by NO_FP64_CHECK_GRAD_OP_LIST, only fp64 could be used here; updated per the suggestion.
class TestLogitShape(TestLogitOp):
    def set_attrs(self):
        self.dtype = np.float32
This one shouldn't need to change, should it?
Same as above; updated per the suggestion.
@@ -43,9 +45,6 @@ class TestLogitOp(OpTest):
    def setUp(self):
        self.op_type = 'logit'
        self.python_api = paddle.logit
        self.dtype = np.float64
These are all default values, so there is no need to delete them; keep `pass` in set_attrs below, and later cases only need to override set_attrs.
This was done to avoid duplicate assignments; updated per the suggestion to make the code logic clearer.
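The override pattern the reviewer recommends can be sketched as follows (a standalone illustration using plain classes rather than Paddle's OpTest; class and attribute names mirror the PR's tests but are simplified here):

```python
import numpy as np

# Base test class keeps the defaults in setUp and exposes a no-op
# set_attrs hook; each later case overrides only set_attrs.
class TestLogitOp:
    def setUp(self):
        self.op_type = 'logit'
        self.dtype = np.float64   # default dtype stays in the base class
        self.eps = 1e-8
        self.set_attrs()          # hook for subclasses to tweak attributes

    def set_attrs(self):
        pass                      # base class keeps the defaults

class TestLogitOpFp32(TestLogitOp):
    def set_attrs(self):
        self.dtype = np.float32   # only the overridden attribute changes
```

This avoids each subclass re-running a full setUp while keeping the base-class defaults in one place.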
__device__ __forceinline__ T operator()(const T x) const {
  MT y = fmin(x, (one - static_cast<MT>(eps)));
  y = fmax(y, static_cast<MT>(eps));
  y = log(y / (one - y));
The API docs describe this behavior, which seems to be missing here compared with the previous Eigen implementation:
"If eps is the default None, and x < 0 or x > 1, the function returns NaN."
Updated per the suggestion.
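The documented NaN behavior can be checked with a small NumPy reference (a sketch, not Paddle's implementation): with no eps clamping, x < 0 or x > 1 makes x / (1 - x) negative, and the log of a negative number is NaN.

```python
import numpy as np

def logit_ref(x, eps=None):
    """Reference logit: no clamping when eps is None, so out-of-range inputs give NaN."""
    x = np.asarray(x, dtype=np.float64)
    if eps is not None:
        x = np.clip(x, eps, 1.0 - eps)
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.log(x / (1.0 - x))
```

With `eps=None`, `logit_ref(-0.5)` and `logit_ref(1.5)` are both NaN; supplying an eps clamps the input first and keeps the result finite.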
MT x = static_cast<MT>(arg_x);
MT dx = (x < static_cast<MT>(eps) || x > one - static_cast<MT>(eps))
            ? zero
            : (static_cast<MT>(dout) / x / (one - x));
(static_cast<MT>(dout) / x / (one - x))
Change this to match the formula; chained division may lose precision.
done
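The backward rule under discussion can be sketched in NumPy (a reference illustration, not Paddle's kernel): d/dx logit(x) = 1 / (x * (1 - x)), written with a single combined division as the reviewer suggests, and zeroed outside the clamped range [eps, 1 - eps].

```python
import numpy as np

def logit_grad_ref(x, dout, eps):
    """Reference gradient: dout / (x * (1 - x)) inside [eps, 1 - eps], else 0."""
    x = np.asarray(x, dtype=np.float64)
    inside = (x >= eps) & (x <= 1.0 - eps)
    # One combined division, matching the formula, instead of dout / x / (1 - x).
    return np.where(inside, dout / (x * (1.0 - x)), 0.0)
```

At x = 0.5 the gradient is 1 / 0.25 = 4 times the upstream gradient, and inputs outside the clamped range get a zero gradient.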
}

// logit(x) = ln(x/(1-x))
__device__ __forceinline__ T operator()(const T x) const {
x needs to be cast to MT first here, right?
done
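The cast-to-compute-type pattern asked for above can be sketched in NumPy (names are illustrative, not Paddle's kernel): cast the low-precision input to a wider "MT" type, do the arithmetic there, then cast the result back to the storage type.

```python
import math
import numpy as np

def logit_fp16(x_fp16, eps=1e-6):
    """fp16-in, fp16-out logit that computes in float32, mirroring MT x = static_cast<MT>(arg_x)."""
    mt = np.float32(x_fp16)                  # cast input to the compute type first
    mt = np.clip(mt, eps, 1.0 - eps)         # clamp in the compute type
    out = np.log(mt / (np.float32(1.0) - mt))
    return np.float16(out)                   # cast back to the storage type
```

Doing the division and log in float32 avoids compounding float16 rounding error in the intermediates; only the final result is rounded back to float16.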
LGTM for @unittest.skip
PR types
Others
PR changes
Others
Describe
add fp32 fp16 bf16 unit test for test_logit_op