[Quant][PT2E] Enable vec code gen for pair of quant/dequant (pytorch#104503)

**Summary**
We already support vectorized code generation for the `dequant-relu-quant` pattern, where `to_uint8` is the last node of the quant pattern before the store into memory. However, there is another case, `dequant1-relu-quant2-dequant2-relu-quant3`, in which `quant2` sits in the middle of the fusion pattern. This PR enables vectorized code generation for the `quant2-dequant2` pair.
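
Below is a minimal sketch of the two cases in terms of the decomposed per-tensor quantize/dequantize math; the helper functions are illustrative stand-ins assuming the standard affine formulas, not the `quantized_decomposed` ops themselves:

```
import torch

# Assumed affine quant/dequant semantics (illustrative, not the real ops).
def quantize_per_tensor(x, scale, zero_point, qmin=0, qmax=255):
    return torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(torch.uint8)

def dequantize_per_tensor(q, scale, zero_point):
    return (q.to(torch.float32) - zero_point) * scale

x = torch.randint(0, 256, (16,), dtype=torch.uint8)
# Already supported: dequant-relu-quant, where to_uint8 is the last step before the store.
y = quantize_per_tensor(torch.relu(dequantize_per_tensor(x, 0.01, 100)), 0.02, 101)
# New case: the chain continues, so quant2/dequant2 land in the middle of the fused pattern.
z = quantize_per_tensor(torch.relu(dequantize_per_tensor(y, 0.02, 101)), 0.03, 102)
```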

**Test Plan**
```
python -u -m pytest -s -v test_cpu_repro.py  -k test_dequant_relu_quant_dequant_relu_quant_lowering
```

**Next Step**
* For better performance, we can add another pass that eliminates pairs of `float_to_uint8` and `uint8_to_float` nodes (see the sketch below).
* For better performance, we should annotate `dequant1` and `quant2` as sharing an observer in the quantization recipe. Then we can lower `dequant1-relu-quant2` into a QReLU node to fully eliminate the computation of `dequant1` and `quant2`.
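
As a quick numerical check of why dropping such a pair is safe (a sketch under the usual affine quantization formula, not the actual Inductor pass):

```
import torch

# The quant computation already rounds and clamps the float value into [0, 255],
# so a float_to_uint8 / uint8_to_float cast pair in the middle of a fused kernel
# is lossless and can be removed once producer and consumer share scale/zero_point.
x = torch.randn(1 << 16) * 100
scale, zero_point = 0.02, 101
quantized_float = torch.clamp(torch.round(x / scale) + zero_point, 0, 255)
roundtrip = quantized_float.to(torch.uint8).to(torch.float32)
assert torch.equal(roundtrip, quantized_float)
```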

Pull Request resolved: pytorch#104503
Approved by: https://github.com/jgong5, https://github.com/jansel
leslie-fang-intel authored and pytorchmergebot committed Jul 5, 2023
1 parent 12ca224 commit ea4d5c4
Showing 2 changed files with 63 additions and 1 deletion.
47 changes: 47 additions & 0 deletions test/inductor/test_cpu_repro.py
```
@@ -714,6 +714,53 @@ def channel_shuffle(x, groups, output_scale, output_zero_point):
         self.common(channel_shuffle, (x, 2, output_scale, output_zero_point))
         assert metrics.generated_cpp_vec_kernel_count == 2
 
+    @unittest.skipIf(
+        not codecache.valid_vec_isa_list(), "Does not support vectorization"
+    )
+    @patch("torch.cuda.is_available", lambda: False)
+    def test_dequant_relu_quant_dequant_relu_quant_lowering(self):
+        def fn(x, scale, zero_point, scale2, zero_point2, scale3, zero_point3):
+            x = torch.ops.quantized_decomposed.dequantize_per_tensor(
+                x, scale, zero_point, 0, 255, torch.uint8
+            )
+            x = torch.relu(x)
+            x = torch.ops.quantized_decomposed.quantize_per_tensor(
+                x, scale2, zero_point2, 0, 255, torch.uint8
+            )
+            x = torch.ops.quantized_decomposed.dequantize_per_tensor(
+                x, scale2, zero_point2, 0, 255, torch.uint8
+            )
+            x = torch.relu(x)
+            x = torch.ops.quantized_decomposed.quantize_per_tensor(
+                x, scale3, zero_point3, 0, 255, torch.uint8
+            )
+            return x
+
+        for use_tensor_overload in [True, False]:
+            x = torch.clamp(
+                torch.randn((1, 7, 7, 9), dtype=torch.float32) * 100, 0, 255
+            ).to(torch.uint8)
+            zero_point_list = [100, 101, 102]
+            scale_list = [0.01, 0.02, 0.03]
+            if use_tensor_overload:
+                for i in range(len(zero_point_list)):
+                    zero_point_list[i] = torch.tensor(
+                        zero_point_list[i], dtype=torch.int64
+                    )
+                    scale_list[i] = torch.tensor(scale_list[i])
+            zero_point, zero_point2, zero_point3 = zero_point_list
+            scale, scale2, scale3 = scale_list
+            with config.patch({"cpp.simdlen": None}):
+                torch._dynamo.reset()
+                metrics.reset()
+                self.common(
+                    fn,
+                    (x, scale, zero_point, scale2, zero_point2, scale3, zero_point3),
+                    rtol=1e-2,
+                    atol=1e-2,
+                )
+                assert metrics.generated_cpp_vec_kernel_count == 1
+
     def test_inplace_add_alpha(self):
         def fn(x, y):
             aten.add_.Tensor(x, y, alpha=0.55)
```
17 changes: 16 additions & 1 deletion torch/_inductor/codegen/cpp.py
```
@@ -2123,7 +2123,22 @@ def to_dtype(x, dtype):
             elif dtype == torch.bool:
                 pass
             elif dtype == torch.uint8:
-                if not all(usr.target in ["store"] for usr in cur_node.users):
+                # Only allow the 2 cases below:
+                # Case 1: to_uint8 followed by store, which corresponds to the single
+                # quant node at the end of the fusion pattern.
+                is_to_uint8_and_store = all(
+                    usr.target in ["store"] for usr in cur_node.users
+                )
+                # Case 2: to_uint8 followed by to_float, which corresponds to a pair
+                # of quant/dequant nodes in the middle of the fusion pattern.
+                is_to_uint8_and_to_float = all(
+                    (
+                        usr.target in ["to_dtype"]
+                        and usr.args[2] == torch.float32
+                    )
+                    for usr in cur_node.users
+                )
+                if not (is_to_uint8_and_store or is_to_uint8_and_to_float):
                     self.disable_vec(f"to_dtype: dtype {dtype}")
             else:
                 self.disable_vec(f"to_dtype: dtype {dtype}")
```
