Get mcmc sampling to work #9
Conversation
@ferrine If we give it only the final logp function then we won't be able to sample traces of intermediate RVs. And should the …
Review
Graph namespace
TF uses graph.as_default() to create the graph, so if you repeatedly call model.unobserved you will pollute the namespace completely. The snippet below reproduces the problem:
import tensorflow as tf

graph = tf.get_default_graph()
sess = tf.InteractiveSession(graph=graph)

def model():
    return tf.ones([1])

# every call traces the function again and adds a fresh `ones` op
model()
model()
model()

graph.as_graph_def()  # the graph now holds three copies of the op
The output contains many copies of tf.ones(). One way to solve the problem is to put all internal things into an auxiliary namespace.
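One way to do that, as a minimal sketch (the pymc4_internal scope name is hypothetical):

import tensorflow as tf

def build_internals():
    # Repeated calls still create new ops, but they end up grouped under
    # pymc4_internal/, pymc4_internal_1/, ... instead of spoiling the
    # user-visible top-level namespace.
    with tf.name_scope("pymc4_internal"):
        return tf.ones([1], name="init")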
pymc4/inference/sampling/sample.py
Outdated
-    for name, shape in model.unobserved.iteritems():
-        initial_state.append(.5 * tf.ones(shape, name="init_{}".format(name)))
+    for name in model.unobserved:
+        initial_state.append(.5 * tf.ones(model.unobserved[name].shape, name="init_{}".format(name)))
This will create a lot of problems with the namespace.
I do not know any way to go other than to avoid repeated calls of self._f.
unobserved = {}
for i in self.variables:
    if self.variables[i] not in self.observed.values():
        unobserved[i] = self.variables[i]
unobserved = collections.OrderedDict(unobserved)
return unobserved
Could I do this so that unobserved avoids calling f() multiple times?
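For illustration, here is one way to make sure f() runs only once, as a hedged sketch (the _unobserved_cache attribute and the _collect_unobserved helper are hypothetical names, not code from this PR):

@property
def unobserved(self):
    # Evaluate the model function once and memoize the result, so that
    # repeated property access does not re-trace the graph and spawn
    # duplicate ops.
    if self._unobserved_cache is None:
        self._unobserved_cache = self._collect_unobserved()  # calls self._f once
    return self._unobserved_cache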
In [2]: graph = tf.Graph()
   ...:
   ...: with graph.as_default():
   ...:     ed.Normal(0., 1.)
   ...:     print(graph.as_graph_def())
   ...:
Outputs:
node {
  name: "Normal/loc/input"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 0.0
      }
    }
  }
}
node {
  name: "Normal/loc"
  op: "Identity"
  input: "Normal/loc/input"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Normal/scale/input"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 1.0
      }
    }
  }
}
node {
  name: "Normal/scale"
  op: "Identity"
  input: "Normal/scale/input"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Normal_1/sample/sample_shape"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
          }
        }
      }
    }
  }
}
node {
  name: "Normal_1/sample/Normal/batch_shape_tensor/batch_shape"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
          }
        }
      }
    }
  }
}
node {
  name: "Normal_1/sample/concat/values_0"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 1
          }
        }
        int_val: 1
      }
    }
  }
}
node {
  name: "Normal_1/sample/concat/axis"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 0
      }
    }
  }
}
node {
  name: "Normal_1/sample/concat"
  op: "ConcatV2"
  input: "Normal_1/sample/concat/values_0"
  input: "Normal_1/sample/Normal/batch_shape_tensor/batch_shape"
  input: "Normal_1/sample/concat/axis"
  attr {
    key: "N"
    value {
      i: 2
    }
  }
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "Tidx"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "Normal_1/sample/random_normal/mean"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 0.0
      }
    }
  }
}
node {
  name: "Normal_1/sample/random_normal/stddev"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 1.0
      }
    }
  }
}
node {
  name: "Normal_1/sample/random_normal/RandomStandardNormal"
  op: "RandomStandardNormal"
  input: "Normal_1/sample/concat"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "seed"
    value {
      i: 0
    }
  }
  attr {
    key: "seed2"
    value {
      i: 0
    }
  }
}
node {
  name: "Normal_1/sample/random_normal/mul"
  op: "Mul"
  input: "Normal_1/sample/random_normal/RandomStandardNormal"
  input: "Normal_1/sample/random_normal/stddev"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Normal_1/sample/random_normal"
  op: "Add"
  input: "Normal_1/sample/random_normal/mul"
  input: "Normal_1/sample/random_normal/mean"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Normal_1/sample/mul"
  op: "Mul"
  input: "Normal_1/sample/random_normal"
  input: "Normal/scale"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Normal_1/sample/add"
  op: "Add"
  input: "Normal_1/sample/mul"
  input: "Normal/loc"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Normal_1/sample/Shape"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 1
          }
        }
        int_val: 1
      }
    }
  }
}
node {
  name: "Normal_1/sample/strided_slice/stack"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 1
          }
        }
        int_val: 1
      }
    }
  }
}
node {
  name: "Normal_1/sample/strided_slice/stack_1"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 1
          }
        }
        int_val: 0
      }
    }
  }
}
node {
  name: "Normal_1/sample/strided_slice/stack_2"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 1
          }
        }
        int_val: 1
      }
    }
  }
}
node {
  name: "Normal_1/sample/strided_slice"
  op: "StridedSlice"
  input: "Normal_1/sample/Shape"
  input: "Normal_1/sample/strided_slice/stack"
  input: "Normal_1/sample/strided_slice/stack_1"
  input: "Normal_1/sample/strided_slice/stack_2"
  attr {
    key: "Index"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "begin_mask"
    value {
      i: 0
    }
  }
  attr {
    key: "ellipsis_mask"
    value {
      i: 0
    }
  }
  attr {
    key: "end_mask"
    value {
      i: 1
    }
  }
  attr {
    key: "new_axis_mask"
    value {
      i: 0
    }
  }
  attr {
    key: "shrink_axis_mask"
    value {
      i: 0
    }
  }
}
node {
  name: "Normal_1/sample/concat_1/axis"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 0
      }
    }
  }
}
node {
  name: "Normal_1/sample/concat_1"
  op: "ConcatV2"
  input: "Normal_1/sample/sample_shape"
  input: "Normal_1/sample/strided_slice"
  input: "Normal_1/sample/concat_1/axis"
  attr {
    key: "N"
    value {
      i: 2
    }
  }
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "Tidx"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "Normal_1/sample/Reshape"
  op: "Reshape"
  input: "Normal_1/sample/add"
  input: "Normal_1/sample/concat_1"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "Tshape"
    value {
      type: DT_INT32
    }
  }
}
versions {
  producer: 26
}
That may be okay if graph modifications are not invoked frequently by the user (or are hard to do).
There is a special interceptor for this purpose
pymc4/model/base.py
Outdated
@@ -68,14 +68,41 @@ def get_mode(state, rv, *args, **kwargs):
         returns = self.session.run(list(values_collector.result.values()))
         return dict(zip(values_collector.result.keys(), returns))

-    def target_log_prob_fn(self, *args, **kwargs):
+    def log_prob_fn(self, x, *args, **kwargs):
what is x for here?
Not necessary, removing it.
pymc4/model/base.py
Outdated
    def log_joint_fn(*args, **kwargs):
        states = dict(zip(self.unobserved.keys(), args))
        states.update(self.observed)
        log_probs = []
https://github.com/pymc-devs/pymc4/blob/functional/pymc4/util/interceptors.py#L110
collect_log_prob = CollectLogProb(states)
with ed.interception(collect_log_prob):
    self._f(self._cfg)
return collect_log_prob.result
Changing it; I was facing problems with states before. It is working now.
@ferrine interceptors.CollectLogProb only works with a model like

@model.define
def process(cfg=None):
    mu = ed.Normal(0., 1., name="mu")
    obs = ed.Normal(0., 1., name="obs")  # obs does not depend on mu
    return obs

and not with a model like

@model.define
def process(cfg=None):
    mu = ed.Normal(0., 1., name="mu")
    obs = ed.Normal(mu, 1., name="obs")  # obs depends on mu
    return obs
hmm, what's happening?
pymc4/model/base.py
Outdated
@@ -67,6 +68,41 @@ def get_mode(state, rv, *args, **kwargs):
         returns = self.session.run(list(values_collector.result.values()))
         return dict(zip(values_collector.result.keys(), returns))

+    def log_prob_fn(self, x, *args, **kwargs):
Hmm, I'm not sure this will work. RVs depend on their ancestors in the graph, and here you do not replace the RV with value=kwargs.get(i).
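For illustration, a hedged sketch of that kind of value substitution, written as an Edward2 interceptor (make_value_setter is a hypothetical name, not code from this PR):

from tensorflow_probability import edward2 as ed

def make_value_setter(states):
    # Substitute each named RV's value while the model function is traced,
    # so that downstream RVs (e.g. obs = ed.Normal(mu, 1.)) are conditioned
    # on the supplied state rather than on a fresh sample of mu.
    def set_values(f, *args, **kwargs):
        name = kwargs.get("name")
        if name in states:
            kwargs["value"] = states[name]
        return ed.interceptable(f)(*args, **kwargs)
    return set_values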
Just a few styling nitpicks around dictionary iteration!
pymc4/inference/sampling/sample.py
Outdated
@@ -10,7 +10,8 @@ def sample(model,
            num_leapfrog_steps=3,
            numpy=True):
     initial_state = []
-    for name, shape in model.unobserved.iteritems():
+    for name in model.unobserved.keys():
Use for name, (_, shape, _) in model.items(): to indicate that dist and rv are not used in the loop.
pymc4/model/base.py
Outdated
    @property
    def unobserved(self):
        unobserved = {}
        for i in self.variables:
for name, variable in self.variables.items():
    if variable not in self.observed.values():
        unobserved[name] = variable
pymc4/model/base.py
Outdated
@@ -83,6 +100,16 @@ def graph(self):
     def observed(self):
         return self._observed

+    @property
+    def unobserved(self):
+        unobserved = {}
This can be initialized as an OrderedDict: currently, in Python < 3.6, the return value will not be ordered, since you built an (unordered) dictionary and then turned it into an OrderedDict. We're targeting 3.6 and higher though, in which case you do not need the OrderedDict at all, since all dicts now maintain insertion order. That's a long way to say: I would just make a plain dictionary, but if you use OrderedDict, it needs to be initialized as such.
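A minimal sketch of that suggestion, assuming the unobserved property from this PR:

import collections

@property
def unobserved(self):
    # Build the OrderedDict directly so insertion order is preserved even
    # on Python < 3.6; on 3.6+ a plain dict would suffice.
    unobserved = collections.OrderedDict()
    for name, variable in self.variables.items():
        if variable not in self.observed.values():
            unobserved[name] = variable
    return unobserved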
Nice point
@@ -9,7 +9,7 @@
     'CollectLogProb'
 ]

-VariableDescription = collections.namedtuple('VariableDescription', 'Dist,shape')
+VariableDescription = collections.namedtuple('VariableDescription', 'Dist,shape,rv')
I'm still worried about this solution. The variable description is supposed to be collected well before sampling (or should we change that?). So when you collect VariableInfo for the first time and get these RVs, you save temporary nodes. When you collect LogProb you run the model again, and the variables involved there are totally different from those stored in VariableDescription. That's why I did not store them there.
Agree with @ferrine - the RVs are not initialized until we configure the model (following the idea in the API discussion doc, we create the RVs when we call model.configure(...) or model.sample(...)). This means that we record the Distribution and the relationships between RVs, but the actual RVs are only initialized when we actually use them (i.e., in the evaluation of logp).
The problem I am facing if RVs are not stored in the VariableDescription is that it doesn't store the specifics of any distribution (like loc or scale) even if they are mentioned in the model definition. So we will have to collect all this already-provided info somehow.

@model.define
def process():
    mu = ed.Normal(loc=0., scale=10., name="mu")
    # here we lose the info that mu has loc 0 and scale 10 if we don't keep the RV

We can try defining a new interceptor which collects this for us for each RV, and then overwrite (replace existing and add new) the collected data every time we call configure. See the sketch below.
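A hedged sketch of what such an interceptor could look like (the class name and the record format are hypothetical, not part of this PR):

import collections
from tensorflow_probability import edward2 as ed

class CollectDistributionSpecs:
    # Record, for every intercepted RV, the distribution constructor and the
    # kwargs it was called with (loc, scale, ...), so those specifics survive
    # even though the temporary RV nodes themselves are thrown away.
    def __init__(self):
        self.specs = collections.OrderedDict()

    def __call__(self, f, *args, **kwargs):
        name = kwargs.get("name")
        self.specs[name] = (f, dict(kwargs))
        return ed.interceptable(f)(*args, **kwargs)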
OK, let's make it later
@ferrine Could I merge this PR?
pymc4/model/base.py
Outdated
        states.update(self.observed)
        log_probs = []

        def interceptor(f, *args, **kwargs):
Don't we want to use a class-based interceptor for consistency?
    assert len(model.observed) == 1
    assert not model.unobserved

    model.reset()
We decided to make a copy of the model each time the state changes.
This one is not critical though; refactoring the interceptor usage is what is really needed to finish this PR (#9 (diff)).
@ferrine I have changed it to a class-based interceptor.
I can't find a sampling test; I think we need one. And the test point is better obtained via …
@ferrine Currently …
            num_results=5000,
            num_burnin_steps=3000,
            step_size=.4,
            num_leapfrog_steps=3,
            numpy=True):
     initial_state = []
-    for name, shape in model.unobserved.iteritems():
+    for name, (_, shape, _) in model.unobserved.items():
for name, point in model.test_point(mode=mode):
    initial_state.append(point)
May be done in the next PR.
TODO:
Write Tests
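A rough sketch of the kind of sampling smoke test the TODO asks for, using the sample() signature shown in this PR; the pm.Model() constructor and the trace format (a dict of NumPy arrays keyed by RV name when numpy=True) are assumptions:

import numpy as np
import pymc4 as pm
from tensorflow_probability import edward2 as ed

def test_sampling_smoke():
    model = pm.Model()  # assumed constructor

    @model.define
    def process(cfg=None):
        mu = ed.Normal(0., 1., name="mu")
        return ed.Normal(mu, 1., name="obs")

    # signature taken from pymc4/inference/sampling/sample.py in this PR
    trace = pm.sample(model, num_results=100, num_burnin_steps=50,
                      step_size=.4, num_leapfrog_steps=3, numpy=True)
    assert np.isfinite(trace["mu"]).all()  # assumed trace format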