From b218d43ca95b49b3e0c8675ec4d071f8f0ad4776 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Wed, 25 Jul 2018 11:38:58 +0530 Subject: [PATCH 1/7] keep a graph copy solution --- pymc4/model/base.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pymc4/model/base.py b/pymc4/model/base.py index 4c8d1613..8666789f 100644 --- a/pymc4/model/base.py +++ b/pymc4/model/base.py @@ -66,8 +66,9 @@ def not_observed(var, *args, **kwargs): # pylint: disable=unused-argument def get_mode(state, rv, *args, **kwargs): # pylint: disable=unused-argument return rv.distribution.mode() chain.insert(0, interceptors.Generic(after=get_mode)) - - with self.graph.as_default(), ed.interception(interceptors.Chain(*chain)): + temp_graph = tf.Graph() + tf.contrib.graph_editor.copy(self.graph, temp_graph) + with temp_graph.as_default(), ed.interception(interceptors.Chain(*chain)): self._f(self.cfg) with self.session.as_default(): returns = self.session.run(list(values_collector.result.values())) @@ -83,7 +84,9 @@ def log_joint_fn(*args, **kwargs): # pylint: disable=unused-argument states = dict(zip(self.unobserved.keys(), args)) states.update(self.observed) interceptor = interceptors.CollectLogProb(states) - with ed.interception(interceptor): + temp_graph = tf.Graph() + tf.contrib.graph_editor.copy(self.graph, temp_graph) + with temp_graph.as_default(), ed.interception(interceptor): self._f(self._cfg) log_prob = sum(interceptor.log_probs) From 88fc3c6a24532fc381dc33e0ab0a56caeb7f5367 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Wed, 25 Jul 2018 11:49:03 +0530 Subject: [PATCH 2/7] clear temp_graph --- pymc4/model/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pymc4/model/base.py b/pymc4/model/base.py index 8666789f..c0d727f5 100644 --- a/pymc4/model/base.py +++ b/pymc4/model/base.py @@ -70,6 +70,7 @@ def get_mode(state, rv, *args, **kwargs): # pylint: disable=unused-argument tf.contrib.graph_editor.copy(self.graph, temp_graph) with temp_graph.as_default(), ed.interception(interceptors.Chain(*chain)): self._f(self.cfg) + del temp_graph with self.session.as_default(): returns = self.session.run(list(values_collector.result.values())) return dict(zip(values_collector.result.keys(), returns)) @@ -88,7 +89,7 @@ def log_joint_fn(*args, **kwargs): # pylint: disable=unused-argument tf.contrib.graph_editor.copy(self.graph, temp_graph) with temp_graph.as_default(), ed.interception(interceptor): self._f(self._cfg) - + del temp_graph log_prob = sum(interceptor.log_probs) return log_prob return log_joint_fn From 8136428246fe685109e25fd80887eee6dbab64f1 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Wed, 25 Jul 2018 14:31:02 +0530 Subject: [PATCH 3/7] fix graph --- pymc4/inference/sampling/sample.py | 25 +++++++++++++------------ pymc4/model/base.py | 16 +++++++++------- tests/test_model.py | 2 +- 3 files changed, 23 insertions(+), 20 deletions(-) diff --git a/pymc4/inference/sampling/sample.py b/pymc4/inference/sampling/sample.py index 588aa369..9c96c3d3 100644 --- a/pymc4/inference/sampling/sample.py +++ b/pymc4/inference/sampling/sample.py @@ -9,21 +9,22 @@ def sample(model, # pylint: disable-msg=too-many-arguments step_size=.4, num_leapfrog_steps=3, numpy=True): - initial_state = [] - for name, (_, shape, _) in model.unobserved.items(): - initial_state.append(.5 * tf.ones(shape, name="init_{}".format(name))) + with model.temp_graph.as_default(): + initial_state = [] + for name, (_, shape, _) in model.unobserved.items(): + initial_state.append(.5 * tf.ones(shape, name="init_{}".format(name))) - states, kernel_results = tfp.mcmc.sample_chain( - num_results=num_results, - num_burnin_steps=num_burnin_steps, - current_state=initial_state, - kernel=tfp.mcmc.HamiltonianMonteCarlo( - target_log_prob_fn=model.target_log_prob_fn(), - step_size=step_size, - num_leapfrog_steps=num_leapfrog_steps)) + states, kernel_results = tfp.mcmc.sample_chain( + num_results=num_results, + num_burnin_steps=num_burnin_steps, + current_state=initial_state, + kernel=tfp.mcmc.HamiltonianMonteCarlo( + target_log_prob_fn=model.target_log_prob_fn(), + step_size=step_size, + num_leapfrog_steps=num_leapfrog_steps)) if numpy: - with tf.Session() as sess: + with tf.Session(graph=model.temp_graph) as sess: states, is_accepted_ = sess.run([states, kernel_results.is_accepted]) accepted = np.sum(is_accepted_) print("Acceptance rate: {}".format(accepted / num_results)) diff --git a/pymc4/model/base.py b/pymc4/model/base.py index c0d727f5..c72bf032 100644 --- a/pymc4/model/base.py +++ b/pymc4/model/base.py @@ -38,6 +38,7 @@ def __init__(self, name=None, graph=None, session=None, **config): session = tf.Session(graph=graph) self.session = session self.observe(**config) + self.temp_graph = tf.Graph() def define(self, f): self._f = f @@ -54,6 +55,7 @@ def _init_variables(self): info_collector = interceptors.CollectVariablesInfo() with self.graph.as_default(), ed.interception(info_collector): self._f(self.cfg) + tf.contrib.graph_editor.copy(self.graph, self.temp_graph) self._variables = info_collector.result def test_point(self, sample=True): @@ -70,26 +72,26 @@ def get_mode(state, rv, *args, **kwargs): # pylint: disable=unused-argument tf.contrib.graph_editor.copy(self.graph, temp_graph) with temp_graph.as_default(), ed.interception(interceptors.Chain(*chain)): self._f(self.cfg) + with tf.Session(graph=temp_graph) as sess: + returns = sess.run(list(values_collector.result.values())) + keys = values_collector.result.keys() del temp_graph - with self.session.as_default(): - returns = self.session.run(list(values_collector.result.values())) - return dict(zip(values_collector.result.keys(), returns)) + return dict(zip(keys, returns)) def target_log_prob_fn(self, *args, **kwargs): # pylint: disable=unused-argument """ Pass the states of the RVs as args in alphabetical order of the RVs. Compatible as `target_log_prob_fn` for tfp samplers. """ + # self.temp_graph = tf.Graph() def log_joint_fn(*args, **kwargs): # pylint: disable=unused-argument states = dict(zip(self.unobserved.keys(), args)) states.update(self.observed) interceptor = interceptors.CollectLogProb(states) - temp_graph = tf.Graph() - tf.contrib.graph_editor.copy(self.graph, temp_graph) - with temp_graph.as_default(), ed.interception(interceptor): + with self.temp_graph.as_default(), ed.interception(interceptor): self._f(self._cfg) - del temp_graph + # del temp_graph log_prob = sum(interceptor.log_probs) return log_prob return log_joint_fn diff --git a/tests/test_model.py b/tests/test_model.py index 85c96475..595ad18c 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -153,5 +153,5 @@ def simple(cfg): log_prob_fn = model.target_log_prob_fn() - with tf.Session(): + with tf.Session(graph=model.temp_graph): assert -0.91893853 == pytest.approx(log_prob_fn(0).eval(), 0.00001) From 3fd3acc690192bf523433391c2b9a6e459951258 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Wed, 25 Jul 2018 14:44:59 +0530 Subject: [PATCH 4/7] update requirements.txt --- requirements.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index 2d1cd48c..7b5aaf30 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,6 @@ xarray==0.10.4 numpy==1.14.3 tqdm==4.23.3 biwrap -tf-nightly==1.8.0.dev20180331 -tfp-nightly==0.0.1.dev20180515 -tb-nightly==1.8.0a20180424 +tf-nightly==1.9.0.dev20180607 +tfp-nightly==0.3.0.dev20180725 +tb-nightly==1.9.0a20180613 From 003217337f596bcf4e8f4b26450125fcce16b045 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Wed, 25 Jul 2018 15:03:29 +0530 Subject: [PATCH 5/7] clean up --- pymc4/model/base.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pymc4/model/base.py b/pymc4/model/base.py index c72bf032..2115960c 100644 --- a/pymc4/model/base.py +++ b/pymc4/model/base.py @@ -27,7 +27,7 @@ def __getattr__(self, item): raise error from e -class Model(object): +class Model(): def __init__(self, name=None, graph=None, session=None, **config): self._cfg = Config(**config) self.name = name @@ -69,7 +69,7 @@ def get_mode(state, rv, *args, **kwargs): # pylint: disable=unused-argument return rv.distribution.mode() chain.insert(0, interceptors.Generic(after=get_mode)) temp_graph = tf.Graph() - tf.contrib.graph_editor.copy(self.graph, temp_graph) + tf.contrib.graph_editor.copy(self.graph, self.temp_graph) with temp_graph.as_default(), ed.interception(interceptors.Chain(*chain)): self._f(self.cfg) with tf.Session(graph=temp_graph) as sess: @@ -83,15 +83,15 @@ def target_log_prob_fn(self, *args, **kwargs): # pylint: disable=unused-argumen Pass the states of the RVs as args in alphabetical order of the RVs. Compatible as `target_log_prob_fn` for tfp samplers. """ - # self.temp_graph = tf.Graph() def log_joint_fn(*args, **kwargs): # pylint: disable=unused-argument states = dict(zip(self.unobserved.keys(), args)) states.update(self.observed) interceptor = interceptors.CollectLogProb(states) + tf.contrib.graph_editor.copy(self.graph, self.temp_graph) with self.temp_graph.as_default(), ed.interception(interceptor): self._f(self._cfg) - # del temp_graph + log_prob = sum(interceptor.log_probs) return log_prob return log_joint_fn From f3b6715e1e9f6a4672bc91853ec764b77b11485f Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Wed, 25 Jul 2018 15:21:57 +0530 Subject: [PATCH 6/7] fix pylint errors --- pymc4/model/base.py | 8 +++----- pymc4/util/interceptors.py | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/pymc4/model/base.py b/pymc4/model/base.py index 2115960c..41ae3138 100644 --- a/pymc4/model/base.py +++ b/pymc4/model/base.py @@ -68,14 +68,12 @@ def not_observed(var, *args, **kwargs): # pylint: disable=unused-argument def get_mode(state, rv, *args, **kwargs): # pylint: disable=unused-argument return rv.distribution.mode() chain.insert(0, interceptors.Generic(after=get_mode)) - temp_graph = tf.Graph() tf.contrib.graph_editor.copy(self.graph, self.temp_graph) - with temp_graph.as_default(), ed.interception(interceptors.Chain(*chain)): + with self.temp_graph.as_default(), ed.interception(interceptors.Chain(*chain)): # pylint: disable=not-context-manager self._f(self.cfg) - with tf.Session(graph=temp_graph) as sess: + with tf.Session(graph=self.temp_graph) as sess: returns = sess.run(list(values_collector.result.values())) keys = values_collector.result.keys() - del temp_graph return dict(zip(keys, returns)) def target_log_prob_fn(self, *args, **kwargs): # pylint: disable=unused-argument @@ -89,7 +87,7 @@ def log_joint_fn(*args, **kwargs): # pylint: disable=unused-argument states.update(self.observed) interceptor = interceptors.CollectLogProb(states) tf.contrib.graph_editor.copy(self.graph, self.temp_graph) - with self.temp_graph.as_default(), ed.interception(interceptor): + with self.temp_graph.as_default(), ed.interception(interceptor): # pylint: disable=not-context-manager self._f(self._cfg) log_prob = sum(interceptor.log_probs) diff --git a/pymc4/util/interceptors.py b/pymc4/util/interceptors.py index 964d98fe..3172da38 100644 --- a/pymc4/util/interceptors.py +++ b/pymc4/util/interceptors.py @@ -12,7 +12,7 @@ VariableDescription = collections.namedtuple('VariableDescription', 'Dist,shape,rv') -class Interceptor(object): +class Interceptor(): def name_scope(self): return tf.name_scope(self.__class__.__name__.lower()) From 72cb7acbe9010a4bef5cc8100d70de78a3f0b511 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Wed, 25 Jul 2018 17:12:35 +0530 Subject: [PATCH 7/7] fix pycodestyle errors --- pymc4/model/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pymc4/model/base.py b/pymc4/model/base.py index 41ae3138..40793b63 100644 --- a/pymc4/model/base.py +++ b/pymc4/model/base.py @@ -69,8 +69,10 @@ def get_mode(state, rv, *args, **kwargs): # pylint: disable=unused-argument return rv.distribution.mode() chain.insert(0, interceptors.Generic(after=get_mode)) tf.contrib.graph_editor.copy(self.graph, self.temp_graph) - with self.temp_graph.as_default(), ed.interception(interceptors.Chain(*chain)): # pylint: disable=not-context-manager + # pylint: disable=not-context-manager + with self.temp_graph.as_default(), ed.interception(interceptors.Chain(*chain)): self._f(self.cfg) + # pylint: enable=not-context-manager with tf.Session(graph=self.temp_graph) as sess: returns = sess.run(list(values_collector.result.values())) keys = values_collector.result.keys()