
[Torch][Quantized] Fix converting serialized quantized models #5839

Merged (3 commits, Jun 18, 2020)

Conversation

@masahi (Member) commented on Jun 18, 2020

This is a workaround for the issue reported in pytorch/pytorch#39690.

In short, when a quantized PyTorch model is serialized and loaded back, the dtypes of output tensors are dropped: the loaded model has no QUInt8 types at all.

This becomes a problem when converting some Torch ops. For example, in the graph below, the output dtype of aten::quantize_per_tensor is recorded as plain Tensor (meaning a float tensor, which is wrong), so aten::adaptive_avg_pool2d appears to be a float operation. But the output of aten::quantize_per_tensor is obviously a quantized tensor, so aten::adaptive_avg_pool2d must be converted to its quantized version.

The quantized ResNet in torchvision uses aten::adaptive_avg_pool2d, so right now, if we save and load back the quantized ResNet, we get garbage results.

  %input.1 : Tensor = aten::quantize_per_tensor(%X.1, %7, %8, %9) # /home/masa/anaconda3
  ...
  %Xq.1 : Tensor = aten::adaptive_avg_pool2d(%input.1, %12)
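Since the dtype annotation in the loaded graph cannot be trusted, quantized-ness has to be inferred from the graph structure instead, by walking back to the op that produced each input. The idea can be sketched as follows; the node and graph representations here are simplified stand-ins for TorchScript's `torch._C.Node` objects, and the helper name and op sets are hypothetical, not TVM's actual implementation.

```python
# Ops whose output is always a quantized tensor (hypothetical, partial list).
QUANTIZED_PRODUCERS = {
    "aten::quantize_per_tensor",
    "quantized::conv2d",
    "quantized::linear",
}

# Ops that pass their input's quantized-ness through unchanged.
PASSTHROUGH = {
    "aten::adaptive_avg_pool2d",
    "aten::max_pool2d",
    "aten::flatten",
}


def input_is_quantized(node, producer_of):
    """Return True if `node`'s first input is a quantized tensor.

    Instead of trusting the (possibly dropped) dtype annotation, walk
    back through producers until an op decides the answer.  `node` is a
    dict {"kind": str, "inputs": [value names]}; `producer_of` maps a
    value name to the node that produced it (graph inputs have none).
    """
    current = producer_of.get(node["inputs"][0])
    while current is not None:
        if current["kind"] in QUANTIZED_PRODUCERS:
            return True
        if current["kind"] in PASSTHROUGH:
            # Keep walking upward through dtype-preserving ops.
            current = producer_of.get(current["inputs"][0])
            continue
        # Any other producer: treat the value as a float tensor.
        return False
    # Reached a graph input with no producer: assume float.
    return False
```

With a toy graph mirroring the snippet above, `input_is_quantized` reports that the input of `aten::adaptive_avg_pool2d` is quantized even though its recorded dtype is plain `Tensor`, which is exactly the decision the converter needs to make.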

Please review @siju-samuel @anijain2305
cc @jjohnson-arm

Files changed (review comments resolved):
tests/python/frontend/pytorch/qnn_test.py
python/tvm/relay/frontend/pytorch.py
@siju-samuel (Member) approved:

LGTM. Thanks for the fix @masahi

@masahi masahi merged commit 082874c into apache:master Jun 18, 2020
@anijain2305 (Contributor) approved:

LGTM

trevor-m pushed a commit to trevor-m/tvm that referenced this pull request on Jun 30, 2020 (…#5839):

* [Torch] Fix converting serialized quantized models
* clean up dtype check
* comment clean up

zhiics pushed a commit to neo-ai/tvm that referenced this pull request on Jul 2, 2020 (…#5839):

* [Torch] Fix converting serialized quantized models
* clean up dtype check
* comment clean up