
optimize prelu alpha grad #7600

Merged: 9 commits merged into master from dev_optinal_prelu_alpha_grad on Feb 26, 2022
Conversation

Flowingsun007 (Contributor)

No description provided.

@@ -400,10 +401,12 @@ class GpuPReluGradKernel final : public user_op::OpKernel {
alpha->dptr<T>(), dy->dptr<T>(), dx->mut_dptr<T>(),
broadcasted_alpha_diff);
}
NdarrayUtil<DeviceType::kCUDA, T>::ReduceSum(
if(alpha_requires_grad){
Contributor:

This if should be moved to line 386 instead, and when alpha does not require grad, the tmp_buffer below should not be allocated at all.

Contributor:

broadcasted_alpha_diff should not be computed either; the CUDA kernel can be adapted to handle the case where broadcasted_alpha_diff is never written.

Contributor Author:

OK.

Contributor Author:

Updated.

@Flowingsun007 Flowingsun007 marked this pull request as ready for review February 25, 2022 08:35
@Flowingsun007 Flowingsun007 enabled auto-merge (squash) February 25, 2022 12:25
@github-actions:
Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.

@Flowingsun007 Flowingsun007 requested review from oneflow-ci-bot and removed request for oneflow-ci-bot February 26, 2022 01:33
@@ -4633,6 +4633,9 @@ def OneFlow_PreluGradOp : OneFlow_BaseOp<"prelu_grad", [NoSideEffect, DeclareOpI
OneFlow_Tensor:$dx,
OneFlow_Tensor:$alpha_diff
);
let attrs = (ins
DefaultValuedAttr<BoolAttr, "false">:$alpha_requires_grad
Contributor:

Is this saying that alpha_requires_grad defaults to false? It feels like it should not have a default value at all.

Contributor Author:

Hmm, a default of false feels more intuitive?

Contributor:

There doesn't seem to be a strong reason to default to false; it might be better to require the user to pass it explicitly.

Collaborator:

> There doesn't seem to be a strong reason to default to false; it might be better to require the user to pass it explicitly.

Agreed. And if there is a default, it should be true: if the value should be false but is left unset and so defaults to true, only performance suffers; the other way around, correctness suffers.

Contributor Author:

> There doesn't seem to be a strong reason to default to false; it might be better to require the user to pass it explicitly.

OK, I'll change the default to true then. (It can be passed explicitly: whenever alpha's requires_grad=True/False is known, the value is passed in explicitly.)

@github-actions:

Speed stats:
GPU Name: GeForce GTX 1080 

✔️ OneFlow resnet50 time: 128.5ms (= 12853.7ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 136.8ms (= 13680.8ms / 100, input_shape=[16, 3, 224, 224])
✔️ Relative speed: 1.06 (= 136.8ms / 128.5ms)

❌ OneFlow resnet50 time: 79.5ms (= 7954.1ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 82.4ms (= 8243.4ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.04 (= 82.4ms / 79.5ms)

OneFlow resnet50 time: 50.4ms (= 10087.1ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 57.4ms (= 11482.3ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.14 (= 57.4ms / 50.4ms)

OneFlow resnet50 time: 43.7ms (= 8733.8ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 45.6ms (= 9121.1ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.04 (= 45.6ms / 43.7ms)

OneFlow resnet50 time: 38.7ms (= 7732.9ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 40.2ms (= 8045.7ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.04 (= 40.2ms / 38.7ms)

✔️ OneFlow resnet50 time: 142.5ms (= 14249.2ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 160.7ms (= 16074.3ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.13 (= 160.7ms / 142.5ms)

OneFlow resnet50 time: 88.6ms (= 8858.8ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 103.9ms (= 10385.7ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.17 (= 103.9ms / 88.6ms)

OneFlow resnet50 time: 61.5ms (= 12309.5ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 74.0ms (= 14791.9ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.20 (= 74.0ms / 61.5ms)

OneFlow resnet50 time: 50.8ms (= 10160.5ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 67.5ms (= 13509.4ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.33 (= 67.5ms / 50.8ms)

OneFlow resnet50 time: 48.3ms (= 9664.4ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 62.6ms (= 12525.9ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.30 (= 62.6ms / 48.3ms)

@Flowingsun007 Flowingsun007 requested review from oneflow-ci-bot and removed request for oneflow-ci-bot February 26, 2022 03:21
@Flowingsun007 Flowingsun007 merged commit b7ad753 into master Feb 26, 2022
@Flowingsun007 Flowingsun007 deleted the dev_optinal_prelu_alpha_grad branch February 26, 2022 05:13
marigoold pushed a commit that referenced this pull request Mar 15, 2022
* optimize prelu alpha grad

* refine

* refine

* refine