diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 306c827bab9f..9c8ef2cb6a46 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -1337,6 +1337,15 @@ def run_trsm(inp): assert(grad[0, 0, 0] == 0) assert(grad[1, 0, 0] == 0) + def check_linalg_gemm2(): + a = mx.nd.ones(shape=(SMALL_Y, LARGE_X)) + b = mx.nd.ones(shape=(LARGE_X, SMALL_Y)) + res = nd.linalg_gemm2(a, b) + res.shape == (SMALL_Y, SMALL_Y) + assert res.asnumpy()[0][0] == LARGE_X + assert res.asnumpy()[-1][-1] == LARGE_X + + check_linalg_gemm2() check_potrf() check_potri() check_syrk_batch()