-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Frontend][PyTorch] support for quantized conv_transpose2d op #9133
Merged
masahi
merged 3 commits into
apache:main
from
abraham-arun:pt_quant_conv_transpose_support
Sep 29, 2021
Merged
[Frontend][PyTorch] support for quantized conv_transpose2d op #9133
masahi
merged 3 commits into
apache:main
from
abraham-arun:pt_quant_conv_transpose_support
Sep 29, 2021
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
abraham-arun
requested review from
areusch,
comaniac,
Huyuwei,
jroesch,
junrushao,
jwfromm,
kazum,
mbrookhart,
merrymercy,
siju-samuel,
srkreddy1238,
tqchen and
yzhliu
as code owners
September 27, 2021 09:56
@masahi Please review when you get time. |
masahi
approved these changes
Sep 27, 2021
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice, thank you!
abraham-arun
force-pushed
the
pt_quant_conv_transpose_support
branch
from
September 28, 2021 14:01
be9dc43
to
b3441ad
Compare
PyTorch uses the same underlying function to pack and unpack the params for conv2d and conv_transpose2d ops. This change adds support for quantized conv_transpose2d op by reusing the ConvPackedParam and adding the output_padding param to it. This output_padding param will remain unused in case of conv2d. Also added test for above with specific condition for torch v1.7.1 and below.
abraham-arun
force-pushed
the
pt_quant_conv_transpose_support
branch
from
September 28, 2021 16:18
b3441ad
to
ae0da6b
Compare
AndrewZhaoLuo
added a commit
to AndrewZhaoLuo/tvm
that referenced
this pull request
Sep 29, 2021
* main: Fix flaky NMS test by making sure scores are unique (apache#9140) [Relay] Merge analysis/context_analysis.cc and transforms/device_annotation.cc (apache#9038) [LLVM] Make changes needed for opaque pointers (apache#9138) Arm(R) Ethos(TM)-U NPU codegen integration (apache#8849) [CI] Split Integration tests out of first phase of pipeline (apache#9128) [Meta Schedule][M3b] Runner (apache#9111) Fix Google Mock differences between Ubuntu 18.04 and 16.04 (apache#9141) [TIR] add loop partition hint pragma (apache#9121) fix things (apache#9146) [Meta Schedule][M3a] SearchStrategy (apache#9132) [Frontend][PyTorch] support for quantized conv_transpose2d op (apache#9133) [UnitTest] Parametrized test_conv2d_int8_intrinsics (apache#9143) [OpenCL] Remove redundant visit statement in CodeGen. (apache#9144) [BYOC] support arbitrary input dims for add/mul/relu of dnnl c_src codegen (apache#9127) [Relay][ConvertLayout] Support for qnn.conv2d_transpose (apache#9139) add nn.global_avgpool to fq2i (apache#9137) [UnitTests] Enable minimum testing on Vulkan target in CI (apache#9093) [Torch] Support returning quantized weights and bias for BYOC use cases (apache#9135) [Relay] Prepare for new plan_devices.cc (part II) (apache#9130) [microTVM][Zephyr] Add MIMXRT1050 board support (apache#9068)
AndrewZhaoLuo
added a commit
to AndrewZhaoLuo/tvm
that referenced
this pull request
Sep 30, 2021
* main: (80 commits) Introduce centralised name transformation functions (apache#9088) [OpenCL] Add vectorization to cuda conv2d_nhwc schedule (apache#8636) [6/6] Arm(R) Ethos(TM)-U NPU codegen integration with `tvmc` (apache#8854) [microTVM] Add wrapper for creating project using a MLF (apache#9090) Fix typo (apache#9156) [Hotfix][Testing] Wait for RPCServer to be established (apache#9150) Update find cublas so it search default path if needed. (apache#9149) [TIR][LowerMatchBuffer] Fix lowering strides when source region has higher dimension than the buffer (apache#9145) Fix flaky NMS test by making sure scores are unique (apache#9140) [Relay] Merge analysis/context_analysis.cc and transforms/device_annotation.cc (apache#9038) [LLVM] Make changes needed for opaque pointers (apache#9138) Arm(R) Ethos(TM)-U NPU codegen integration (apache#8849) [CI] Split Integration tests out of first phase of pipeline (apache#9128) [Meta Schedule][M3b] Runner (apache#9111) Fix Google Mock differences between Ubuntu 18.04 and 16.04 (apache#9141) [TIR] add loop partition hint pragma (apache#9121) fix things (apache#9146) [Meta Schedule][M3a] SearchStrategy (apache#9132) [Frontend][PyTorch] support for quantized conv_transpose2d op (apache#9133) [UnitTest] Parametrized test_conv2d_int8_intrinsics (apache#9143) ...
ylc
pushed a commit
to ylc/tvm
that referenced
this pull request
Jan 7, 2022
…#9133) * [Frontend][PyTorch] support for quantized conv_transpose2d op PyTorch uses the same underlying function to pack and unpack the params for conv2d and conv_transpose2d ops. This change adds support for quantized conv_transpose2d op by reusing the ConvPackedParam and adding the output_padding param to it. This output_padding param will remain unused in case of conv2d. Also added test for above with specific condition for torch v1.7.1 and below. * fix after merging main
ylc
pushed a commit
to ylc/tvm
that referenced
this pull request
Jan 13, 2022
…#9133) * [Frontend][PyTorch] support for quantized conv_transpose2d op PyTorch uses the same underlying function to pack and unpack the params for conv2d and conv_transpose2d ops. This change adds support for quantized conv_transpose2d op by reusing the ConvPackedParam and adding the output_padding param to it. This output_padding param will remain unused in case of conv2d. Also added test for above with specific condition for torch v1.7.1 and below. * fix after merging main
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
PyTorch uses the same underlying function to pack and unpack the params for conv2d and conv_transpose2d ops.
This change adds support for quantized conv_transpose2d op by reusing the ConvPackedParam and adding the output_padding param to it.
This output_padding param will remain unused in case of conv2d.
Also added test for above with specific condition for torch v1.7.1 and below.