Skip to content

Commit

Permalink
fix block.export (apache#17970)
Browse files Browse the repository at this point in the history
* fix block.export

```net.hybridize``` may optimize out some ops. These ops are alive in nn.Block(also nn.HybridBlock), but its names are not contained in symbol's ```arg_names``` list. So ignore these ops except that their name are end with 'running_mean' or 'running_var'.

* Update block.py

let user can save their extra param.

* add allow_extra

add allow_extra to let user decide whether to save extra parameters or not.

* Update block.py

add moving_mean and moving_var when export model with SymbolBlock

* Update python/mxnet/gluon/block.py

typo

Co-authored-by: Sheng Zha <szha@users.noreply.github.com>

* Update block.py

* Update block.py

* Update python/mxnet/gluon/block.py

Co-authored-by: Leonard Lausen <leonard@lausen.nl>

Co-authored-by: Sheng Zha <szha@users.noreply.github.com>
Co-authored-by: Leonard Lausen <leonard@lausen.nl>
  • Loading branch information
3 people committed Nov 17, 2020
1 parent 6c5730a commit 6130612
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -1278,12 +1278,16 @@ def export(self, path, epoch=0, remove_amp_cast=True):
arg_names = set(sym.list_arguments())
aux_names = set(sym.list_auxiliary_states())
arg_dict = {}
for name, param in self.collect_params().items():
if name in arg_names:
arg_dict['arg:%s'%name] = param._reduce()
else:
assert name in aux_names
arg_dict['aux:%s'%name] = param._reduce()
for is_arg, name, param in self._cached_op_args:
if not is_arg:
if name in arg_names:
arg_dict['arg:{}'.format(name)] = param._reduce()
else:
if name not in aux_names:
warnings.warn('Parameter "{name}" is not found in the graph. '
.format(name=name), stacklevel=3)
else:
arg_dict['aux:%s'%name] = param._reduce()
save_fn = _mx_npx.save if is_np_array() else ndarray.save
params_filename = '%s-%04d.params'%(path, epoch)
save_fn(params_filename, arg_dict)
Expand Down

0 comments on commit 6130612

Please sign in to comment.