Skip to content

Commit

Permalink
fix for torch 1.7.0 error in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
abraham-arun committed Sep 28, 2021
1 parent 19043af commit 308b484
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions tests/python/frontend/pytorch/qnn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,8 @@ def get_qconfig(per_channel):
from torch.quantization.observer import default_weight_observer

if per_channel:
torch.backends.quantized.engine = "fbgemm"
return torch.quantization.get_default_qconfig("fbgemm")
else:
torch.backends.quantized.engine = "qnnpack"
act = MovingAverageMinMaxObserver.with_args(reduce_range=False)
return torch.quantization.QConfig(activation=act, weight=default_weight_observer)

Expand Down Expand Up @@ -298,7 +296,14 @@ def test_quantized_modules():
raw_module.eval()
inp = torch.rand(ishape)

quantize_model(raw_module, inp, per_channel=per_channel)
if module_name == "conv_transpose" and not is_version_greater_than("1.7.1"):
prev_engine = torch.backends.quantized.engine
torch.backends.quantized.engine = "qnnpack"
quantize_model(raw_module, inp, per_channel=per_channel)
torch.backends.quantized.engine = prev_engine
else:
quantize_model(raw_module, inp, per_channel=per_channel)

script_module = torch.jit.trace(raw_module, inp).eval()

with torch.no_grad():
Expand Down

0 comments on commit 308b484

Please sign in to comment.