-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix dropout static when axis != None #37223
Fix dropout static when axis != None #37223
Conversation
Thanks for your contribution! |
… fix_dropout_axis
@@ -940,6 +940,7 @@ def get_attrs(prog, dropout_prob, is_test, seed): | |||
|
|||
#get mask shape | |||
input_shape = x.shape | |||
input_shape_tensor = paddle.shape(x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为了减小对性能的影响,是否区分动静态图呢
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是说动态图的时候可以少走一个op是么
x=input, p=0.7, axis=1, training=True, mode='upscale_in_train') | ||
|
||
in_np = np.ones([40, 40]).astype("float32") | ||
in_np2 = np.ones([1, 250000000]).astype("float32") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
因为这个指定axis的分支不会走dropout op,感觉加的单测不用和去dropout OP比较,能跑下来就行,这个大size的单测可能会带来其他问题
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
* fix dropout static when axis != None * update dropout test * add dropout test * fix test * Update test_dropout_op.py * Update test_dropout_op.py * fix testcase * fix testcase * Update test_dropout_op.py * fix testcase * fix testcase * optimize perf * add new test * fix testcase
* fix dropout static when axis != None * update dropout test * add dropout test * fix test * Update test_dropout_op.py * Update test_dropout_op.py * fix testcase * fix testcase * Update test_dropout_op.py * fix testcase * fix testcase * optimize perf * add new test * fix testcase
* fix dropout static when axis != None * update dropout test * add dropout test * fix test * Update test_dropout_op.py * Update test_dropout_op.py * fix testcase * fix testcase * Update test_dropout_op.py * fix testcase * fix testcase * optimize perf * add new test * fix testcase
PR types
Bug fixes
PR changes
APIs
Describe
Fix dropout static when axis != None