
A simple solution to work on the copy of the graph #15


Status: Closed · wants to merge 7 commits
25 changes: 13 additions & 12 deletions pymc4/inference/sampling/sample.py
```diff
@@ -9,21 +9,22 @@ def sample(model,  # pylint: disable-msg=too-many-arguments
            step_size=.4,
            num_leapfrog_steps=3,
            numpy=True):
-    initial_state = []
-    for name, (_, shape, _) in model.unobserved.items():
-        initial_state.append(.5 * tf.ones(shape, name="init_{}".format(name)))
+    with model.temp_graph.as_default():
+        initial_state = []
+        for name, (_, shape, _) in model.unobserved.items():
+            initial_state.append(.5 * tf.ones(shape, name="init_{}".format(name)))
 
-    states, kernel_results = tfp.mcmc.sample_chain(
-        num_results=num_results,
-        num_burnin_steps=num_burnin_steps,
-        current_state=initial_state,
-        kernel=tfp.mcmc.HamiltonianMonteCarlo(
-            target_log_prob_fn=model.target_log_prob_fn(),
-            step_size=step_size,
-            num_leapfrog_steps=num_leapfrog_steps))
+        states, kernel_results = tfp.mcmc.sample_chain(
+            num_results=num_results,
+            num_burnin_steps=num_burnin_steps,
+            current_state=initial_state,
+            kernel=tfp.mcmc.HamiltonianMonteCarlo(
+                target_log_prob_fn=model.target_log_prob_fn(),
+                step_size=step_size,
+                num_leapfrog_steps=num_leapfrog_steps))
 
     if numpy:
-        with tf.Session() as sess:
+        with tf.Session(graph=model.temp_graph) as sess:
             states, is_accepted_ = sess.run([states, kernel_results.is_accepted])
             accepted = np.sum(is_accepted_)
             print("Acceptance rate: {}".format(accepted / num_results))
```
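The diff above keeps op construction and the session bound to the same `tf.Graph`. The invariant being enforced can be mimicked with plain-Python stand-ins (assumption: `Graph` and `Session` here are toy classes for illustration, not the TensorFlow API): a session bound to one graph can only run ops that were created in that same graph.

```python
class Graph:
    """Toy stand-in for tf.Graph: just a registry of op names."""

    def __init__(self):
        self.ops = set()

    def add_op(self, name):
        # In real TF this would happen under graph.as_default().
        self.ops.add(name)
        return name


class Session:
    """Toy stand-in for tf.Session(graph=...)."""

    def __init__(self, graph):
        self.graph = graph

    def run(self, op):
        # A real session raises if asked to run an op from another graph.
        if op not in self.graph.ops:
            raise ValueError("op %r is not an element of this graph" % op)
        return "ran " + op


temp_graph = Graph()
chain = temp_graph.add_op("sample_chain")  # built "inside" temp_graph

print(Session(temp_graph).run(chain))      # same graph: fine
try:
    Session(Graph()).run(chain)            # different graph: fails
except ValueError as err:
    print("mismatch:", err)
```

This is why the `if numpy:` branch now passes `graph=model.temp_graph` to the session explicitly.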
20 changes: 13 additions & 7 deletions pymc4/model/base.py
```diff
@@ -27,7 +27,7 @@ def __getattr__(self, item):
         raise error from e
 
 
-class Model(object):
+class Model():
     def __init__(self, name=None, graph=None, session=None, **config):
         self._cfg = Config(**config)
         self.name = name
@@ -38,6 +38,7 @@ def __init__(self, name=None, graph=None, session=None, **config):
             session = tf.Session(graph=graph)
         self.session = session
         self.observe(**config)
+        self.temp_graph = tf.Graph()
```
**Member** commented:

I would use this pattern for creating and destroying the temporary graph. I'm using a pymc3 model here because I think the pattern is similar, but let me know if it isn't! Note that you will only need to implement the equivalent of the `temp_model` method below.

```python
from contextlib import contextmanager

import pymc3 as pm


def check_in_model_context():
    try:
        m = pm.modelcontext(None)
        return f'In model context! {len(m.vars)} vars!'
    except TypeError:
        return 'Not in model context!'


class Model(object):
    def __init__(self):
        self._temp_model = None

    @contextmanager
    def temp_model(self):
        self._temp_model = pm.Model()
        try:
            with self._temp_model:
                yield
        finally:
            self._temp_model = None

    def calculation(self):
        print(check_in_model_context())
        with self.temp_model():
            pm.Normal('x', 0, 1)
            print(check_in_model_context())

        print(check_in_model_context())
        with self.temp_model():
            print(check_in_model_context())
        print(check_in_model_context())
```
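Applied to this PR's `tf.Graph`, the reviewer's create-and-destroy pattern might look like the following sketch. Assumptions are mine: the method name `temp_graph` is hypothetical, and a plain `object()` stands in for `tf.Graph()` so the sketch runs without TensorFlow.

```python
from contextlib import contextmanager


class Model:
    """Minimal sketch of the reviewer's pattern for a temporary graph.

    Assumption: object() stands in for tf.Graph(); the real method would
    build ops under self._temp_graph.as_default() inside the with block.
    """

    def __init__(self):
        self._temp_graph = None

    @contextmanager
    def temp_graph(self):
        self._temp_graph = object()  # would be tf.Graph()
        try:
            yield self._temp_graph
        finally:
            # the temporary graph is always dropped, even on error
            self._temp_graph = None


m = Model()
with m.temp_graph() as g:
    assert m._temp_graph is g
assert m._temp_graph is None  # cleaned up on exit
```

The `try`/`finally` guarantees the temporary graph never leaks past the `with` block, which avoids keeping a stale `temp_graph` attribute on the model between calls.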
**Contributor (author)** commented:

The decorator on an in-class function is causing problems; I will look into it.

```diff
     def define(self, f):
         self._f = f
@@ -54,6 +55,7 @@ def _init_variables(self):
         info_collector = interceptors.CollectVariablesInfo()
         with self.graph.as_default(), ed.interception(info_collector):
             self._f(self.cfg)
+        tf.contrib.graph_editor.copy(self.graph, self.temp_graph)
```
**Member** commented:

So, 1) you create a temp graph for inner manipulations.

```diff
         self._variables = info_collector.result
 
     def test_point(self, sample=True):
@@ -66,12 +68,15 @@ def not_observed(var, *args, **kwargs):  # pylint: disable=unused-argument
         def get_mode(state, rv, *args, **kwargs):  # pylint: disable=unused-argument
             return rv.distribution.mode()
         chain.insert(0, interceptors.Generic(after=get_mode))
 
-        with self.graph.as_default(), ed.interception(interceptors.Chain(*chain)):
+        tf.contrib.graph_editor.copy(self.graph, self.temp_graph)
+        # pylint: disable=not-context-manager
+        with self.temp_graph.as_default(), ed.interception(interceptors.Chain(*chain)):
```
**Member** (@ferrine, Aug 2, 2018) commented:

You reuse the temp graph created in step 1). Suppose `cfg` holds a tensor from the original graph and you call your model under `temp_graph.as_default()`: that creates all the variables in `temp_graph`. This looks fine except for the case I'm describing: if you attempt an operation on two variables from different graphs, you will get an exception.

(screenshot of the exception omitted)

**Member** commented:

I do not know a solution for this case yet. I've seen a somewhat relevant, working example in the tensorflow codebase:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/function.py#L347

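The failure mode @ferrine describes above can be mimicked with plain-Python stand-ins (assumption: `Graph` and `Tensor` are toy classes, not the TensorFlow ones): each tensor remembers its owning graph, and combining operands from two different graphs raises.

```python
class Graph:
    """Toy stand-in for tf.Graph (identity only)."""


class Tensor:
    """Toy tensor that remembers which graph owns it."""

    def __init__(self, graph, value):
        self.graph, self.value = graph, value

    def __add__(self, other):
        # Real TF raises a similar ValueError for cross-graph operands.
        if self.graph is not other.graph:
            raise ValueError("Cannot combine tensors from different graphs")
        return Tensor(self.graph, self.value + other.value)


g1, g2 = Graph(), Graph()
x, y = Tensor(g1, 1.0), Tensor(g1, 2.0)
assert (x + y).value == 3.0       # same graph: fine

try:
    x + Tensor(g2, 4.0)           # different graphs: raises
except ValueError as err:
    print(err)                    # the exception the reviewer describes
```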
```diff
             self._f(self.cfg)
-        with self.session.as_default():
-            returns = self.session.run(list(values_collector.result.values()))
-            return dict(zip(values_collector.result.keys(), returns))
+        # pylint: enable=not-context-manager
+        with tf.Session(graph=self.temp_graph) as sess:
+            returns = sess.run(list(values_collector.result.values()))
+            keys = values_collector.result.keys()
+            return dict(zip(keys, returns))
```

```diff
     def target_log_prob_fn(self, *args, **kwargs):  # pylint: disable=unused-argument
         """
@@ -83,7 +88,8 @@ def log_joint_fn(*args, **kwargs):  # pylint: disable=unused-argument
             states = dict(zip(self.unobserved.keys(), args))
             states.update(self.observed)
             interceptor = interceptors.CollectLogProb(states)
-            with ed.interception(interceptor):
+            tf.contrib.graph_editor.copy(self.graph, self.temp_graph)
+            with self.temp_graph.as_default(), ed.interception(interceptor):  # pylint: disable=not-context-manager
                 self._f(self._cfg)
 
             log_prob = sum(interceptor.log_probs)
```
2 changes: 1 addition & 1 deletion pymc4/util/interceptors.py
```diff
@@ -12,7 +12,7 @@
 VariableDescription = collections.namedtuple('VariableDescription', 'Dist,shape,rv')
 
 
-class Interceptor(object):
+class Interceptor():
```
**Member** commented:

Good practice to inherit from object, even if you don't need to!
**Member** commented:

Why? I thought in py3 we don't need that anymore.
**Contributor (author)** commented:

Yes, I am facing a linting error when I inherit from object.

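On the "py3 doesn't need it" point: in Python 3 both spellings produce new-style classes with `object` at the root of the MRO, so dropping the explicit base changes nothing at runtime.

```python
class A:
    pass

class B(object):
    pass

# In Python 3 every class is new-style; the explicit base is redundant.
assert A.__mro__ == (A, object)
assert B.__mro__ == (B, object)
assert isinstance(A(), object) and isinstance(B(), object)
```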
```diff
     def name_scope(self):
         return tf.name_scope(self.__class__.__name__.lower())
```
6 changes: 3 additions & 3 deletions requirements.txt
```diff
@@ -2,6 +2,6 @@ xarray==0.10.4
 numpy==1.14.3
 tqdm==4.23.3
 biwrap
-tf-nightly==1.8.0.dev20180331
-tfp-nightly==0.0.1.dev20180515
-tb-nightly==1.8.0a20180424
+tf-nightly==1.9.0.dev20180607
+tfp-nightly==0.3.0.dev20180725
+tb-nightly==1.9.0a20180613
```
2 changes: 1 addition & 1 deletion tests/test_model.py
```diff
@@ -153,5 +153,5 @@ def simple(cfg):
 
     log_prob_fn = model.target_log_prob_fn()
 
-    with tf.Session():
+    with tf.Session(graph=model.temp_graph):
         assert -0.91893853 == pytest.approx(log_prob_fn(0).eval(), 0.00001)
```
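The constant asserted in this test is just the standard normal log-density at 0, i.e. -0.5 * log(2 * pi) which is approximately -0.9189385, and it can be confirmed independently of TensorFlow:

```python
import math

def normal_logpdf(x, mu=0.0, sigma=1.0):
    """Log-density of a Normal(mu, sigma) evaluated at x."""
    var = sigma ** 2
    return -0.5 * math.log(2 * math.pi * var) - (x - mu) ** 2 / (2 * var)

# Matches the constant asserted in tests/test_model.py.
assert abs(normal_logpdf(0.0) - (-0.91893853)) < 1e-7
```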