Skip to content

Commit

Permalink
Fix the monitor_callback invalid issue during calibration with variab…
Browse files Browse the repository at this point in the history
…le input shapes (apache#18705)
  • Loading branch information
ciyongch authored and ChaiBapchya committed Aug 15, 2020
1 parent 818549f commit fc2e3eb
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 0 deletions.
9 changes: 9 additions & 0 deletions python/mxnet/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(self, handle, symbol, ctx, grad_req, group2ctx):
self._aux_dict = None
self._output_dict = None
self._monitor_callback = None
self._monitor_all = None
self._ctx = copy.deepcopy(ctx)
self._grad_req = copy.deepcopy(grad_req)
self._group2ctx = copy.deepcopy(group2ctx)
Expand Down Expand Up @@ -253,6 +254,7 @@ def set_monitor_callback(self, callback, monitor_all=False):
"""
cb_type = ctypes.CFUNCTYPE(None, ctypes.c_char_p, NDArrayHandle, ctypes.c_void_p)
self._monitor_callback = cb_type(_monitor_callback_wrapper(callback))
self._monitor_all = monitor_all
check_call(_LIB.MXExecutorSetMonitorCallbackEX(
self.handle,
self._monitor_callback,
Expand Down Expand Up @@ -477,6 +479,13 @@ def reshape(self, partial_shaping=False, allow_up_sizing=False, **kwargs):
executor.arg_arrays = arg_arrays
executor.grad_arrays = grad_arrays
executor.aux_arrays = aux_arrays
if (self._monitor_callback is not None) and (self._monitor_all is not None):
# rebind callback to the new executor if the callback is valid
check_call(_LIB.MXExecutorSetMonitorCallbackEX(
handle,
self._monitor_callback,
None,
ctypes.c_int(self._monitor_all)))
return executor

def debug_str(self):
Expand Down
53 changes: 53 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8364,6 +8364,59 @@ def get_output_names_callback(name, arr):
check_name(us_sym, ['data', 'pooling_data', 'pooling_output'])
del os.environ['MXNET_SUBGRAPH_BACKEND']

@with_seed()
def test_monitor_with_variable_input_shape():
output = {}

def get_output_min_callback(name, arr):
name = py_str(name)
handle = ctypes.cast(arr, NDArrayHandle)
arr = NDArray(handle, writable=False)
min_val = mx.ndarray.min(arr).asscalar()
if name in output:
output[name] = min(output[name], min_val)
else:
output[name] = min_val

def check_result(output, names):
assert len(output) > 0
for k, v in output.items():
assert k in names
assert v is not None

is_windows = sys.platform.startswith('win')
if (is_windows):
# Windows doesn't support set environment variable on the fly, so disable it for now
pass
else:
# Disable subgraph in case subgraph will replace symbol
os.environ['MXNET_SUBGRAPH_BACKEND'] = "NONE"

batch_size = 1
op_name = 'conv'
dshape = (batch_size, 3, 10, 10)
data = mx.sym.Variable('data', shape=dshape)
sym = mx.sym.Convolution(data, kernel=(1, 1), num_filter=1, name=op_name)

mod = mx.module.Module(symbol=sym, label_names=None)
mod.bind(for_training=False, data_shapes=[('data', dshape)])
mod.init_params()
mod._exec_group.execs[0].set_monitor_callback(get_output_min_callback, monitor_all=True)

new_dshape = dshape[:-1] + (dshape[-1] + 4,)
new_data = mx.nd.random.uniform(shape=new_dshape)
new_data = mx.io.NDArrayIter(data=new_data, batch_size=batch_size)
new_data = DummyIter(new_data)

for batch in new_data:
mod.forward(data_batch=batch, is_train=False)
mx.nd.waitall()
break

name_list = ['data', 'conv_data', 'conv_weight', 'conv_bias', 'conv_output']
check_result(output, name_list)
del os.environ['MXNET_SUBGRAPH_BACKEND']

@with_seed()
@unittest.skip("test fails intermittently. temporarily disabled till it gets fixed. tracked at https://github.com/apache/incubator-mxnet/issues/13915")
def test_activation():
Expand Down

0 comments on commit fc2e3eb

Please sign in to comment.