Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Refactor kvstore test
Browse files Browse the repository at this point in the history
  • Loading branch information
larroy committed Nov 8, 2018
1 parent 722ad7a commit e1d60bd
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 47 deletions.
21 changes: 21 additions & 0 deletions python/mxnet/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1998,3 +1998,24 @@ def compare_optimizer(opt1, opt2, shape, dtype, w_stype='default', g_stype='defa
if compare_states:
compare_ndarray_tuple(state1, state2, rtol=rtol, atol=atol)
assert_almost_equal(w1.asnumpy(), w2.asnumpy(), rtol=rtol, atol=atol)


class EnvManager(object):
"""Class to set an environment variable with 'with' idiom:
with EnvManager(key,val):
...
"""
def __init__(self, key, val):
self._key = key
self._next_val = val
self._prev_val = None

def __enter__(self):
self._prev_val = os.environ.get(self._key)
os.environ[self._key] = self._next_val

def __exit__(self, ptype, value, trace):
if self._prev_val:
os.environ[self._key] = self._prev_val
else:
del os.environ[self._key]
44 changes: 14 additions & 30 deletions tests/python/gpu/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,35 +19,23 @@
import numpy as np
import unittest
import os
import logging

from mxnet.test_utils import EnvManager

shapes = [(10), (100), (1000), (10000), (100000), (2,2), (2,3,4,5,6,7,8)]
keys = [1,2,3,4,5,6,7]
num_gpus = len(mx.test_utils.list_gpus())


if num_gpus > 8 :
print("The machine has {} gpus. We will run the test on 8 gpus.".format(num_gpus))
print("There is a limit for all PCI-E hardware on creating number of P2P peers. The limit is 8.")
logging.warn("The machine has {} gpus. We will run the test on 8 gpus.".format(num_gpus))
logging.warn("There is a limit for all PCI-E hardware on creating number of P2P peers. The limit is 8.")
num_gpus = 8;

gpus = range(1, 1+num_gpus)

class EnvManager:
def __init__(self, key, val):
self._key = key
self._next_val = val
self._prev_val = None

def __enter__(self):
try:
self._prev_val = os.environ[self._key]
except KeyError:
self._prev_val = ''
os.environ[self._key] = self._next_val

def __exit__(self, ptype, value, trace):
os.environ[self._key] = self._prev_val

@unittest.skipIf(mx.context.num_gpus() < 1, "test_device_pushpull needs at least 1 GPU")
def test_device_pushpull():
def check_dense_pushpull(kv_type):
for shape, key in zip(shapes, keys):
Expand All @@ -63,20 +51,16 @@ def check_dense_pushpull(kv_type):
for x in range(n_gpus):
assert(np.sum(np.abs((res[x]-n_gpus).asnumpy()))==0)

envs1 = '1'
key1 = 'MXNET_KVSTORE_TREE_ARRAY_BOUND'
envs2 = ['','1']
key2 = 'MXNET_KVSTORE_USETREE'
for i in range(2):
for val2 in envs2:
with EnvManager(key2, val2):
kvstore_tree_array_bound = 'MXNET_KVSTORE_TREE_ARRAY_BOUND'
kvstore_usetree_values = ['','1']
kvstore_usetree = 'MXNET_KVSTORE_USETREE'
for _ in range(2):
for x in kvstore_usetree_values:
with EnvManager(kvstore_usetree, x):
check_dense_pushpull('local')
check_dense_pushpull('device')

os.environ[key1] = envs1
os.environ[key1] = ''

print ("Passed")
os.environ[kvstore_tree_array_bound] = '1'
del os.environ[kvstore_tree_array_bound]

if __name__ == '__main__':
test_device_pushpull()
18 changes: 1 addition & 17 deletions tests/python/gpu/test_kvstore_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import mxnet as mx
import numpy as np
import unittest
from mxnet.test_utils import assert_almost_equal, default_context
from mxnet.test_utils import assert_almost_equal, default_context, EnvManager
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(curr_path, '../unittest'))
from common import setup_module, with_seed, teardown
Expand All @@ -30,22 +30,6 @@
keys = [5, 7, 11]
str_keys = ['b', 'c', 'd']

class EnvManager:
def __init__(self, key, val):
self._key = key
self._next_val = val
self._prev_val = None

def __enter__(self):
try:
self._prev_val = os.environ[self._key]
except KeyError:
self._prev_val = ''
os.environ[self._key] = self._next_val

def __exit__(self, ptype, value, trace):
os.environ[self._key] = self._prev_val

def init_kv_with_str(stype='default', kv_type='local'):
"""init kv """
kv = mx.kv.create(kv_type)
Expand Down

0 comments on commit e1d60bd

Please sign in to comment.