Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[prototype] Clean up and port the resize kernel in V2 #6892

Merged
merged 6 commits into from
Nov 3, 2022

Conversation

datumbox
Copy link
Contributor

@datumbox datumbox commented Nov 2, 2022

This PR:

  • Moves the implementation of resize into V2
  • Removes unnecessary logic related to reshaping from the previous implementation and performs equivalent simplifications
  • Does a few minor perf optimzations such inplace clamp/round etc

No speed regression, small speed improvements due to simplified logic:

InterpolationMode.BICUBIC

[----------------- resize cpu torch.float32 ----------------]
                         |        old       |        new     
1 threads: --------------------------------------------------
      (3, 400, 500)      |    1 (+-  0) ms  |    1 (+-  0) ms
      (16, 3, 400, 500)  |   22 (+-  0) ms  |   22 (+-  0) ms
6 threads: --------------------------------------------------
      (3, 400, 500)      |    1 (+-  0) ms  |    1 (+-  0) ms
      (16, 3, 400, 500)  |   22 (+-  0) ms  |   22 (+-  0) ms

Times are in milliseconds (ms).

[---------------- resize cuda torch.float32 ----------------]
                         |        old       |        new     
1 threads: --------------------------------------------------
      (3, 400, 500)      |   23 (+-  0) us  |   22 (+-  0) us
      (16, 3, 400, 500)  |  104 (+-  1) us  |  104 (+-  0) us
6 threads: --------------------------------------------------
      (3, 400, 500)      |   23 (+-  0) us  |   22 (+-  0) us
      (16, 3, 400, 500)  |  104 (+-  0) us  |  104 (+-  0) us

Times are in microseconds (us).

[------------------ resize cpu torch.uint8 -----------------]
                         |        old       |        new     
1 threads: --------------------------------------------------
      (3, 400, 500)      |    2 (+-  0) ms  |    2 (+-  0) ms
      (16, 3, 400, 500)  |   45 (+-  0) ms  |   42 (+-  0) ms
6 threads: --------------------------------------------------
      (3, 400, 500)      |    2 (+-  0) ms  |    2 (+-  0) ms
      (16, 3, 400, 500)  |   47 (+-  0) ms  |   43 (+-  1) ms

Times are in milliseconds (ms).

[----------------- resize cuda torch.uint8 -----------------]
                         |        old       |        new     
1 threads: --------------------------------------------------
      (3, 400, 500)      |   60 (+-  0) us  |   56 (+-  0) us
      (16, 3, 400, 500)  |  190 (+-  0) us  |  188 (+-  0) us
6 threads: --------------------------------------------------
      (3, 400, 500)      |   61 (+-  1) us  |   56 (+-  0) us
      (16, 3, 400, 500)  |  190 (+-  1) us  |  188 (+-  0) us
Times are in microseconds (us).

InterpolationMode.BILINEAR:
[------------------ resize cpu torch.float32 -----------------]
                         |         old       |         new     
1 threads: ----------------------------------------------------
      (3, 400, 500)      |   320 (+-  1) us  |   319 (+-  1) us
      (16, 3, 400, 500)  |  6120 (+-100) us  |  6037 (+- 37) us
6 threads: ----------------------------------------------------
      (3, 400, 500)      |   372 (+-  4) us  |   370 (+-  8) us
      (16, 3, 400, 500)  |  6111 (+-174) us  |  6113 (+-109) us

Times are in microseconds (us).

[---------------- resize cuda torch.float32 ----------------]
                         |        old       |        new     
1 threads: --------------------------------------------------
      (3, 400, 500)      |   17 (+-  0) us  |   16 (+-  0) us
      (16, 3, 400, 500)  |   53 (+-  0) us  |   53 (+-  0) us
6 threads: --------------------------------------------------
      (3, 400, 500)      |   17 (+-  0) us  |   16 (+-  0) us
      (16, 3, 400, 500)  |   53 (+-  0) us  |   53 (+-  0) us

Times are in microseconds (us).

[-------------------- resize cpu torch.uint8 -------------------]
                         |         old        |         new      
1 threads: ------------------------------------------------------
      (3, 400, 500)      |   624 (+-  1) us   |   591 (+-  1) us 
      (16, 3, 400, 500)  |  27336 (+- 47) us  |  25671 (+-182) us
6 threads: ------------------------------------------------------
      (3, 400, 500)      |   797 (+- 18) us   |   777 (+- 12) us 
      (16, 3, 400, 500)  |  28049 (+-321) us  |  26405 (+-237) us

Times are in microseconds (us).

[----------------- resize cuda torch.uint8 -----------------]
                         |        old       |        new     
1 threads: --------------------------------------------------
      (3, 400, 500)      |   44 (+-  0) us  |   41 (+-  0) us
      (16, 3, 400, 500)  |  131 (+-  0) us  |  130 (+-  0) us
6 threads: --------------------------------------------------
      (3, 400, 500)      |   44 (+-  0) us  |   41 (+-  1) us
      (16, 3, 400, 500)  |  131 (+-  1) us  |  130 (+-  2) us

Times are in microseconds (us).

cc @vfdev-5 @bjuncek @pmeier

Copy link
Collaborator

@pmeier pmeier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One question, otherwise LGTM if CI is green! Thanks Vasilis.

torchvision/prototype/transforms/functional/_geometry.py Outdated Show resolved Hide resolved
torchvision/prototype/transforms/functional/_geometry.py Outdated Show resolved Hide resolved
Comment on lines +146 to +148
if interpolation == InterpolationMode.BICUBIC and dtype == torch.uint8:
image = image.clamp_(min=0, max=255)
image = image.round_().to(dtype=dtype)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reminder to self that I need to address this in #6830.

@datumbox datumbox merged commit d8cec34 into pytorch:main Nov 3, 2022
@datumbox datumbox deleted the prototype/port_resize branch November 3, 2022 09:22
facebook-github-bot pushed a commit that referenced this pull request Nov 4, 2022
Summary:
* Ported `resize`

* Align with previous behaviour

* Update torchvision/prototype/transforms/functional/_geometry.py

* Moving input verification on top of method.

Reviewed By: datumbox

Differential Revision: D41020541

fbshipit-source-id: ef55147eb263fc530e2068b914d1e2b539415260

Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants