diff --git a/tests/test_backend.py b/tests/test_backend.py
index d93b25059..396ed57df 100644
--- a/tests/test_backend.py
+++ b/tests/test_backend.py
@@ -2859,6 +2859,31 @@ def func(a, b, c):
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, _INPUT2: x_val3},
                             graph_validator=lambda g: check_op_count(g, "Gemm", 1))
 
+    # test for gemm pattern0: alpha*A*B + beta*C
+    @check_opset_min_version(12, "Optimizer bug in ORT 1.2")
+    def test_gemm_pattern0_fail_broadcast(self):
+        # shapes (3, 3) * (3, 1) + (1, 4) => (3, 1) + (1, 4)
+        # c not uni-broadcastable to a * b, so should not use GEMM
+        m, n, k = 3, 3, 1
+        x_val1 = np.random.rand(m, n).astype("float32")
+        x_val2 = np.random.rand(n, k).astype("float32")
+        x_val3 = np.random.rand(k, 4).astype("float32")
+
+        def func(a, b, c):
+            alpha = tf.constant(1.0, dtype=tf.float32)
+            beta = tf.constant(2.0, dtype=tf.float32)
+            mul1 = tf.multiply(alpha, tf.matmul(a, b))
+            mul2 = tf.multiply(beta, c)
+            x_ = mul1 + mul2
+            return tf.identity(x_, name=_TFOUTPUT)
+
+        def graph_validator(g):
+            # The rewrite must be skipped here, so the graph may not contain any Gemm node.
+            return all(node.type != 'Gemm' for node in g.get_nodes())
+
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, _INPUT2: x_val3},
+                            graph_validator=graph_validator)
+
     def test_graph_matcher(self):
         shape = [2, 6]
         x_val = np.random.random(shape).astype(np.float32)
diff --git a/tf2onnx/rewriter/gemm_rewriter.py b/tf2onnx/rewriter/gemm_rewriter.py
index c79bf5e9b..22abb45c7 100644
--- a/tf2onnx/rewriter/gemm_rewriter.py
+++ b/tf2onnx/rewriter/gemm_rewriter.py
@@ -8,6 +8,7 @@
 from onnx import onnx_pb
 from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
 
+
 # pylint: disable=missing-docstring
 
 def rewrite_gemm(g, ops):
@@ -77,10 +78,26 @@ def rewrite_gemm(g, ops):
                 b_edge_name = matmul_node.input[1]
                 c_edge_name = input_c_node.output[0]
 
+                a_mul_b_shape = g.get_shape(matmul_node.output[0])
+                c_shape = g.get_shape(c_edge_name)
+                # Both shapes must be fully known to prove that C broadcasts onto A * B.
+                if c_shape is None or a_mul_b_shape is None:
+                    continue
+                if -1 in c_shape + a_mul_b_shape:
+                    continue
+                # ONNX Gemm yields a 2-D result and C may only be broadcast unidirectionally
+                # onto it, so a non-2-D product or a higher-rank C can never be fused.
+                if len(a_mul_b_shape) != 2 or len(c_shape) > 2:
+                    continue
+                # Walk the trailing dims: every dim of C must be 1 or match the product's dim.
+                if any(c_dim not in (1, ab_dim)
+                       for c_dim, ab_dim in zip(reversed(c_shape), reversed(a_mul_b_shape))):
+                    continue
+
                 gemm = g.make_node("Gemm", inputs=[a_edge_name, b_edge_name, c_edge_name],
                                    attr=attr,
                                    shapes=[g.get_shape(add_node.output[0])],
-                                   dtypes=[g.get_dtype(add_node.output[0])])
+                                   dtypes=[g.get_dtype(add_node.output[0])], op_name_scope=matmul_node.name)
                 ops.append(gemm)
 
                 g.replace_all_inputs(ops, add_node.output[0], gemm.output[0])
@@ -88,6 +105,7 @@ def rewrite_gemm(g, ops):
                 g.safe_remove_nodes(to_delete)
     return ops
 
+
 def get_gemm_attr(match):
     attr = {}
     for arg in ["alpha", "beta"]: