Commit 00600ee

[Inductor][float8] Register qconv-binary fusion pass for float8
1 parent 0160379 commit 00600ee

File tree

2 files changed: +63 -20 lines changed


test/quantization/pt2e/test_x86inductor_fusion.py

Lines changed: 44 additions & 9 deletions
@@ -770,7 +770,7 @@ def test_qconv2d_silu_fp8_mixed_bf16_cpu(self):
         )
 
     def _qconv2d_add_test_helper(
-        self, device="cpu", use_relu=False, int8_mixed_bf16=False
+        self, device="cpu", use_relu=False, mixed_bf16=False, is_fp8=False
     ):
         r"""
         This testcase will quantize a Conv2d->Add pattern as:
@@ -844,11 +844,12 @@ def matcher_check_fn():
             (v,),
             matcher_check_fn,
             check_quantization=True,
-            check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float32,
+            check_autocast=torch.bfloat16 if mixed_bf16 else torch.float32,
+            is_fp8=is_fp8,
         )
 
     def _qconv2d_add_test_helper2(
-        self, device="cpu", use_relu=False, int8_mixed_bf16=False
+        self, device="cpu", use_relu=False, mixed_bf16=False, is_fp8=False
     ):
         r"""
         This testcase will quantize two Conv2d->Add patterns as:
@@ -907,8 +908,11 @@ def forward(self, x, x2, x3):
                     res = self.relu2(res)
                 return res
 
+        add_fn_list = quantization_add_fn_list
+        if not is_fp8:
+            add_fn_list = add_fn_list + quantization_inplace_add_fn_list
         for add_fn, swap_inputs in itertools.product(
-            quantization_add_fn_list + quantization_inplace_add_fn_list, [False, True]
+            add_fn_list, [False, True]
         ):
             mod = M(add_fn, use_relu, swap_inputs).eval().to(device=device)
             x = torch.randn(
@@ -941,7 +945,8 @@ def matcher_check_fn():
             (x, x2, x3),
             matcher_check_fn,
             check_quantization=True,
-            check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float32,
+            check_autocast=torch.bfloat16 if mixed_bf16 else torch.float32,
+            is_fp8=is_fp8,
         )
 
     @skipIfNoDynamoSupport
@@ -950,25 +955,55 @@ def test_qconv2d_add_cpu(self):
         self._qconv2d_add_test_helper()
         self._qconv2d_add_test_helper2()
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    def test_qconv2d_add_fp8_cpu(self):
+        self._qconv2d_add_test_helper(is_fp8=True)
+        self._qconv2d_add_test_helper2(is_fp8=True)
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
     def test_qconv2d_add_int8_mixed_bf16(self):
-        self._qconv2d_add_test_helper(int8_mixed_bf16=True)
-        self._qconv2d_add_test_helper2(int8_mixed_bf16=True)
+        self._qconv2d_add_test_helper(mixed_bf16=True)
+        self._qconv2d_add_test_helper2(mixed_bf16=True)
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNNBF16
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    def test_qconv2d_add_fp8_mixed_bf16(self):
+        self._qconv2d_add_test_helper(mixed_bf16=True, is_fp8=True)
+        self._qconv2d_add_test_helper2(mixed_bf16=True, is_fp8=True)
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     def test_qconv2d_add_relu_cpu(self):
         self._qconv2d_add_test_helper(use_relu=True)
         self._qconv2d_add_test_helper2(use_relu=True)
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    def test_qconv2d_add_relu_fp8_cpu(self):
+        self._qconv2d_add_test_helper(use_relu=True, is_fp8=True)
+        self._qconv2d_add_test_helper2(use_relu=True, is_fp8=True)
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
     def test_qconv2d_add_relu_int8_mixed_bf16(self):
-        self._qconv2d_add_test_helper(use_relu=True, int8_mixed_bf16=True)
-        self._qconv2d_add_test_helper2(use_relu=True, int8_mixed_bf16=True)
+        self._qconv2d_add_test_helper(use_relu=True, mixed_bf16=True)
+        self._qconv2d_add_test_helper2(use_relu=True, mixed_bf16=True)
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNNBF16
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    def test_qconv2d_add_relu_fp8_mixed_bf16(self):
+        self._qconv2d_add_test_helper(use_relu=True, mixed_bf16=True, is_fp8=True)
+        self._qconv2d_add_test_helper2(use_relu=True, mixed_bf16=True, is_fp8=True)
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN

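For orientation, a minimal sketch (illustrative module and shapes, not the test file's actual M class) of the Conv2d -> Add -> optional ReLU pattern that _qconv2d_add_test_helper and _qconv2d_add_test_helper2 quantize, and that the qconv2d binary fusion is meant to match:

import torch
import torch.nn as nn

class ConvAdd(nn.Module):
    # Illustrative stand-in for the test's M module: Conv2d -> Add -> optional ReLU.
    def __init__(self, use_relu: bool = False):
        super().__init__()
        self.conv = nn.Conv2d(3, 6, kernel_size=3, padding=1)
        self.relu = nn.ReLU() if use_relu else nn.Identity()

    def forward(self, x, other):
        # After PT2E quantization, Inductor rewrites this add (plus the optional
        # relu) into onednn.qconv2d_pointwise.binary / .binary_tensor.
        return self.relu(self.conv(x) + other)

mod = ConvAdd(use_relu=True).eval()
out = mod(torch.randn(1, 3, 8, 8), torch.randn(1, 6, 8, 8))
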
torchao/quantization/pt2e/inductor_passes/x86.py

Lines changed: 19 additions & 11 deletions
@@ -441,6 +441,7 @@ def fn(match):
             return False
         binary_node_inputs = next(iter(compute_node.users)).args
         assert len(binary_node_inputs) == 2, "Expects binary node with 2 inputs"
+        is_fp8 = match.kwargs["x"].meta["val"].dtype is torch.float8_e4m3fn
         if output_dtype in [torch.float32, torch.bfloat16]:
             extra_input_of_binary_node = None
             for arg in binary_node_inputs:
@@ -449,7 +450,7 @@ def fn(match):
                     break
             assert extra_input_of_binary_node is not None
             # Extra input of binary node comes from dequant pattern
-            if extra_input_from_dequant and (
+            if not is_fp8 and extra_input_from_dequant and (
                 (not isinstance(extra_input_of_binary_node, torch.fx.Node))
                 or (
                     extra_input_of_binary_node.target
@@ -2293,37 +2294,44 @@ def _register_qconv_unary_fusion():
 
 
 def _register_qconv_binary_fusion():
-    for int8_mixed_bf16_with_inplace_add in [False, True]:
+    for int8_mixed_bf16_with_inplace_add, x_scale_zp_are_tensors in itertools.product([False, True], [False, True]):
+        qconv_binary_op = (
+            torch.ops.onednn.qconv2d_pointwise.binary_tensor
+            if x_scale_zp_are_tensors
+            else torch.ops.onednn.qconv2d_pointwise.binary
+        )
         # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output
         swap_binary_inputs_list = [False, True]
         binary_replace_patterns = {}
-        for swap_inputs in swap_binary_inputs_list:
+        for swap_inputs, is_fp8 in itertools.product(swap_binary_inputs_list, [False, True]):
             binary_replace_patterns.update(
                 {
                     PostOpAttr(
                         "sum", 1.0, "none", [], ""
                     ): generate_pattern_with_output_quant(
                         generate_pattern_with_binary(
                             aten.add.Tensor,
-                            get_qconv_pt2e_pattern(users=1),
+                            get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1),
                             dequantize_accum_pattern,
                             int8_mixed_bf16_with_inplace_add,
                             swap_inputs=swap_inputs,
                         ),
+                        is_fp8=is_fp8,
                     ),
                     PostOpAttr(
                         "sum", 1.0, "relu", [], ""
                     ): generate_pattern_with_output_quant(
                         generate_pattern_with_unary(
                             generate_pattern_with_binary(
                                 aten.add.Tensor,
-                                get_qconv_pt2e_pattern(users=1),
+                                get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1),
                                 dequantize_accum_pattern,
                                 int8_mixed_bf16_with_inplace_add,
                                 swap_inputs=swap_inputs,
                             ),
                             aten.relu.default,
                         ),
+                        is_fp8=is_fp8,
                     ),
                 }
             )
@@ -2332,7 +2340,7 @@ def _register_qconv_binary_fusion():
             _register_qconv_post_op_fusion_pass(
                 patterns,
                 3,  # pass_number
-                torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
+                qconv_binary_op,  # computation_op
                 binary_unary_attr,  # binary_unary_attr
             )

@@ -2344,7 +2352,7 @@ def _register_qconv_binary_fusion():
                 PostOpAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary(
                     generate_pattern_with_binary(
                         aten.add.Tensor,
-                        get_qconv_pt2e_pattern(users=1),
+                        get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1),
                         KeywordArg("accum_after_dequant"),
                         int8_mixed_bf16_with_inplace_add,
                         swap_inputs=swap_inputs,
@@ -2362,14 +2370,14 @@ def _register_qconv_binary_fusion():
                 _register_qconv_post_op_fusion_pass(
                     patterns,
                     3,  # pass_number
-                    torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
+                    qconv_binary_op,  # computation_op
                     binary_unary_attr,  # binary_unary_attr
                 )
             else:
                 _register_qconv_post_op_fusion_pass(
                     patterns,
                     4,  # pass_number
-                    torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
+                    qconv_binary_op,  # computation_op
                     binary_unary_attr,  # binary_unary_attr
                 )

@@ -2382,7 +2390,7 @@ def _register_qconv_binary_fusion():
                         "sum", 1.0, "none", [], ""
                     ): generate_pattern_with_binary(
                         aten.add.Tensor,
-                        get_qconv_pt2e_pattern(users=1),
+                        get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1),
                         KeywordArg("accum_after_dequant"),
                         int8_mixed_bf16_with_inplace_add,
                         swap_inputs=swap_inputs,
@@ -2397,7 +2405,7 @@ def _register_qconv_binary_fusion():
             _register_qconv_post_op_fusion_pass(
                 patterns,
                 4 if int8_mixed_bf16_with_inplace_add else 5,  # pass_number
-                torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
+                qconv_binary_op,  # computation_op
                 binary_unary_attr,  # binary_unary_attr
             )

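To summarize the mechanics of the patch, here is a minimal sketch (assumed helper names, not code from this commit) of the two decisions the updated pass encodes: detecting a float8 activation the same way the new is_fp8 line does, and picking the qconv2d binary overload from x_scale_zp_are_tensors as qconv_binary_op now does. The op lookup assumes a oneDNN-enabled PyTorch build.

import torch

def is_fp8_activation(t: torch.Tensor) -> bool:
    # Same dtype test as the added `is_fp8 = ... is torch.float8_e4m3fn` check.
    return t.dtype is torch.float8_e4m3fn

def pick_qconv_binary_op(x_scale_zp_are_tensors: bool):
    # Mirrors the new qconv_binary_op selection: tensor scale/zero-point inputs
    # use the .binary_tensor overload, scalar ones use .binary.
    return (
        torch.ops.onednn.qconv2d_pointwise.binary_tensor
        if x_scale_zp_are_tensors
        else torch.ops.onednn.qconv2d_pointwise.binary
    )

print(is_fp8_activation(torch.zeros(4, dtype=torch.float8_e4m3fn)))  # True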