Skip to content

Commit

Permalink
Merge pull request #906 from jignparm/jignparm/gemm_broadcast
Browse files Browse the repository at this point in the history
Fix GEMM to check for shape broadcast compatibility of A*B and C
  • Loading branch information
jignparm authored May 1, 2020
2 parents 7c37ccb + cffe8c5 commit 59fed17
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
25 changes: 25 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2859,6 +2859,31 @@ def func(a, b, c):
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, _INPUT2: x_val3},
graph_validator=lambda g: check_op_count(g, "Gemm", 1))

# test for gemm pattern0: alpha*A*B + beta*C
@check_opset_min_version(12, "Optimizer bug in ORT 1.2")
def test_gemm_pattern0_fail_broadcast(self):
# shapes (3, 3) * (3, 1) + (1, 4) => (3, 1) + (1, 4)
# c not uni-broadcastable to a * b, so should not use GEMM
m, n, k = 3, 3, 1
x_val1 = np.random.rand(m, n).astype("float32")
x_val2 = np.random.rand(n, k).astype("float32")
x_val3 = np.random.rand(k, 4).astype("float32")

def func(a, b, c):
alpha = tf.constant(1.0, dtype=tf.float32)
beta = tf.constant(2.0, dtype=tf.float32)
mul1 = tf.multiply(alpha, tf.matmul(a, b))
mul2 = tf.multiply(beta, c)
x_ = mul1 + mul2
return tf.identity(x_, name=_TFOUTPUT)

def graph_validator(g):
if 'Gemm' in [n.type for n in g.get_nodes()]: return False
return True

self._run_test_case(func, [_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, _INPUT2: x_val3},
graph_validator=graph_validator)

def test_graph_matcher(self):
shape = [2, 6]
x_val = np.random.random(shape).astype(np.float32)
Expand Down
15 changes: 14 additions & 1 deletion tf2onnx/rewriter/gemm_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from onnx import onnx_pb
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher


# pylint: disable=missing-docstring

def rewrite_gemm(g, ops):
Expand Down Expand Up @@ -77,17 +78,29 @@ def rewrite_gemm(g, ops):
b_edge_name = matmul_node.input[1]
c_edge_name = input_c_node.output[0]

a_mul_b_shape = g.get_shape(matmul_node.output[0])
c_shape = g.get_shape(c_edge_name)
if c_shape is None: continue
if a_mul_b_shape is None: continue
if -1 in c_shape + a_mul_b_shape: continue
compatible = True
for i in range(1, len(c_shape) + 1):
if c_shape[-i] not in [1, a_mul_b_shape[-i]]:
compatible = False
if not compatible: continue

gemm = g.make_node("Gemm", inputs=[a_edge_name, b_edge_name, c_edge_name],
attr=attr,
shapes=[g.get_shape(add_node.output[0])],
dtypes=[g.get_dtype(add_node.output[0])])
dtypes=[g.get_dtype(add_node.output[0])], op_name_scope=matmul_node.name)

ops.append(gemm)
g.replace_all_inputs(ops, add_node.output[0], gemm.output[0])
to_delete = [add_node, matmul_node]
g.safe_remove_nodes(to_delete)
return ops


def get_gemm_attr(match):
attr = {}
for arg in ["alpha", "beta"]:
Expand Down

0 comments on commit 59fed17

Please sign in to comment.