diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py
index 520b5fbdfb..2e0a4f7738 100644
--- a/test/quantization/pt2e/test_x86inductor_fusion.py
+++ b/test/quantization/pt2e/test_x86inductor_fusion.py
@@ -44,7 +44,6 @@ from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import (
     X86InductorQuantizer,
 )
-from torchao.testing.utils import skip_if_rocm
 from torchao.utils import torch_version_at_least
 
 # The dict value is match_nodes(computation_op+unary_op)
@@ -93,6 +92,9 @@ skipIfNoFloat8Support = unittest.skipIf(
     not torch_version_at_least("2.9.0"), "Float8 requires torch 2.9+"
 )
+skipIfNoQConvFp8Support = unittest.skipIf(
+    not torch_version_at_least("2.10.0.dev"), "QConv fp8 requires torch 2.10+"
+)
 
 
 def get_default_quantizer(is_qat, is_dynamic):
@@ -138,6 +140,61 @@ def forward(self, input):
         return out
 
 
+class FP8QDQConv2d(torch.nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        padding=0,
+        dilation=1,
+        groups=1,
+        bias=True,
+    ):
+        super().__init__()
+        self.qtype = torch.float8_e4m3fn
+        self.weight = torch.randn(
+            (out_channels, in_channels // groups, *kernel_size)
+        ).to(self.qtype)
+        self.weight_scale = 2.0
+        self.scale = 2.0
+        self.bias = None
+        if bias:
+            self.bias = torch.randn((out_channels,))
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.groups = groups
+
+    def forward(self, input):
+        weight = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
+            tensor=self.weight.data,
+            scale=torch.tensor([self.weight_scale]),
+            output_dtype=torch.float,
+        )
+        q_input = torch.ops.torchao.quantize_affine_float8_non_decomposed.default(
+            tensor=input,
+            scale=torch.tensor([self.scale]),
+            float8_dtype=self.qtype,
+        )
+        dq_input = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
+            tensor=q_input,
+            scale=torch.tensor([self.scale]),
+            output_dtype=torch.float,
+        )
+
+        return torch.nn.functional.conv2d(
+            dq_input,
+            weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+        )
+
+
 def qdq(input, scale):
     dtype = input.dtype
     q_input = torch.ops.torchao.quantize_affine_float8_non_decomposed.default(
@@ -171,9 +228,7 @@ def create_mod_info_recursion(parent):
     parent_child_mod_dict = generate_model_info(model)
     for name, mod in model.named_modules():
         mod_type_str = mod.__class__.__name__
-        if mod_type_str not in [
-            "Linear",
-        ]:
+        if mod_type_str not in ["Linear", "Conv2d"]:
             continue
         param = mod.weight
         xmax = torch.max(param)
@@ -190,6 +245,20 @@ def create_mod_info_recursion(parent):
            patched_mod.bias = mod.bias
            patched_mod.weight_scale = weight_scale.item()
            patched_mod.weight.data = q_param
+        elif mod_type_str in ["Conv2d"]:
+            patched_mod = FP8QDQConv2d(
+                mod.in_channels,
+                mod.out_channels,
+                mod.kernel_size,
+                mod.stride,
+                mod.padding,
+                mod.dilation,
+                mod.groups,
+                False,
+            )
+            patched_mod.bias = mod.bias
+            patched_mod.weight_scale = weight_scale.item()
+            patched_mod.weight.data = q_param
 
         parent = parent_child_mod_dict[mod].parent
         name = parent_child_mod_dict[mod].name
@@ -381,8 +450,9 @@ def _test_code_common(
 
 
 @unittest.skipIf(not torch_version_at_least("2.8.0"), "Requires torch 2.8+")
+@unittest.skipIf(torch.version.hip is not None, "Not applicable to ROCm")
 class TestPatternMatcher(TestPatternMatcherBase):
-    def _qconv2d_test_helper(self, device="cpu", int8_mixed_bf16=False):
+    def _qconv2d_test_helper(self, device="cpu", mixed_bf16=False, is_fp8=False):
         class M(torch.nn.Module):
             def __init__(
                 self,
@@ -408,14 +478,14 @@ def forward(self, x):
         def matcher_check_fn():
             # 1. Dequant-Conv2D pattern matched in QConv2D weight prepack * 1
             # int8_mixed_fp32: [dequant_node, dequantize_per_channel, clone, convolution]
-            # int8_mixed_bf16: [dequant_node, optional(convert_element_type_4),
+            # mixed_bf16: [dequant_node, optional(convert_element_type_4),
            #  dequantize_per_channel, optional(convert_element_type_3), clone, convolution]
            self.assertEqual(
                counters["inductor"]["qconv_weight_prepack_matcher_count"], 3
            )
            self.assertEqual(
                counters["inductor"]["qconv_weight_prepack_matcher_nodes"],
-                18 if int8_mixed_bf16 else 12,
+                18 if mixed_bf16 else 12,
            )
            self.assertEqual(
                counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 3
@@ -426,34 +496,53 @@ def matcher_check_fn():
             (v,),
             matcher_check_fn,
             check_quantization=True,
-            check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float32,
+            check_autocast=torch.bfloat16 if mixed_bf16 else torch.float32,
+            is_fp8=is_fp8,
         )
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
-    @skip_if_rocm("Not applicable to ROCm")
     def test_qconv2d_cpu(self):
         r"""
         This testcase will quantize a single Conv2d module.
         """
         self._qconv2d_test_helper("cpu")
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfNoQConvFp8Support
+    def test_qconv2d_fp8_cpu(self):
+        r"""
+        This testcase will quantize a single Conv2d module with fp8.
+        """
+        self._qconv2d_test_helper("cpu", is_fp8=True)
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
-    @skip_if_rocm("Not applicable to ROCm")
     def test_qconv2d_int8_mixed_bf16(self):
         r"""
         This testcase will quantize a single Conv2d module with int8_mixed_bf16 quantization.
         """
-        self._qconv2d_test_helper(int8_mixed_bf16=True)
+        self._qconv2d_test_helper(mixed_bf16=True)
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNNBF16
+    @skipIfNoONEDNN
+    @skipIfNoQConvFp8Support
+    def test_qconv2d_fp8_mixed_bf16(self):
+        r"""
+        This testcase will quantize a single Conv2d module with fp8_mixed_bf16 quantization.
+        """
+        self._qconv2d_test_helper(mixed_bf16=True, is_fp8=True)
 
     def _qconv2d_unary_test_helper(
         self,
         device="cpu",
-        int8_mixed_bf16=False,
+        mixed_bf16=False,
         unary_op=torch.nn.ReLU(),
         qconv_unary_matcher_nodes=None,
+        is_fp8=False,
     ):
         class M(torch.nn.Module):
             def __init__(
@@ -502,8 +591,9 @@ def matcher_check_fn():
             mod,
             (v,),
             check_quantization=True,
-            check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float32,
+            check_autocast=torch.bfloat16 if mixed_bf16 else torch.float32,
             matcher_check_fn=matcher_check_fn,
+            is_fp8=is_fp8,
         )
 
     @skipIfNoDynamoSupport
@@ -514,6 +604,15 @@ def test_qconv2d_relu_cpu(self):
         """
         self._qconv2d_unary_test_helper(device="cpu")
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfNoQConvFp8Support
+    def test_qconv2d_relu_fp8_cpu(self):
+        r"""
+        This testcase will quantize Conv2d->ReLU pattern with fp8.
+        """
+        self._qconv2d_unary_test_helper(device="cpu", is_fp8=True)
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
@@ -521,7 +620,7 @@ def test_qconv2d_relu_int8_mixed_bf16_xpu(self):
         r"""
         This testcase will quantize Conv2d->ReLU pattern with int8_mixed_bf16 quantization.
""" - self._qconv2d_unary_test_helper(int8_mixed_bf16=True) + self._qconv2d_unary_test_helper(mixed_bf16=True) @skipIfNoDynamoSupport @skipIfNoONEDNN @@ -531,6 +630,17 @@ def test_qconv2d_relu6_cpu(self): """ self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.ReLU6()) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoQConvFp8Support + def test_qconv2d_relu6_fp8_cpu(self): + r""" + This testcase will quantize Conv2d->ReLU6 pattern. + """ + self._qconv2d_unary_test_helper( + device="cpu", unary_op=torch.nn.ReLU6(), is_fp8=True + ) + @skipIfNoDynamoSupport @skipIfNoONEDNN def test_qconv2d_hardtanh_cpu(self): @@ -539,6 +649,17 @@ def test_qconv2d_hardtanh_cpu(self): """ self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardtanh()) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoQConvFp8Support + def test_qconv2d_hardtanh_fp8_cpu(self): + r""" + This testcase will quantize Conv2d->Hardtanh pattern. + """ + self._qconv2d_unary_test_helper( + device="cpu", unary_op=torch.nn.Hardtanh(), is_fp8=True + ) + @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN @@ -551,8 +672,26 @@ def test_qconv2d_hardtanh_int8_mixed_bf16_cpu(self): """ self._qconv2d_unary_test_helper( unary_op=torch.nn.Hardtanh(), - int8_mixed_bf16=True, + mixed_bf16=True, + qconv_unary_matcher_nodes=11, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + @skipIfNoQConvFp8Support + def test_qconv2d_hardtanh_fp8_mixed_bf16_cpu(self): + r""" + This testcase will quantize Conv2d->Hardtanh pattern. + Match.nodes: + [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type, quantize_per_tensor] + [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type] + """ + self._qconv2d_unary_test_helper( + unary_op=torch.nn.Hardtanh(), + mixed_bf16=True, qconv_unary_matcher_nodes=11, + is_fp8=True, ) @skipIfNoDynamoSupport @@ -563,6 +702,17 @@ def test_qconv2d_hardswish_cpu(self): """ self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardswish()) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoQConvFp8Support + def test_qconv2d_hardswish_fp8_cpu(self): + r""" + This testcase will quantize Conv2d->Hardswish pattern. + """ + self._qconv2d_unary_test_helper( + device="cpu", unary_op=torch.nn.Hardswish(), is_fp8=True + ) + @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN @@ -576,8 +726,27 @@ def test_qconv2d_hardswish_int8_mixed_bf16_cpu(self): """ self._qconv2d_unary_test_helper( unary_op=torch.nn.Hardswish(), - int8_mixed_bf16=True, + mixed_bf16=True, + qconv_unary_matcher_nodes=17, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + @skipIfNoQConvFp8Support + def test_qconv2d_hardswish_fp8_mixed_bf16_cpu(self): + r""" + This testcase will quantize Conv2d->Hardswish pattern. 
+        Match.nodes:
+        [qconv2d_pointwise_default, convert_element_type, add, clamp_min,
+        clamp_max, mul, div, convert_element_type, quantize_per_tensor]
+        [qconv2d_pointwise_default, convert_element_type, add, clamp_min, clamp_max, mul, div, convert_element_type]
+        """
+        self._qconv2d_unary_test_helper(
+            unary_op=torch.nn.Hardswish(),
+            mixed_bf16=True,
             qconv_unary_matcher_nodes=17,
+            is_fp8=True,
         )
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
@@ -588,6 +757,17 @@ def test_qconv2d_silu_cpu(self):
         """
         self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.SiLU())
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfNoQConvFp8Support
+    def test_qconv2d_silu_fp8_cpu(self):
+        r"""
+        This testcase will quantize Conv2d->SiLU pattern with fp8.
+        """
+        self._qconv2d_unary_test_helper(
+            device="cpu", unary_op=torch.nn.SiLU(), is_fp8=True
+        )
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
@@ -601,12 +781,31 @@ def test_qconv2d_silu_int8_mixed_bf16_cpu(self):
         """
         self._qconv2d_unary_test_helper(
             unary_op=torch.nn.SiLU(),
-            int8_mixed_bf16=True,
+            mixed_bf16=True,
+            qconv_unary_matcher_nodes=11,
+        )
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNNBF16
+    @skipIfNoONEDNN
+    @skipIfNoQConvFp8Support
+    def test_qconv2d_silu_fp8_mixed_bf16_cpu(self):
+        r"""
+        This testcase will quantize Conv2d->SiLU pattern with fp8_mixed_bf16 quantization.
+        Match.nodes:
+        [qconv2d_pointwise_default, convert_element_type, sigmoid, mul,
+        convert_element_type, quantize_per_tensor]
+        [qconv2d_pointwise_default, convert_element_type, sigmoid, mul, convert_element_type]
+        """
+        self._qconv2d_unary_test_helper(
+            unary_op=torch.nn.SiLU(),
+            mixed_bf16=True,
             qconv_unary_matcher_nodes=11,
+            is_fp8=True,
         )
 
     def _qconv2d_add_test_helper(
-        self, device="cpu", use_relu=False, int8_mixed_bf16=False
+        self, device="cpu", use_relu=False, mixed_bf16=False, is_fp8=False
     ):
         r"""
         This testcase will quantize a Conv2d->Add pattern as:
@@ -680,11 +879,12 @@ def matcher_check_fn():
             (v,),
             matcher_check_fn,
             check_quantization=True,
-            check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float32,
+            check_autocast=torch.bfloat16 if mixed_bf16 else torch.float32,
+            is_fp8=is_fp8,
         )
 
     def _qconv2d_add_test_helper2(
-        self, device="cpu", use_relu=False, int8_mixed_bf16=False
+        self, device="cpu", use_relu=False, mixed_bf16=False, is_fp8=False
     ):
         r"""
         This testcase will quantize two Conv2d->Add patterns as:
@@ -743,9 +943,10 @@ def forward(self, x, x2, x3):
                     res = self.relu2(res)
                 return res
 
-        for add_fn, swap_inputs in itertools.product(
-            quantization_add_fn_list + quantization_inplace_add_fn_list, [False, True]
-        ):
+        add_fn_list = quantization_add_fn_list
+        if not is_fp8:
+            add_fn_list = add_fn_list + quantization_inplace_add_fn_list
+        for add_fn, swap_inputs in itertools.product(add_fn_list, [False, True]):
             mod = M(add_fn, use_relu, swap_inputs).eval().to(device=device)
             x = torch.randn(
                 (1, 3, 8, 8), dtype=torch.float32, requires_grad=False, device=device
@@ -777,7 +978,8 @@ def matcher_check_fn():
                 (x, x2, x3),
                 matcher_check_fn,
                 check_quantization=True,
-                check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float32,
+                check_autocast=torch.bfloat16 if mixed_bf16 else torch.float32,
+                is_fp8=is_fp8,
             )
 
     @skipIfNoDynamoSupport
@@ -786,12 +988,27 @@ def test_qconv2d_add_cpu(self):
         self._qconv2d_add_test_helper()
         self._qconv2d_add_test_helper2()
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfNoQConvFp8Support
+    def test_qconv2d_add_fp8_cpu(self):
+        self._qconv2d_add_test_helper(is_fp8=True)
+        self._qconv2d_add_test_helper2(is_fp8=True)
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
     def test_qconv2d_add_int8_mixed_bf16(self):
-        self._qconv2d_add_test_helper(int8_mixed_bf16=True)
-        self._qconv2d_add_test_helper2(int8_mixed_bf16=True)
+        self._qconv2d_add_test_helper(mixed_bf16=True)
+        self._qconv2d_add_test_helper2(mixed_bf16=True)
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNNBF16
+    @skipIfNoONEDNN
+    @skipIfNoQConvFp8Support
+    def test_qconv2d_add_fp8_mixed_bf16(self):
+        self._qconv2d_add_test_helper(mixed_bf16=True, is_fp8=True)
+        self._qconv2d_add_test_helper2(mixed_bf16=True, is_fp8=True)
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
@@ -799,12 +1016,27 @@ def test_qconv2d_add_relu_cpu(self):
         self._qconv2d_add_test_helper(use_relu=True)
         self._qconv2d_add_test_helper2(use_relu=True)
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfNoQConvFp8Support
+    def test_qconv2d_add_relu_fp8_cpu(self):
+        self._qconv2d_add_test_helper(use_relu=True, is_fp8=True)
+        self._qconv2d_add_test_helper2(use_relu=True, is_fp8=True)
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
     def test_qconv2d_add_relu_int8_mixed_bf16(self):
-        self._qconv2d_add_test_helper(use_relu=True, int8_mixed_bf16=True)
-        self._qconv2d_add_test_helper2(use_relu=True, int8_mixed_bf16=True)
+        self._qconv2d_add_test_helper(use_relu=True, mixed_bf16=True)
+        self._qconv2d_add_test_helper2(use_relu=True, mixed_bf16=True)
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNNBF16
+    @skipIfNoONEDNN
+    @skipIfNoQConvFp8Support
+    def test_qconv2d_add_relu_fp8_mixed_bf16(self):
+        self._qconv2d_add_test_helper(use_relu=True, mixed_bf16=True, is_fp8=True)
+        self._qconv2d_add_test_helper2(use_relu=True, mixed_bf16=True, is_fp8=True)
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
@@ -1035,7 +1267,6 @@ def matcher_check_fn():
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
-    @skip_if_rocm("Not applicable to ROCm")
     def test_qat_qconv2d(self):
         r"""
         This testcase will quantize a single Conv2d module with qat flow.
@@ -1178,7 +1409,6 @@ def test_qat_qconv2d_hardswish(self):
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
-    @skip_if_rocm("Not applicable to ROCm")
     def test_qat_qconv2d_add(self):
         r"""
         This testcase will quantize a Conv2d->Add pattern as:
@@ -1244,7 +1474,6 @@ def matcher_check_fn():
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
-    @skip_if_rocm("Not applicable to ROCm")
     def test_qat_qconv2d_add_relu(self):
         r"""
         This testcase will quantize a Conv2d->Add->ReLU pattern as:
@@ -1384,7 +1613,6 @@ def matcher_check_fn():
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
-    @skip_if_rocm("Not applicable to ROCm")
     def test_qconv2d_dequant_promotion_cpu(self):
         self._test_qconv2d_dequant_promotion_helper()
 
diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py
index a0aef11541..c5280b9db0 100644
--- a/torchao/quantization/pt2e/inductor_passes/x86.py
+++ b/torchao/quantization/pt2e/inductor_passes/x86.py
@@ -167,60 +167,49 @@ def get_dequantize_per_tensor_activation_pattern(
     KeywordArg("w_dtype"),
 )
 
-dequantize_per_channel_to_bf16_weight_pattern = (
-    _may_generate_pattern_with_dtype_convert(
-        dequantize_per_channel_weight_pattern,
-        KeywordArg("autocast_wgt_dtype"),
-    )
+dequantize_fp8_weight_pattern = CallFunction(
+    torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
+    KeywordArg("q_weight"),
+    KeywordArg("w_scale"),
+    output_dtype=KeywordArg("w_dtype"),
 )
 
-dequantize_per_channel_clone_weight_pattern = CallFunction(
-    aten.clone.default,
-    dequantize_per_channel_weight_pattern,
-    memory_format=KeywordArg("memory_format"),
-)
 
-dequantize_per_channel_to_bf16_clone_weight_pattern = CallFunction(
-    aten.clone.default,
-    dequantize_per_channel_to_bf16_weight_pattern,
-    memory_format=KeywordArg("memory_format"),
-)
+def get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern):
+    return _may_generate_pattern_with_dtype_convert(
+        dequant_wgt_pattern,
+        KeywordArg("autocast_wgt_dtype"),
+    )
 
 
-def get_qconv_pt2e_pattern(users=1):
+def get_dequantize_clone_weight_pattern(dequant_wgt_pattern):
     return CallFunction(
-        torch.ops.onednn.qconv_pointwise.default,
-        KeywordArg("x"),
-        KeywordArg("x_scale"),
-        KeywordArg("x_zp"),
-        KeywordArg("packed_weight"),
-        KeywordArg("w_scale"),
-        KeywordArg("w_zp"),
-        KeywordArg("b"),
-        KeywordArg("stride"),
-        KeywordArg("padding"),
-        KeywordArg("dilation"),
-        KeywordArg("groups"),
-        KeywordArg("output_scale"),
-        KeywordArg("output_zero_point"),
-        KeywordArg("output_dtype"),
-        KeywordArg("postop_name"),
-        KeywordArg("postop_args"),
-        KeywordArg("postop_algorithm"),
-        _users=users,
+        aten.clone.default,
+        dequant_wgt_pattern,
+        memory_format=KeywordArg("memory_format"),
+    )
+
+
+def get_dequantize_to_bf16_clone_weight_pattern(dequant_wgt_pattern):
+    return get_dequantize_clone_weight_pattern(
+        get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern)
     )
 
 
-def get_qconv2d_binary_pt2e_pattern(users=1):
+def get_qconv_pt2e_pattern(x_scale_zp_are_tensors=False, users=1):
+    qconv_op = (
+        torch.ops.onednn.qconv_pointwise.tensor
+        if x_scale_zp_are_tensors
+        else torch.ops.onednn.qconv_pointwise.default
+    )
     return CallFunction(
-        torch.ops.onednn.qconv2d_pointwise.binary,
+        qconv_op,
         KeywordArg("x"),
         KeywordArg("x_scale"),
         KeywordArg("x_zp"),
         KeywordArg("packed_weight"),
         KeywordArg("w_scale"),
         KeywordArg("w_zp"),
-        KeywordArg("accum"),
         KeywordArg("b"),
         KeywordArg("stride"),
         KeywordArg("padding"),
@@ -229,13 +218,9 @@ def get_qconv2d_binary_pt2e_pattern(users=1):
         KeywordArg("output_scale"),
         KeywordArg("output_zero_point"),
KeywordArg("output_dtype"), - KeywordArg("accum_scale"), - KeywordArg("accum_zero_point"), - KeywordArg("binary_op_name"), - KeywordArg("alpha"), - KeywordArg("unary_op_name"), - KeywordArg("unary_op_args"), - KeywordArg("unary_op_algorithm"), + KeywordArg("postop_name"), + KeywordArg("postop_args"), + KeywordArg("postop_algorithm"), _users=users, ) @@ -461,6 +446,7 @@ def fn(match): return False binary_node_inputs = next(iter(compute_node.users)).args assert len(binary_node_inputs) == 2, "Expects binary node with 2 inputs" + is_fp8 = match.kwargs["x"].meta["val"].dtype is torch.float8_e4m3fn if output_dtype in [torch.float32, torch.bfloat16]: extra_input_of_binary_node = None for arg in binary_node_inputs: @@ -469,14 +455,18 @@ def fn(match): break assert extra_input_of_binary_node is not None # Extra input of binary node comes from dequant pattern - if extra_input_from_dequant and ( - (not isinstance(extra_input_of_binary_node, torch.fx.Node)) - or ( - extra_input_of_binary_node.target - not in [ - quantized_decomposed.dequantize_per_tensor.default, - torch.ops.torchao.dequantize_affine_float8_non_decomposed.default, - ] + if ( + not is_fp8 + and extra_input_from_dequant + and ( + (not isinstance(extra_input_of_binary_node, torch.fx.Node)) + or ( + extra_input_of_binary_node.target + not in [ + quantized_decomposed.dequantize_per_tensor.default, + torch.ops.torchao.dequantize_affine_float8_non_decomposed.default, + ] + ) ) ): return False @@ -711,7 +701,9 @@ def _inner(match): return _inner -def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32): +def _register_qconv_weight_prepack_pass( + pattern, pass_number, dtype=torch.float32, is_fp8=False +): @register_freezing_graph_pattern( pattern, extra_check=_is_valid_dequant_conv_pattern(dtype), @@ -724,7 +716,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): | dequant_per_tensor | - Conv2d <- optional(aten.clone.default) <- dequant_per_channel <- int8_weight + Conv2d <- optional(aten.clone.default) <- dequant <- int8_weight Insert weight prepack node and change the pattern to: int8 activation @@ -747,7 +739,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): ) if dtype == torch.float32: - dequant_per_channel = ( + dequant = ( clone_node.args[0] # type: ignore[union-attr] if has_clone_to_channel_last_node_in_pattern else conv_node.args[1] @@ -758,9 +750,9 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): if has_clone_to_channel_last_node_in_pattern else conv_node.args[1] ) - dequant_per_channel = weight_to_bf16_node.args[0] # type: ignore[union-attr] + dequant = weight_to_bf16_node.args[0] # type: ignore[union-attr] - assert dequant_per_channel.target in [ # type: ignore[union-attr] + assert dequant.target in [ # type: ignore[union-attr] quantized_decomposed.dequantize_per_channel.default, torch.ops.torchao.dequantize_affine_float8_non_decomposed.default, ] @@ -768,7 +760,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): # Activation QParams qx, x_zp, x_scale = ( kwargs["x"], - kwargs["x_zp"], + kwargs["x_zp"] if "x_zp" in kwargs else None, kwargs["x_scale"], ) @@ -776,7 +768,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): qw, w_scale, w_zp = ( kwargs["q_weight"], kwargs["w_scale"], - kwargs["w_zp"], + kwargs["w_zp"] if "w_zp" in kwargs else None, ) # Conv Params @@ -792,14 +784,25 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): if has_free_symbols(x_shape): # For dynamic shape case, we can't get activation shape ahead of runtime. 
             x_shape = None
+        if is_fp8:
+            # For float8, we assume the scales are from aten.full.default instead of
+            # a constant buffer to avoid constant folding of q/dq before fusion passes.
+            assert (
+                w_scale.target is torch.ops.aten.full.default
+                and x_scale.target is torch.ops.aten.full.default
+            )
+            with torch.utils._python_dispatch._disable_current_modes():
+                w_scale_tensor = torch.tensor([w_scale.args[1]])
+                match.graph.owning_module.register_buffer("w_scale", w_scale_tensor)
+            w_scale = match.graph.create_node("get_attr", "w_scale")
         graph = match.graph
         with graph.inserting_before(conv_node):
             # Insert weight prepack node and the QConv node
             packed_weight_inputs = (
                 qw,
                 w_scale,
-                x_scale,
-                x_zp,
+                x_scale.args[1] if is_fp8 else x_scale,
+                0,
                 stride,
                 padding,
                 dilation,
@@ -830,9 +833,16 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
                 [],  # scalars
                 "",  # algorithm
             )
-            new_conv_node = graph.call_function(
-                torch.ops.onednn.qconv_pointwise.default, args=new_args
-            )
+            Node = torch.fx.node.Node
+            # fp8 does not need a zero point
+            if isinstance(x_scale, Node) and (isinstance(x_zp, Node) or is_fp8):
+                new_conv_node = graph.call_function(
+                    torch.ops.onednn.qconv_pointwise.tensor, args=new_args
+                )
+            else:
+                new_conv_node = graph.call_function(
+                    torch.ops.onednn.qconv_pointwise.default, args=new_args
+                )
             conv_node.replace_all_uses_with(new_conv_node)
             new_conv_node.meta.update(conv_node.meta)
 
@@ -847,7 +857,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
                 graph.erase_node(clone_node)  # type: ignore[arg-type]
             if dtype == torch.bfloat16:
                 graph.erase_node(weight_to_bf16_node)  # type: ignore[possibly-undefined, arg-type]
-            graph.erase_node(dequant_per_channel)  # type: ignore[arg-type]
+            graph.erase_node(dequant)  # type: ignore[arg-type]
             counters["inductor"]["qconv_weight_prepack_matcher_count"] += 1
             counters["inductor"]["qconv_weight_prepack_matcher_nodes"] += len(
                 match.nodes
@@ -855,17 +865,17 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
             )
 
 
 def _generate_dequant_convolution_node_pattern(
-    _dequant_per_channel_pattern, dtype=torch.float32
+    _dequant_pattern, dtype=torch.float32, is_fp8=False
 ):
     assert dtype in [torch.float32, torch.bfloat16]
     dequant_convolution_node_pattern = CallFunction(
         aten.convolution.default,
         _may_generate_pattern_with_dtype_convert(
-            get_dequantize_per_tensor_activation_pattern(),
+            get_dequantize_per_tensor_activation_pattern(is_fp8=is_fp8),
             KeywordArg("autocast_act_dtype"),
             dtype == torch.bfloat16,
         ),
-        _dequant_per_channel_pattern,
+        _dequant_pattern,
         KeywordArg("b"),
         KeywordArg("stride"),
         KeywordArg("padding"),
@@ -877,24 +887,30 @@ def _generate_dequant_convolution_node_pattern(
     return dequant_convolution_node_pattern
 
 
-def _generate_qconv_weight_prepack_patterns(dtype=torch.float32):
+def _generate_qconv_weight_prepack_patterns(dtype=torch.float32, is_fp8=False):
     assert dtype in [torch.float32, torch.bfloat16]
+    if is_fp8:
+        dequant_wgt_pattern = dequantize_fp8_weight_pattern
+    else:
+        dequant_wgt_pattern = dequantize_per_channel_weight_pattern
     return (
         _generate_dequant_convolution_node_pattern(
-            dequantize_per_channel_weight_pattern
+            dequant_wgt_pattern
             if dtype == torch.float32
-            else dequantize_per_channel_to_bf16_weight_pattern,
+            else get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern),
             dtype,
+            is_fp8=is_fp8,
         ),
         # There is another pattern due to the pass of convert_conv_weights_to_channels_last
        # https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/_inductor/freezing.py#L338-L362.
        # Depend on some heuristics, it may or may not insert to(channel_last) node
-        # between convolution and dequant_per_channel node
+        # between convolution and dequant node
         _generate_dequant_convolution_node_pattern(
-            dequantize_per_channel_clone_weight_pattern
+            get_dequantize_clone_weight_pattern(dequant_wgt_pattern)
             if dtype == torch.float32
-            else dequantize_per_channel_to_bf16_clone_weight_pattern,
+            else get_dequantize_to_bf16_clone_weight_pattern(dequant_wgt_pattern),
             dtype,
+            is_fp8=is_fp8,
         ),
     )
 
@@ -1302,12 +1318,7 @@ def _generate_qlinear_weight_prepack_patterns(
     is_fp8=False,
 ):
     if is_fp8:
-        dequant_wgt_pattern = CallFunction(
-            torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
-            KeywordArg("q_weight"),
-            KeywordArg("w_scale"),
-            output_dtype=KeywordArg("w_dtype"),
-        )
+        dequant_wgt_pattern = dequantize_fp8_weight_pattern
     else:
         dequant_wgt_pattern = dequantize_per_channel_weight_pattern
     if input_dim_exceeds_two and not input_contiguous:
@@ -1449,12 +1460,16 @@ def _register_dequant_promotion():
 
 
 def _register_qconv_weight_prepack():
-    for dtype in [torch.float32, torch.bfloat16]:
-        weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype)
+    for dtype, is_fp8 in itertools.product(
+        [torch.float32, torch.bfloat16], [True, False]
+    ):
+        weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(
+            dtype, is_fp8=is_fp8
+        )
         for weight_prepack_pattern in weight_prepack_patterns:
             # Register to pass_number 1, so we can do dequant promotion in pass_number 0.
             _register_qconv_weight_prepack_pass(
-                weight_prepack_pattern, pass_number=1, dtype=dtype
+                weight_prepack_pattern, pass_number=1, dtype=dtype, is_fp8=is_fp8
             )
 
@@ -2053,13 +2068,25 @@ def qconv(match: Match, *args, **kwargs):
             kwargs["groups"],
         )
         output_dtype = _get_pattern_output_dtype(match)
-        assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16]
+        assert output_dtype in [
+            torch.int8,
+            torch.uint8,
+            torch.float8_e4m3fn,
+            torch.float32,
+            torch.bfloat16,
+        ]
         # Output QParams
-        o_inv_scale = (
-            kwargs["o_inv_scale"]
-            if (output_dtype == torch.uint8 or output_dtype == torch.int8)
-            else 1.0
-        )
+        if output_dtype == torch.float8_e4m3fn:
+            # For float8, we assume the scale is from aten.full.default instead of
+            # a constant buffer to avoid constant folding of q/dq before fusion passes.
+            assert kwargs["o_inv_scale"].target is torch.ops.aten.full.default
+            o_inv_scale = kwargs["o_inv_scale"].args[1]
+        else:
+            o_inv_scale = (
+                kwargs["o_inv_scale"]
+                if (output_dtype == torch.uint8 or output_dtype == torch.int8)
+                else 1.0
+            )
         o_zero_point = (
             kwargs["o_zp"]
             if (output_dtype == torch.uint8 or output_dtype == torch.int8)
@@ -2165,56 +2192,69 @@ def _register_qconv_unary_fusion():
         _silu_fusion,
     )
 
-    for original_pattern_output_dtype in [torch.float32, torch.bfloat16]:
+    combinations = itertools.product(
+        [torch.float32, torch.bfloat16], [False, True], [False, True]
+    )
+    for original_pattern_output_dtype, x_scale_zp_are_tensors, is_fp8 in combinations:
         # Priority 1 to match: QConv2d Unary pattern with int8 output
        # If a pattern1 is a sub-set of pattern2, we should try to match pattern2 firstly.
        # For example: pattern1 is qconv_fp32 -> relu, pattern2 is qconv_fp32 -> relu -> quant
         is_bf16 = original_pattern_output_dtype == torch.bfloat16
+        computation_op = (
+            torch.ops.onednn.qconv_pointwise.tensor
+            if x_scale_zp_are_tensors
+            else torch.ops.onednn.qconv_pointwise.default
+        )
         conv_unary_replace_patterns = {
             PostOpAttr(
                 "none", None, "none", [], ""
             ): generate_pattern_with_output_quant(
-                get_qconv_pt2e_pattern(1),
+                get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1),
+                is_fp8=is_fp8,
             ),
             PostOpAttr(
                 "none", None, "relu", [], ""
             ): generate_pattern_with_output_quant(
                 generate_pattern_with_unary(
-                    get_qconv_pt2e_pattern(1), aten.relu.default
+                    get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), aten.relu.default
                 ),
+                is_fp8=is_fp8,
             ),
             PostOpAttr(
                 "none", None, "hardtanh", [], ""
             ): generate_pattern_with_output_quant(
                 _unary_fusion_pattern(
                     _hardtanh_fusion,
-                    get_qconv_pt2e_pattern(1),
+                    get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1),
                     1,
                     is_bf16,
                 ),
                 with_dtype_convert=is_bf16,
+                is_fp8=is_fp8,
             ),
             PostOpAttr(
                 "none", None, "hardswish", [], ""
             ): generate_pattern_with_output_quant(
                 _unary_fusion_pattern(
                     _hardswish_fusion,
-                    get_qconv_pt2e_pattern(1 if is_bf16 else 2),
+                    get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1 if is_bf16 else 2),
                     2,
                     is_bf16,
                 ),
                 with_dtype_convert=is_bf16,
+                is_fp8=is_fp8,
             ),
             PostOpAttr(
                 "none", None, "swish", [], ""
             ): generate_pattern_with_output_quant(
                 _unary_fusion_pattern(
                     _silu_fusion,
-                    get_qconv_pt2e_pattern(1 if is_bf16 else 2),
+                    get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1 if is_bf16 else 2),
                     2,
                     is_bf16,
                 ),
                 with_dtype_convert=is_bf16,
+                is_fp8=is_fp8,
             ),
         }
 
@@ -2223,21 +2263,21 @@ def _register_qconv_unary_fusion():
             _register_qconv_post_op_fusion_pass(
                 patterns,
                 3,  # pass_number
-                torch.ops.onednn.qconv_pointwise.default,  # computation_op
+                computation_op,  # computation_op
                 unary_attr,  # unary_attr
             )
 
         # Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output
         conv_unary_replace_float_out_patterns = {
             PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary(
-                get_qconv_pt2e_pattern(1), aten.relu.default
+                get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), aten.relu.default
             ),
             PostOpAttr(
                 "none", None, "hardtanh", [], ""
             ): _may_generate_pattern_with_dtype_convert(
                 _unary_fusion_pattern(
                     _hardtanh_fusion,
-                    get_qconv_pt2e_pattern(1),
+                    get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1),
                     1,
                     is_bf16,
                 ),
@@ -2249,7 +2289,7 @@ def _register_qconv_unary_fusion():
             ): _may_generate_pattern_with_dtype_convert(
                 _unary_fusion_pattern(
                     _hardswish_fusion,
-                    get_qconv_pt2e_pattern(1 if is_bf16 else 2),
+                    get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1 if is_bf16 else 2),
                     2,
                     is_bf16,
                 ),
@@ -2261,7 +2301,7 @@ def _register_qconv_unary_fusion():
             ): _may_generate_pattern_with_dtype_convert(
                 _unary_fusion_pattern(
                     _silu_fusion,
-                    get_qconv_pt2e_pattern(1 if is_bf16 else 2),
+                    get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1 if is_bf16 else 2),
                     2,
                     is_bf16,
                 ),
@@ -2275,17 +2315,26 @@ def _register_qconv_unary_fusion():
             _register_qconv_post_op_fusion_pass(
                 patterns,
                 4,  # pass_number
-                torch.ops.onednn.qconv_pointwise.default,  # computation_op
+                computation_op,  # computation_op
                 unary_attr,  # unary_attr
             )
 
 
 def _register_qconv_binary_fusion():
-    for int8_mixed_bf16_with_inplace_add in [False, True]:
+    for int8_mixed_bf16_with_inplace_add, x_scale_zp_are_tensors in itertools.product(
+        [False, True], [False, True]
+    ):
+        qconv_binary_op = (
+            torch.ops.onednn.qconv2d_pointwise.binary_tensor
+            if x_scale_zp_are_tensors
+            else torch.ops.onednn.qconv2d_pointwise.binary
+        )
         # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output
         swap_binary_inputs_list = [False, True]
         binary_replace_patterns = {}
-        for swap_inputs in swap_binary_inputs_list:
+        for swap_inputs, is_fp8 in itertools.product(
+            swap_binary_inputs_list, [False, True]
+        ):
             binary_replace_patterns.update(
                 {
                     PostOpAttr(
@@ -2293,11 +2342,12 @@ def _register_qconv_binary_fusion():
                     ): generate_pattern_with_output_quant(
                         generate_pattern_with_binary(
                             aten.add.Tensor,
-                            get_qconv_pt2e_pattern(1),
+                            get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1),
                             dequantize_accum_pattern,
                             int8_mixed_bf16_with_inplace_add,
                             swap_inputs=swap_inputs,
                         ),
+                        is_fp8=is_fp8,
                     ),
                     PostOpAttr(
                         "sum", 1.0, "relu", [], ""
@@ -2305,13 +2355,14 @@ def _register_qconv_binary_fusion():
                     ): generate_pattern_with_output_quant(
                         generate_pattern_with_unary(
                             generate_pattern_with_binary(
                                 aten.add.Tensor,
-                                get_qconv_pt2e_pattern(1),
+                                get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1),
                                 dequantize_accum_pattern,
                                 int8_mixed_bf16_with_inplace_add,
                                 swap_inputs=swap_inputs,
                             ),
                             aten.relu.default,
                         ),
+                        is_fp8=is_fp8,
                     ),
                 }
             )
@@ -2320,7 +2371,7 @@ def _register_qconv_binary_fusion():
             _register_qconv_post_op_fusion_pass(
                 patterns,
                 3,  # pass_number
-                torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
+                qconv_binary_op,  # computation_op
                 binary_unary_attr,  # binary_unary_attr
             )
 
@@ -2332,7 +2383,7 @@ def _register_qconv_binary_fusion():
                 PostOpAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary(
                     generate_pattern_with_binary(
                         aten.add.Tensor,
-                        get_qconv_pt2e_pattern(1),
+                        get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1),
                         KeywordArg("accum_after_dequant"),
                         int8_mixed_bf16_with_inplace_add,
                         swap_inputs=swap_inputs,
                     ),
@@ -2350,14 +2401,14 @@ def _register_qconv_binary_fusion():
                 _register_qconv_post_op_fusion_pass(
                     patterns,
                     3,  # pass_number
-                    torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
+                    qconv_binary_op,  # computation_op
                     binary_unary_attr,  # binary_unary_attr
                 )
             else:
                 _register_qconv_post_op_fusion_pass(
                     patterns,
                     4,  # pass_number
-                    torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
+                    qconv_binary_op,  # computation_op
                     binary_unary_attr,  # binary_unary_attr
                 )
 
@@ -2370,7 +2421,7 @@ def _register_qconv_binary_fusion():
                         "sum", 1.0, "none", [], ""
                     ): generate_pattern_with_binary(
                         aten.add.Tensor,
-                        get_qconv_pt2e_pattern(1),
+                        get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1),
                         KeywordArg("accum_after_dequant"),
                         int8_mixed_bf16_with_inplace_add,
                         swap_inputs=swap_inputs,
                     ),
@@ -2385,7 +2436,7 @@ def _register_qconv_binary_fusion():
             _register_qconv_post_op_fusion_pass(
                 patterns,
                 4 if int8_mixed_bf16_with_inplace_add else 5,  # pass_number
-                torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
+                qconv_binary_op,  # computation_op
                 binary_unary_attr,  # binary_unary_attr
             )
 
@@ -2427,8 +2478,8 @@ def qlinear_post_op_fusion(match: Match, *args, **kwargs):
         # Output QParams
         if output_dtype == torch.float8_e4m3fn:
-            # For float8, torchao.quantize_affine_float8 requires tensor as scale
-            # Support scale node is full firstly
+            # For float8, we assume the scale is from aten.full.default instead of
+            # a constant buffer to avoid constant folding of q/dq before fusion passes.
             assert kwargs["o_inv_scale"].target is torch.ops.aten.full.default
             o_inv_scale = kwargs["o_inv_scale"].args[1]
         else:
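
For reference, the sketch below (not part of the diff) mirrors the fp8 QDQ Conv2d graph that `FP8QDQConv2d` builds in the test file, using the same `torchao` non-decomposed float8 ops that appear above. The module name, shapes, and scale values are made up for illustration, and it assumes a torch 2.10+ build with those ops available. Whether the new weight-prepack and post-op fusion patterns actually fire depends on the Inductor/freezing configuration that the test harness (`_test_common`) sets up; this is only the eager-mode shape of the pattern.

```python
# Hypothetical stand-alone sketch of the dq -> conv2d pattern targeted by the
# new fp8 QConv passes; not taken from the PR, scales and shapes are arbitrary.
import torch


class TinyFP8QDQConv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = 2.0          # activation scale (illustrative)
        self.weight_scale = 2.0   # weight scale (illustrative)
        # weight is stored pre-quantized in fp8, as in FP8QDQConv2d
        self.weight = torch.randn(8, 3, 3, 3).to(torch.float8_e4m3fn)

    def forward(self, x):
        # activation: quantize then dequantize around the conv input
        qx = torch.ops.torchao.quantize_affine_float8_non_decomposed.default(
            tensor=x,
            scale=torch.tensor([self.scale]),
            float8_dtype=torch.float8_e4m3fn,
        )
        dqx = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
            tensor=qx,
            scale=torch.tensor([self.scale]),
            output_dtype=torch.float,
        )
        # weight: dequantize only, since it is already stored in fp8
        w = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
            tensor=self.weight,
            scale=torch.tensor([self.weight_scale]),
            output_dtype=torch.float,
        )
        return torch.nn.functional.conv2d(dqx, w, None)


with torch.no_grad():
    m = TinyFP8QDQConv().eval()
    # The unit tests drive an equivalent graph through torch.compile with the
    # Inductor freezing passes enabled, which is where the qconv fusion runs.
    out = torch.compile(m)(torch.randn(1, 3, 8, 8))
```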