[v1.6] Fix the monitor_callback invalid issue during calibration with variable input shapes (#18632)

* Fix the monitor_callback invalid issue during calibration with variable input shapes

* retrigger CI

* Add UT for monitor check and disable codecov
ciyongch authored Jul 2, 2020
1 parent fb3fea4 commit e503704
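
For context, the underlying bug: `Executor.reshape()` builds a new executor but, before this commit, did not re-register a monitor callback installed via `set_monitor_callback()`. Quantization calibration with variable input shapes reshapes executors on the fly, so the collected statistics silently went missing. A minimal sketch of the affected flow (shapes and layer names are illustrative, mirroring the unit test added below):

import mxnet as mx

def print_name(name, arr):
    # Monitor callback: `name` is the tensor name (bytes), `arr` a raw NDArray handle.
    print(name)

data = mx.sym.Variable('data')
sym = mx.sym.Convolution(data, kernel=(1, 1), num_filter=1, name='conv')
exe = sym.simple_bind(mx.cpu(), data=(1, 3, 10, 10), grad_req='null')
exe.set_monitor_callback(print_name, monitor_all=True)

# Feeding a different input shape goes through Executor.reshape(); before this
# fix the reshaped executor silently dropped the callback registered above.
new_exe = exe.reshape(allow_up_sizing=True, data=(1, 3, 10, 14))
new_exe.forward(is_train=False, data=mx.nd.random.uniform(shape=(1, 3, 10, 14)))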
Showing 3 changed files with 66 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .codecov.yml
@@ -4,6 +4,9 @@ codecov:
   require_ci_to_pass: yes
 
 coverage:
+  status:
+    project: off
+    patch: off
   precision: 2
   round: down
   range: "70...100"
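
For reference, what the touched codecov settings do (annotations are editorial, per codecov.io's documented semantics, and are not part of the commit):

coverage:
  status:
    project: off    # no pass/fail commit status for overall project coverage
    patch: off      # no pass/fail commit status for coverage of the changed lines
  precision: 2      # show coverage percentages with two decimal places
  round: down
  range: "70...100" # color scale: red below 70%, green at 100%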
9 changes: 9 additions & 0 deletions python/mxnet/executor.py
@@ -79,6 +79,7 @@ def __init__(self, handle, symbol, ctx, grad_req, group2ctx):
         self._aux_dict = None
         self._output_dict = None
         self._monitor_callback = None
+        self._monitor_all = None
         self._ctx = copy.deepcopy(ctx)
         self._grad_req = copy.deepcopy(grad_req)
         self._group2ctx = copy.deepcopy(group2ctx)
@@ -253,6 +254,7 @@ def set_monitor_callback(self, callback, monitor_all=False):
         """
         cb_type = ctypes.CFUNCTYPE(None, ctypes.c_char_p, NDArrayHandle, ctypes.c_void_p)
         self._monitor_callback = cb_type(_monitor_callback_wrapper(callback))
+        self._monitor_all = monitor_all
         check_call(_LIB.MXExecutorSetMonitorCallbackEX(
             self.handle,
             self._monitor_callback,
@@ -469,6 +471,13 @@ def reshape(self, partial_shaping=False, allow_up_sizing=False, **kwargs):
         executor.arg_arrays = arg_arrays
         executor.grad_arrays = grad_arrays
         executor.aux_arrays = aux_arrays
+        if (self._monitor_callback is not None) and (self._monitor_all is not None):
+            # rebind callback to the new executor if the callback is valid
+            check_call(_LIB.MXExecutorSetMonitorCallbackEX(
+                handle,
+                self._monitor_callback,
+                None,
+                ctypes.c_int(self._monitor_all)))
         return executor
 
     def debug_str(self):
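
A note on why both pieces are stored on `self`: a `ctypes.CFUNCTYPE` trampoline is only valid while a Python reference to it exists, and `reshape()` needs both the wrapped callback and the `monitor_all` flag to re-register monitoring on the new handle. A minimal, self-contained sketch of that keep-alive rule (illustrative, not MXNet code):

import ctypes

# Signature shaped like the executor monitor callback: (name, array handle, opaque).
CB_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_char_p, ctypes.c_void_p, ctypes.c_void_p)

class Holder:
    def __init__(self, py_cb):
        def wrapper(name, array_handle, _):
            py_cb(name, array_handle)
        # Keep the trampoline referenced on self: were it a local variable, it
        # could be garbage-collected while C code still holds the raw pointer.
        self._cb = CB_TYPE(wrapper)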
55 changes: 54 additions & 1 deletion tests/python/unittest/test_operator.py
@@ -8269,6 +8269,59 @@ def get_output_names_callback(name, arr):
     check_name(us_sym, ['data', 'pooling_data', 'pooling_output'])
     del os.environ['MXNET_SUBGRAPH_BACKEND']
 
+@with_seed()
+def test_monitor_with_variable_input_shape():
+    output = {}
+
+    def get_output_min_callback(name, arr):
+        name = py_str(name)
+        handle = ctypes.cast(arr, NDArrayHandle)
+        arr = NDArray(handle, writable=False)
+        min_val = mx.ndarray.min(arr).asscalar()
+        if name in output:
+            output[name] = min(output[name], min_val)
+        else:
+            output[name] = min_val
+
+    def check_result(output, names):
+        assert len(output) > 0
+        for k, v in output.items():
+            assert k in names
+            assert v is not None
+
+    is_windows = sys.platform.startswith('win')
+    if (is_windows):
+        # Windows doesn't support setting environment variables on the fly, so disable it for now
+        pass
+    else:
+        # Disable subgraph in case subgraph will replace symbol
+        os.environ['MXNET_SUBGRAPH_BACKEND'] = "NONE"
+
+        batch_size = 1
+        op_name = 'conv'
+        dshape = (batch_size, 3, 10, 10)
+        data = mx.sym.Variable('data', shape=dshape)
+        sym = mx.sym.Convolution(data, kernel=(1, 1), num_filter=1, name=op_name)
+
+        mod = mx.module.Module(symbol=sym, label_names=None)
+        mod.bind(for_training=False, data_shapes=[('data', dshape)])
+        mod.init_params()
+        mod._exec_group.execs[0].set_monitor_callback(get_output_min_callback, monitor_all=True)
+
+        new_dshape = dshape[:-1] + (dshape[-1] + 4,)
+        new_data = mx.nd.random.uniform(shape=new_dshape)
+        new_data = mx.io.NDArrayIter(data=new_data, batch_size=batch_size)
+        new_data = DummyIter(new_data)
+
+        for batch in new_data:
+            mod.forward(data_batch=batch, is_train=False)
+            mx.nd.waitall()
+            break
+
+        name_list = ['data', 'conv_data', 'conv_weight', 'conv_bias', 'conv_output']
+        check_result(output, name_list)
+        del os.environ['MXNET_SUBGRAPH_BACKEND']
+
 @with_seed()
 @unittest.skip("test fails intermittently. temporarily disabled till it gets fixed. tracked at https://github.com/apache/incubator-mxnet/issues/13915")
 def test_activation():
@@ -9558,7 +9611,7 @@ def convert_bias(F, k_bias, v_bias, num_heads):
     q_proj = mx.sym.FullyConnected(q, weight=q_weight, bias=q_bias, flatten=False,
                                    num_hidden=qkv_units, no_bias=False)
     att_score = mx.sym.contrib.interleaved_matmul_encdec_qk(
-            q_proj, kv_proj, heads=num_heads)
+        q_proj, kv_proj, heads=num_heads)
     att_score = att_score + sonde
     weighted_value = mx.sym.contrib.interleaved_matmul_encdec_valatt(
         kv_proj, att_score, heads=num_heads)
