[prototype] Gaussian Blur clean up #6888

Merged 5 commits into pytorch:main on Nov 2, 2022

Conversation

datumbox
Contributor

@datumbox datumbox commented Nov 2, 2022

Related to #6818

This PR:

  • Cleans up the assertions on the gaussian_blur kernel
  • Simplifies the reshaping logic
  • Adds in-place ops where possible
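A minimal sketch of what a kernel along these lines could look like (hypothetical helper names; this is not the actual torchvision implementation, just an illustration of flattening the leading dims, using a separable convolution, and applying in-place ops where possible):

```python
import torch
import torch.nn.functional as F

def _gaussian_kernel_1d(kernel_size: int, sigma: float, dtype, device):
    # Sample a 1D Gaussian and normalize it in-place so it sums to 1.
    x = torch.linspace(-(kernel_size - 1) / 2, (kernel_size - 1) / 2,
                       kernel_size, dtype=dtype, device=device)
    pdf = torch.exp(x.pow_(2).div_(-2 * sigma ** 2))  # in-place ops on the temp
    return pdf.div_(pdf.sum())

def gaussian_blur_sketch(image: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor:
    # Flatten all leading dims into one batch dim: (..., H, W) -> (N, 1, H, W),
    # so a single code path handles both batched and unbatched inputs.
    shape = image.shape
    img = image.reshape(-1, 1, shape[-2], shape[-1])
    needs_cast = not img.is_floating_point()
    if needs_cast:
        img = img.to(torch.float32)
    k1d = _gaussian_kernel_1d(kernel_size, sigma, img.dtype, img.device)
    # Separable blur: two 1D convolutions instead of one 2D convolution.
    pad = kernel_size // 2
    img = F.pad(img, [pad, pad, pad, pad], mode="reflect")
    img = F.conv2d(img, k1d.view(1, 1, 1, -1))  # horizontal pass
    img = F.conv2d(img, k1d.view(1, 1, -1, 1))  # vertical pass
    if needs_cast:
        img = img.round_().to(image.dtype)      # round in-place, cast back
    return img.reshape(shape)
```

Since the Gaussian kernel sums to 1, each output pixel is a convex combination of inputs, so the round-and-cast back to uint8 cannot overflow.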

No speed regression, and a small ~5% improvement on CUDA:

[------------- gaussian_blur cpu torch.float32 -------------]
                         |        old       |        new     
1 threads: --------------------------------------------------
      (3, 400, 500)      |    4 (+-  0) ms  |    4 (+-  0) ms
      (16, 3, 400, 500)  |  306 (+-  3) ms  |  306 (+-  2) ms
6 threads: --------------------------------------------------
      (3, 400, 500)      |    6 (+-  0) ms  |    6 (+-  0) ms
      (16, 3, 400, 500)  |  334 (+-  1) ms  |  334 (+-  4) ms

Times are in milliseconds (ms).

[------------- gaussian_blur cuda torch.float32 ------------]
                         |        old       |        new     
1 threads: --------------------------------------------------
      (3, 400, 500)      |  119 (+-  1) us  |  112 (+-  1) us
      (16, 3, 400, 500)  |  266 (+-  0) us  |  266 (+-  0) us
6 threads: --------------------------------------------------
      (3, 400, 500)      |  119 (+-  2) us  |  113 (+-  2) us
      (16, 3, 400, 500)  |  266 (+-  1) us  |  266 (+-  0) us

Times are in microseconds (us).

[-------------- gaussian_blur cpu torch.uint8 --------------]
                         |        old       |        new     
1 threads: --------------------------------------------------
      (3, 400, 500)      |    5 (+-  0) ms  |    5 (+-  0) ms
      (16, 3, 400, 500)  |  355 (+-  1) ms  |  331 (+-  1) ms
6 threads: --------------------------------------------------
      (3, 400, 500)      |    7 (+-  0) ms  |    7 (+-  0) ms
      (16, 3, 400, 500)  |  383 (+-  2) ms  |  359 (+-  5) ms

Times are in milliseconds (ms).

[-------------- gaussian_blur cuda torch.uint8 -------------]
                         |        old       |        new     
1 threads: --------------------------------------------------
      (3, 400, 500)      |  150 (+-  1) us  |  142 (+-  1) us
      (16, 3, 400, 500)  |  423 (+-  0) us  |  423 (+-  0) us
6 threads: --------------------------------------------------
      (3, 400, 500)      |  150 (+-  3) us  |  142 (+-  3) us
      (16, 3, 400, 500)  |  423 (+-  0) us  |  423 (+-  0) us

Times are in microseconds (us).

cc @vfdev-5 @bjuncek @pmeier

Collaborator

@vfdev-5 vfdev-5 left a comment

OK by me. I'm not sure why we can't use _cast_squeeze_in / _cast_squeeze_out anymore, but OK.

@datumbox
Contributor Author

datumbox commented Nov 2, 2022

@vfdev-5 We could use them, but that would mean having multiple pieces of code handling the reshaping (or needs_unsquash and need_squeeze in the previous code). In addition, the casting mechanism in _cast_squeeze_in and _cast_squeeze_out makes assumptions about the order of preference of the provided dtypes and requires unnecessarily complex checks for rounding. Casting can be simplified: the only thing we need to check is whether the input was float, and if it wasn't, just round and cast back.

@datumbox datumbox merged commit 1921613 into pytorch:main Nov 2, 2022
@datumbox datumbox deleted the prototype/gaussian_blur branch November 2, 2022 13:32
@datumbox
Contributor Author

datumbox commented Nov 2, 2022

It seems that this closed the gap between V1 and V2 for the GaussianBlur transform. My new benchmarks between V1+pure tensor and V2+feature report:

[------------- gaussian_blur cpu torch.float32 -------------]
                         |        old       |        new     
1 threads: --------------------------------------------------
      (3, 400, 500)      |   13 (+-  1) ms  |    4 (+-  0) ms
      (16, 3, 400, 500)  |  306 (+-  1) ms  |  306 (+-  0) ms

Times are in milliseconds (ms).

[------------- gaussian_blur cuda torch.float32 ------------]
                         |        old       |        new     
1 threads: --------------------------------------------------
      (3, 400, 500)      |  247 (+- 37) us  |  156 (+-  1) us
      (16, 3, 400, 500)  |  356 (+-  2) us  |  269 (+-  0) us

Times are in microseconds (us).

[-------------- gaussian_blur cpu torch.uint8 --------------]
                         |        old       |        new     
1 threads: --------------------------------------------------
      (3, 400, 500)      |    5 (+-  0) ms  |    5 (+-  0) ms
      (16, 3, 400, 500)  |  351 (+-  1) ms  |  330 (+-  1) ms

Times are in milliseconds (ms).

[-------------- gaussian_blur cuda torch.uint8 -------------]
                         |        old       |        new     
1 threads: --------------------------------------------------
      (3, 400, 500)      |  261 (+-  5) us  |  189 (+-  1) us
      (16, 3, 400, 500)  |  515 (+-  4) us  |  426 (+-  0) us

Times are in microseconds (us).

@vfdev-5 Might be worth rerunning the benchmarks later on your side to confirm.
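For reference, tables like the ones above can be produced with torch.utils.benchmark; a minimal sketch (the shapes and labels mirror the report, but the exact harness used for the PR is not shown here):

```python
import torch
from torch.utils.benchmark import Timer

def bench(fn, shape, dtype, num_threads=1):
    # Time a transform on a random input of the given shape/dtype.
    x = torch.randint(0, 256, shape).to(dtype)
    t = Timer(
        stmt="fn(x)",
        globals={"fn": fn, "x": x},
        num_threads=num_threads,
        label=f"gaussian_blur cpu {dtype}",
        sub_label=str(tuple(shape)),
    )
    return t.blocked_autorange(min_run_time=0.2)

# Placeholder op standing in for the gaussian_blur kernel under test.
m = bench(lambda x: x.float().mean(), (3, 40, 50), torch.uint8)
print(m)
```

Measurements from several runs can be collected into a torch.utils.benchmark.Compare object to print side-by-side old/new tables like the ones in this thread.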

@datumbox datumbox added the Perf (for performance improvements) label Nov 2, 2022
facebook-github-bot pushed a commit that referenced this pull request Nov 4, 2022
Summary:
* Refactor gaussian_blur

* Add conditional reshape

* Further refactoring

* Remove unused import.

Reviewed By: datumbox

Differential Revision: D41020542

fbshipit-source-id: 72694024272d91818c4154f7b5f7097e6d21154f