Skip to content

Commit

Permalink
fix batch matmul test
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Brookhart committed Sep 11, 2020
1 parent 6ae7c02 commit 2892e6a
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,7 +858,7 @@ def test_matmul():
tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)


def verify_batch_matmul(a_shape, b_shape):
def verify_batch_matmul(a_shape, b_shape, target, ctx):
a_array = np.random.uniform(size=a_shape).astype("float32")
b_array = np.random.uniform(size=b_shape).astype("float32")
out_np = np.matmul(a_array, b_array)
Expand All @@ -877,17 +877,16 @@ def verify_batch_matmul(a_shape, b_shape):

model = helper.make_model(graph, producer_name="matmul_test")

for target, ctx in tvm.testing.enabled_targets():
tvm_out = get_tvm_output_with_vm(model, [a_array, b_array], target, ctx)
tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
tvm_out = get_tvm_output_with_vm(model, [a_array, b_array], target, ctx)
tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)


# TODO(mbrookhart): enable cuda once VM supports heterogenous execution
@tvm.testing.parametrize_targets("llvm")
def test_batch_matmul():
verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4))
verify_batch_matmul((2, 4, 3), (3, 4))
verify_batch_matmul((2, 3, 4, 3), (3, 4))
def test_batch_matmul(target, ctx):
verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4), target, ctx)
verify_batch_matmul((2, 4, 3), (3, 4), target, ctx)
verify_batch_matmul((2, 3, 4, 3), (3, 4), target, ctx)


def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None):
Expand Down

0 comments on commit 2892e6a

Please sign in to comment.