@@ -441,6 +441,7 @@ def fn(match):
441441 return False
442442 binary_node_inputs = next (iter (compute_node .users )).args
443443 assert len (binary_node_inputs ) == 2 , "Expects binary node with 2 inputs"
444+ is_fp8 = match .kwargs ["x" ].meta ["val" ].dtype is torch .float8_e4m3fn
444445 if output_dtype in [torch .float32 , torch .bfloat16 ]:
445446 extra_input_of_binary_node = None
446447 for arg in binary_node_inputs :
@@ -449,7 +450,7 @@ def fn(match):
449450 break
450451 assert extra_input_of_binary_node is not None
451452 # Extra input of binary node comes from dequant pattern
452- if extra_input_from_dequant and (
453+ if not is_fp8 and extra_input_from_dequant and (
453454 (not isinstance (extra_input_of_binary_node , torch .fx .Node ))
454455 or (
455456 extra_input_of_binary_node .target
@@ -2293,37 +2294,44 @@ def _register_qconv_unary_fusion():
22932294
22942295
22952296def _register_qconv_binary_fusion ():
2296- for int8_mixed_bf16_with_inplace_add in [False , True ]:
2297+ for int8_mixed_bf16_with_inplace_add , x_scale_zp_are_tensors in itertools .product ([False , True ], [False , True ]):
2298+ qconv_binary_op = (
2299+ torch .ops .onednn .qconv2d_pointwise .binary_tensor
2300+ if x_scale_zp_are_tensors
2301+ else torch .ops .onednn .qconv2d_pointwise .binary
2302+ )
22972303 # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output
22982304 swap_binary_inputs_list = [False , True ]
22992305 binary_replace_patterns = {}
2300- for swap_inputs in swap_binary_inputs_list :
2306+ for swap_inputs , is_fp8 in itertools . product ( swap_binary_inputs_list , [ False , True ]) :
23012307 binary_replace_patterns .update (
23022308 {
23032309 PostOpAttr (
23042310 "sum" , 1.0 , "none" , [], ""
23052311 ): generate_pattern_with_output_quant (
23062312 generate_pattern_with_binary (
23072313 aten .add .Tensor ,
2308- get_qconv_pt2e_pattern (users = 1 ),
2314+ get_qconv_pt2e_pattern (x_scale_zp_are_tensors , 1 ),
23092315 dequantize_accum_pattern ,
23102316 int8_mixed_bf16_with_inplace_add ,
23112317 swap_inputs = swap_inputs ,
23122318 ),
2319+ is_fp8 = is_fp8 ,
23132320 ),
23142321 PostOpAttr (
23152322 "sum" , 1.0 , "relu" , [], ""
23162323 ): generate_pattern_with_output_quant (
23172324 generate_pattern_with_unary (
23182325 generate_pattern_with_binary (
23192326 aten .add .Tensor ,
2320- get_qconv_pt2e_pattern (users = 1 ),
2327+ get_qconv_pt2e_pattern (x_scale_zp_are_tensors , 1 ),
23212328 dequantize_accum_pattern ,
23222329 int8_mixed_bf16_with_inplace_add ,
23232330 swap_inputs = swap_inputs ,
23242331 ),
23252332 aten .relu .default ,
23262333 ),
2334+ is_fp8 = is_fp8 ,
23272335 ),
23282336 }
23292337 )
@@ -2332,7 +2340,7 @@ def _register_qconv_binary_fusion():
23322340 _register_qconv_post_op_fusion_pass (
23332341 patterns ,
23342342 3 , # pass_number
2335- torch . ops . onednn . qconv2d_pointwise . binary , # computation_op
2343+ qconv_binary_op , # computation_op
23362344 binary_unary_attr , # binary_unary_attr
23372345 )
23382346
@@ -2344,7 +2352,7 @@ def _register_qconv_binary_fusion():
23442352 PostOpAttr ("sum" , 1.0 , "relu" , [], "" ): generate_pattern_with_unary (
23452353 generate_pattern_with_binary (
23462354 aten .add .Tensor ,
2347- get_qconv_pt2e_pattern (users = 1 ),
2355+ get_qconv_pt2e_pattern (x_scale_zp_are_tensors , 1 ),
23482356 KeywordArg ("accum_after_dequant" ),
23492357 int8_mixed_bf16_with_inplace_add ,
23502358 swap_inputs = swap_inputs ,
@@ -2362,14 +2370,14 @@ def _register_qconv_binary_fusion():
23622370 _register_qconv_post_op_fusion_pass (
23632371 patterns ,
23642372 3 , # pass_number
2365- torch . ops . onednn . qconv2d_pointwise . binary , # computation_op
2373+ qconv_binary_op , # computation_op
23662374 binary_unary_attr , # binary_unary_attr
23672375 )
23682376 else :
23692377 _register_qconv_post_op_fusion_pass (
23702378 patterns ,
23712379 4 , # pass_number
2372- torch . ops . onednn . qconv2d_pointwise . binary , # computation_op
2380+ qconv_binary_op , # computation_op
23732381 binary_unary_attr , # binary_unary_attr
23742382 )
23752383
@@ -2382,7 +2390,7 @@ def _register_qconv_binary_fusion():
23822390 "sum" , 1.0 , "none" , [], ""
23832391 ): generate_pattern_with_binary (
23842392 aten .add .Tensor ,
2385- get_qconv_pt2e_pattern (users = 1 ),
2393+ get_qconv_pt2e_pattern (x_scale_zp_are_tensors , 1 ),
23862394 KeywordArg ("accum_after_dequant" ),
23872395 int8_mixed_bf16_with_inplace_add ,
23882396 swap_inputs = swap_inputs ,
@@ -2397,7 +2405,7 @@ def _register_qconv_binary_fusion():
23972405 _register_qconv_post_op_fusion_pass (
23982406 patterns ,
23992407 4 if int8_mixed_bf16_with_inplace_add else 5 , # pass_number
2400- torch . ops . onednn . qconv2d_pointwise . binary , # computation_op
2408+ qconv_binary_op , # computation_op
24012409 binary_unary_attr , # binary_unary_attr
24022410 )
24032411
0 commit comments