
Commit

Merge pull request #186 from marpulli/fix_variables

Add ability to fix/unfix parameters

zhenwendai authored Jul 1, 2019
2 parents 28c57ff + 380522d commit 587026a
Showing 2 changed files with 91 additions and 0 deletions.
29 changes: 29 additions & 0 deletions mxfusion/inference/inference_parameters.py
@@ -140,6 +140,35 @@ def fix_all(self):
        for p in self.param_dict.values():
            p.grad_req = 'null'

    def fix_variable(self, variable, value=None):
        """
        Fixes a variable so that it isn't changed by the inference algorithm. Optionally, a value can be
        specified to which the variable's value will be fixed. If no value is specified, the variable is
        fixed at its current value.

        :param variable: The variable to be fixed
        :type variable: MXFusion Variable
        :param value: (Optional) Value to which the variable will be fixed
        :type value: MXNet NDArray
        """
        if value is not None:
            self[variable] = value

        parameter = self.param_dict[variable.uuid]
        parameter.grad_req = 'null'

    def unfix_variable(self, variable):
        """
        Allows a variable to be changed by the inference algorithm if it has previously been fixed.

        :param variable: The variable to be unfixed
        :type variable: MXFusion Variable
        """
        parameter = self.param_dict[variable.uuid]
        parameter.grad_req = 'write'

    @property
    def param_dict(self):
        return self._params
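Taken together, the two new methods form a small API for pinning and releasing individual parameters during inference. A minimal usage sketch, assuming a model m and a GradBasedInference object infr (the names are illustrative; the call pattern mirrors the tests added below):

    import mxnet as mx

    # Fix the variable at a given value; inference will not update it.
    infr.params.fix_variable(m.x.factor.mean, mx.nd.ones(1))

    # Or fix it at whatever value it currently holds.
    infr.params.fix_variable(m.x.factor.mean)

    # Later, allow the inference algorithm to update it again.
    infr.params.unfix_variable(m.x.factor.mean)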
62 changes: 62 additions & 0 deletions testing/inference/test_inference_parameters.py
@@ -0,0 +1,62 @@
import unittest

import mxnet as mx
import numpy as np

from mxfusion.common.config import get_default_dtype
from mxfusion.components.variables import PositiveTransformation
from mxfusion.inference import BatchInferenceLoop, GradBasedInference, MAP
import mxfusion as mf


class InferenceParameterTests(unittest.TestCase):
    def setUp(self):
        dtype = get_default_dtype()
        m = mf.models.Model()
        m.mean = mf.components.Variable()
        m.var = mf.components.Variable(transformation=PositiveTransformation())
        m.N = mf.components.Variable()
        m.x = mf.components.distributions.Normal.define_variable(mean=m.mean, variance=m.var, shape=(m.N,))
        m.y = mf.components.distributions.Normal.define_variable(mean=m.x, variance=mx.nd.array([1], dtype=dtype), shape=(m.N,))
        self.m = m

    def test_variable_fixing(self):
        N = 10
        dtype = get_default_dtype()
        observed = [self.m.y]

        # First check the parameter varies if it isn't fixed
        alg = MAP(model=self.m, observed=observed)
        infr = GradBasedInference(inference_algorithm=alg, grad_loop=BatchInferenceLoop())
        infr.initialize(y=mx.nd.array(np.random.rand(N)))
        infr.run(y=mx.nd.array(np.random.rand(N), dtype=dtype), max_iter=10)
        assert infr.params[self.m.x.factor.mean] != mx.nd.ones(1)

        # Now fix parameter and check it does not change
        alg = MAP(model=self.m, observed=observed)
        infr = GradBasedInference(inference_algorithm=alg, grad_loop=BatchInferenceLoop())
        infr.initialize(y=mx.nd.array(np.random.rand(N)))
        infr.params.fix_variable(self.m.x.factor.mean, mx.nd.ones(1))
        infr.run(y=mx.nd.array(np.random.rand(N), dtype=dtype), max_iter=10)
        assert infr.params[self.m.x.factor.mean] == mx.nd.ones(1)

    def test_variable_unfixing(self):
        N = 10
        y = np.random.rand(N)
        dtype = get_default_dtype()
        observed = [self.m.y]

        # First fix variable and run inference
        alg = MAP(model=self.m, observed=observed)
        infr = GradBasedInference(inference_algorithm=alg, grad_loop=BatchInferenceLoop())
        infr.initialize(y=mx.nd.array(np.random.rand(N)))
        infr.params.fix_variable(self.m.x.factor.mean, mx.nd.ones(1))
        infr.run(y=mx.nd.array(y, dtype=dtype), max_iter=10)

        assert infr.params[self.m.x.factor.mean] == mx.nd.ones(1)

        # Now unfix and run inference again
        infr.params.unfix_variable(self.m.x.factor.mean)
        infr.run(y=mx.nd.array(y, dtype=dtype), max_iter=10)

        assert infr.params[self.m.x.factor.mean] != mx.nd.ones(1)
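Under the hood, fixing toggles the MXNet parameter's grad_req between 'null' (no gradient is computed or stored, so gradient-based optimizers leave the value untouched) and 'write'. A standalone sketch of that mechanism, independent of MXFusion, assuming MXNet 1.x Gluon behavior in which gluon.Trainer skips parameters whose grad_req is 'null':

    import mxnet as mx
    from mxnet import autograd, gluon

    p = gluon.Parameter('p', shape=(1,), init=mx.init.One())
    q = gluon.Parameter('q', shape=(1,), init=mx.init.Zero())
    for par in (p, q):
        par.initialize()
    trainer = gluon.Trainer([p, q], 'sgd', {'learning_rate': 0.1})

    p.grad_req = 'null'  # "fix" p: no gradient is allocated for it
    with autograd.record():
        loss = (p.data() - 3.0) ** 2 + (q.data() - 3.0) ** 2
    loss.backward()
    trainer.step(1)  # q moves toward 3.0; p stays at 1.0

    p.grad_req = 'write'  # "unfix" p: gradients flow again
    with autograd.record():
        loss = (p.data() - 3.0) ** 2 + (q.data() - 3.0) ** 2
    loss.backward()
    trainer.step(1)  # now p is updated as well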
