Skip to content
This repository has been archived by the owner on Jun 14, 2024. It is now read-only.

Add ability to fix/unfix parameters #186

Merged
merged 3 commits into from
Jul 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions mxfusion/inference/inference_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,35 @@ def fix_all(self):
for p in self.param_dict.values():
p.grad_req = 'null'

def fix_variable(self, variable, value=None):
"""
Fixes a variable so that it isn't changed by the inference algorithm. Optionally a value can be specified to
which to variable's value will be fixed. If a value is not specified, the variable will be fixed to it's current
value

:param variable: The variable to be fixed
:type variable: MXFusion Variable
:param value: (Optional) Value to which the variable will be fixed
:type value: MXNet NDArray
"""

if value is not None:
self[variable] = value

parameter = self.param_dict[variable.uuid]
parameter.grad_req = 'null'

def unfix_variable(self, variable):
    """
    Allow a previously fixed variable to be changed by the inference algorithm
    again.

    :param variable: The variable to be unfixed
    :type variable: MXFusion Variable
    """
    # Re-enable gradient computation so the optimizer writes updates again.
    self.param_dict[variable.uuid].grad_req = 'write'

@property
def param_dict(self):
    """Return the underlying dictionary of parameters, keyed by variable uuid."""
    return self._params
Expand Down
62 changes: 62 additions & 0 deletions testing/inference/test_inference_parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import unittest

import mxnet as mx
import numpy as np

from mxfusion.common.config import get_default_dtype
from mxfusion.components.variables import PositiveTransformation
from mxfusion.inference import BatchInferenceLoop, GradBasedInference, MAP
import mxfusion as mf


class InferenceParameterTests(unittest.TestCase):
    """Tests for fixing and unfixing inference parameters."""

    def setUp(self):
        """Build the shared model: y ~ N(x, 1), x ~ N(mean, var)."""
        dtype = get_default_dtype()
        model = mf.models.Model()
        model.mean = mf.components.Variable()
        model.var = mf.components.Variable(transformation=PositiveTransformation())
        model.N = mf.components.Variable()
        model.x = mf.components.distributions.Normal.define_variable(
            mean=model.mean, variance=model.var, shape=(model.N,))
        model.y = mf.components.distributions.Normal.define_variable(
            mean=model.x, variance=mx.nd.array([1], dtype=dtype), shape=(model.N,))
        self.m = model

def test_variable_fixing(self):
    """A parameter fixed with fix_variable must keep its value through inference."""
    N = 10
    dtype = get_default_dtype()
    observed = [self.m.y]

    # Baseline: without fixing, inference moves the mean away from ones.
    algorithm = MAP(model=self.m, observed=observed)
    inference = GradBasedInference(
        inference_algorithm=algorithm, grad_loop=BatchInferenceLoop())
    inference.initialize(y=mx.nd.array(np.random.rand(N)))
    inference.run(y=mx.nd.array(np.random.rand(N), dtype=dtype), max_iter=10)
    assert inference.params[self.m.x.factor.mean] != mx.nd.ones(1)

    # With the mean fixed to ones, inference must leave it untouched.
    algorithm = MAP(model=self.m, observed=observed)
    inference = GradBasedInference(
        inference_algorithm=algorithm, grad_loop=BatchInferenceLoop())
    inference.initialize(y=mx.nd.array(np.random.rand(N)))
    inference.params.fix_variable(self.m.x.factor.mean, mx.nd.ones(1))
    inference.run(y=mx.nd.array(np.random.rand(N), dtype=dtype), max_iter=10)
    assert inference.params[self.m.x.factor.mean] == mx.nd.ones(1)

def test_variable_unfixing(self):
    """After unfix_variable, inference must be able to change the parameter again."""
    N = 10
    y = np.random.rand(N)
    dtype = get_default_dtype()
    observed = [self.m.y]

    # Fix the mean to ones and verify inference leaves it there.
    algorithm = MAP(model=self.m, observed=observed)
    inference = GradBasedInference(
        inference_algorithm=algorithm, grad_loop=BatchInferenceLoop())
    inference.initialize(y=mx.nd.array(np.random.rand(N)))
    inference.params.fix_variable(self.m.x.factor.mean, mx.nd.ones(1))
    inference.run(y=mx.nd.array(y, dtype=dtype), max_iter=10)
    assert inference.params[self.m.x.factor.mean] == mx.nd.ones(1)

    # Unfix and rerun: the parameter should now move away from ones.
    inference.params.unfix_variable(self.m.x.factor.mean)
    inference.run(y=mx.nd.array(y, dtype=dtype), max_iter=10)
    assert inference.params[self.m.x.factor.mean] != mx.nd.ones(1)