Conversation
Codecov Report
@@            Coverage Diff            @@
##           develop     #186     +/-   ##
==========================================
- Coverage    88.15%   88.15%   -0.01%
==========================================
  Files           80       80
  Lines         4077     4085       +8
  Branches       692      693       +1
==========================================
+ Hits          3594     3601       +7
  Misses         310      310
- Partials       173      174       +1

Continue to review the full report at Codecov.
Looks good, thanks for the addition, Mark.
m.y = mf.components.distributions.Normal.define_variable(mean=m.x, variance=mx.nd.array([1], dtype=dtype), shape=(m.N,))
self.m = m

def test_variable_fixing(self):
Can you add to this test something like the following, which first runs inference without fixing, to verify that the value does change?
# First, run inference without fixing the mean, to verify that its value does change.
alg = MAP(model=self.m, observed=observed)
infr = GradBasedInference(inference_algorithm=alg, grad_loop=BatchInferenceLoop())
infr.initialize(y=mx.nd.array(np.random.rand(N)))
infr.run(y=mx.nd.array(np.random.rand(N), dtype=dtype), max_iter=10)
assert infr.params[self.m.x.factor.mean] != mx.nd.ones(1)

# Then run again with the mean fixed and check that it stays at ones.
# Create m2?
alg = MAP(model=self.m2, observed=observed)
infr = GradBasedInference(inference_algorithm=alg, grad_loop=BatchInferenceLoop())
infr.initialize(y=mx.nd.array(np.random.rand(N)))
infr.params.fix_variable(self.m.x.factor.mean, mx.nd.ones(1))
infr.run(y=mx.nd.array(np.random.rand(N), dtype=dtype), max_iter=10)
assert infr.params[self.m.x.factor.mean] == mx.nd.ones(1)
infr.initialize(y=mx.nd.array(np.random.rand(N)))
infr.params.fix_variable(self.m.x.factor.mean, mx.nd.ones(1))
infr.run(y=mx.nd.array(y, dtype=dtype), max_iter=10)
Add an `assert infr.params[self.m.x.factor.mean] == mx.nd.ones(1)` here?
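A minimal sketch of how that could look at the end of this test, assuming `infr`, `self.m`, `np`, and `mx` are the ones from the diff above; going through `.asnumpy()` is just one way to reduce the NDArray comparison to a plain bool:

```python
# After running inference with the mean fixed to ones,
# the fixed parameter should still hold exactly that value.
assert (infr.params[self.m.x.factor.mean].asnumpy() == np.ones(1)).all()
```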
Issue #, if available: #94
Description of changes: Added convenience functions to fix/unfix parameters on the InferenceParameters class. I didn't think it was necessary to put them in the Inference class too.
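For illustration, a hedged end-to-end sketch of the new convenience functions, modelled on the test in this diff. The model setup and names (`N`, `dtype`, `m.mean`) are illustrative, and the `unfix_variable` counterpart is an assumption based on the description above; only `fix_variable` appears in the diff:

```python
import numpy as np
import mxnet as mx
from mxfusion import Model, Variable
from mxfusion.components.distributions import Normal
from mxfusion.inference import GradBasedInference, MAP, BatchInferenceLoop

N, dtype = 10, 'float32'

# Toy model in the spirit of the test's setUp (names are illustrative).
m = Model()
m.mean = Variable()
m.x = Normal.define_variable(mean=m.mean, variance=mx.nd.array([1], dtype=dtype), shape=(1,))
m.y = Normal.define_variable(mean=m.x, variance=mx.nd.array([1], dtype=dtype), shape=(N,))

infr = GradBasedInference(inference_algorithm=MAP(model=m, observed=[m.y]),
                          grad_loop=BatchInferenceLoop())
infr.initialize(y=mx.nd.array(np.random.rand(N), dtype=dtype))

# Fix the prior mean to ones; inference then leaves it untouched.
infr.params.fix_variable(m.x.factor.mean, mx.nd.ones(1))
infr.run(y=mx.nd.array(np.random.rand(N), dtype=dtype), max_iter=10)
assert (infr.params[m.x.factor.mean].asnumpy() == np.ones(1)).all()

# Release it again so it becomes trainable in later runs
# (unfix_variable is the assumed name of the counterpart).
infr.params.unfix_variable(m.x.factor.mean)
```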
By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.