Add mkldnn_softmax #4331
Conversation
Force-pushed from 89cd07a to 799f80a
Could the implementations of resetFwd and resetBwd in MKLDNNActivation.h be moved into MKLDNNActivation.cpp? Those two function bodies are too long to belong in a header file.
  return Error();
/**
 * @brief Base class of MKLDNN softmax Activation,
 * only have mkldnn forward, use cpu implement for backward.
Why isn't the mkldnn backward used here?
Because mkldnn has not implemented the softmax backward yet, we keep it the same as the CPU implementation for now.
      "mkldnn_" #ACT_TYPE); \
  });

/**
 * @def DEFINE_MKLDNN_ELTWISE_ACTIVATION
 */
#define DEFINE_MKLDNN_ELTWISE_ACTIVATION(ACT_TYPE, ALPHA, BWD_ALPHA) \
Could the `class MKLDNN_ACTIVATION_CLASS_NAME` classes inherit from MKLDNN_ACTIVATION_CLASS_NAME? That way the following would not need to be written repeatedly: line 55, line 60, lines 64-65, and lines 68-71.
OK, it can be simplified a bit more.
paddle/math/Matrix.cpp
Outdated
@@ -3637,7 +3637,7 @@ void CpuMatrix::oneHotCrossEntropy(Matrix& output, IVector& label) {
   for (size_t i = 0; i < numSamples; ++i, out += dim) {
     CHECK_GE(lbl[i], 0);
     CHECK_LT((size_t)lbl[i], dim);
-    cost[i] = -std::log(out[lbl[i]]);
+    cost[i] = -std::log(std::max(out[lbl[i]], real(FLT_MIN)));
Where is FLT_MIN defined? Paddle does not have this variable.
It is defined in #include <float.h>.
paddle/math/Matrix.cpp
Outdated
@@ -3652,7 +3652,7 @@ void CpuMatrix::oneHotCrossEntropyBp(Matrix& output, IVector& label) {
   real* grad = getData();
   int* lbl = label.getData();
   for (size_t i = 0; i < numSamples; ++i, out += dim, grad += dim) {
-    grad[lbl[i]] -= 1 / out[lbl[i]];
+    grad[lbl[i]] -= 1 / std::max(out[lbl[i]], real(FLT_MIN));
Same as above.
It is defined in #include <float.h>.
@@ -93,42 +128,21 @@ class MKLDNNEltwiseActivation : public MKLDNNActivation {
   return (mkldnn::algorithm)0;
For lines 119-124, you could use a regular expression to replace mkldnn with eltwise.
I will use a map to simplify it.
Force-pushed from c5216eb to d5f5828
Force-pushed from d5f5828 to 2c6ac62
if (outputG->useGpu()) {
  outputG->softmaxBackward(*outputV);
} else {
  SetDevice device(act.deviceId);
This block was copied directly from SoftmaxActivation::backward, right? Lines 193-196 were copied unnecessarily. Could a MatrixPtr be used here?
They can indeed be removed, thx. Done.
Force-pushed from 007a64a to 043aa8a
Force-pushed from 043aa8a to 672c968
LGTM
Resolves #4330