
Get mcmc sampling to work #9

Merged: 17 commits into pymc-devs:functional, Jul 6, 2018

Conversation


@sharanry sharanry commented Jun 13, 2018

  • Unobserved variables accessible
  • Sampling works

TODO:
Write Tests

@sharanry sharanry requested review from ferrine and removed request for ferrine June 13, 2018 18:37
@sharanry sharanry changed the title Get mcmc sampling to work [WIP] Get mcmc sampling to work Jun 13, 2018

sharanry commented Jun 13, 2018

@ferrine
What do you suggest model.target_log_prob_fn() should return? The logp function of only the result of model.f, or of all the intermediate RVs too?

If we give it only the final logp function, then we won't be able to sample traces of intermediate RVs.
Another problem is that tfp.mcmc.sample_chain() might only work with the logp function of a single RV (in the current implementation, the final one).

And should model.unobserved contain only the final RV, i.e., the result of f()?

@sharanry sharanry self-assigned this Jun 13, 2018
@ferrine ferrine left a comment

Review

Graph namespace

TF uses graph.as_default() to create the graph, so if you repeatedly call model.unobserved it will completely pollute the namespace. The snippet below reproduces the problem:

import tensorflow as tf
graph = tf.get_default_graph()
sess = tf.InteractiveSession(graph=graph)
def model():
    return tf.ones([1])
model()
model()
model()
graph.as_graph_def()

The output contains many copies of the tf.ones() ops. One way to solve the problem is to put all internal things into an auxiliary namespace.
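To illustrate the auxiliary-namespace idea with a toy sketch (the scope name below is made up, not something from this PR): grouping the repeated internal calls under one name scope keeps the duplicates out of the top-level namespace, even though the ops are still created.

import tensorflow as tf

graph = tf.get_default_graph()

def model():
    return tf.ones([1])

# Hypothetical auxiliary scope: repeated internal calls end up with names
# like "aux/ones", "aux/ones_1", ... instead of polluting the top level.
with tf.name_scope("aux"):
    model()
    model()
    model()

print([node.name for node in graph.as_graph_def().node])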

-for name, shape in model.unobserved.iteritems():
-    initial_state.append(.5 * tf.ones(shape, name="init_{}".format(name)))
+for name in model.unobserved:
+    initial_state.append(.5 * tf.ones(model.unobserved[name].shape, name="init_{}".format(name)))
Member

This will create a lot of problems with the namespace.

@ferrine ferrine Jun 15, 2018

I do not know the best way to go, other than avoiding repeated calls of self._f.

@sharanry sharanry Jun 16, 2018

unobserved = {}
for i in self.variables:
    if self.variables[i] not in self.observed.values():
        unobserved[i] = self.variables[i]
        
unobserved = collections.OrderedDict(unobserved)
return unobserved

Could I do this to avoid unobserved calling f() multiple times?
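A possible way to go further (just a sketch; the caching attribute below is invented, not in the PR): compute the dictionary once and cache it on the instance, so accessing the property never triggers another run of self._f.

import collections

@property
def unobserved(self):
    # Hypothetical cache: built once, reused on every later access,
    # so the model function is never re-run just to list variables.
    if getattr(self, "_unobserved_cache", None) is None:
        self._unobserved_cache = collections.OrderedDict(
            (name, var)
            for name, var in self.variables.items()
            if var not in self.observed.values()
        )
    return self._unobserved_cache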

@sharanry sharanry Jun 16, 2018

In [2]: graph = tf.Graph()
   ...:
   ...: with graph.as_default():
   ...:     ed.Normal(0., 1.)
   ...:     print(graph.as_graph_def())
   ...:

Outputs:

node {
  name: "Normal/loc/input"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 0.0
      }
    }
  }
}
node {
  name: "Normal/loc"
  op: "Identity"
  input: "Normal/loc/input"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Normal/scale/input"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 1.0
      }
    }
  }
}
node {
  name: "Normal/scale"
  op: "Identity"
  input: "Normal/scale/input"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Normal_1/sample/sample_shape"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
          }
        }
      }
    }
  }
}
node {
  name: "Normal_1/sample/Normal/batch_shape_tensor/batch_shape"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
          }
        }
      }
    }
  }
}
node {
  name: "Normal_1/sample/concat/values_0"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 1
          }
        }
        int_val: 1
      }
    }
  }
}
node {
  name: "Normal_1/sample/concat/axis"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 0
      }
    }
  }
}
node {
  name: "Normal_1/sample/concat"
  op: "ConcatV2"
  input: "Normal_1/sample/concat/values_0"
  input: "Normal_1/sample/Normal/batch_shape_tensor/batch_shape"
  input: "Normal_1/sample/concat/axis"
  attr {
    key: "N"
    value {
      i: 2
    }
  }
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "Tidx"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "Normal_1/sample/random_normal/mean"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 0.0
      }
    }
  }
}
node {
  name: "Normal_1/sample/random_normal/stddev"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 1.0
      }
    }
  }
}
node {
  name: "Normal_1/sample/random_normal/RandomStandardNormal"
  op: "RandomStandardNormal"
  input: "Normal_1/sample/concat"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "seed"
    value {
      i: 0
    }
  }
  attr {
    key: "seed2"
    value {
      i: 0
    }
  }
}
node {
  name: "Normal_1/sample/random_normal/mul"
  op: "Mul"
  input: "Normal_1/sample/random_normal/RandomStandardNormal"
  input: "Normal_1/sample/random_normal/stddev"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Normal_1/sample/random_normal"
  op: "Add"
  input: "Normal_1/sample/random_normal/mul"
  input: "Normal_1/sample/random_normal/mean"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Normal_1/sample/mul"
  op: "Mul"
  input: "Normal_1/sample/random_normal"
  input: "Normal/scale"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Normal_1/sample/add"
  op: "Add"
  input: "Normal_1/sample/mul"
  input: "Normal/loc"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Normal_1/sample/Shape"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 1
          }
        }
        int_val: 1
      }
    }
  }
}
node {
  name: "Normal_1/sample/strided_slice/stack"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 1
          }
        }
        int_val: 1
      }
    }
  }
}
node {
  name: "Normal_1/sample/strided_slice/stack_1"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 1
          }
        }
        int_val: 0
      }
    }
  }
}
node {
  name: "Normal_1/sample/strided_slice/stack_2"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 1
          }
        }
        int_val: 1
      }
    }
  }
}
node {
  name: "Normal_1/sample/strided_slice"
  op: "StridedSlice"
  input: "Normal_1/sample/Shape"
  input: "Normal_1/sample/strided_slice/stack"
  input: "Normal_1/sample/strided_slice/stack_1"
  input: "Normal_1/sample/strided_slice/stack_2"
  attr {
    key: "Index"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "begin_mask"
    value {
      i: 0
    }
  }
  attr {
    key: "ellipsis_mask"
    value {
      i: 0
    }
  }
  attr {
    key: "end_mask"
    value {
      i: 1
    }
  }
  attr {
    key: "new_axis_mask"
    value {
      i: 0
    }
  }
  attr {
    key: "shrink_axis_mask"
    value {
      i: 0
    }
  }
}
node {
  name: "Normal_1/sample/concat_1/axis"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 0
      }
    }
  }
}
node {
  name: "Normal_1/sample/concat_1"
  op: "ConcatV2"
  input: "Normal_1/sample/sample_shape"
  input: "Normal_1/sample/strided_slice"
  input: "Normal_1/sample/concat_1/axis"
  attr {
    key: "N"
    value {
      i: 2
    }
  }
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "Tidx"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "Normal_1/sample/Reshape"
  op: "Reshape"
  input: "Normal_1/sample/add"
  input: "Normal_1/sample/concat_1"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "Tshape"
    value {
      type: DT_INT32
    }
  }
}
versions {
  producer: 26
}
For just one Edward random variable there is this much change in the graph.

Member

That may be okay, as long as graph modifications are not triggered frequently by the user (or are hard to trigger).

@ferrine ferrine left a comment

There is a special interceptor for this purpose

@@ -68,14 +68,41 @@ def get_mode(state, rv, *args, **kwargs):
        returns = self.session.run(list(values_collector.result.values()))
        return dict(zip(values_collector.result.keys(), returns))

-    def target_log_prob_fn(self, *args, **kwargs):
+    def log_prob_fn(self, x, *args, **kwargs):
Member

what is x for here?

Contributor Author

Not necessary, removing it.

def log_joint_fn(*args, **kwargs):
    states = dict(zip(self.unobserved.keys(), args))
    states.update(self.observed)
    log_probs = []
Member

https://github.com/pymc-devs/pymc4/blob/functional/pymc4/util/interceptors.py#L110

collect_log_prob = CollectLogProb(states)
with ed.interception(collect_log_prob):
    self._f(self._cfg)
return collect_log_prob.result

Contributor Author

Changing it; I was facing problems with states before. It's working now.

Contributor Author

@ferrine interceptors.CollectLogProb only works with a model like

@model.define
def process(cfg=None):
    mu = ed.Normal(0., 1., name="mu")
    obs = ed.Normal(0., 1., name="obs")
    return obs

and not with a model like

@model.define
def process(cfg=None):
    mu = ed.Normal(0., 1., name="mu")
    obs = ed.Normal(mu, 1., name="obs")
    return obs

Member

hmm, what's happening?

@@ -67,6 +68,41 @@ def get_mode(state, rv, *args, **kwargs):
        returns = self.session.run(list(values_collector.result.values()))
        return dict(zip(values_collector.result.keys(), returns))

    def log_prob_fn(self, x, *args, **kwargs):
Member

Hmm, I'm not sure this will work. The ancestors of an RV depend on the RV; here you do not replace the RV with value=kwargs.get(i).
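For reference, a sketch of the usual Edward2 value-substitution pattern (illustrative names, not the PR's code): every intercepted RV whose name appears in the states dict gets value= set to the proposed state, so descendants such as obs = ed.Normal(mu, 1.) are built from the substituted mu rather than from a fresh sample.

import tensorflow as tf
from tensorflow_probability import edward2 as ed

def make_log_joint_fn(process, states):
    # `states` maps RV names to proposed value tensors.
    def log_joint_fn():
        log_probs = []

        def interceptor(rv_constructor, *args, **kwargs):
            name = kwargs.get("name")
            if name in states:
                kwargs["value"] = states[name]  # substitute the proposed value
            rv = rv_constructor(*args, **kwargs)
            log_probs.append(tf.reduce_sum(rv.distribution.log_prob(rv.value)))
            return rv

        with ed.interception(interceptor):
            process()
        return tf.add_n(log_probs)

    return log_joint_fn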

@ColCarroll ColCarroll left a comment

Just a few styling nitpicks around dictionary iteration!

@@ -10,7 +10,8 @@ def sample(model,
           num_leapfrog_steps=3,
           numpy=True):
    initial_state = []
-    for name, shape in model.unobserved.iteritems():
+    for name in model.unobserved.keys():
Member

Use for name, (_, shape, _) in model.unobserved.items(): to indicate that dist and rv are not used in the loop.

@property
def unobserved(self):
    unobserved = {}
    for i in self.variables:
Member

for name, variable in self.variables.items():
    if variable not in self.observed.values():
        unobserved[name] = variable

@@ -83,6 +100,16 @@ def graph(self):
    def observed(self):
        return self._observed

    @property
    def unobserved(self):
        unobserved = {}
Member

This can be initialized as an OrderedDict: as written, on Python < 3.6 the return value will not be ordered, since you built an (unordered) dictionary and then turned it into an OrderedDict. We're targeting 3.6 and higher though, in which case you do not need the OrderedDict at all, since all dicts now maintain insertion order.

That's a long way to say: I would just make a plain dictionary, but if you use OrderedDict, it needs to be initialized as such.
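To spell the point out with a toy example (not project code):

import collections

# On Python < 3.6 a plain dict has arbitrary order, so wrapping it afterwards
# only freezes that arbitrary order:
plain = {}
plain["b"] = 2
plain["a"] = 1
late = collections.OrderedDict(plain)

# Inserting into an OrderedDict from the start preserves insertion order on
# every Python version (on 3.6+ a plain dict already does this):
early = collections.OrderedDict()
early["b"] = 2
early["a"] = 1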

Member

Nice point

@@ -9,7 +9,7 @@
     'CollectLogProb'
 ]

-VariableDescription = collections.namedtuple('VariableDescription', 'Dist,shape')
+VariableDescription = collections.namedtuple('VariableDescription', 'Dist,shape,rv')
Member

I'm still worried about this solution. The variable description is supposed to be collected well before sampling (or should we change that?). So when you collect VariableInfo for the first time and get these RVs, you save temporary nodes. When you collect LogProb you run the model again, and the variables involved there are totally different from the ones stored in VariableDescription. That's why I did not store them there.

Member

Agree with @ferrine - the RVs are not initialized until we configure the model (following the idea in the API discussion doc, we create the RVs when we call model.configure(...) or model.sample(...)). This means that we record the Distribution and the relationships between RVs, but the actual RVs are only initialized when we actually use them (i.e., in the evaluation of logp).

Contributor Author

The problem I am facing if RVs are not stored in the VariableDescription is that it doesn't store the specifics of any distribution (like loc or scale), even when they are given in the model definition. So we will have to collect all of this already-provided info somehow.

@model.define
def process():
    mu = ed.Normal(loc=0., scale=10., name="mu")
    # here we lose the info that it has loc 0 and scale 10 without RV.

We can try defining a new Interceptor which does this for us for each RV, roughly as sketched below. We can then overwrite (replace existing and add new) the collected data every time we call configure.
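A rough sketch of such an interceptor (class name and stored fields are hypothetical, not part of this PR): it records the constructor arguments of every RV as the model runs, and configure() could simply rebuild this record on each call.

from tensorflow_probability import edward2 as ed

class CollectVariableSpecifics(object):
    """Hypothetical interceptor: records each RV's declared parameters as the model runs."""

    def __init__(self):
        self.result = {}

    def __call__(self, rv_constructor, *args, **kwargs):
        rv = rv_constructor(*args, **kwargs)
        name = kwargs.get("name", rv_constructor.__name__)
        # Record positional and keyword parameters exactly as written in the model definition.
        self.result[name] = {
            "args": args,
            "kwargs": {k: v for k, v in kwargs.items() if k not in ("name", "value")},
        }
        return rv

# Usage sketch: rebuild the record on every configure() call.
# collector = CollectVariableSpecifics()
# with ed.interception(collector):
#     process()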

Member

OK, let's make it later

@sharanry sharanry changed the title [WIP] Get mcmc sampling to work Get mcmc sampling to work Jul 1, 2018

sharanry commented Jul 3, 2018

@ferrine Could I merge this PR?

    states.update(self.observed)
    log_probs = []

    def interceptor(f, *args, **kwargs):
Member

Don't we want to use a class-based interceptor, for consistency?
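For context, "class-based" here just means making the interceptor a callable object so its state lives on the instance rather than in a closure; a minimal skeleton (hypothetical name, not the PR's class):

class StatefulInterceptor(object):
    """Hypothetical skeleton: an interceptor is anything callable with the
    (rv_constructor, *args, **kwargs) signature, so a class with __call__
    can hold the state the closure above kept in local variables."""

    def __init__(self, states):
        self.states = states    # e.g. RV name -> substituted value
        self.log_probs = []     # filled in during the model run

    def __call__(self, rv_constructor, *args, **kwargs):
        # The body of the function-based interceptor above would move here unchanged.
        return rv_constructor(*args, **kwargs)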

@@ -9,7 +9,7 @@
     'CollectLogProb'
 ]

-VariableDescription = collections.namedtuple('VariableDescription', 'Dist,shape')
+VariableDescription = collections.namedtuple('VariableDescription', 'Dist,shape,rv')
Member

OK, let's make it later

assert len(model.observed) == 1
assert not model.unobserved

model.reset()
Member

We decided to make a copy of the model each time the state changes.

Member

This one is not critical though; refactoring the interceptor usage is what is really needed to finish this PR (#9 (diff)).


sharanry commented Jul 6, 2018

@ferrine I have changed it to a class-based interceptor.

@ferrine ferrine left a comment

I can't find a sampling test; I think we need one.
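A sampling test might look roughly like the following (the import names, Model construction, and trace layout are guesses pieced together from the snippets in this PR, so treat every name here as an assumption): sample a standard normal and check the trace moments.

import numpy as np
import pymc4 as pm                                 # assumed import name
from tensorflow_probability import edward2 as ed

def test_sample_standard_normal():
    model = pm.Model()                             # assumed constructor

    @model.define
    def process(cfg=None):
        mu = ed.Normal(0., 1., name="mu")
        return mu

    # Signature taken from the sample() diff above; numpy=True is assumed to
    # return NumPy arrays keyed by RV name.
    trace = pm.sample(model, num_results=2000, num_burnin_steps=1000,
                      step_size=.4, num_leapfrog_steps=3, numpy=True)
    mu_samples = np.asarray(trace["mu"])
    assert abs(mu_samples.mean()) < 0.2
    assert abs(mu_samples.std() - 1.) < 0.2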


ferrine commented Jul 6, 2018

And the test point is better obtained via model.test_point().


sharanry commented Jul 6, 2018

@ferrine Currently model.test_point() is how you get the test point.
I am not sure I understand what you are saying.

           num_results=5000,
           num_burnin_steps=3000,
           step_size=.4,
           num_leapfrog_steps=3,
           numpy=True):
    initial_state = []
-    for name, shape in model.unobserved.iteritems():
+    for name, (_, shape, _) in model.unobserved.items():
Member

for name, point in model.test_point(mode=mode):
    initial_state.append(point)

@ferrine ferrine Jul 6, 2018

may be done in next PR

@ColCarroll ColCarroll merged commit 4c8d0d5 into pymc-devs:functional Jul 6, 2018
@ColCarroll

Congrats @sharanry, and thanks for the thorough review @ferrine!
