diff --git a/pymc4/inference/sampling/sample.py b/pymc4/inference/sampling/sample.py
index 588aa369..9c96c3d3 100644
--- a/pymc4/inference/sampling/sample.py
+++ b/pymc4/inference/sampling/sample.py
@@ -9,21 +9,22 @@ def sample(model,  # pylint: disable-msg=too-many-arguments
            step_size=.4,
            num_leapfrog_steps=3,
            numpy=True):
-    initial_state = []
-    for name, (_, shape, _) in model.unobserved.items():
-        initial_state.append(.5 * tf.ones(shape, name="init_{}".format(name)))
+    with model.temp_graph.as_default():
+        initial_state = []
+        for name, (_, shape, _) in model.unobserved.items():
+            initial_state.append(.5 * tf.ones(shape, name="init_{}".format(name)))
 
-    states, kernel_results = tfp.mcmc.sample_chain(
-        num_results=num_results,
-        num_burnin_steps=num_burnin_steps,
-        current_state=initial_state,
-        kernel=tfp.mcmc.HamiltonianMonteCarlo(
-            target_log_prob_fn=model.target_log_prob_fn(),
-            step_size=step_size,
-            num_leapfrog_steps=num_leapfrog_steps))
+        states, kernel_results = tfp.mcmc.sample_chain(
+            num_results=num_results,
+            num_burnin_steps=num_burnin_steps,
+            current_state=initial_state,
+            kernel=tfp.mcmc.HamiltonianMonteCarlo(
+                target_log_prob_fn=model.target_log_prob_fn(),
+                step_size=step_size,
+                num_leapfrog_steps=num_leapfrog_steps))
 
     if numpy:
-        with tf.Session() as sess:
+        with tf.Session(graph=model.temp_graph) as sess:
             states, is_accepted_ = sess.run([states, kernel_results.is_accepted])
             accepted = np.sum(is_accepted_)
             print("Acceptance rate: {}".format(accepted / num_results))
diff --git a/pymc4/model/base.py b/pymc4/model/base.py
index 4c8d1613..40793b63 100644
--- a/pymc4/model/base.py
+++ b/pymc4/model/base.py
@@ -27,7 +27,7 @@ def __getattr__(self, item):
             raise error from e
 
 
-class Model(object):
+class Model():
     def __init__(self, name=None, graph=None, session=None, **config):
         self._cfg = Config(**config)
         self.name = name
@@ -38,6 +38,7 @@ def __init__(self, name=None, graph=None, session=None, **config):
             session = tf.Session(graph=graph)
         self.session = session
         self.observe(**config)
+        self.temp_graph = tf.Graph()
 
     def define(self, f):
         self._f = f
@@ -54,6 +55,7 @@ def _init_variables(self):
         info_collector = interceptors.CollectVariablesInfo()
         with self.graph.as_default(), ed.interception(info_collector):
             self._f(self.cfg)
+            tf.contrib.graph_editor.copy(self.graph, self.temp_graph)
         self._variables = info_collector.result
 
     def test_point(self, sample=True):
@@ -66,12 +68,15 @@ def not_observed(var, *args, **kwargs):  # pylint: disable=unused-argument
             def get_mode(state, rv, *args, **kwargs):  # pylint: disable=unused-argument
                 return rv.distribution.mode()
             chain.insert(0, interceptors.Generic(after=get_mode))
-
-        with self.graph.as_default(), ed.interception(interceptors.Chain(*chain)):
+        tf.contrib.graph_editor.copy(self.graph, self.temp_graph)
+        # pylint: disable=not-context-manager
+        with self.temp_graph.as_default(), ed.interception(interceptors.Chain(*chain)):
             self._f(self.cfg)
-        with self.session.as_default():
-            returns = self.session.run(list(values_collector.result.values()))
-            return dict(zip(values_collector.result.keys(), returns))
+        # pylint: enable=not-context-manager
+        with tf.Session(graph=self.temp_graph) as sess:
+            returns = sess.run(list(values_collector.result.values()))
+            keys = values_collector.result.keys()
+            return dict(zip(keys, returns))
 
     def target_log_prob_fn(self, *args, **kwargs):  # pylint: disable=unused-argument
         """
@@ -83,7 +88,8 @@ def log_joint_fn(*args, **kwargs):  # pylint: disable=unused-argument
             states = dict(zip(self.unobserved.keys(), args))
             states.update(self.observed)
             interceptor = interceptors.CollectLogProb(states)
-            with ed.interception(interceptor):
+            tf.contrib.graph_editor.copy(self.graph, self.temp_graph)
+            with self.temp_graph.as_default(), ed.interception(interceptor):  # pylint: disable=not-context-manager
                 self._f(self._cfg)
 
             log_prob = sum(interceptor.log_probs)
diff --git a/pymc4/util/interceptors.py b/pymc4/util/interceptors.py
index 964d98fe..3172da38 100644
--- a/pymc4/util/interceptors.py
+++ b/pymc4/util/interceptors.py
@@ -12,7 +12,7 @@
 VariableDescription = collections.namedtuple('VariableDescription', 'Dist,shape,rv')
 
 
-class Interceptor(object):
+class Interceptor():
     def name_scope(self):
         return tf.name_scope(self.__class__.__name__.lower())
 
diff --git a/requirements.txt b/requirements.txt
index 2d1cd48c..7b5aaf30 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,6 +2,6 @@ xarray==0.10.4
 numpy==1.14.3
 tqdm==4.23.3
 biwrap
-tf-nightly==1.8.0.dev20180331
-tfp-nightly==0.0.1.dev20180515
-tb-nightly==1.8.0a20180424
+tf-nightly==1.9.0.dev20180607
+tfp-nightly==0.3.0.dev20180725
+tb-nightly==1.9.0a20180613
diff --git a/tests/test_model.py b/tests/test_model.py
index 85c96475..595ad18c 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -153,5 +153,5 @@ def simple(cfg):
 
     log_prob_fn = model.target_log_prob_fn()
 
-    with tf.Session():
+    with tf.Session(graph=model.temp_graph):
         assert -0.91893853 == pytest.approx(log_prob_fn(0).eval(), 0.00001)