Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix the monitor_callback invalid issue during calibration with variab…
Browse files Browse the repository at this point in the history
…le input shapes
  • Loading branch information
ciyongch committed Jul 1, 2020
1 parent fb3fea4 commit 72ba804
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions python/mxnet/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(self, handle, symbol, ctx, grad_req, group2ctx):
self._aux_dict = None
self._output_dict = None
self._monitor_callback = None
self._monitor_all = None
self._ctx = copy.deepcopy(ctx)
self._grad_req = copy.deepcopy(grad_req)
self._group2ctx = copy.deepcopy(group2ctx)
Expand Down Expand Up @@ -253,6 +254,7 @@ def set_monitor_callback(self, callback, monitor_all=False):
"""
cb_type = ctypes.CFUNCTYPE(None, ctypes.c_char_p, NDArrayHandle, ctypes.c_void_p)
self._monitor_callback = cb_type(_monitor_callback_wrapper(callback))
self._monitor_all = monitor_all
check_call(_LIB.MXExecutorSetMonitorCallbackEX(
self.handle,
self._monitor_callback,
Expand Down Expand Up @@ -469,6 +471,13 @@ def reshape(self, partial_shaping=False, allow_up_sizing=False, **kwargs):
executor.arg_arrays = arg_arrays
executor.grad_arrays = grad_arrays
executor.aux_arrays = aux_arrays
if (self._monitor_callback is not None) and (self._monitor_all is not None):
# rebind callback to the new executor if the callback is valid
check_call(_LIB.MXExecutorSetMonitorCallbackEX(
handle,
self._monitor_callback,
None,
ctypes.c_int(self._monitor_all)))
return executor

def debug_str(self):
Expand Down

0 comments on commit 72ba804

Please sign in to comment.