Skip to content

Commit

Permalink
add ctx for rand_ndarray and rand_sparse_ndarray (apache#14966)
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 authored and haohuw committed Jun 23, 2019
1 parent d92e9bd commit 1eeb013
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions python/mxnet/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def assign_each2(input1, input2, function):

def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=None,
data_init=None, rsp_indices=None, modifier_func=None,
shuffle_csr_indices=False):
shuffle_csr_indices=False, ctx=None):
"""Generate a random sparse ndarray. Returns the ndarray, value(np) and indices(np)
Parameters
Expand Down Expand Up @@ -301,6 +301,7 @@ def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=Non
>>> assert(row4nnz == 2*row3nnz)
"""
ctx = ctx if ctx else default_context()
density = rnd.rand() if density is None else density
dtype = default_dtype() if dtype is None else dtype
distribution = "uniform" if distribution is None else distribution
Expand All @@ -315,7 +316,7 @@ def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=Non
idx_sample = rnd.rand(shape[0])
indices = np.argwhere(idx_sample < density).flatten()
if indices.shape[0] == 0:
result = mx.nd.zeros(shape, stype='row_sparse', dtype=dtype)
result = mx.nd.zeros(shape, stype='row_sparse', dtype=dtype, ctx=ctx)
return result, (np.array([], dtype=dtype), np.array([]))
# generate random values
val = rnd.rand(indices.shape[0], *shape[1:]).astype(dtype)
Expand All @@ -326,17 +327,17 @@ def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=Non
if modifier_func is not None:
val = assign_each(val, modifier_func)

arr = mx.nd.sparse.row_sparse_array((val, indices), shape=shape, dtype=dtype)
arr = mx.nd.sparse.row_sparse_array((val, indices), shape=shape, dtype=dtype, ctx=ctx)
return arr, (val, indices)
elif stype == 'csr':
assert len(shape) == 2
if distribution == "uniform":
csr = _get_uniform_dataset_csr(shape[0], shape[1], density,
data_init=data_init,
shuffle_csr_indices=shuffle_csr_indices, dtype=dtype)
shuffle_csr_indices=shuffle_csr_indices, dtype=dtype).as_in_context(ctx)
return csr, (csr.indptr, csr.indices, csr.data)
elif distribution == "powerlaw":
csr = _get_powerlaw_dataset_csr(shape[0], shape[1], density=density, dtype=dtype)
csr = _get_powerlaw_dataset_csr(shape[0], shape[1], density=density, dtype=dtype).as_in_context(ctx)
return csr, (csr.indptr, csr.indices, csr.data)
else:
assert(False), "Distribution not supported: %s" % (distribution)
Expand All @@ -345,15 +346,17 @@ def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=Non
assert(False), "unknown storage type"
return False

def rand_ndarray(shape, stype='default', density=None, dtype=None,
modifier_func=None, shuffle_csr_indices=False, distribution=None):
def rand_ndarray(shape, stype='default', density=None, dtype=None, modifier_func=None,
shuffle_csr_indices=False, distribution=None, ctx=None):
"""Generate a random sparse ndarray. Returns the generated ndarray."""
ctx = ctx if ctx else default_context()
if stype == 'default':
arr = mx.nd.array(random_arrays(shape), dtype=dtype)
arr = mx.nd.array(random_arrays(shape), dtype=dtype, ctx=ctx)
else:
arr, _ = rand_sparse_ndarray(shape, stype, density=density,
modifier_func=modifier_func, dtype=dtype,
shuffle_csr_indices=shuffle_csr_indices,
distribution=distribution)
distribution=distribution, ctx=ctx)
return arr


Expand Down

0 comments on commit 1eeb013

Please sign in to comment.