Skip to content
This repository has been archived by the owner on Jun 14, 2024. It is now read-only.

Add ability to fix/unfix parameters #186

Merged
merged 3 commits into from
Jul 1, 2019
Merged

Conversation

marpulli
Copy link
Contributor

Issue #, if available: #94

Description of changes: Added convenience functions to fix/unfix parameters on the InferenceParameters class. I didn't think it was necessary to put them in the Inference class too

By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.

@codecov-io
Copy link

codecov-io commented Jun 27, 2019

Codecov Report

Merging #186 into develop will decrease coverage by <.01%.
The diff coverage is 87.5%.

Impacted file tree graph

@@             Coverage Diff             @@
##           develop     #186      +/-   ##
===========================================
- Coverage    88.15%   88.15%   -0.01%     
===========================================
  Files           80       80              
  Lines         4077     4085       +8     
  Branches       692      693       +1     
===========================================
+ Hits          3594     3601       +7     
  Misses         310      310              
- Partials       173      174       +1
Impacted Files Coverage Δ
mxfusion/inference/inference_parameters.py 84.07% <87.5%> (+0.26%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 28c57ff...380522d. Read the comment docs.

Copy link
Contributor

@meissnereric meissnereric left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, thanks for the addition Mark.

m.y = mf.components.distributions.Normal.define_variable(mean=m.x, variance=mx.nd.array([1], dtype=dtype), shape=(m.N,))
self.m = m

def test_variable_fixing(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add onto this test something like this that runs inference without fixing first, to verify that the value does change.

       alg = MAP(model=self.m, observed=observed)
        infr = GradBasedInference(inference_algorithm=alg, grad_loop=BatchInferenceLoop())
        infr.initialize(y=mx.nd.array(np.random.rand(N)))
        infr.run(y=mx.nd.array(np.random.rand(N), dtype=dtype), max_iter=10)
        assert infr.params[self.m.x.factor.mean] != mx.nd.ones(1)

       # Create m2?
        alg = MAP(model=self.m2, observed=observed)
        infr = GradBasedInference(inference_algorithm=alg, grad_loop=BatchInferenceLoop())
        infr.initialize(y=mx.nd.array(np.random.rand(N)))
        infr.params.fix_variable(self.m.x.factor.mean, mx.nd.ones(1))
        infr.run(y=mx.nd.array(np.random.rand(N), dtype=dtype), max_iter=10)
        assert infr.params[self.m.x.factor.mean] == mx.nd.ones(1)
``

infr.initialize(y=mx.nd.array(np.random.rand(N)))
infr.params.fix_variable(self.m.x.factor.mean, mx.nd.ones(1))
infr.run(y=mx.nd.array(y, dtype=dtype), max_iter=10)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add an assert infr.params[self.m.x.factor.mean] == mx.nd.ones(1) here?

@meissnereric meissnereric added the enhancement New feature or request label Jun 28, 2019
@meissnereric meissnereric added this to the MXFusion v0.4.0 milestone Jun 28, 2019
@zhenwendai zhenwendai merged commit 587026a into amzn:develop Jul 1, 2019
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants