Commit 26c536c

[Inductor][float8] Register qconv-binary fusion pass for float8
1 parent 04bf850 commit 26c536c
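
The fp8 test cases added below mirror the existing int8 variants of _qconv2d_unary_test_helper and are gated on float8 support. Assuming a PyTorch build with oneDNN and float8 support, they can presumably be selected by name with pytest, for example:

    pytest test/quantization/pt2e/test_x86inductor_fusion.py -k "fp8"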

File tree

2 files changed: +157 lines, -59 lines

test/quantization/pt2e/test_x86inductor_fusion.py

Lines changed: 109 additions & 6 deletions
@@ -512,9 +512,10 @@ def test_qconv2d_fp8_mixed_bf16(self):
     def _qconv2d_unary_test_helper(
         self,
         device="cpu",
-        int8_mixed_bf16=False,
+        mixed_bf16=False,
         unary_op=torch.nn.ReLU(),
         qconv_unary_matcher_nodes=None,
+        is_fp8=False,
     ):
         class M(torch.nn.Module):
             def __init__(
@@ -563,8 +564,9 @@ def matcher_check_fn():
             mod,
             (v,),
             check_quantization=True,
-            check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float32,
+            check_autocast=torch.bfloat16 if mixed_bf16 else torch.float32,
             matcher_check_fn=matcher_check_fn,
+            is_fp8=is_fp8,
         )
 
     @skipIfNoDynamoSupport
@@ -575,14 +577,23 @@ def test_qconv2d_relu_cpu(self):
         """
         self._qconv2d_unary_test_helper(device="cpu")
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    def test_qconv2d_relu_fp8_cpu(self):
+        r"""
+        This testcase will quantize Conv2d->ReLU pattern.
+        """
+        self._qconv2d_unary_test_helper(device="cpu", is_fp8=True)
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
     def test_qconv2d_relu_int8_mixed_bf16_xpu(self):
         r"""
         This testcase will quantize Conv2d->ReLU pattern with int8_mixed_bf16 quantization.
         """
-        self._qconv2d_unary_test_helper(int8_mixed_bf16=True)
+        self._qconv2d_unary_test_helper(mixed_bf16=True)
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
@@ -592,6 +603,15 @@ def test_qconv2d_relu6_cpu(self):
         """
         self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.ReLU6())
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    def test_qconv2d_relu6_fp8_cpu(self):
+        r"""
+        This testcase will quantize Conv2d->ReLU6 pattern.
+        """
+        self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.ReLU6(), is_fp8=True)
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     def test_qconv2d_hardtanh_cpu(self):
@@ -600,6 +620,15 @@ def test_qconv2d_hardtanh_cpu(self):
         """
         self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardtanh())
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    def test_qconv2d_hardtanh_fp8_cpu(self):
+        r"""
+        This testcase will quantize Conv2d->Hardtanh pattern.
+        """
+        self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardtanh(), is_fp8=True)
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
@@ -612,8 +641,26 @@ def test_qconv2d_hardtanh_int8_mixed_bf16_cpu(self):
         """
         self._qconv2d_unary_test_helper(
             unary_op=torch.nn.Hardtanh(),
-            int8_mixed_bf16=True,
+            mixed_bf16=True,
+            qconv_unary_matcher_nodes=11,
+        )
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNNBF16
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    def test_qconv2d_hardtanh_fp8_mixed_bf16_cpu(self):
+        r"""
+        This testcase will quantize Conv2d->Hardtanh pattern.
+        Match.nodes:
+            [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type, quantize_per_tensor]
+            [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type]
+        """
+        self._qconv2d_unary_test_helper(
+            unary_op=torch.nn.Hardtanh(),
+            mixed_bf16=True,
             qconv_unary_matcher_nodes=11,
+            is_fp8=True,
         )
 
     @skipIfNoDynamoSupport
@@ -624,6 +671,15 @@ def test_qconv2d_hardswish_cpu(self):
         """
         self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardswish())
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    def test_qconv2d_hardswish_fp8_cpu(self):
+        r"""
+        This testcase will quantize Conv2d->Hardswish pattern.
+        """
+        self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardswish(), is_fp8=True)
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
@@ -637,8 +693,27 @@ def test_qconv2d_hardswish_int8_mixed_bf16_cpu(self):
         """
         self._qconv2d_unary_test_helper(
             unary_op=torch.nn.Hardswish(),
-            int8_mixed_bf16=True,
+            mixed_bf16=True,
+            qconv_unary_matcher_nodes=17,
+        )
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNNBF16
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    def test_qconv2d_hardswish_fp8_mixed_bf16_cpu(self):
+        r"""
+        This testcase will quantize Conv2d->Hardswish pattern.
+        Match.nodes:
+            [qconv2d_pointwise_default, convert_element_type, add, clamp_min,
+             clamp_max, mul, div, convert_element_type, quantize_per_tensor]
+            [qconv2d_pointwise_default, convert_element_type, add, clamp_min, clamp_max, mul, div, convert_element_type]
+        """
+        self._qconv2d_unary_test_helper(
+            unary_op=torch.nn.Hardswish(),
+            mixed_bf16=True,
             qconv_unary_matcher_nodes=17,
+            is_fp8=True,
         )
 
     @skipIfNoDynamoSupport
@@ -649,6 +724,15 @@ def test_qconv2d_silu_cpu(self):
         """
         self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.SiLU())
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    def test_qconv2d_silu_fp8_cpu(self):
+        r"""
+        This testcase will quantize Conv2d->SiLU pattern.
+        """
+        self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.SiLU(), is_fp8=True)
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
@@ -662,10 +746,29 @@ def test_qconv2d_silu_int8_mixed_bf16_cpu(self):
         """
         self._qconv2d_unary_test_helper(
             unary_op=torch.nn.SiLU(),
-            int8_mixed_bf16=True,
+            mixed_bf16=True,
             qconv_unary_matcher_nodes=11,
         )
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNNBF16
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    def test_qconv2d_silu_fp8_mixed_bf16_cpu(self):
+        r"""
+        This testcase will quantize Conv2d->SiLU pattern.
+        Match.nodes:
+            [qconv2d_pointwise_default, convert_element_type, sigmoid, mul,
+             convert_element_type, quantize_per_tensor]
+            [qconv2d_pointwise_default, convert_element_type, sigmoid, mul, convert_element_type]
+        """
+        self._qconv2d_unary_test_helper(
+            unary_op=torch.nn.SiLU(),
+            mixed_bf16=True,
+            qconv_unary_matcher_nodes=11,
+            is_fp8=True,
+        )
+
     def _qconv2d_add_test_helper(
         self, device="cpu", use_relu=False, int8_mixed_bf16=False
     ):

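For orientation, the sketch below is a hypothetical, hand-written stand-in (not part of the commit) for the fp8 dequantize -> Conv2d -> ReLU -> quantize pattern that tests such as test_qconv2d_relu_fp8_cpu exercise; the real tests build this graph through the PT2E quantization flow, so the module name, shapes, and per-tensor scale here are purely illustrative.

    import torch

    class ConvReLUFP8Pattern(torch.nn.Module):
        # Hypothetical stand-in for the pattern produced by PT2E fp8 quantization:
        # dequantize (fp8 -> fp32) -> Conv2d -> ReLU -> quantize (fp32 -> fp8).
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)
            self.relu = torch.nn.ReLU()
            self.scale = 0.1  # illustrative per-tensor scale

        def forward(self, x_fp8):
            x = x_fp8.to(torch.float32) * self.scale         # dequantize the fp8 activation
            y = self.relu(self.conv(x))                      # Conv2d -> ReLU in float32
            return (y / self.scale).to(torch.float8_e4m3fn)  # requantize the output to fp8

    mod = ConvReLUFP8Pattern().eval()
    x_fp8 = (torch.randn(1, 3, 32, 32) / 0.1).to(torch.float8_e4m3fn)
    with torch.no_grad():
        # Compiled run of the reference pattern; note that Inductor's fusion pass
        # targets the graph produced by the PT2E flow, not this hand-written version.
        out = torch.compile(mod)(x_fp8)

In the actual tests, matcher_check_fn counts the qconv2d_pointwise matches (see the qconv_unary_matcher_nodes argument) to confirm that the fusion fired.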