diff --git a/docs/api/python/gluon/gluon.md b/docs/api/python/gluon/gluon.md
index 2ae766fdcba3..f523e649a458 100644
--- a/docs/api/python/gluon/gluon.md
+++ b/docs/api/python/gluon/gluon.md
@@ -34,6 +34,7 @@ in Python and then deploy with symbolic graph in C++ and Scala.
     :nosignatures:

     Parameter
+    Constant
     ParameterDict
 ```
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index fa3828591140..7dc724339265 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -18,8 +18,9 @@
 # coding: utf-8
 # pylint: disable=
 """Neural network parameter."""
-__all__ = ['DeferredInitializationError', 'Parameter', 'ParameterDict',
-           'tensor_types']
+__all__ = ['DeferredInitializationError', 'Parameter', 'Constant',
+           'ParameterDict', 'tensor_types']
+
 from collections import OrderedDict
 import warnings
@@ -459,6 +460,46 @@ def cast(self, dtype):
         autograd.mark_variables(self._data, self._grad, self.grad_req)


+class Constant(Parameter):
+    """A constant parameter for holding immutable tensors.
+    `Constant`s are ignored by `autograd` and `Trainer`, thus their values
+    will not change during training. But you can still update their values
+    manually with the `set_data` method.
+
+    `Constant`s can be created with either::
+
+        const = mx.gluon.Constant('const', [[1,2],[3,4]])
+
+    or::
+
+        class Block(gluon.Block):
+            def __init__(self, **kwargs):
+                super(Block, self).__init__(**kwargs)
+                self.const = self.params.get_constant('const', [[1,2],[3,4]])
+
+    Parameters
+    ----------
+    name : str
+        Name of the parameter.
+    value : array-like
+        Initial value for the constant.
+    """
+    def __init__(self, name, value):
+        if not isinstance(value, ndarray.NDArray):
+            value = ndarray.array(value)
+        self.value = value
+
+        class Init(initializer.Initializer):
+            def _init_weight(self, _, arr):
+                value.copyto(arr)
+        init_name = 'Constant_{}_{}'.format(name, id(self))
+        initializer.alias(init_name)(Init)
+
+        super(Constant, self).__init__(
+            name, grad_req='null', shape=value.shape, dtype=value.dtype,
+            init=init_name)
+
+
 class ParameterDict(object):
     """A dictionary managing a set of parameters.

@@ -548,6 +589,45 @@ def get(self, name, **kwargs):
                 setattr(param, k, v)
         return param

+    def get_constant(self, name, value=None):
+        """Retrieves a :py:class:`Constant` with name ``self.prefix+name``. If not found,
+        :py:func:`get_constant` will first try to retrieve it from the "shared" dict.
+        If still not found, :py:func:`get_constant` will create a new :py:class:`Constant`
+        with the given value and insert it into self.
+
+        Parameters
+        ----------
+        name : str
+            Name of the desired Constant. It will be prepended with this dictionary's
+            prefix.
+        value : array-like
+            Initial value of the constant.
+
+        Returns
+        -------
+        Constant
+            The created or retrieved :py:class:`Constant`.
+        """
+        name = self.prefix + name
+        param = self._get_impl(name)
+        if param is None:
+            if value is None:
+                raise KeyError("No constant named {}. Please specify value " \
+                               "if you want to create a new constant.".format(
+                                   name))
+            param = Constant(name, value)
+            self._params[name] = param
+        elif value is not None:
+            assert isinstance(param, Constant), \
+                "Parameter {} already exists but it is not a constant.".format(
+                    name)
+            if isinstance(value, ndarray.NDArray):
+                value = value.asnumpy()
+            assert param.shape == value.shape and \
+                (param.value.asnumpy() == value).all(), \
+                "Constant {} already exists but its value doesn't match new value".format(name)
+        return param
+
     def update(self, other):
         """Copies all Parameters in ``other`` to self."""
         for k, v in other.items():
diff --git a/python/mxnet/registry.py b/python/mxnet/registry.py
index 4c131a1b755a..0e4ac1c0b8c1 100644
--- a/python/mxnet/registry.py
+++ b/python/mxnet/registry.py
@@ -69,7 +69,8 @@ def register(klass, name=None):
         assert issubclass(klass, base_class), \
             "Can only register subclass of %s"%base_class.__name__
         if name is None:
-            name = klass.__name__.lower()
+            name = klass.__name__
+        name = name.lower()
         if name in registry:
             warnings.warn(
                 "\033[91mNew %s %s.%s registered with name %s is"
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 50b60a2db3eb..89f521543706 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -51,6 +51,34 @@ def test_paramdict():
     params.load('test.params', mx.cpu())


+@with_seed()
+def test_constant():
+    class Test(gluon.HybridBlock):
+        def __init__(self, **kwargs):
+            super(Test, self).__init__(**kwargs)
+            self.value = np.asarray([[1,2], [3,4]])
+            self.const = self.params.get_constant('const', self.value)
+
+        def hybrid_forward(self, F, x, const):
+            return x + const
+
+    test = Test()
+    test.initialize()
+    trainer = gluon.Trainer(test.collect_params(), 'sgd',
+                            {'learning_rate': 1.0, 'momentum': 0.5})
+
+    with mx.autograd.record():
+        x = mx.nd.ones((2,2))
+        x.attach_grad()
+        y = test(x)
+        y.backward()
+
+    trainer.step(1)
+
+    assert (test.const.data().asnumpy() == test.value).all()
+    assert (x.grad.asnumpy() == 1).all()
+
+
 @with_seed()
 def test_parameter_sharing():
     class Net(gluon.Block):
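
For reference, a minimal end-to-end sketch of the API this patch adds, mirroring the new unit test. The `OffsetBlock` name and the offset values are illustrative, not part of the change:

```python
import mxnet as mx
import numpy as np
from mxnet import gluon

class OffsetBlock(gluon.HybridBlock):
    """Adds a fixed, non-trainable offset to its input."""
    def __init__(self, **kwargs):
        super(OffsetBlock, self).__init__(**kwargs)
        # Registered with grad_req='null', so autograd and Trainer ignore it.
        self.offset = self.params.get_constant('offset', [[1, 2], [3, 4]])

    def hybrid_forward(self, F, x, offset):
        # The constant is passed in like any other parameter.
        return x + offset

net = OffsetBlock()
net.initialize()
x = mx.nd.ones((2, 2))
print(net(x).asnumpy())  # [[2. 3.] [4. 5.]]

# A training step leaves the constant untouched.
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 1.0})
x.attach_grad()
with mx.autograd.record():
    loss = net(x).sum()
loss.backward()
trainer.step(1)
assert (net.offset.data().asnumpy() == np.array([[1, 2], [3, 4]])).all()
```

Because the constant's `grad_req` is `'null'`, no gradient buffer is allocated for it and optimizer updates skip it; `set_data` remains available if its value needs to be changed manually.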