@@ -512,9 +512,10 @@ def test_qconv2d_fp8_mixed_bf16(self):
     def _qconv2d_unary_test_helper(
         self,
         device="cpu",
-        int8_mixed_bf16=False,
+        mixed_bf16=False,
        unary_op=torch.nn.ReLU(),
         qconv_unary_matcher_nodes=None,
+        is_fp8=False,
     ):
         class M(torch.nn.Module):
             def __init__(
@@ -563,8 +564,9 @@ def matcher_check_fn():
             mod,
             (v,),
             check_quantization=True,
-            check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float32,
+            check_autocast=torch.bfloat16 if mixed_bf16 else torch.float32,
             matcher_check_fn=matcher_check_fn,
+            is_fp8=is_fp8,
         )
 
     @skipIfNoDynamoSupport
@@ -575,14 +577,23 @@ def test_qconv2d_relu_cpu(self):
         """
         self._qconv2d_unary_test_helper(device="cpu")
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    def test_qconv2d_relu_fp8_cpu(self):
+        r"""
+        This testcase will quantize Conv2d->ReLU pattern with fp8 quantization.
+        """
+        self._qconv2d_unary_test_helper(device="cpu", is_fp8=True)
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
     def test_qconv2d_relu_int8_mixed_bf16_xpu(self):
         r"""
         This testcase will quantize Conv2d->ReLU pattern with int8_mixed_bf16 quantization.
         """
-        self._qconv2d_unary_test_helper(int8_mixed_bf16=True)
+        self._qconv2d_unary_test_helper(mixed_bf16=True)
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
@@ -592,6 +603,15 @@ def test_qconv2d_relu6_cpu(self):
         """
         self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.ReLU6())
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    def test_qconv2d_relu6_fp8_cpu(self):
+        r"""
+        This testcase will quantize Conv2d->ReLU6 pattern with fp8 quantization.
+        """
+        self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.ReLU6(), is_fp8=True)
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     def test_qconv2d_hardtanh_cpu(self):
@@ -600,6 +620,15 @@ def test_qconv2d_hardtanh_cpu(self):
         """
         self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardtanh())
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    def test_qconv2d_hardtanh_fp8_cpu(self):
+        r"""
+        This testcase will quantize Conv2d->Hardtanh pattern with fp8 quantization.
+        """
+        self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardtanh(), is_fp8=True)
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
@@ -612,8 +641,26 @@ def test_qconv2d_hardtanh_int8_mixed_bf16_cpu(self):
         """
         self._qconv2d_unary_test_helper(
             unary_op=torch.nn.Hardtanh(),
-            int8_mixed_bf16=True,
+            mixed_bf16=True,
+            qconv_unary_matcher_nodes=11,
+        )
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNNBF16
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    def test_qconv2d_hardtanh_fp8_mixed_bf16_cpu(self):
+        r"""
+        This testcase will quantize Conv2d->Hardtanh pattern with fp8_mixed_bf16 quantization.
+        Match.nodes:
+            [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type, quantize_per_tensor]
+            [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type]
+        """
+        self._qconv2d_unary_test_helper(
+            unary_op=torch.nn.Hardtanh(),
+            mixed_bf16=True,
             qconv_unary_matcher_nodes=11,
+            is_fp8=True,
         )
 
     @skipIfNoDynamoSupport
@@ -624,6 +671,15 @@ def test_qconv2d_hardswish_cpu(self):
         """
         self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardswish())
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    def test_qconv2d_hardswish_fp8_cpu(self):
+        r"""
+        This testcase will quantize Conv2d->Hardswish pattern with fp8 quantization.
+        """
+        self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardswish(), is_fp8=True)
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
@@ -637,8 +693,27 @@ def test_qconv2d_hardswish_int8_mixed_bf16_cpu(self):
         """
         self._qconv2d_unary_test_helper(
             unary_op=torch.nn.Hardswish(),
-            int8_mixed_bf16=True,
+            mixed_bf16=True,
+            qconv_unary_matcher_nodes=17,
+        )
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNNBF16
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    def test_qconv2d_hardswish_fp8_mixed_bf16_cpu(self):
+        r"""
+        This testcase will quantize Conv2d->Hardswish pattern with fp8_mixed_bf16 quantization.
+        Match.nodes:
+            [qconv2d_pointwise_default, convert_element_type, add, clamp_min,
+             clamp_max, mul, div, convert_element_type, quantize_per_tensor]
+            [qconv2d_pointwise_default, convert_element_type, add, clamp_min, clamp_max, mul, div, convert_element_type]
+        """
+        self._qconv2d_unary_test_helper(
+            unary_op=torch.nn.Hardswish(),
+            mixed_bf16=True,
             qconv_unary_matcher_nodes=17,
+            is_fp8=True,
         )
 
     @skipIfNoDynamoSupport
@@ -649,6 +724,15 @@ def test_qconv2d_silu_cpu(self):
         """
         self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.SiLU())
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    def test_qconv2d_silu_fp8_cpu(self):
+        r"""
+        This testcase will quantize Conv2d->SiLU pattern with fp8 quantization.
+        """
+        self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.SiLU(), is_fp8=True)
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
@@ -662,10 +746,29 @@ def test_qconv2d_silu_int8_mixed_bf16_cpu(self):
         """
         self._qconv2d_unary_test_helper(
             unary_op=torch.nn.SiLU(),
-            int8_mixed_bf16=True,
+            mixed_bf16=True,
             qconv_unary_matcher_nodes=11,
         )
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNNBF16
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    def test_qconv2d_silu_fp8_mixed_bf16_cpu(self):
+        r"""
+        This testcase will quantize Conv2d->SiLU pattern with fp8_mixed_bf16 quantization.
+        Match.nodes:
+            [qconv2d_pointwise_default, convert_element_type, sigmoid, mul,
+             convert_element_type, quantize_per_tensor]
+            [qconv2d_pointwise_default, convert_element_type, sigmoid, mul, convert_element_type]
+        """
+        self._qconv2d_unary_test_helper(
+            unary_op=torch.nn.SiLU(),
+            mixed_bf16=True,
+            qconv_unary_matcher_nodes=11,
+            is_fp8=True,
+        )
+
     def _qconv2d_add_test_helper(
         self, device="cpu", use_relu=False, int8_mixed_bf16=False
     ):
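
For readers unfamiliar with the fp8 path these tests exercise, here is a minimal sketch (not code from this PR) of the numeric pattern being matched: a Conv2d->unary op computed in fp32/bf16, then quantized per-tensor to fp8 by scaling and casting to torch.float8_e4m3fn, and dequantized for the next op. The function name and the fixed scale are illustrative assumptions, not the quantizer's actual API.

    import torch

    # Minimal sketch, assuming per-tensor fp8 (e4m3) quantization with a fixed
    # scale. `conv_relu_fp8` and `scale=0.1` are illustrative, not from the PR.
    def conv_relu_fp8(x, conv, scale=0.1):
        y = torch.relu(conv(x))                  # Conv2d -> ReLU in fp32
        q = (y / scale).to(torch.float8_e4m3fn)  # quantize: scale, cast to fp8
        return q.to(torch.float32) * scale       # dequantize for the next op

    conv = torch.nn.Conv2d(3, 16, kernel_size=3)
    out = conv_relu_fp8(torch.randn(1, 3, 8, 8), conv)

The Inductor pattern matcher fuses this quantize/dequantize sandwich around the conv into a single qconv2d_pointwise node, which is what the qconv_unary_matcher_nodes counts in the docstrings above refer to.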