【Hackathon No.46】Add float16 data type support for the Paddle gumbel_softmax operator #50923
@@ -103,6 +103,14 @@ def init_attrs(self):
        self.dtype = "float64"


class TestGumbelSoftmaxOp6(TestGumbelSoftmaxOp):
    def init_attrs(self):
        self.shape = [20, 10, 5]
        self.attrs = {"hard": True, "axis": 1}
        self.count_expected = 100
        self.dtype = np.float16
Reviewer: These 4 unit tests should inherit from TestGumbelSoftmaxFP16OP.

Author: Hi — the earlier TestGumbelSoftmax_ZeroDim_FP16OP targets the zero-dimensional case and therefore has no init_attrs() method inside, so it cannot be renamed to TestGumbelSoftmaxFP16OP. These tests therefore inherit directly from TestGumbelSoftmaxOp.
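To illustrate the structure the author describes, here is a minimal, self-contained sketch (not code from this PR; the base class is a stand-in for Paddle's OpTest and the zero-dim class body is an assumption): the zero-dimensional FP16 test configures itself directly in setUp and exposes no init_attrs() hook, so the shaped FP16 tests subclass TestGumbelSoftmaxOp and only override init_attrs().

```python
import unittest

import numpy as np


class TestGumbelSoftmaxOp(unittest.TestCase):
    """Stand-in for the real OpTest-based float64 test (sketch only)."""

    def init_attrs(self):
        # Hook that subclasses override to change shape/attrs/dtype.
        self.shape = [20, 10]
        self.attrs = {"hard": True, "axis": -1}
        self.count_expected = 20
        self.dtype = "float64"

    def setUp(self):
        self.init_attrs()
        self.x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)


class TestGumbelSoftmax_ZeroDim_FP16OP(unittest.TestCase):
    """Zero-dim FP16 case: everything is set up directly, so there is no
    init_attrs() hook that a shaped FP16 test could override (sketch only)."""

    def setUp(self):
        self.dtype = np.float16
        self.x = np.random.uniform(0.1, 1, []).astype(self.dtype)


class TestGumbelSoftmaxOp6(TestGumbelSoftmaxOp):
    """Shaped FP16 case from the diff above: reuse the float test's setUp and
    only override the attribute hook, switching the dtype to float16."""

    def init_attrs(self):
        self.shape = [20, 10, 5]
        self.attrs = {"hard": True, "axis": 1}
        self.count_expected = 100
        self.dtype = np.float16
```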
Reviewer: The FP16 unit tests need to be revised to follow the unit test specification for low-precision operators:

Author: done
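As a rough illustration of what such a specification typically asks for (an FP16-suffixed test class, dtype set to np.float16, a CUDA-only guard, and relaxed tolerances), here is a minimal sketch. It is not the specification's wording or this PR's final code; the import path, guard, tolerance values, and verify_output body are assumptions modeled on the float tests shown above.

```python
import unittest

import numpy as np
from op_test import OpTest  # Paddle's operator-test base class (path assumed)

from paddle.fluid import core


@unittest.skipIf(
    not core.is_compiled_with_cuda(), "FP16 gumbel_softmax is only tested on CUDA"
)
class TestGumbelSoftmaxFP16Op(OpTest):
    def setUp(self):
        self.op_type = "gumbel_softmax"
        self.dtype = np.float16
        self.shape = [20, 10]
        self.attrs = {"hard": True, "axis": -1}
        self.count_expected = 20  # one 1.0 per row when hard=True
        x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
        self.inputs = {"X": x}
        self.outputs = {"Out": np.zeros(self.shape).astype(self.dtype)}

    def verify_output(self, outs):
        # The op is stochastic, so check the one-hot structure rather than values.
        out_np = np.array(outs[0]).reshape(self.shape)
        self.assertEqual(out_np.sum(), self.count_expected)

    def test_check_output(self):
        self.check_output_customized(self.verify_output)

    def test_check_grad(self):
        # Half precision usually needs a looser tolerance (value is an assumption).
        self.check_grad(["X"], "Out", max_relative_error=1e-2)
```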
class TestGumbelSoftmaxOpSampleDistribution(OpTest):
    def softmax(self, x):
        x_row_max = x.max(axis=-1)
Reviewer: Is this change necessary? Whatever type T is, the value is cast to FP16 here.

Author: done
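For context on the softmax helper whose first line appears in the hunk above: it is the usual numerically stable NumPy reference, which subtracts the per-row maximum before exponentiating so the comparison against the kernel output does not overflow. A self-contained sketch of that pattern, not the PR's exact code, is shown below.

```python
import numpy as np


def softmax(x):
    # Subtract the per-row maximum so exp() never overflows, then normalize.
    x_row_max = x.max(axis=-1, keepdims=True)
    x_exp = np.exp(x - x_row_max)
    return x_exp / x_exp.sum(axis=-1, keepdims=True)


# Example: each row of the result sums to 1, even for float16 inputs.
probs = softmax(np.array([[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]], dtype=np.float16))
print(probs.sum(axis=-1))  # approximately [1. 1.]
```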