Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge branch 'v1.x' into add_lt_gemm2
Browse files Browse the repository at this point in the history
  • Loading branch information
ChaiBapchya authored Jul 27, 2020
2 parents e170000 + 566d9d3 commit 9ded67b
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 1 deletion.
13 changes: 13 additions & 0 deletions python/mxnet/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,19 @@ def create_vector(size, dtype=np.int64):
a = mx.nd.arange(0, size, dtype=dtype)
return a

# For testing Large Square Matrix with total size > 2^32 elements
def get_identity_mat(size):
A = mx.nd.zeros((size, size))
for i in range(size):
A[i, i] = 1
return A

# For testing Batch of Large Square Matrix with total size > 2^32 elements
def get_identity_mat_batch(size):
A = get_identity_mat(size)
A_np = A.asnumpy()
return mx.nd.array([A_np, A_np])

def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=None,
data_init=None, rsp_indices=None, modifier_func=None,
shuffle_csr_indices=False, ctx=None):
Expand Down
119 changes: 118 additions & 1 deletion tests/nightly/test_large_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.append(os.path.join(curr_path, '../python/unittest/'))

from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context, check_symbolic_forward, create_2d_tensor
from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context, check_symbolic_forward, create_2d_tensor, get_identity_mat, get_identity_mat_batch
from mxnet import gluon, nd
from common import with_seed, with_post_test_cleanup
from nose.tools import with_setup
Expand Down Expand Up @@ -1226,11 +1226,128 @@ def run_gemm2(inp1,inp2):
assert inp2_grad.asnumpy()[0][0] == (LARGE_X, SMALL_Y)
assert_almost_equal(inp2_grad.asnumpy()[0][0],49.1)

def check_det():
def run_det(inp):
inp.attach_grad()
with mx.autograd.record():
out = mx.nd.linalg.det(inp)
return inp.grad, out

A = get_identity_mat(LARGE_SQ_X)
grad, out = run_det(A)
assert(out.shape == (1,))
assert(out[0] == 1)
out.backward()
assert(grad.shape == (LARGE_SQ_X, LARGE_SQ_X))
assert(grad[0, 0] == 1)

def check_inverse():
def run_inverse(inp):
inp.attach_grad()
with mx.autograd.record():
out = mx.nd.linalg.inverse(inp)
return inp.grad, out

A = get_identity_mat(LARGE_SQ_X)
grad, out = run_inverse(A)
assert(out.shape == (LARGE_SQ_X, LARGE_SQ_X))
assert(out[0, 0] == 1)
out.backward()
assert(grad.shape == (LARGE_SQ_X, LARGE_SQ_X))
assert(grad[0, 0] == -1)

def check_trmm():
def run_trmm(inp):
inp.attach_grad()
with mx.autograd.record():
out = mx.nd.linalg.trmm(inp, inp)
return inp.grad, out

A = get_identity_mat(LARGE_SQ_X)
grad, out = run_trmm(A)
assert(out.shape == (LARGE_SQ_X, LARGE_SQ_X))
assert(out[0, 0] == 1)
out.backward()
assert(grad.shape == (LARGE_SQ_X, LARGE_SQ_X))
assert(grad[0, 0] == 2)

def check_trsm():
def run_trsm(inp):
inp.attach_grad()
with mx.autograd.record():
out = mx.nd.linalg.trsm(inp, inp)
return inp.grad, out

A = get_identity_mat(LARGE_SQ_X)
grad, out = run_trsm(A)
assert(out.shape == (LARGE_SQ_X, LARGE_SQ_X))
assert(out[0, 0] == 1)
out.backward()
assert(grad.shape == (LARGE_SQ_X, LARGE_SQ_X))
assert(grad[0, 0] == 0)

def check_batch_inverse():
def run_inverse(inp):
inp.attach_grad()
with mx.autograd.record():
out = mx.nd.linalg.inverse(inp)
return inp.grad, out

B = get_identity_mat_batch(LARGE_SQ_X)
grad, out = run_inverse(B)
assert(out.shape == (2, LARGE_SQ_X, LARGE_SQ_X))
assert(out[0, 0, 0] == 1)
assert(out[1, 0, 0] == 1)
out.backward()
assert(grad.shape == (2, LARGE_SQ_X, LARGE_SQ_X))
assert(grad[0, 0, 0] == -1)
assert(grad[1, 0, 0] == -1)

def check_batch_trmm():
def run_trmm(inp):
inp.attach_grad()
with mx.autograd.record():
out = mx.nd.linalg.trmm(inp, inp)
return inp.grad, out

B = get_identity_mat_batch(LARGE_SQ_X)
grad, out = run_trmm(B)
assert(out.shape == (2, LARGE_SQ_X, LARGE_SQ_X))
assert(out[0, 0, 0] == 1)
assert(out[1, 0, 0] == 1)
out.backward()
assert(grad.shape == (2, LARGE_SQ_X, LARGE_SQ_X))
assert(grad[0, 0, 0] == 2)
assert(grad[1, 0, 0] == 2)

def check_batch_trsm():
def run_trsm(inp):
inp.attach_grad()
with mx.autograd.record():
out = mx.nd.linalg.trsm(inp, inp)
return inp.grad, out

B = get_identity_mat_batch(LARGE_SQ_X)
grad, out = run_trsm(B)
assert(out.shape == (2, LARGE_SQ_X, LARGE_SQ_X))
assert(out[0, 0, 0] == 1)
assert(out[1, 0, 0] == 1)
out.backward()
assert(grad.shape == (2, LARGE_SQ_X, LARGE_SQ_X))
assert(grad[0, 0, 0] == 0)
assert(grad[1, 0, 0] == 0)

check_potrf()
check_potri()
check_syrk_batch()
check_gemm2()
check_det()
check_inverse()
check_trmm()
check_trsm()
check_batch_inverse()
check_batch_trmm()
check_batch_trsm()


def test_basic():
Expand Down

0 comments on commit 9ded67b

Please sign in to comment.